In [1]:
# pip install torch
# pip install tokenizers
# pip install transformers

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tokenizers import Tokenizer
from models.progen.modeling_progen import ProGenForCausalLM
import os
import pandas as pd
from sklearn.model_selection import train_test_split
import logging

import numpy as np
import subprocess
from itertools import combinations
from seq import ab_number as abn

from sklearn.model_selection import KFold

import wandb
wandb.login()

# Set logging level to ERROR to suppress the download progress updates
logging.getLogger('transformers').setLevel(logging.ERROR)

### configure fine-tunings
n_splits = 3
epochs = 10
batch_size = 5
learning_rate = 1e-5
foundation_model_name = 'fine-tuned-progen2-large' #model id/name in huggingface. Will be downloaded from Huggingface.
model_name = foundation_model_name

fine_tuning_strategy = None
prompting_strategy = 'prompted' #zero_shot or prompted

experiment_name =f'{model_name}_{prompting_strategy}'
run_description = f'Running {prompting_strategy} across {model_name}'

# start a new wandb run to track this script
wandb.init(
    # set the wandb project where this run will be logged
    project="berkeley_antibody_generation",
    entity='antibody_generation',
    name=experiment_name,
    notes=run_description,
    
    # track hyperparameters and run metadata
    config={
    "learning_rate": learning_rate,
    "epochs": epochs,
    "foundational_model": foundation_model_name,
    "model_name": model_name,
    "run_description": run_description,
    "fine_tuning_strategy": fine_tuning_strategy
    }
)

  from .autonotebook import tqdm as notebook_tqdm
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mjoethequant[0m ([33mantibody_generation[0m). Use [1m`wandb login --relogin`[0m to force relogin


# Final Sampling Results and Scoring

In [2]:
# make sure to add ANARCI to Path
os.environ['PATH'] = '/root/miniconda3/envs/anarci/bin:' + os.environ['PATH']

def run_anarci(sequence):
    # Run ANARCI as a subprocess

    result = subprocess.run(['ANARCI', '--sequence', sequence, '--scheme', 'aho'], capture_output=True, text=True)

    sequence_results = result.stdout.split('\n')

    species = None
    e_value = None
    score = None
    heavy_chain = np.array([])
    light_chain = np.array([])

    try: #push this into try as we do not want to stop the program if ANARCI fails. If it fails, it will return an empty arrays and thus not be included in the anarci results and data files.
        if len(sequence_results) > 4:

            blank, species, chain_type, e_value, score, seqstart_index, seqend_index, blank_2 = sequence_results[5].split('|')

            h_seq = []
            l_seq = []
            for row in sequence_results[7:]:
                row = [x for x in row.split(' ') if x != '']
                if (len(row) == 3) and (row[0] == 'H'):       
                    h_seq.append(row[2])
                elif (len(row) == 3) and (row[0] == 'L'):
                    l_seq.append(row[2])

            heavy_chain = np.array(h_seq)
            light_chain = np.array(l_seq)

    except:
        pass
    
    return species, e_value, score, heavy_chain, light_chain


def predict_sequence(model, tokenizer, sequence, device='cuda:0', number_of_sequences=1 ):
    # Tokenize the sequence
    tokenized_sequence = tokenizer.encode(sequence)
    
    # Convert to PyTorch tensor and add batch dimension
    input_tensor = torch.tensor([tokenized_sequence.ids]).to(device)
    
    # Pass the tensor through the model
    with torch.no_grad():
        output = model.generate(input_tensor, max_length=1024, pad_token_id=tokenizer.encode('<|pad|>').ids[0], do_sample=True, top_p=0.9, temperature=0.8, num_return_sequences=number_of_sequences)

        as_lists = lambda batch: [batch[i, ...].detach().cpu().numpy().tolist() for i in range(batch.shape[0])]
        sequences = tokenizer.decode_batch(as_lists(output))

        if len(sequences) > 0:
            sequences = [x.replace('2', '') for x in sequences] #replace stop token with empty string
        else:
            return []

        sequence_with_heavy_and_light_chains = []

        #filter out sequences that don't have heavy and light chains
        for sequence in sequences:
            # print(sequence)
            species, e_value, score, heavy_chain, light_chain = run_anarci(sequence)
            if (len(heavy_chain) > 0) and (len(light_chain) > 0):
                sequence_with_heavy_and_light_chains.append(sequence)

        return sequence_with_heavy_and_light_chains


def percent_identity(seq1, seq2):
    """ Compute the percent identity of two strings of equal length. """
    if len(seq1) != len(seq2):
        raise ValueError('Sequences must be the same length.')
    i = 0
    for r1, r2 in zip(seq1, seq2):
        i += int(r1 == r2)
    return i * 100 / len(seq1)

def full_seq_identity(df_anarci_H, df_anarci_KL):
    df_anarci_H = df_anarci_H.copy().set_index('Id')
    df_anarci_KL = df_anarci_KL.copy().set_index('Id')
    df_anarci_H['full_seq_H'] = df_anarci_H.loc[:, '1':].apply(lambda x: ''.join(x), axis=1)
    df_anarci_KL['full_seq_KL'] = df_anarci_KL.loc[:, '1':].apply(lambda x: ''.join(x), axis=1)
    df_anarci = df_anarci_H.merge(df_anarci_KL, left_index=True, right_index=True)
    seqs = [x['full_seq_H'] + x['full_seq_KL'] for _, x in df_anarci.iterrows()]

    identity_dist = [percent_identity(s[0], s[1]) for s in combinations(seqs, 2)]
    return sum(identity_dist) / len(identity_dist)

def cdr3_seq_identity(df_anarci_H):
    df_seqs = df_anarci_H.loc[:, '105':'117']
    seqs = [''.join(x) for _, x in df_seqs.iterrows()]
    identity_dist = [percent_identity(s[0], s[1]) for s in combinations(seqs, 2)]
    return sum(identity_dist) / len(identity_dist)

def diversity_metrics(anarci_csv_path_H, anarci_csv_path_KL):
    df_anarci_H = pd.read_csv(anarci_csv_path_H)
    df_anarci_KL = pd.read_csv(anarci_csv_path_KL)
    results = {'avg_seq_identity_full': full_seq_identity(df_anarci_H, df_anarci_KL),
               'avg_seq_identity_cdr3': cdr3_seq_identity(df_anarci_H)}
    return results

In [3]:
%%time

model_path = f'AntibodyGeneration/{model_name}'
device = 'cuda:0'  # Define the device variable outside the if-else condition

# Initialize the model first
model = ProGenForCausalLM.from_pretrained(model_path).to(device)

# Check if multiple GPUs are available and use ProGen's parallelization
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model.parallelize() # ProGen's parallelize method
else:
    print(f'Device: {device}')

tokenizer = Tokenizer.from_file('tokenizer.json')

number_of_sequences = 10
start_of_antibody_sequence_prompter = 'EVQLVESGGGLVQPGGSLRLSC'

targets = [
    { "sequence_id": 'PD1',
      "sequence": 'MQIPQAPWPVVWAVLQLGWRPGWFLDSPDRPWNPPTFSPALLVVTEGDNATFTCSFSNTSESFVLNWYRMSPSNQTDKLAAFPEDRSQPGQDCRFRVTQLPNGRDFHMSVVRARRNDSGTYLCGAISLAPKAQIKESLRAELRVTERRAEVPTAHPSPSPRPAGQFQTLVVGVVGGLLGSLVLLVWVLAVICSRAARGTIGARRTGQPLKEDPSAVPVFSVDYGELDFQWREKTPEPPVPCVPEQTEYATIVFPSGMGTSSPARRGSADGPRSAQPLRPEDGHCSWPLGGGGGSGGGGSGGGGS'
    },
    { "sequence_id": 'SARS-CoV2',
      "sequence": 'RVQPTESIVRFPNITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYLYRLFRKSNLKPFERDISTEIYQAGSTPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNFGGGGGSGGGGSGGGGS'
    },
    { "sequence_id": 'vWF',
      "sequence": 'DLVFLLDGSSRLSEAEFEVLKAFVVDMMERLRISQKWVRVAVVEYHDGSHAYIGLKDRKRPSELRRIASQVKYAGSQVASTSEVLKYTLFQIFSKIDRPEASRITLLLMASQEPQRMSRNFVRYVQGLKKKKVIVIPVGIGPHANLKQIRLIEKQAPENKAFVLSSVDELEQQRDEIGGGGGSGGGGSGGGGS'
    }
]

total_number_of_sequences_requested = number_of_sequences * len(targets)
total_number_of_sequences_returned = 0
total_percent_sequences_returned = 0
total_df_result_H = pd.DataFrame()
total_df_result_KL = pd.DataFrame()

total_avg_seq_identity_full = []
total_avg_seq_identity_cdr3 = []

for target in targets[:]:

    target_sequence_id = target['sequence_id']
    target_sequence = target['sequence']
    experiment_target_id = f'{experiment_name}_{target_sequence_id}'
    
    if prompting_strategy == 'prompted':
        target_sequence = target_sequence + start_of_antibody_sequence_prompter

    sampled_sequences = predict_sequence(model, tokenizer, target_sequence, device, number_of_sequences=number_of_sequences)

    df_result_H, df_result_KL = abn.number_seqs_as_df(sampled_sequences)

    if (df_result_H is not None) and (len(df_result_H) > 0):
        df_result_H['model_id'] = f'{experiment_target_id}_H'
        df_result_H.to_csv(f'./results/{experiment_target_id}_H.csv')
        mean_anarci_h_score = df_result_H['score'].mean()
        min_anarci_h_score = df_result_H['score'].min()
        total_df_result_H = pd.concat([total_df_result_H, df_result_H], axis=0)
    else:
        mean_anarci_h_score = 0
        min_anarci_h_score = 0
    
    if (df_result_KL is not None) and (len(df_result_KL) > 0):
        df_result_KL['model_id'] = f'{experiment_target_id}_KL'
        df_result_KL.to_csv(f'./results/{experiment_target_id}_KL.csv')
        mean_anarci_kl_score = df_result_KL['score'].mean()
        min_anarci_kl_score = df_result_KL['score'].min()
        total_df_result_KL = pd.concat([total_df_result_KL, df_result_KL], axis=0)
    else:
        mean_anarci_kl_score = 0
        min_anarci_kl_score = 0

    if (df_result_H is not None) and (len(df_result_H) > 0) and (df_result_KL is not None) and (len(df_result_KL) > 0):
        try:
            diversity_metrics_dict = diversity_metrics(f'./results/{experiment_target_id}_H.csv', f'./results/{experiment_target_id}_KL.csv')
            avg_seq_identity_full = diversity_metrics_dict['avg_seq_identity_full']
            avg_seq_identity_cdr3 = diversity_metrics_dict['avg_seq_identity_cdr3']
        except:
            avg_seq_identity_full = 0.0
            avg_seq_identity_cdr3 = 0.0

    else:
        avg_seq_identity_full = 0.0
        avg_seq_identity_cdr3 = 0.0

    number_of_sequences_returned = len(sampled_sequences)
    total_number_of_sequences_returned += number_of_sequences_returned    
    total_avg_seq_identity_full.append(avg_seq_identity_full)
    total_avg_seq_identity_cdr3.append(avg_seq_identity_cdr3)
    
    wandb.log({
        f"{target_sequence_id}_number_of_sequences_requested": number_of_sequences, 
        f"{target_sequence_id}_number_of_sequences_returned": number_of_sequences_returned, 
        f"{target_sequence_id}_percent_sequences_returned": number_of_sequences_returned/number_of_sequences,
        f"{target_sequence_id}_mean_anarci_h_score": mean_anarci_h_score,
        f"{target_sequence_id}_min_anarci_h_score": min_anarci_h_score,
        f"{target_sequence_id}_mean_anarci_kl_score": mean_anarci_kl_score,
        f"{target_sequence_id}_min_anarci_kl_score": min_anarci_kl_score,

        f"{target_sequence_id}_avg_seq_identity_full": avg_seq_identity_full,
        f"{target_sequence_id}_avg_seq_identity_cdr3": avg_seq_identity_cdr3,
    })

    print(f'Total Sequences Asked For: {number_of_sequences}, Total Sequences Returned: {len(sampled_sequences)}, Percent Returned: {len(sampled_sequences)/number_of_sequences}')

if len(total_df_result_H) > 0:
    total_mean_anarci_h_score = total_df_result_H['score'].mean()
    total_min_anarci_h_score = total_df_result_H['score'].min()
else:
    total_mean_anarci_h_score = 0
    total_min_anarci_h_score = 0
    
if len(total_df_result_KL) > 0:
    total_mean_anarci_kl_score = total_df_result_KL['score'].mean()
    total_min_anarci_kl_score = total_df_result_KL['score'].min()
else:
    total_mean_anarci_kl_score = 0
    total_min_anarci_kl_score = 0
    
wandb.log({
        f"total_number_of_sequences_requested": total_number_of_sequences_requested, 
        f"total_number_of_sequences_returned": total_number_of_sequences_returned, 
        f"total_percent_sequences_returned": total_number_of_sequences_returned/total_number_of_sequences_requested,
        f"total_mean_anarci_h_score": total_mean_anarci_h_score,
        f"total_min_anarci_h_score": total_min_anarci_h_score,
        f"total_mean_anarci_kl_score": total_mean_anarci_kl_score,
        f"total_min_anarci_kl_score": total_min_anarci_kl_score,
        f"mean_of_avg_seq_identity_full": np.mean(total_avg_seq_identity_full),
        f"mean_of_avg_seq_identity_cdr3": np.mean(total_avg_seq_identity_cdr3),
    })

wandb.finish()

config.json: 100%|██████████| 972/972 [00:00<00:00, 2.34MB/s]
model.safetensors.index.json: 100%|██████████| 24.5k/24.5k [00:00<00:00, 13.4MB/s]
Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]
model-00001-of-00003.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s][A
model-00001-of-00003.safetensors:   0%|          | 10.5M/4.95G [00:00<04:41, 17.5MB/s][A
model-00001-of-00003.safetensors:   0%|          | 21.0M/4.95G [00:01<04:01, 20.4MB/s][A
model-00001-of-00003.safetensors:   1%|          | 31.5M/4.95G [00:01<04:36, 17.8MB/s][A
model-00001-of-00003.safetensors:   1%|          | 41.9M/4.95G [00:02<05:20, 15.3MB/s][A
model-00001-of-00003.safetensors:   1%|          | 52.4M/4.95G [00:03<06:05, 13.4MB/s][A
model-00001-of-00003.safetensors:   1%|▏         | 62.9M/4.95G [00:04<06:53, 11.8MB/s][A
model-00001-of-00003.safetensors:   1%|▏         | 73.4M/4.95G [00:05<07:14, 11.2MB/s][A
model-00001-of-00003.safetensors:   2%|▏         | 83.9M/4.95G [00:06<07:19, 11.1MB/s]

Using 2 GPUs!
Total Sequences Asked For: 10, Total Sequences Returned: 7, Percent Returned: 0.7
Total Sequences Asked For: 10, Total Sequences Returned: 10, Percent Returned: 1.0
Total Sequences Asked For: 10, Total Sequences Returned: 10, Percent Returned: 1.0




0,1
PD1_avg_seq_identity_cdr3,▁
PD1_avg_seq_identity_full,▁
PD1_mean_anarci_h_score,▁
PD1_mean_anarci_kl_score,▁
PD1_min_anarci_h_score,▁
PD1_min_anarci_kl_score,▁
PD1_number_of_sequences_requested,▁
PD1_number_of_sequences_returned,▁
PD1_percent_sequences_returned,▁
SARS-CoV2_avg_seq_identity_cdr3,▁

0,1
PD1_avg_seq_identity_cdr3,34.46712
PD1_avg_seq_identity_full,82.74854
PD1_mean_anarci_h_score,184.7
PD1_mean_anarci_kl_score,174.15714
PD1_min_anarci_h_score,164.9
PD1_min_anarci_kl_score,158.2
PD1_number_of_sequences_requested,10.0
PD1_number_of_sequences_returned,7.0
PD1_percent_sequences_returned,0.7
SARS-CoV2_avg_seq_identity_cdr3,61.88034


CPU times: user 20min 30s, sys: 46min 34s, total: 1h 7min 4s
Wall time: 30min 6s


In [None]:
# runs.summary.map((row, index) => {Run_name: row.run.name, Train Loss: row["avg_train_loss"], Validation Loss: row["avg_val_loss"], Test Loss: row["test_loss"], Percent Valid Sequences Returned: row["total_percent_sequences_returned"], "H Chain: Mean ANARCI Score": row['total_mean_anarci_h_score'], "KL Chain: Mean ANARCI Score": row['total_mean_anarci_kl_score']})
# runs.summary.map((row, index) => {Run_name: row.run.name, "Train Loss": row["avg_train_loss"], "Validation Loss": row["avg_val_loss"], "Test Loss": row["test_loss"], "Percent Valid Sequences Returned": row["total_percent_sequences_returned"]})

In [None]:
# runs.summary.map((row, index) => {Run_name: row.run.name, Train_Loss: row["avg_train_loss"], Validation_Loss: row["avg_val_loss"], Test_Loss: row["test_loss"], Percent_Valid_Sequences_Returned: row["total_percent_sequences_returned"], HChain_Mean_ANARCI_Score: row['total_mean_anarci_h_score'], KLChain_Mean_ANARCI_Score: row['total_mean_anarci_kl_score']})