# Embedding 3D Data with M3D-CLIP in Julia using PythonCall
The high-level goal of this notebook is to load CT data and use the `M3D-CLIP` foundation model to obtain an embedding that can be fed into other models.

To do so, we need to do the following:
1. Create a Python environment for PythonCall that supports the `M3D` model as well as its `CLIP` tokenizer
2. Use PythonCall to load in the necessary Python modules to interact with `M3D`
3. Compute the embeddings and bring them back into Julia for use in other models

## 0. Set Up Our Julia Environment
The only packages we will need are `CondaPkg.jl` to manage our Python environment, `PythonCall.jl` to interface between Python and Julia, and `JLD2.jl` to load the example CT data.

In [1]:
using Pkg
Pkg.activate(".")
Pkg.add("CondaPkg")
Pkg.add("PythonCall")
Pkg.add("JLD2")

[32m[1m  Activating[22m[39m new project at `/content`
[32m[1m    Updating[22m[39m registry at `~/.julia/registries/General.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m   Installed[22m[39m micromamba_jll ─ v1.5.8+0
[32m[1m   Installed[22m[39m pixi_jll ─────── v0.41.3+0
[32m[1m   Installed[22m[39m MicroMamba ───── v0.1.14
[32m[1m   Installed[22m[39m CondaPkg ─────── v0.2.28
[32m[1m   Installed[22m[39m Pidfile ──────── v1.3.0
[32m[1m   Installed[22m[39m JSON3 ────────── v1.14.2
[32m[1m   Installed[22m[39m StructTypes ──── v1.11.0
[32m[1m    Updating[22m[39m `/content/Project.toml`
  [90m[992eb4ea] [39m[92m+ CondaPkg v0.2.28[39m
[32m[1m    Updating[22m[39m `/content/Manifest.toml`
  [90m[992eb4ea] [39m[92m+ CondaPkg v0.2.28[39m
  [90m[692b3bcd] [39m[92m+ JLLWrappers v1.7.0[39m
  [90m[0f8b85d8] [39m[92m+ JSON3 v1.14.2[39m
  [90m[0b3b1443] [39m[92m+ MicroMamba v0.1.14[39m
  [90m[69de0a69] [39m[92m+ Parser

## 1. Create Our Python Environment
To create our Python environment with `CondaPkg.jl`, we will:
1. Install Python 3.11 (a few of the Python libraries require this exact version to function)
2. Read in a `requirements.txt` file to extract the necessary Python libraries and their specific version numbers
3. Use `pip` to install the libraries

In [2]:
import CondaPkg

### Installing Python 3.11

In [3]:
CondaPkg.add("python"; version=">3.10,<3.12")

[32m[1m    CondaPkg [22m[39m[0mFound dependencies: /content/CondaPkg.toml
[32m[1m    CondaPkg [22m[39m[0mFound dependencies: /root/.julia/packages/PythonCall/WMWY0/CondaPkg.toml
[32m[1m    CondaPkg [22m[39m[0mFound dependencies: /root/.julia/packages/Reactant/OgayD/CondaPkg.toml
[32m[1m    CondaPkg [22m[39m[0mResolving changes
[32m[1m             [22m[39m[32m+ jax (pip)[39m
[32m[1m             [22m[39m[32m+ libstdcxx-ng[39m
[32m[1m             [22m[39m[32m+ python[39m
[32m[1m             [22m[39m[32m+ uv[39m
[32m[1m    CondaPkg [22m[39m[0mInitialising pixi
[32m[1m             [22m[39m│ [90m/root/.julia/artifacts/cefba4912c2b400756d043a2563ef77a0088866b/bin/pixi[39m
[32m[1m             [22m[39m│ [90minit[39m
[32m[1m             [22m[39m│ [90m--format pixi[39m
[32m[1m             [22m[39m└ [90m/content/.CondaPkg[39m
✔ Created /content/.CondaPkg/pixi.toml
[32m[1m    CondaPkg [22m[39m[0mWrote /content/.CondaPkg/p

### Parsing through the `requirements.txt` File
`CondaPkg.jl` does not yet support reading a `requirements.txt` file directly. However, basic file-reading functions from `Base` are enough to extract each package name and its pinned version.

In [1]:
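function parse_pip_requirements(path)
    packages = []
    versions = []
    open(path) do file
        for raw_line in eachline(file)
            # Drop inline comments and surrounding whitespace
            line = strip(first(split(raw_line, '#')))
            # Skip empty lines and comment-only lines
            isempty(line) && continue

            # This parser only understands exact pins of the form "package==version"
            if !occursin("==", line)
                error("Unsupported requirement (expected `package==version`): $line")
            end

            # Split the line into package name and pinned version
            package, version = strip.(split(line, "=="; limit=2))
            push!(packages, package)
            push!(versions, "==" * version)
        end
    end
    return packages, versions
end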

packages, versions = parse_pip_requirements("requirements.txt")
println(packages)
println(versions)

Any["deepspeed", "einops", "evaluate", "matplotlib", "monai", "nibabel", "numpy", "opencv_python", "pandas", "peft", "Pillow", "pycocotools", "Requests", "rouge", "safetensors", "scipy", "simple_slice_viewer", "SimpleITK", "torch", "torchvision", "tqdm", "transformers", "tweepy"]
Any["==0.13.4", "==0.8.0", "==0.4.1", "==3.8.4", "==1.2.0", "==5.2.1", "==1.26.4", "==4.9.0.80", "==2.2.2", "==0.8.2", "==10.3.0", "==2.0.7", "==2.31.0", "==1.0.1", "==0.4.3", "==1.13.0", "==0.97", "==2.3.1", "==2.6.0", "==0.21.0", "==4.66.2", "==4.39.1", "==4.14.0"]


### Installing the Python Libraries
We can then install each library individually with `pip` by calling `CondaPkg.add_pip(package_name; version=package_version)`.

In [5]:
# Install each pinned pip package into the CondaPkg environment
for (package, version) in zip(packages, versions)
    CondaPkg.add_pip(package; version=version)
end

[32m[1m    CondaPkg [22m[39m[0mFound dependencies: /content/CondaPkg.toml
[32m[1m    CondaPkg [22m[39m[0mFound dependencies: /root/.julia/packages/PythonCall/WMWY0/CondaPkg.toml
[32m[1m    CondaPkg [22m[39m[0mFound dependencies: /root/.julia/packages/Reactant/OgayD/CondaPkg.toml
[32m[1m    CondaPkg [22m[39m[0mResolving changes
[32m[1m             [22m[39m[32m+ deepspeed (pip)[39m
[32m[1m    CondaPkg [22m[39m[0mInitialising pixi
[32m[1m             [22m[39m│ [90m/root/.julia/artifacts/cefba4912c2b400756d043a2563ef77a0088866b/bin/pixi[39m
[32m[1m             [22m[39m│ [90minit[39m
[32m[1m             [22m[39m│ [90m--format pixi[39m
[32m[1m             [22m[39m└ [90m/content/.CondaPkg[39m
✔ Created /content/.CondaPkg/pixi.toml
[32m[1m    CondaPkg [22m[39m[0mWrote /content/.CondaPkg/pixi.toml
[32m[1m             [22m[39m│ [90m[dependencies][39m
[32m[1m             [22m[39m│ [90muv = ">=0.4"[39m
[32m[1m             [

### Verify That We Have the Correct Python Version and Libraries Installed

In [6]:
CondaPkg.status()

CondaPkg Status /content/CondaPkg.toml
Environment
  /content/.CondaPkg/.pixi/envs/default
Packages
  python v3.11.12 (>3.10,<3.12)
Pip packages
  deepspeed v0.13.4 (==0.13.4)
  einops v0.8.0 (==0.8.0)
  evaluate v0.4.1 (==0.4.1)
  matplotlib v3.8.4 (==3.8.4)
  monai v1.2.0 (==1.2.0)
  nibabel v5.2.1 (==5.2.1)
  numpy v1.26.4 (==1.26.4)
  opencv-python v4.9.0.80 (==4.9.0.80)
  pandas v2.2.2 (==2.2.2)
  peft v0.8.2 (==0.8.2)
  pillow v10.3.0 (==10.3.0)
  pycocotools v2.0.7 (==2.0.7)
  requests v2.31.0 (==2.31.0)
  rouge v1.0.1 (==1.0.1)
  safetensors v0.4.3 (==0.4.3)
  scipy v1.13.0 (==1.13.0)
  simple-slice-viewer v0.97 (==0.97)
  simpleitk v2.3.1 (==2.3.1)
  torch v2.6.0 (==2.6.0)
  torchvision v0.21.0 (=
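To check the pins programmatically rather than by eye, one option is to compare the `versions` parsed earlier against what is actually installed. The sketch below is hypothetical: `mismatched_pins` is an illustrative helper, and the commented lines assume that Python's standard `importlib.metadata.version` can be reached through `PythonCall.jl` once it is loaded and that it accepts the package names exactly as written in `requirements.txt`.

```julia
# Hypothetical helper: report packages whose installed version differs from an exact "==" pin.
# `installed` maps package names (as written in requirements.txt) to version strings.
function mismatched_pins(installed::AbstractDict, packages, versions)
    mismatches = Tuple{String,String,String}[]  # (package, pinned, installed)
    for (pkg, spec) in zip(packages, versions)
        pinned = replace(spec, "==" => "")
        found = get(installed, pkg, "missing")
        found == pinned || push!(mismatches, (String(pkg), pinned, found))
    end
    return mismatches
end

# With PythonCall.jl loaded, `installed` could be filled from Python's importlib.metadata:
#   md = PythonCall.pyimport("importlib.metadata")
#   installed = Dict(p => PythonCall.pyconvert(String, md.version(p)) for p in packages)
#   mismatched_pins(installed, packages, versions)  # an empty vector means every pin matched
```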

## 2. Using `PythonCall.jl` to Load in the Model
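With `PythonCall.jl`, we import the Python modules we need and load the `M3D-CLIP` model and its tokenizer from Hugging Face. The resulting Python objects can then be used from Julia in a very natural way: attributes and methods are accessed with ordinary dot syntax, and keyword arguments are passed as Julia keywords.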

In [1]:
import PythonCall
torch = PythonCall.pyimport("torch")
np = PythonCall.pyimport("numpy")
AutoTokenizer = PythonCall.pyimport("transformers").AutoTokenizer
AutoModel = PythonCall.pyimport("transformers").AutoModel

tokenizer = AutoTokenizer.from_pretrained("GoodBaiBai88/M3D-CLIP", model_max_length=512, padding_side="right", use_fast=false)
model = AutoModel.from_pretrained("GoodBaiBai88/M3D-CLIP", trust_remote_code=true)



Python:
M3DCLIP(
  (vision_encoder): ViT(
    (patch_embedding): PatchEmbeddingBlock(
      (patch_embeddings): Sequential(
        (0): Rearrange('b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)', p1…
        (1): Linear(in_features=1024, out_features=768, bias=True)
      )
      (dropout): Dropout(p=0, inplace=False)
    )
    (blocks): ModuleList(
    ... 57 more lines ...
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
    )
  )
  (mm_vision_proj): Linear(in_features=768, out_features=768, bias=True)
  (mm_language_proj): Linear(in_features=768, out_features=768, bias=True)
)

## 3. Obtaining the Embeddings
To get the embeddings, we pass a batched image tensor to the model's built-in `encode_image` method, which runs the vision encoder and returns one 768-dimensional vector per token.

In [None]:
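using JLD2

# Load the example CT volume stored under the key "jl_image"
image = JLD2.load("data/example.jld2", "jl_image")

# `pyrowlist` converts the Julia array into nested Python lists in row-major order,
# so the indices (and therefore the shape) are preserved; `unsqueeze(0)` adds a batch dimension
torch_image = torch.tensor(PythonCall.pyrowlist(image)).unsqueeze(0)

# Run the M3D-CLIP vision encoder on the batched volume
torch_encoded_image = model.encode_image(torch_image)
torch_image.shape, torch_encoded_image.shape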

(<py torch.Size([1, 1, 32, 256, 256])>, <py torch.Size([1, 2049, 768])>)
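The output shape can be sanity-checked with a little arithmetic. The patch-embedding layer printed earlier has `in_features=1024`, which for a single-channel volume means 1024 voxels per patch, so a 32×256×256 volume yields 32·256·256 / 1024 = 2048 patches regardless of the exact patch shape. The sketch assumes the one remaining token is a prepended class token, and that a 4×16×16 patch is one tiling consistent with these numbers; neither is read directly from the model printout.

```julia
# Input shape reported by PyTorch above: (batch, channels, depth, height, width)
input_shape = (1, 1, 32, 256, 256)

# `in_features` of the patch-embedding Linear layer = voxels per patch × channels
patch_in_features = 1024
channels = input_shape[2]
voxels_per_patch = patch_in_features ÷ channels

# Non-overlapping patches tile the volume, so patch count = total voxels / voxels per patch
n_voxels = prod(input_shape[3:end])        # 32 * 256 * 256 = 2_097_152
n_patches = n_voxels ÷ voxels_per_patch    # 2048

# Assumption: one extra (class) token is prepended to the patch tokens
n_tokens = n_patches + 1                   # 2049, matching torch.Size([1, 2049, 768])

# One patch shape consistent with these numbers (assumed, not read from the model): 4×16×16
@assert 4 * 16 * 16 == voxels_per_patch
@assert (32 ÷ 4) * (256 ÷ 16) * (256 ÷ 16) == n_patches
```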

In [40]:
encoded_image = PythonCall.pyconvert(Array, torch_encoded_image.detach().numpy())

1×2049×768 Array{Float32, 3}:
[:, :, 1] =
 0.0206281  -0.006389  -0.00490209  -0.0278195  …  -0.00684712  -0.00559487

[:, :, 2] =
 0.0736189  0.069139  0.0704843  0.057757  …  0.0697616  0.0697715  0.0685673

[:, :, 3] =
 0.00992233  0.00230995  -0.000363045  0.0486071  …  0.000932447  0.00255924

;;; … 

[:, :, 766] =
 0.0498103  0.0209607  0.0230589  …  0.0226209  0.0219514  0.0209882

[:, :, 767] =
 0.00530913  0.0151308  0.0169863  …  0.0146035  0.0167086  0.0118581

[:, :, 768] =
 -0.0554023  -0.0689056  -0.0704629  …  -0.0680503  -0.0677214  -0.0672402
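Downstream models usually want one vector per scan rather than one per token. Two common reductions are sketched below: taking the first token (assuming, as in many ViT/CLIP setups, that it is the class token) or averaging over all tokens. `token_matrix`, `cls_embedding`, `mean_embedding`, and `cosine_similarity` are hypothetical helpers written for this sketch, and which reduction suits a given downstream model is a modeling choice this notebook does not settle.

```julia
using LinearAlgebra  # dot, norm

# Hypothetical helpers for a (1, n_tokens, dim) embedding array such as `encoded_image`.

# Drop the batch dimension and put tokens in columns (dim × n_tokens),
# so each token embedding is a contiguous column in Julia's column-major layout
token_matrix(E::AbstractArray{<:Real,3}) = permutedims(dropdims(E; dims=1))

# Assumption: token 1 is the class token
cls_embedding(E::AbstractArray{<:Real,3}) = token_matrix(E)[:, 1]

# Average over all tokens
mean_embedding(E::AbstractArray{<:Real,3}) = vec(sum(token_matrix(E); dims=2)) ./ size(E, 2)

# Cosine similarity between two embedding vectors
cosine_similarity(a::AbstractVector, b::AbstractVector) = dot(a, b) / (norm(a) * norm(b))
```

Applied to the array above, `cls_embedding(encoded_image)` and `mean_embedding(encoded_image)` would each return a 768-element vector, and `cosine_similarity` can then compare two scans.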