In [6]:
import os
import torch
import json
import numpy as np
import pandas as pd
from PIL import Image
from open_clip import create_model_from_pretrained, get_tokenizer

# Load the model and tokenizer
model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')

# Set device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
model.eval()
print(f"Model loaded on device: {device}")

Model loaded on device: cpu


In [8]:
# Define input directories
image_folder = "../data/unified/imgs"
text_folder = "../data/unified/alt"
spec_folder = "../data/unified/specs"
output_folder = "./embeddings"
os.makedirs(output_folder, exist_ok=True)

batch_size = 50  # Adjust batch size as needed
context_length = 256

def generate_text_embeddings(folder, output_name):
    text_files = [f for f in os.listdir(folder) if f.lower().endswith(".txt") or f.lower().endswith(".json")]
    text_embeddings = []
    file_names = []
    
    for i in range(0, len(text_files), batch_size):
        batch_files = text_files[i:i + batch_size]
        texts = []
        
        for text_file in batch_files:
            text_path = os.path.join(folder, text_file)
            
            with open(text_path, "r", encoding="utf-8") as f:
                if text_file.lower().endswith(".json"):
                    try:
                        data = json.load(f)
                        text_desc = json.dumps(data)  # Convert JSON object to a string
                    except json.JSONDecodeError:
                        print(f"Warning: Could not parse JSON in {text_file}, skipping.")
                        continue
                else:
                    text_desc = f.read().strip()
            
            texts.append(text_desc)
            file_names.append(text_file)
        
        text_input = tokenizer(texts, context_length=context_length).to(device)
        
        with torch.no_grad():
            text_features = model.encode_text(text_input)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        
        text_embeddings.append(text_features.cpu().numpy())
        print(f"Processed {i + batch_size} of {len(text_files)} text files.")
    
    save_embeddings(text_embeddings, file_names, output_name)

# Function to generate image embeddings in batches
def generate_image_embeddings(folder, output_name):
    image_files = [f for f in os.listdir(folder) if f.lower().endswith((".png", ".jpg", ".jpeg"))]
    image_embeddings = []
    file_names = []
    
    for i in range(0, len(image_files), batch_size):
        batch_files = image_files[i:i + batch_size]
        images = []
        
        for image_file in batch_files:
            image_path = os.path.join(folder, image_file)
            image = preprocess(Image.open(image_path).convert("RGB"))
            images.append(image)
            file_names.append(image_file)
        
        images_tensor = torch.stack(images).to(device)
        
        with torch.no_grad():
            image_features = model.encode_image(images_tensor)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        
        image_embeddings.append(image_features.cpu().numpy())
        print(f"Processed {i + batch_size} of {len(image_files)} image files.")
    
    save_embeddings(image_embeddings, file_names, output_name)

# Function to save embeddings as TSV and Parquet
def save_embeddings(embeddings, file_names, output_name):
    if not embeddings:
        print(f"No embeddings generated for {output_name}")
        return
    
    embeddings_np = np.vstack(embeddings)
    df = pd.DataFrame(embeddings_np, columns=[f"dim_{i}" for i in range(embeddings_np.shape[1])])
    df.insert(0, "Filename", file_names)
    
    tsv_path = os.path.join(output_folder, f"{output_name}.tsv")
    parquet_path = os.path.join(output_folder, f"{output_name}.parquet")
    
    df.to_csv(tsv_path, sep="\t", index=False)
    df.to_parquet(parquet_path, index=False)
    
    print(f"Embeddings saved: {tsv_path} and {parquet_path}")

In [9]:
# Generate embeddings
print("Generating text embeddings...")
generate_text_embeddings(text_folder, "text_embeddings")

print("Generating specification embeddings...")
generate_text_embeddings(spec_folder, "spec_embeddings")

print("Generating image embeddings...")
generate_image_embeddings(image_folder, "image_embeddings")

print("All embeddings generated and saved successfully!")

Generating text embeddings...
Processed 50 of 3200 text files.
Processed 100 of 3200 text files.
Processed 150 of 3200 text files.
Processed 200 of 3200 text files.
Processed 250 of 3200 text files.
Processed 300 of 3200 text files.
Processed 350 of 3200 text files.
Processed 400 of 3200 text files.
Processed 450 of 3200 text files.
Processed 500 of 3200 text files.
Processed 550 of 3200 text files.
Processed 600 of 3200 text files.
Processed 650 of 3200 text files.
Processed 700 of 3200 text files.
Processed 750 of 3200 text files.
Processed 800 of 3200 text files.
Processed 850 of 3200 text files.
Processed 900 of 3200 text files.
Processed 950 of 3200 text files.
Processed 1000 of 3200 text files.
Processed 1050 of 3200 text files.
Processed 1100 of 3200 text files.
Processed 1150 of 3200 text files.
Processed 1200 of 3200 text files.
Processed 1250 of 3200 text files.
Processed 1300 of 3200 text files.
Processed 1350 of 3200 text files.
Processed 1400 of 3200 text files.
Processed 

In [10]:
import pandas as pd

# Load the TSV file
tsv_path = "./embeddings/text_embeddings.tsv"
df_tsv = pd.read_csv(tsv_path, sep="\t")

# Display the first few rows
print(df_tsv.head())

                                           Filename     dim_0     dim_1  \
0     EX_SPEC_ALIGNMENT_CHART_sw_1_2_s_0_7_cc_2.txt -0.018945  0.034008   
1  BRCA-EU-fc8130e0-a8b4-d80d-e040-11ac0c483272.txt -0.012375 -0.020747   
2               breast_cancer_sw_1_0_s_0_5_cc_0.txt -0.005983  0.016782   
3                EX_SPEC_CIRCOS_sw_1_0_s_0_7_oc.txt -0.050138 -0.004010   
4                 rule-mark_p_0_sw_1_2_s_0_7_oc.txt -0.013534  0.004531   

      dim_2     dim_3     dim_4     dim_5     dim_6     dim_7     dim_8  ...  \
0 -0.007164  0.047764 -0.012338 -0.039841 -0.014265 -0.042667  0.077361  ...   
1 -0.045096  0.018411  0.031258 -0.022743 -0.015150 -0.126027  0.024359  ...   
2 -0.036115 -0.000103  0.019164 -0.004601 -0.040937 -0.090851  0.039501  ...   
3 -0.044706  0.001753 -0.013668 -0.011940 -0.032503 -0.069678  0.001210  ...   
4 -0.012634 -0.000794  0.016319 -0.029607  0.017443 -0.008724  0.016731  ...   

    dim_502   dim_503   dim_504   dim_505   dim_506   dim_507   dim_

In [12]:
# Load the Parquet file
parquet_path = "./embeddings/spec_embeddings.parquet"
df_parquet = pd.read_parquet(parquet_path)

# Display the first few rows
print(df_parquet.head())

                                            Filename     dim_0     dim_1  \
0       EX_SPEC_CIRCULR_RANGE_sw_0_7_s_0_7_cc_0.json -0.001649 -0.064245   
1             EX_SPEC_GREMLIN_sw_1_2_s_0_7_cc_1.json -0.021860 -0.027158   
2  gray_heatmap_sw_1_0_s_1_0_oc_sw_0_7_s_1_0_cc_0...  0.009591 -0.089052   
3                        TEXT_sw_0_7_s_1_0_cc_0.json  0.037189 -0.026111   
4  PBCA-DE-2009e5e7-1796-445b-8677-46b3804fe0bf.json  0.027299 -0.008137   

      dim_2     dim_3     dim_4     dim_5     dim_6     dim_7     dim_8  ...  \
0 -0.095751  0.007678  0.042844 -0.061872 -0.043238 -0.023653 -0.016709  ...   
1 -0.045956 -0.003640 -0.009432 -0.034230 -0.024108 -0.050020  0.005480  ...   
2 -0.061562  0.053112  0.019418 -0.013294 -0.009502 -0.045159  0.033293  ...   
3 -0.063490  0.049418  0.078321 -0.031790 -0.013515 -0.027990  0.014375  ...   
4 -0.057222 -0.018383 -0.002513 -0.033442  0.006589 -0.044671  0.046795  ...   

    dim_502   dim_503   dim_504   dim_505   dim_506   dim_507 