In [None]:
import sys
from unittest.mock import MagicMock

# Mock flash_attn module
sys.modules['flash_attn'] = MagicMock()
import os

import torch
import pandas as pd

from procyon.data.inference_utils import (
    create_caption_input_simple,
    create_qa_input_simple,
    uniprot_id_to_index,
    ProCyonQAInference,
)
from procyon.model.model_unified import UnifiedProCyon
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch

  import pynvml  # type: ignore[import]


[2025-12-07 13:52:44,575] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)


  import pkg_resources


In [None]:
from accelerate import dispatch_model
# -------------------------
# 1. Load checkpoint & args
# -------------------------
# NOTE(review): absolute, cluster-specific path; it must be edited to run anywhere else.
checkpoint_path = "/oscar/data/rsingh47/drraghav/procupine/ProCyon-Full"
CKPT_NAME = checkpoint_path  # alias; `checkpoint_path` itself is reused by the tokenizer cell

# data_args is handed to create_caption_input_simple / create_qa_input_simple further down.
# NOTE(review): torch.load unpickles the file, so only point it at a trusted checkpoint.
data_args = torch.load(os.path.join(CKPT_NAME, "data_args.pt"))

device = torch.device('cuda')
# The second return value of from_pretrained is discarded here.
model, _ = UnifiedProCyon.from_pretrained(checkpoint_dir=CKPT_NAME)
model.eval()
model.bfloat16()

# --------------------------------------------------------
# 2. Generate initial device map using accelerate
# --------------------------------------------------------
# This gives us a starting point, but we will override
device_map = infer_auto_device_map(
    model,
    max_memory={0: "20GB", 1: "20GB"},
    no_split_module_classes=["LlamaDecoderLayer"]
)

# --------------------------------------------------------
# 3. Force text encoder fully on GPU0
# --------------------------------------------------------
# Work on a copy so the inferred map stays available for inspection.
fixed_map = dict(device_map)

# Only values of existing keys are reassigned, so iterating over keys() while writing is safe.
for k in fixed_map.keys():
    if k.startswith("text_encoder"):
        fixed_map[k] = 0

# --------------------------------------------------------
# 4. Put protein modules on GPU1
# --------------------------------------------------------
protein_modules = [
    "protein_seq_embeddings",
    "protein_struct_embeddings",
    "domain_embeddings",
    "token_projectors",
    "drug_structure_embeddings",
    "aaseq_shared_projector",
    "aaseq_lm_projector",
    "contrastive_head",
]

# `if k in fixed_map` silently skips any name that is not an exact key of the inferred map.
# NOTE(review): the recorded output of this cell prints "Protein Embeddings are on: cuda:0",
# so the GPU1 placement apparently did not take effect for protein_seq_embeddings —
# check which keys infer_auto_device_map actually produced.
for k in protein_modules:
    if k in fixed_map:
        fixed_map[k] = 1

# --------------------------------------------------------
# 5. Make sure input embeddings are on GPU0
# --------------------------------------------------------
# Unlike steps 3 and 4 these two entries are written unconditionally (added if missing).
fixed_map["input_embeddings"] = 0
fixed_map["text_encoder.model.lm_head"] = 0

# --------------------------------------------------------
# 6. Dispatch model according to this fixed_map
# --------------------------------------------------------
model = dispatch_model(model, fixed_map)
# NOTE(review): eval()/bfloat16() were already applied before dispatch; presumably repeated as a safeguard.
model.eval()
model.bfloat16()

# Sanity check of the final placement (compare with the intent stated in steps 3 and 4).
print(f"Text Encoder is on: {model.text_encoder.model.device}")
print(f"Protein Embeddings are on: {model.protein_seq_embeddings.weight.device}")

updating model args DATA_DIR from /n/holystore01/LABS/mzitnik_lab/Lab/PLM -> /oscar/data/rsingh47/drraghav/procupine/ProCyon-Instruct
updating stale DATA_DIR for model arg: go_embeddings_path
updating stale DATA_DIR for model arg: pfam_embeddings_path
updating stale DATA_DIR for model arg: drugbank_embeddings_path
updating stale DATA_DIR for model arg: reactome_embeddings_path
updating stale DATA_DIR for model arg: omim_embeddings_path
updating stale DATA_DIR for model arg: ec_embeddings_path
updating stale DATA_DIR for model arg: protein_seq_embeddings_path
updating stale DATA_DIR for model arg: protein_struct_embeddings_path
updating stale DATA_DIR for model arg: protein_embeddings_idmap_path
updating stale DATA_DIR for model arg: drug_struct_embeddings_path
updating stale DATA_DIR for model arg: domain_embeddings_path
updating stale DATA_DIR for model arg: domain_embeddings_idmap_path
updating stale DATA_DIR for model arg: mouse_ortholog_embeddings_path
updating stale DATA_DIR for m

Using sep_token, but it is not set yet.
Using pad_token, but it is not set yet.


Processing zero checkpoint '/oscar/data/rsingh47/drraghav/procupine/ProCyon-Full/global_step59469'
Detected checkpoint of type zero stage ZeroStageEnum.gradients, world_size: 32
Parsing checkpoint created by deepspeed==0.12.4
Reconstructed fp32 state dict with 322 params 8141117441 elements
Text Encoder is on: cuda:0
Protein Embeddings are on: cuda:0


In [None]:
## This runs the Procyon inference model

# NOTE(review): `df_clean` is only assigned by the data-loading cell that appears *below* this
# one, so on a fresh kernel run top-to-bottom this line raises NameError — run that cell first
# or move it above.
random_idx = 1  # fixed row index despite the name
sample_row = df_clean.iloc[random_idx]
gene_name = sample_row['gene_name']
sample_prot_id = int(sample_row['procyon_id'])
protein_ids = [sample_prot_id]
# Build the captioning prompt/input structure for this single protein.
input_simple = create_caption_input_simple(
    input_aaseq_ids=protein_ids,
    data_args=data_args,
    # The `instruction_source_dataset` and `instruction_source_relation` here control the style
    # of pre-templated instruction used in these queries. In particular, here we query for UniProt-style
    # functional descriptions.
    instruction_source_dataset="uniprot",
    instruction_source_relation="all",
    aaseq_type="protein",
    task_type="caption",
    icl_example_number=1,
     device=device,
)
def move_to_device(obj, device="cuda:0"):
    """Recursively move every tensor inside a nested container to ``device``.

    Args:
        obj: A tensor, or an arbitrarily nested dict / list / tuple that may
            contain tensors. Any other object is passed through untouched.
        device: Target device accepted by ``Tensor.to``.

    Returns:
        A new container with the same structure whose tensors live on
        ``device``. The input container itself is not modified.
    """
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, list):
        return [move_to_device(x, device) for x in obj]
    if isinstance(obj, tuple):
        # Tuples used to fall through to the final `return obj`, leaving any
        # tensors inside them on their old device.
        moved = [move_to_device(x, device) for x in obj]
        # namedtuples take their fields positionally; plain tuples take one iterable.
        return type(obj)(*moved) if hasattr(obj, "_fields") else type(obj)(moved)
    return obj

# Move every tensor of the caption input onto GPU0 before generation.
input_simple = move_to_device(input_simple, "cuda:0")

text_gen_args = {
    "method": "beam",
    # Maximum length of generated text.
    "max_len": 200,
    # Total number of beams maintained per input. `beam_size` / `beam_group_size` = number of phenotypes returned per input.
    "beam_size": 20,
    # Size of the individual beam groups in DBS.
    "beam_group_size": 2,
    # Penalty applied to repetition within a beam group.
    "diversity_penalty": 0.8,
}

# NOTE(review): this call is not wrapped in torch.no_grad(), unlike the QA call below —
# confirm that model.generate disables autograd internally.
out_tokens, log_probs, output_logits, out_text = model.generate(
    inputs=input_simple,
    aaseq_type="protein",
    **text_gen_args
)
# Keep the first text of every beam group (every `beam_group_size`-th entry) for the single input.
output_phenotypes = [
    phen for i, phen in enumerate(out_text[0]) if i % text_gen_args["beam_group_size"] == 0
]
qa_model = ProCyonQAInference(model, device=device)

# Try QA filtering
# Each generated description is fed back as a yes/no question about the same protein.
results = []
for i, query_text in enumerate(output_phenotypes):
    input_qa_simple = create_qa_input_simple(
        input_aaseq_ids=protein_ids,
        data_args=data_args,
        input_description=query_text,
        instruction_source_dataset="uniprot",
        instruction_source_relation="all",
        aaseq_type="protein",
        icl_example_number=1,
        device=device,
    )

    with torch.no_grad():
        model_qa_out = qa_model(input_qa_simple)

    # Read the scores stored at the wrapper's yes/no token positions for the first (only) input.
    yes_prob = model_qa_out["pred"][0, qa_model.yes_token].item()
    no_prob = model_qa_out["pred"][0, qa_model.no_token].item()


    print(f"TEXT {i} --------------------------------------------")
    print(query_text)
    print(f"Yes: {yes_prob:0.3f}")
    print(f"No: {no_prob:0.3f}")

    # Only the "yes" score is kept; `no_prob` is printed but not stored.
    results.append({
        "phenotype": query_text,
        "yes_prob": yes_prob
    })

# One row per kept candidate, in generation order (not sorted by score).
results = pd.DataFrame(results)

In [None]:
import pandas as pd
import numpy as np

CSV_PATH = "pbmk.csv"  # input table, resolved relative to the kernel's working directory
EMBEDDING_COL = "biogpt_embedding"  # column holding the serialized embedding vector
TEXT_COL = "summary"  # free-text column; rows missing it are dropped during loading
ID_COL = "uniprot_id"  # column whose values are passed to uniprot_id_to_index
GENE_COL = "gene_name"  # NOTE(review): not referenced in the visible cells (they use the literal 'gene_name')

# Expected Dimensions
INPUT_DIM = 1024 # BioGPT dimension
LLAMA_DIM = 4096  # output_dim handed to GeneAdapter when the adapter is built

def clean_and_load_data(csv_path):
    """Load the gene CSV, parse its embedding column and map rows to ProCyon indices.

    Args:
        csv_path: Path to a CSV that contains at least the ``ID_COL``,
            ``TEXT_COL`` and ``EMBEDDING_COL`` columns.

    Returns:
        ``(df_mapped, actual_dim)``. ``df_mapped`` holds only the rows whose
        UniProt ID could be mapped, with two added columns: ``vec`` (the parsed
        embedding as a NumPy array) and ``procyon_id``. ``actual_dim`` is the
        length of the first embedding with more than one element, or
        ``INPUT_DIM`` if there is none. Returns ``(None, None)`` when
        ``csv_path`` does not exist.
    """
    print(f"Loading {csv_path}...")
    try:
        df = pd.read_csv(csv_path)
    except FileNotFoundError:
        print("File not found!")
        return None, None

    # 1. Filter Missing Data
    initial_len = len(df)
    df = df.dropna(subset=[ID_COL, TEXT_COL, EMBEDDING_COL])
    print(f"Filtered {initial_len - len(df)} rows with missing data. Remaining: {len(df)}")

    # 2. Parse Embeddings
    print(f"Parsing embedding column: '{EMBEDDING_COL}'...")
    def parse_embedding(x):
        """Turn one CSV cell into a 1-D float array (zeros if it cannot be converted)."""
        try:
            # Handle string format "0.0,0.0,0.0"
            if isinstance(x, str):
                clean_str = x.replace('[', '').replace(']', '').replace('"', '').replace("'", "")
                return np.fromstring(clean_str, sep=',')
            # atleast_1d: a bare number would otherwise become a 0-d array, and the
            # `len()` call in step 3 raises TypeError on 0-d arrays.
            return np.atleast_1d(np.asarray(x, dtype=float))
        except (ValueError, TypeError):
            # Only conversion failures are treated as "bad embedding"; anything
            # else (e.g. KeyboardInterrupt) is no longer swallowed by a bare except.
            return np.zeros(INPUT_DIM)

    df['vec'] = df[EMBEDDING_COL].apply(parse_embedding)

    # 3. Detect Dimension
    valid_vecs = df[df['vec'].apply(len) > 1]['vec']
    if len(valid_vecs) > 0:
        actual_dim = len(valid_vecs.iloc[0])
    else:
        actual_dim = INPUT_DIM
    print(f"Detected Embedding Dimension: {actual_dim}")

    # 4. Map UniProt IDs to ProCyon Indices
    print("Mapping UniProt IDs to ProCyon Indices...")

    valid_indices = []
    valid_rows = []
    unmapped = 0

    # Only the ID column is needed, so iterate over it instead of building a
    # full row Series per record with iterrows().
    for idx, uid in df[ID_COL].items():
        try:
            procyon_idx = uniprot_id_to_index(uid)
        except Exception:
            # The exception type raised for an unknown ID is not visible from
            # here, so the broad catch is kept; the count is reported below so
            # that a systematic failure does not go unnoticed.
            unmapped += 1
            continue
        valid_indices.append(procyon_idx)
        valid_rows.append(idx)

    if unmapped:
        print(f"Skipped {unmapped} rows whose UniProt ID could not be mapped.")

    # Filter DF to only include mappable proteins
    df_mapped = df.loc[valid_rows].copy()
    df_mapped['procyon_id'] = valid_indices

    print(f"Final Count: {len(df_mapped)} samples ready for training.")
    return df_mapped, actual_dim

# Runs the load when the cell executes; both names are None if CSV_PATH does not exist.
df_clean, DETECTED_DIM = clean_and_load_data(CSV_PATH)

Loading pbmk.csv...
Filtered 40 rows with missing data. Remaining: 636
Parsing embedding column: 'biogpt_embedding'...
Detected Embedding Dimension: 1024
Mapping UniProt IDs to ProCyon Indices...
Final Count: 116 samples ready for training.


In [None]:
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Relies on `model` and `checkpoint_path` from the model-loading cell.
# Resolve where the Llama weights live: prefer the path recorded on the loaded
# text encoder's config, otherwise read it from the saved model args.
try:
    llama_path = model.text_encoder.model.config._name_or_path
except AttributeError:
    model_args_file = os.path.join(checkpoint_path, "model_args.pt")
    saved_model_args = torch.load(model_args_file)
    llama_path = getattr(saved_model_args, 'model_name_or_path', None)

# Without a path there is nothing to load, so stop here.
if not llama_path:
    raise ValueError("CRITICAL: Could not find the Llama path to load the tokenizer.")

print(f"Attempting to load tokenizer from: {llama_path}")
tokenizer = AutoTokenizer.from_pretrained(llama_path)

# A missing pad token is replaced by the EOS token so padded batches can be built.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print("Fixed padding token.")

Attempting to load tokenizer from: /oscar/data/rsingh47/drraghav/procupine/ProCyon-Instruct/model_weights/llama-3-8b


Using pad_token, but it is not set yet.


Fixed padding token.


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch.nn.utils.rnn as rnn_utils
import torch.nn.functional as F

class GeneAdapterDataset(Dataset):
    """Dataset pairing a gene embedding with its tokenized summary text.

    Args:
        df: DataFrame with one row per sample.
        tokenizer: Callable tokenizer; it is invoked with ``padding='max_length'``,
            ``truncation=True``, ``max_length`` and ``return_tensors='pt'`` and must
            return a mapping with ``'input_ids'`` and ``'attention_mask'``.
        gene_embedding_col: Column holding one equal-length vector per row.
        summary_col: Column holding the target text.
        prot_id_col: Column holding the ProCyon protein index.
        max_length: Number of tokens every summary is padded/truncated to
            (default 128, the previously hard-coded value).

    Raises:
        ValueError: If ``df`` has no rows.
    """
    def __init__(self, df, tokenizer, gene_embedding_col='vec', summary_col='summary', prot_id_col='procyon_id', max_length=128):
        if len(df) == 0:
            # np.stack would fail with a generic "need at least one array" message.
            raise ValueError("GeneAdapterDataset needs at least one row.")
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Cast once here instead of on every __getitem__ call.
        self.gene_feats = np.stack(df[gene_embedding_col].values).astype(np.float32)
        self.texts = df[summary_col].tolist()
        self.prot_ids = df[prot_id_col].tolist()

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the float32 feature vector, protein id and tokenized summary for ``idx``."""
        # Tokenize the target text (Gene Summary)
        gene_text = self.texts[idx]
        tokenized_output = self.tokenizer(gene_text,
                                          padding='max_length',
                                          truncation=True,
                                          max_length=self.max_length,
                                          return_tensors='pt')

        # The tokenizer returns a leading batch axis of size 1; drop it.
        return {
            'gene_feat': self.gene_feats[idx],
            'prot_id': self.prot_ids[idx],
            'gene_text_labels': tokenized_output['input_ids'].squeeze(0),
            'attention_mask': tokenized_output['attention_mask'].squeeze(0)
        }

def collate_fn(batch, pad_token_id=None):
    """Merge a list of dataset items into one batch dictionary of tensors.

    Args:
        batch: List of dicts with the keys ``'gene_feat'`` (NumPy vector),
            ``'prot_id'`` (int), ``'gene_text_labels'`` and ``'attention_mask'``
            (1-D tensors).
        pad_token_id: Value used to right-pad ``gene_text_labels`` when the
            sequences differ in length. When omitted (as it is when a DataLoader
            calls this function) the notebook-level ``tokenizer.pad_token_id``
            is used, which is the previous behaviour.

    Returns:
        Dict with ``'gene_feat'`` (float tensor), ``'prot_id'`` (long tensor),
        ``'gene_text_labels'`` and ``'attention_mask'`` (padded, batch-first).
    """
    if pad_token_id is None:
        # Hidden dependency on the global tokenizer, kept only as the default.
        pad_token_id = tokenizer.pad_token_id

    gene_feats = [item['gene_feat'] for item in batch]
    prot_ids = [item['prot_id'] for item in batch]
    gene_text_labels = [item['gene_text_labels'] for item in batch]
    attention_masks = [item['attention_mask'] for item in batch]

    # Gene vectors all have the same length, so they are stacked, not padded.
    gene_feats = torch.tensor(np.stack(gene_feats), dtype=torch.float)
    prot_ids = torch.tensor(prot_ids, dtype=torch.long)

    # Right-pad text and mask sequences to the longest one in the batch.
    gene_text_labels_padded = rnn_utils.pad_sequence(gene_text_labels, batch_first=True, padding_value=pad_token_id)
    attention_masks_padded = rnn_utils.pad_sequence(attention_masks, batch_first=True, padding_value=0)

    return {
        'gene_feat': gene_feats,
        'prot_id': prot_ids,
        'gene_text_labels': gene_text_labels_padded,
        'attention_mask': attention_masks_padded,
    }

print("Instantiating dataset and train_loader...")

# Wrap the cleaned frame and the loaded tokenizer in the training dataset.
dataset = GeneAdapterDataset(df=df_clean, tokenizer=tokenizer)

# One sample per batch keeps memory use low; the training loop iterates over `train_loader`.
train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

print(f"train_loader created with {len(train_loader)} batches.")

Instantiating dataset and train_loader...
train_loader created with 116 batches.


In [None]:
## TRAIN THE ADAPTER

import torch.nn as nn
import torch.optim as optim
import os
import gc

# --- TRAINING PARAMETERS ---
NUM_EPOCHS = 3  # full passes over train_loader
ACCUMULATION = 8  # micro-batches whose gradients are accumulated before each optimizer step

# Release cached GPU blocks and collect unreachable Python objects before building the adapter.
torch.cuda.empty_cache()
gc.collect()

class GeneAdapter(nn.Module):
    """Two-layer MLP that maps a gene embedding to one unit-norm pseudo-token.

    Args:
        input_dim: Size of the incoming gene embedding.
        output_dim: Size of the produced token embedding.
    """
    def __init__(self, input_dim=1024, output_dim=4096):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.GELU(),
            nn.Linear(output_dim, output_dim),
            nn.LayerNorm(output_dim)
        )
        # Small-variance init for the linear layers; LayerNorm keeps its defaults.
        for m in self.net.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
                torch.nn.init.zeros_(m.bias)

    def forward(self, x):
        """Map ``x`` of shape (batch, input_dim) to (batch, 1, output_dim), L2-normalized on the last axis."""
        out = self.net(x).unsqueeze(1)
        # No explicit requires_grad_(True) here: during training the output
        # already requires grad through the adapter weights, while forcing it
        # made outputs computed under torch.no_grad() require grad as well,
        # which breaks a later `.numpy()` call on them.
        return nn.functional.normalize(out, p=2, dim=-1)

class ProCyonIntegrated(nn.Module):
    """Frozen ProCyon model plus a trainable gene adapter.

    The forward pass prepends two pseudo-tokens (adapter output, projected
    protein embedding) to the embedded summary tokens and returns the language
    model's output for that sequence, with the summary tokens as labels.

    Args:
        original_procyon: Loaded ProCyon model; all of its parameters are frozen.
        adapter: Module mapping ``gene_feat`` to a (batch, 1, dim) tensor; its
            parameters are the only trainable ones.
        tokenizer: Stored on the instance; not used inside ``forward``.
    """
    def __init__(self, original_procyon, adapter, tokenizer):
        super().__init__()
        self.procyon = original_procyon
        self.adapter = adapter
        self.tokenizer = tokenizer

        for p in self.procyon.parameters():
            p.requires_grad = False
        self.procyon.text_encoder.model.gradient_checkpointing_enable()
        for p in self.adapter.parameters():
            p.requires_grad = True

    def forward(self, batch):
        """Run one batch produced by ``collate_fn`` and return the language model output."""
        # Device and dtype follow the adapter instead of being hard-coded to
        # cuda:0 / bfloat16; with the notebook's setup (adapter on cuda:0 in
        # bfloat16) this resolves to exactly the previous values.
        adapter_param = next(self.adapter.parameters())
        main_device = adapter_param.device
        prot_device = self.procyon.protein_seq_embeddings.weight.device

        # A. Gene Branch (Adapter)
        g_input = batch['gene_feat'].to(device=main_device, dtype=adapter_param.dtype)
        g_embed = self.adapter(g_input)

        # B. Protein Branch: look up on the embedding table's device, then bring
        # the projected vector over to the main device.
        p_ids = batch['prot_id'].to(prot_device)
        p_raw = self.procyon.protein_seq_embeddings(p_ids)
        p_embed = self.procyon.token_projectors['aaseq'](p_raw).to(main_device)
        if p_embed.dim() == 2:
            p_embed = p_embed.unsqueeze(1)

        # C. Ground Truth
        target_ids = batch['gene_text_labels'].to(main_device)
        target_embeds = self.procyon.input_embeddings(target_ids)

        # D. The Input Sandwich: [gene token, protein token, summary tokens]
        inputs_embeds = torch.cat([g_embed, p_embed, target_embeds], dim=1)

        # E. Masks & Labels setup
        batch_size = len(p_ids)
        target_mask = batch['attention_mask'].to(main_device)
        # Derived from the tensors rather than assumed to be 2.
        prefix_len = g_embed.shape[1] + p_embed.shape[1]
        prefix_mask = torch.ones((batch_size, prefix_len), device=main_device, dtype=target_mask.dtype)
        full_mask = torch.cat([prefix_mask, target_mask], dim=1)

        # Padding positions must not contribute to the loss. The pad token is
        # the EOS token here, so unmasked labels would mostly train the model to
        # repeat EOS. masked_fill returns a new tensor; the batch is untouched.
        target_labels = target_ids.masked_fill(target_mask == 0, -100)
        ignore = torch.full((batch_size, prefix_len), -100, device=main_device, dtype=torch.long)
        full_labels = torch.cat([ignore, target_labels], dim=1)

        return self.procyon.text_encoder.model(
            inputs_embeds=inputs_embeds,
            attention_mask=full_mask,
            labels=full_labels
        )

# --- 3. Initialization (Adapter on GPU 0) ---
adapter = GeneAdapter(input_dim=DETECTED_DIM, output_dim=LLAMA_DIM).to("cuda:0").bfloat16()

wrapper = ProCyonIntegrated(model, adapter, tokenizer)
optimizer = optim.AdamW(adapter.parameters(), lr=1e-4)

# SAVE_DIR is not assigned anywhere in the visible notebook, so the checkpoint
# save below used to fail with NameError after the first epoch. Reuse it if an
# earlier cell defined it, otherwise fall back to a local directory.
SAVE_DIR = globals().get("SAVE_DIR", "adapter_checkpoints")
os.makedirs(SAVE_DIR, exist_ok=True)


def _apply_accumulated_step():
    """Clip the adapter gradients, take one optimizer step and reset the gradients."""
    torch.nn.utils.clip_grad_norm_(adapter.parameters(), 1.0)
    optimizer.step()
    optimizer.zero_grad()


# --- 4. Training Loop ---
print(f"\n--- Starting Integrated Adapter Training (Batch Size 1) ---")
# train() mode is required for the wrapper as a whole: the adapter is trained
# and the frozen language model runs with gradient checkpointing enabled.
wrapper.train()
optimizer.zero_grad()

for epoch in range(NUM_EPOCHS):
    total_loss = 0
    num_batches = 0   # micro-batches that completed forward + backward
    pending = 0       # micro-batches accumulated since the last optimizer step
    steps_done = 0    # optimizer steps taken in this epoch
    print(f"\n=== Epoch {epoch + 1}/{NUM_EPOCHS} ===")

    for i, batch in enumerate(train_loader):
        try:
            # Forward
            outputs = wrapper(batch)

            # Divide loss by accumulation steps
            loss = outputs.loss / ACCUMULATION
            loss.backward()
        except RuntimeError as e:
            # CUDA out-of-memory and other device errors are RuntimeErrors; the
            # batch is skipped (not retried). Programming errors such as
            # NameError or KeyError now surface instead of being printed once
            # per batch. Gradients accumulated so far are dropped because a
            # failed backward may have left them half-written.
            print(f"FAILURE at Batch {i}: {e}. Skipping batch after cleanup...")
            torch.cuda.empty_cache()
            optimizer.zero_grad()
            pending = 0
            continue

        # Track every successful micro-batch so the epoch average reflects all
        # of them, not only the batch on which an optimizer step happened.
        actual_loss = loss.item() * ACCUMULATION
        total_loss += actual_loss
        num_batches += 1
        pending += 1

        # Step (Every ACCUMULATION successful batches)
        if pending == ACCUMULATION:
            _apply_accumulated_step()
            pending = 0
            steps_done += 1

            if steps_done % 5 == 0:
                print(f"Batch {i}: Loss = {actual_loss:.4f}")

    # Flush the remainder: when the number of batches is not a multiple of
    # ACCUMULATION the leftover gradients used to be carried silently into the
    # next epoch (or discarded after the last one).
    if pending:
        _apply_accumulated_step()

    # End of Epoch Stats
    avg_loss = total_loss / max(1, num_batches)
    print(f"--- Epoch {epoch + 1} Complete. Avg Loss: {avg_loss:.4f} ---")

    save_path = os.path.join(SAVE_DIR, f"sc_adapter_integrated_epoch_{epoch+1}.pt")
    torch.save(adapter.state_dict(), save_path)
    print(f"Saved checkpoint: {save_path}")

print("Training Complete.")

In [None]:
import torch
import torch.nn as nn
import os
import torch.nn.functional as F
import numpy as np
import gc

def generate_sc_phenotype_icl(wrapper, prot_id_int, gene_vec_np, data_args, device):
    """Generate phenotype candidates for one protein with the gene embedding added to the input.

    Args:
        wrapper: ``ProCyonIntegrated`` instance providing ``.procyon`` and ``.adapter``.
        prot_id_int: ProCyon index of the protein.
        gene_vec_np: Raw gene embedding (array-like) fed to the adapter.
        data_args: ProCyon data arguments forwarded to ``create_caption_input_simple``.
        device: Device forwarded to ``create_caption_input_simple``.

    Returns:
        ``(output_phenotypes, input_simple, g_emb)``: the first text of every
        beam group, the modified input dictionary, and the adapter output
        tensor for reuse in the QA step.
    """
    wrapper.eval()

    # --- A. Generate the working ICL Input Dictionary ---
    input_simple = create_caption_input_simple(
        input_aaseq_ids=[prot_id_int],
        data_args=data_args,
        instruction_source_dataset="uniprot",
        instruction_source_relation="all",
        aaseq_type="protein",
        task_type="caption",
        icl_example_number=1,
        device=device,
    )

    # --- B. Calculate Gene Embedding ---
    main_device = wrapper.procyon.text_encoder.model.device
    g_feat = torch.tensor(gene_vec_np, dtype=torch.float).unsqueeze(0).to(main_device).bfloat16()

    with torch.no_grad():
        g_emb = wrapper.adapter(g_feat)

    # --- C. INJECT THE GENE EMBEDDING ---
    # The previous `.replace("Protein: <|protein|>", "Protein: <|protein|>")`
    # replaced a string with itself, so only the prefix ever changed the prompt.
    # NOTE(review): whether ProCyon's input pipeline recognizes the "<|gene|>"
    # placeholder and the 'gene' entries added below cannot be verified from
    # this notebook — confirm, otherwise the gene embedding never reaches the model.
    original_instruction = input_simple['instructions'][0]
    new_instruction = "Context: <|gene|>\n" + original_instruction
    input_simple['instructions'][0] = new_instruction

    # detach(): `.numpy()` raises on a tensor that requires grad, which the
    # adapter output can do depending on how the adapter builds it.
    gene_emb_numpy = g_emb.detach().squeeze(0).cpu().float().numpy()

    input_simple['data']['gene'] = [gene_emb_numpy]
    input_simple['input']['gene'] = [[0]]

    # --- D. Run Generation ---
    text_gen_args = {
        "method": "beam", "max_len": 200, "beam_size": 20,
        "beam_group_size": 2, "diversity_penalty": 0.8,
    }

    out_tokens, log_probs, output_logits, out_text = wrapper.procyon.generate(
        inputs=input_simple,
        aaseq_type="protein",
        **text_gen_args
    )

    # Keep the first text of every beam group for the single input.
    output_phenotypes = [
        phen for i, phen in enumerate(out_text[0]) if i % text_gen_args["beam_group_size"] == 0
    ]
    # Return the full tensor embedding for the QA step
    return output_phenotypes, input_simple, g_emb

def predict_phenotype_score_icl(wrapper, prot_id_int, g_emb_vector, data_args, candidate_text, device, qa_model=None):
    """QA step: score one candidate description for a protein.

    Args:
        wrapper: ``ProCyonIntegrated`` instance providing ``.procyon``.
        prot_id_int: ProCyon index of the protein.
        g_emb_vector: Adapter output tensor returned by ``generate_sc_phenotype_icl``.
        data_args: ProCyon data arguments forwarded to ``create_qa_input_simple``.
        candidate_text: Description to be judged.
        device: Device forwarded to the ProCyon helpers.
        qa_model: Optional ready-made ``ProCyonQAInference`` to reuse across
            calls. When omitted a new one is built, as before.

    Returns:
        The value stored at the "yes" token position of the QA prediction, as a float.
    """
    wrapper.eval()

    with torch.no_grad():
        # Use the ProCyon QA helper for scoring
        input_qa_simple = create_qa_input_simple(
            input_aaseq_ids=[prot_id_int],
            data_args=data_args,
            input_description=candidate_text,
            instruction_source_dataset="uniprot",
            instruction_source_relation="all",
            aaseq_type="protein",
            icl_example_number=1,
            device=device,
        )

        # detach() keeps `.numpy()` valid even if the tensor requires grad.
        gene_emb_numpy = g_emb_vector.detach().squeeze(0).cpu().float().numpy()

        # NOTE(review): unlike the generation step, the QA instruction text is not
        # given a "<|gene|>" placeholder here — confirm that these entries are used.
        input_qa_simple['data']['gene'] = [gene_emb_numpy]
        input_qa_simple['input']['gene'] = [[0]]

        # Build the QA wrapper only if the caller did not pass one, so a loop
        # over many candidates can share a single instance.
        if qa_model is None:
            qa_model = ProCyonQAInference(wrapper.procyon, device=device)
        model_qa_out = qa_model(input_qa_simple)

        # Extract "Yes" Probability
        yes_prob = model_qa_out["pred"][0, qa_model.yes_token].item()

        return yes_prob

In [None]:
test_index = 1
sample_row = df_clean.iloc[test_index]

# Pull everything needed for this protein out of the selected row.
gene_name, sample_gene_vec, true_summary, TARGET_UNIPROT_ID = (
    sample_row[column] for column in ('gene_name', 'vec', 'summary', 'uniprot_id')
)
sample_prot_id = int(sample_row['procyon_id'])

# Human-readable label built from the row itself.
protein_name_label = f"UniProt ID: {TARGET_UNIPROT_ID} (Gene: {gene_name})"

print(f"\n=== Gene-Augmented Phenotype Generation for {protein_name_label} ===")
print(f"Gene Context Summary: {true_summary[:100]}...\n")

print("1. Generating candidates with ICL Injection (Gene-Augmented)...")
candidates, qa_input_template, g_emb_vector = generate_sc_phenotype_icl(
    wrapper, sample_prot_id, sample_gene_vec, data_args, device
)

print("2. Scoring candidates (QA Check)...")
results = []
for query_text in candidates:
    yes_prob = predict_phenotype_score_icl(
        wrapper, sample_prot_id, g_emb_vector, data_args, query_text, device
    )
    results.append(dict(phenotype=query_text, yes_prob=yes_prob))

# Highest "yes" score first.
results = sorted(results, key=lambda entry: entry['yes_prob'], reverse=True)

separator = '-' * 30
for rank, res in enumerate(results, start=1):
    print(f"\nRANK {rank} (Confidence: {res['yes_prob']:.4f}) {separator}")
    print(res['phenotype'].strip())

In [None]:
# NOTE(review): these installs are unpinned and run after the ProCyon cells above have already
# imported torch/transformers; a changed package version only takes effect after a kernel
# restart. Consider pinning versions and moving this cell to the top of the notebook.
!pip install transformers
!pip install bert-score
# NOTE(review): the first cell registers a MagicMock under sys.modules['flash_attn'];
# presumably the real package is removed here for the same reason — confirm.
!pip uninstall -y flash_attn

In [None]:
from bert_score import BERTScorer
from transformers import pipeline

In [None]:
ground_truth_Q9UKD2 = ["Component of the ribosome assembly machinery. Nuclear paralog of the ribosomal protein P0, it binds pre-60S subunits at an early stage of assembly in the nucleolus, and is replaced by P0 in cytoplasmic pre-60S subunits and mature 80S ribosomes."]
phenotypes_with_genes_Q9UKD2 = ["Required for pre-rRNA splicing as component of the spliceosome. Binds to the 5' splice site of pre-18S ribosomal RNA by RNA polymerase I, which is a component of the 5S rRNA."]
phenotypes_only_protein_Q9UKD2 = ["Involved in pre-mRNA splicing as component of the spliceosome. Binds to the 5S rRNA. Binds to the 5S rRNA. Binds to the 5S rRNA. Binds to the 5S rRNA."]
#########################################
ground_truth_Q96L58 = ["Beta-1,3-galactosyltransferase that transfers galactose from UDP-galactose to substrates with a terminal beta-linked galactose residue. Has a preference for galactose-beta-1,4-xylose that is found in the linker region of glycosaminoglycans, such as heparan sulfate and chondroitin sulfate. Has no activity towards substrates with terminal glucosamine or galactosamine residues."]
phenotypes_with_genes_Q96L58 = ["Component of the V-ATPase V-ATPase family, which mediates the production of phosphatidylethanolamine (PE) to phosphatidic acid (PA) to produce lysophosphatidylserine (PS) to form adenylylsulfate (Glc) and glycerol (By similarity)."]
phenotypes_only_protein_Q96L58 = ["Phosphatidylserine (PS) phosphatidylserine (PS) and phosphatidylethanolamine (PE), phosphatidylcholine (PC) and phosphatidylethanolamine (PE), phosphatidylserine and phosphatidylethanolamine (PE), phosphatidylserine and sphingosine 1-phosphate (PtdIns(4)P)."]

In [None]:
import numpy as np
# Module-level cache so the scorer is built only once per session.
_BERT_SCORER = None


def _get_bert_scorer():
  """Build the SciBERT-based BERTScorer on first use and reuse it afterwards."""
  global _BERT_SCORER
  if _BERT_SCORER is None:
    _BERT_SCORER = BERTScorer(lang="en",
                              model_type="allenai/scibert_scivocab_uncased",
                              num_layers=8,
                              rescale_with_baseline=False)
  return _BERT_SCORER


def get_bertscore(g, p, gt, scorer=None):
  """
  Get the BERT Score for a given protein.

  g: phenotype described using gene embeddings
  p: phenotype described without gene embeddings
  gt: ground truth val
  scorer: optional object with a ``score(candidates, references)`` method to use
    instead of the cached default scorer

  Returns the third element of each ``score`` result, first for ``g`` vs ``gt``
  and then for ``p`` vs ``gt``.
  """
  if scorer is None:
    # A new BERTScorer used to be constructed on every call; share one instead.
    scorer = _get_bert_scorer()

  gene_score = scorer.score(g, gt)
  protein_score = scorer.score(p, gt)

  return gene_score[2], protein_score[2]

# Score both proteins: gene-augmented and protein-only text against the reference.
g_score_Q9UKD2, p_score_Q9UKD2 = get_bertscore(
  phenotypes_with_genes_Q9UKD2, phenotypes_only_protein_Q9UKD2, ground_truth_Q9UKD2
)
g_score_Q96L58, p_score_Q96L58 = get_bertscore(
  phenotypes_with_genes_Q96L58, phenotypes_only_protein_Q96L58, ground_truth_Q96L58
)

# Report the pair of scores for each protein in turn.
for protein_id, gene_value, protein_value in (
  ("Q9UKD2", g_score_Q9UKD2, p_score_Q9UKD2),
  ("Q96L58", g_score_Q96L58, p_score_Q96L58),
):
  print(f"Bert Score for phenotypes of {protein_id} with gene embeddings: {gene_value}")
  print(f"Bert Score for phenotyes of {protein_id} with ProCyon protein embeddings: {protein_value}")

#Procyon does a sample of 1000 for their bootstrap CI so we use that as our default
def compute_bootstrap_ci(scores, samples=1000, seed=None):
  """
  Calculates the bootstrap CI for the mean of the given BERTScores.

  scores: 1-D array-like of scores (list, tuple or NumPy array)
  samples: number of bootstrap resamples to draw (>= 1)
  seed: optional integer seed for a private random generator; when omitted the
    global NumPy random state is used, as before

  Returns (mean of scores, 2.5th percentile, 97.5th percentile of the resampled means).
  Raises ValueError if scores is empty or samples < 1.
  """
  scores = np.asarray(scores, dtype=float)
  if scores.size == 0:
    raise ValueError("compute_bootstrap_ci needs at least one score.")
  if samples < 1:
    raise ValueError("samples must be >= 1.")

  sampler = np.random if seed is None else np.random.RandomState(seed)
  # One draw of shape (samples, n) replaces the Python loop: every row is one
  # resample with replacement of the original scores.
  resampled = sampler.choice(scores, size=(samples, scores.size), replace=True)
  means = resampled.mean(axis=1)

  lower, upper = np.percentile(means, [2.5, 97.5])

  return scores.mean(), lower, upper

g_scores = np.array([g_score_Q9UKD2.item(), g_score_Q96L58.item()])
print(f"g_scores: {g_scores}")

p_scores = np.array([p_score_Q9UKD2.item(), p_score_Q96L58.item()])
print(f"p_scores: {p_scores}")

# Use the default number of bootstrap resamples. The second argument is the
# number of resamples, not the number of proteins: passing 1 made the reported
# "95% CI" a single random resample mean with identical lower and upper bounds.
# With only two proteins the interval is still only illustrative.
mean_g, low_g, high_g = compute_bootstrap_ci(g_scores)
print(f"Mean for Gene Model: {mean_g:.6f} (95% CI: {low_g:.6f} - {high_g:.6f})")

mean_p, low_p, high_p = compute_bootstrap_ci(p_scores)
print(f"Mean Protein Model: {mean_p:.6f} (95% CI: {low_p:.6f} - {high_p:.6f})")

In [None]:
from transformers import pipeline
import torch
import gc
import random

# Cache for the judge LLM. It has its own name because the previous cache
# variable, `model`, is also the name of the ProCyon model created in the
# loading cell; assigning `model = None` here overwrote that reference, so
# re-running any earlier cell that uses `model` broke.
_JUDGE_PIPELINE = None


def get_pipeline():
    """Return the shared text-generation pipeline, creating it on first use.

    Ensures the LLM is only instantiated once per session, not once per
    judge_phenotypes call.

    NOTE(review): the ProCyon model stays referenced by the training wrapper, so
    it remains in GPU memory while the judge model loads — free it explicitly
    first if memory is tight.
    """
    global _JUDGE_PIPELINE
    if _JUDGE_PIPELINE is None:
        model_id = "Qwen/Qwen2.5-7B-Instruct"
        _JUDGE_PIPELINE = pipeline(
            "text-generation",
            model=model_id,
            dtype="auto",
            return_full_text=False,
            temperature=0.01,
            device_map="auto",
        )
    return _JUDGE_PIPELINE
"""
Prevents memory from filling up
"""
def clear_gpu_memory():
    """Collect garbage, empty the CUDA cache and, when CUDA is available, wait for the GPU."""
    gc.collect()
    torch.cuda.empty_cache()
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize()
"""
LLM-as-judge approach for rating the biological accuracy/quality of the generated phenotype descriptions.
"""
def judge_phenotypes(ground_truth, a, b):
    '''
    Ask the judge LLM which of two generated descriptions better matches the ground truth.

    The two candidates are shown to the model in random order to avoid position
    bias. The verdict is then translated back, so "Option A" in the return value
    always refers to argument `a` and "Option B" to argument `b`. Previously the
    model's answer was returned in the shuffled frame, so the caller could not
    tell which description had won.

    ground_truth: GT phenotype description
    a: generated phenotype description (ProCyon or Procupine)
    b: generated phenotype description (ProCyon or Procupine)

    Returns "Option A", "Option B" or "Tie". If the reply cannot be parsed, the
    first line of the raw reply is returned unchanged (still in the shuffled frame).
    '''
    clear_gpu_memory()
    pipe = get_pipeline()

    options = {"a": a, "b": b}
    #Shuffle the options to avoid bias
    keys = list(options.keys())
    random.shuffle(keys)
    shuffled_a = options[keys[0]]
    shuffled_b = options[keys[1]]

    message = f"You are an expert biologist. Compare these two phenotype descriptions against the Ground Truth. Ground Truth: {ground_truth} Option A: {shuffled_a} Option B: {shuffled_b} Which option better captures the specific biological context (tissue, regulation, pathway)? Provide a single word response: ['Option A', 'Option B', 'Tie'] and no additional explination."
    #Limit the number of tokens to prevent the LLM from providing a run-on explanation.
    outputs = pipe(
        message,
        max_new_tokens=10,
    )

    raw_verdict = outputs[0]['generated_text'].split('\n')[0].strip(' ')

    # Map the displayed position back to the caller's argument: the text shown
    # as "Option A" was options[keys[0]], the one shown as "Option B" was
    # options[keys[1]].
    normalized = raw_verdict.lower()
    picked_first = "option a" in normalized
    picked_second = "option b" in normalized
    if picked_first != picked_second:
        winner = keys[0] if picked_first else keys[1]
        return f"Option {winner.upper()}"
    if "tie" in normalized:
        return "Tie"
    # Neither or both options mentioned: hand back what the model said.
    return raw_verdict


#Test on two proteins
# Argument order: ground truth, gene-augmented output (a), protein-only output (b).
# Each argument is a one-element list, so the prompt contains the list's repr rather than the bare text.
Q9UKD2=judge_phenotypes(ground_truth_Q9UKD2 ,phenotypes_with_genes_Q9UKD2, phenotypes_only_protein_Q9UKD2)
Q96L58=judge_phenotypes(ground_truth_Q96L58 ,phenotypes_with_genes_Q96L58, phenotypes_only_protein_Q96L58)

# String concatenation requires judge_phenotypes to return a str.
print("Results for Q9UKD2: " + Q9UKD2)
print("Results for Q96L58: " + Q96L58)