In [1]:
# pip install torch
# pip install tokenizers
# pip install transformers

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset
from tokenizers import Tokenizer
from models.progen.modeling_progen import ProGenForCausalLM
import os
import pandas as pd
from sklearn.model_selection import train_test_split

import numpy as np
import subprocess
from itertools import combinations
from seq import ab_number as abn

from sklearn.model_selection import KFold


import wandb
wandb.login()

### configure fine-tuning
n_splits = 3
epochs = 10
batch_size = 20
learning_rate = 1e-5
foundation_model_name = 'simple_fine_tuned_progen2-small' #progen2-small, progen2-medium, progen2-large, progen2-xlarge,  simple_fine_tuned_progen2-small

fine_tuning_strategy = None #None, simple_fine_tuned, frozen_layers_tuned, or etc
prompting_strategy = 'prompted' #zero_shot or prompted

if fine_tuning_strategy is None:
    model_name = f'{foundation_model_name}'
else:
    model_name = f'{fine_tuning_strategy}_{foundation_model_name}'
    
experiment_name =f'{model_name}_{prompting_strategy}'
run_description = f'Running {prompting_strategy} across {model_name}'

# start a new wandb run to track this script
wandb.init(
    # set the wandb project where this run will be logged
    project="berkeley_antibody_generation",
    entity='antibody_generation',
    name=experiment_name,
    notes=run_description,
    
    # track hyperparameters and run metadata
    config={
    "learning_rate": learning_rate,
    "epochs": epochs,
    "foundational_model": foundation_model_name,
    "model_name": model_name,
    "run_description": run_description,
    "fine_tuning_strategy": fine_tuning_strategy
    }
)

  from .autonotebook import tqdm as notebook_tqdm
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjoethequant[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Currently logged in as: [33mjoethequant[0m ([33mantibody_generation[0m). Use [1m`wandb login --relogin`[0m to force relogin


# Fine Tuning Script

In [2]:
if fine_tuning_strategy == 'simple_fine_tuned':
    # Define a Dataset for loading protein sequences
    class ProteinDataset(Dataset):
        def __init__(self, sequences, tokenizer, begin_token_id, end_token_id):
            self.tokenized_sequences = [tokenizer.encode(sequence, add_special_tokens=False) for sequence in sequences]
            for i, encoding in enumerate(self.tokenized_sequences):
                modified_ids = [begin_token_id] + encoding.ids + [end_token_id]
                self.tokenized_sequences[i] = modified_ids
            
        def __len__(self):
            return len(self.tokenized_sequences)
    
        def __getitem__(self, idx):
            return self.tokenized_sequences[idx]
    
    def collate_fn(batch):
        # Find the max length of sequences in the batch
        max_length = max([len(sequence) for sequence in batch])
        
        # Pad each sequence to the max length and stack them
        padded_input_ids = torch.stack([torch.tensor(sequence + [0]*(max_length - len(sequence))) for sequence in batch])
        
        return {"input_ids": padded_input_ids, "labels": padded_input_ids.clone()}
    
    
    def main(epochs, batch_size, learning_rate, foundation_model_name):
        
        # Load the tokenizer
        tokenizer = Tokenizer.from_file('tokenizer.json')


        # load Model
        model_path = f'./model_checkpoints/{foundation_model_name}'
        
        #initial load of single GPU, will handle multiple GPUs a few lines down.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Load the pre-trained model
        model = ProGenForCausalLM.from_pretrained(model_path).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    
        # Check for multiple GPUs
        if torch.cuda.device_count() > 1:
            print(f"Using {torch.cuda.device_count()} GPUs!")
            model = nn.DataParallel(model)
        
        sequences = []
        # Open and read the file
        with open("sabdab_joint_sequences_uniprot.txt", "r") as file:
            [sequences.append(line.strip()) for line in file]

        # Create the KFold object
        kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
        
        # Splitting the data for train, and test.
        train_sequences, test_sequences = train_test_split(sequences, test_size=0.1, random_state=42)  # 70% train, 10% temp
        
        # Initialize test DataLoader for test set. Train and validation are handled in the training loop.
        test_dataset = ProteinDataset(test_sequences, tokenizer, begin_token_id=1, end_token_id=2)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    
        for fold, (train_ids, val_ids) in enumerate(kf.split(sequences)):
            print(f"FOLD {fold}")
            print("--------------------------------")

            # Splitting the data into the current fold
            train_sequences = [sequences[index] for index in train_ids]
            val_sequences = [sequences[index] for index in val_ids]

            # Initialize DataLoaders for the current fold
            train_dataset = ProteinDataset(train_sequences, tokenizer, begin_token_id=1, end_token_id=2)
            val_dataset = ProteinDataset(val_sequences, tokenizer, begin_token_id=1, end_token_id=2)
    
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

            # Fine-tuning Loop
            model.train()
            for epoch in range(epochs):
                total_loss = 0.0
                for batch in train_loader:
                    optimizer.zero_grad()
                    inputs = batch["input_ids"].to(device)
                    labels = batch["labels"].to(device)
                    outputs = model(inputs, labels=labels)
                    loss = outputs.loss
        
                    #ig using multiple GPUs, This ensures that the loss is always a scalar, irrespective of the number of GPUs.
                    if loss.dim() > 0:
                        loss = loss.sum()
                    
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
        
                avg_loss = total_loss / len(train_loader)
                
                # Evaluate on validation set
                model.eval()
                val_loss = 0.0
                with torch.no_grad():
                    for batch in val_loader:
                        inputs = batch["input_ids"].to(device)
                        labels = batch["labels"].to(device)
                        outputs = model(inputs, labels=labels)
                        loss = outputs.loss
                        if loss.dim() > 0:
                            loss = loss.sum()
                        val_loss += loss.item()
                avg_val_loss = val_loss / len(val_loader)
                
                wandb.log({"avg_train_loss": avg_loss, "avg_val_loss": avg_val_loss})
                
                print(f"Epoch: {epoch+1}/{epochs}, Training Loss: {avg_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")
                
                model.train()

            # Save the fine-tuned model
            if isinstance(model, nn.DataParallel):
                model.module.save_pretrained(f'model_checkpoints/{model_name}')
            else:
                model.save_pretrained(f'model_checkpoints/{model_name}')
        

        # Evaluate on test set
        model.eval()
        test_loss = 0.0
        with torch.no_grad():
            for batch in test_loader:
                inputs = batch["input_ids"].to(device)
                labels = batch["labels"].to(device)
                outputs = model(inputs, labels=labels)
                loss = outputs.loss
                if loss.dim() > 0:
                    loss = loss.sum()
                test_loss += loss.item()
        avg_test_loss = test_loss / len(test_loader)
        wandb.log({"test_loss": avg_test_loss})
        print(f"Test Loss: {avg_test_loss:.4f}")

    if __name__ == '__main__':
        main(epochs, batch_size, learning_rate, foundation_model_name)

# Final Sampling Results and Scoring

In [3]:
# make sure to add ANARCI to Path
os.environ['PATH'] = '/root/miniconda3/envs/anarci/bin:' + os.environ['PATH']

def run_anarci(sequence):
    # Run ANARCI as a subprocess

    result = subprocess.run(['ANARCI', '--sequence', sequence, '--scheme', 'aho'], capture_output=True, text=True)

    sequence_results = result.stdout.split('\n')

    species = None
    e_value = None
    score = None
    heavy_chain = np.array([])
    light_chain = np.array([])

    try: #push this into try as we do not want to stop the program if ANARCI fails. If it fails, it will return an empty arrays and thus not be included in the anarci results and data files.
        if len(sequence_results) > 4:

            blank, species, chain_type, e_value, score, seqstart_index, seqend_index, blank_2 = sequence_results[5].split('|')

            h_seq = []
            l_seq = []
            for row in sequence_results[7:]:
                row = [x for x in row.split(' ') if x != '']
                if (len(row) == 3) and (row[0] == 'H'):       
                    h_seq.append(row[2])
                elif (len(row) == 3) and (row[0] == 'L'):
                    l_seq.append(row[2])

            heavy_chain = np.array(h_seq)
            light_chain = np.array(l_seq)

    except:
        pass
    
    return species, e_value, score, heavy_chain, light_chain


def predict_sequence(model, tokenizer, sequence, device='cuda:0', number_of_sequences=1 ):
    # Tokenize the sequence
    tokenized_sequence = tokenizer.encode(sequence)
    
    # Convert to PyTorch tensor and add batch dimension
    input_tensor = torch.tensor([tokenized_sequence.ids]).to(device)
    
    # Pass the tensor through the model
    with torch.no_grad():
        output = model.generate(input_tensor, max_length=1024, pad_token_id=tokenizer.encode('<|pad|>').ids[0], do_sample=True, top_p=0.9, temperature=0.8, num_return_sequences=number_of_sequences)

        as_lists = lambda batch: [batch[i, ...].detach().cpu().numpy().tolist() for i in range(batch.shape[0])]
        sequences = tokenizer.decode_batch(as_lists(output))

        if len(sequences) > 0:
            sequences = [x.replace('2', '') for x in sequences] #replace stop token with empty string
        else:
            return []

        sequence_with_heavy_and_light_chains = []

        #filter out sequences that don't have heavy and light chains
        for sequence in sequences:
            # print(sequence)
            species, e_value, score, heavy_chain, light_chain = run_anarci(sequence)
            if (len(heavy_chain) > 0) and (len(light_chain) > 0):
                sequence_with_heavy_and_light_chains.append(sequence)

        return sequence_with_heavy_and_light_chains


def percent_identity(seq1, seq2):
    """ Compute the percent identity of two strings of equal length. """
    if len(seq1) != len(seq2):
        raise ValueError('Sequences must be the same length.')
    i = 0
    for r1, r2 in zip(seq1, seq2):
        i += int(r1 == r2)
    return i * 100 / len(seq1)

def full_seq_identity(df_anarci_H, df_anarci_KL):
    df_anarci_H = df_anarci_H.copy().set_index('Id')
    df_anarci_KL = df_anarci_KL.copy().set_index('Id')
    df_anarci_H['full_seq_H'] = df_anarci_H.loc[:, '1':].apply(lambda x: ''.join(x), axis=1)
    df_anarci_KL['full_seq_KL'] = df_anarci_KL.loc[:, '1':].apply(lambda x: ''.join(x), axis=1)
    df_anarci = df_anarci_H.merge(df_anarci_KL, left_index=True, right_index=True)
    seqs = [x['full_seq_H'] + x['full_seq_KL'] for _, x in df_anarci.iterrows()]

    identity_dist = [percent_identity(s[0], s[1]) for s in combinations(seqs, 2)]
    return sum(identity_dist) / len(identity_dist)

def cdr3_seq_identity(df_anarci_H):
    df_seqs = df_anarci_H.loc[:, '105':'117']
    seqs = [''.join(x) for _, x in df_seqs.iterrows()]
    identity_dist = [percent_identity(s[0], s[1]) for s in combinations(seqs, 2)]
    return sum(identity_dist) / len(identity_dist)

def diversity_metrics(anarci_csv_path_H, anarci_csv_path_KL):
    df_anarci_H = pd.read_csv(anarci_csv_path_H)
    df_anarci_KL = pd.read_csv(anarci_csv_path_KL)
    results = {'avg_seq_identity_full': full_seq_identity(df_anarci_H, df_anarci_KL),
               'avg_seq_identity_cdr3': cdr3_seq_identity(df_anarci_H)}
    return results

In [4]:
%%time

model_path = f'model_checkpoints/{model_name}'
device = 'cuda:0'  # Define the device variable outside the if-else condition

# Initialize the model first
model = ProGenForCausalLM.from_pretrained(model_path).to(device)

# Check if multiple GPUs are available and use ProGen's parallelization
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model.parallelize() # ProGen's parallelize method
else:
    print(f'Device: {device}')

tokenizer = Tokenizer.from_file('tokenizer.json')

number_of_sequences = 10
start_of_antibody_sequence_prompter = 'EVQLVESGGGLVQPGGSLRLSC'

targets = [
    { "sequence_id": 'PD1',
      "sequence": 'MQIPQAPWPVVWAVLQLGWRPGWFLDSPDRPWNPPTFSPALLVVTEGDNATFTCSFSNTSESFVLNWYRMSPSNQTDKLAAFPEDRSQPGQDCRFRVTQLPNGRDFHMSVVRARRNDSGTYLCGAISLAPKAQIKESLRAELRVTERRAEVPTAHPSPSPRPAGQFQTLVVGVVGGLLGSLVLLVWVLAVICSRAARGTIGARRTGQPLKEDPSAVPVFSVDYGELDFQWREKTPEPPVPCVPEQTEYATIVFPSGMGTSSPARRGSADGPRSAQPLRPEDGHCSWPLGGGGGSGGGGSGGGGS'
    },
    { "sequence_id": 'SARS-CoV2',
      "sequence": 'RVQPTESIVRFPNITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYLYRLFRKSNLKPFERDISTEIYQAGSTPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNFGGGGGSGGGGSGGGGS'
    },
    { "sequence_id": 'vWF',
      "sequence": 'DLVFLLDGSSRLSEAEFEVLKAFVVDMMERLRISQKWVRVAVVEYHDGSHAYIGLKDRKRPSELRRIASQVKYAGSQVASTSEVLKYTLFQIFSKIDRPEASRITLLLMASQEPQRMSRNFVRYVQGLKKKKVIVIPVGIGPHANLKQIRLIEKQAPENKAFVLSSVDELEQQRDEIGGGGGSGGGGSGGGGS'
    }
]

total_number_of_sequences_requested = number_of_sequences * len(targets)
total_number_of_sequences_returned = 0
total_percent_sequences_returned = 0
total_df_result_H = pd.DataFrame()
total_df_result_KL = pd.DataFrame()

total_avg_seq_identity_full = []
total_avg_seq_identity_cdr3 = []

for target in targets[:]:

    target_sequence_id = target['sequence_id']
    target_sequence = target['sequence']
    experiment_target_id = f'{experiment_name}_{target_sequence_id}'
    
    if prompting_strategy == 'prompted':
        target_sequence = target_sequence + start_of_antibody_sequence_prompter

    sampled_sequences = predict_sequence(model, tokenizer, target_sequence, device, number_of_sequences=number_of_sequences)

    df_result_H, df_result_KL = abn.number_seqs_as_df(sampled_sequences)

    if (df_result_H is not None) and (len(df_result_H) > 0):
        df_result_H['model_id'] = f'{experiment_target_id}_H'
        df_result_H.to_csv(f'./results/{experiment_target_id}_H.csv')
        mean_anarci_h_score = df_result_H['score'].mean()
        min_anarci_h_score = df_result_H['score'].min()
        total_df_result_H = pd.concat([total_df_result_H, df_result_H], axis=0)
    else:
        mean_anarci_h_score = 0
        min_anarci_h_score = 0
    
    if (df_result_KL is not None) and (len(df_result_KL) > 0):
        df_result_KL['model_id'] = f'{experiment_target_id}_KL'
        df_result_KL.to_csv(f'./results/{experiment_target_id}_KL.csv')
        mean_anarci_kl_score = df_result_KL['score'].mean()
        min_anarci_kl_score = df_result_KL['score'].min()
        total_df_result_KL = pd.concat([total_df_result_KL, df_result_KL], axis=0)
    else:
        mean_anarci_kl_score = 0
        min_anarci_kl_score = 0

    if (df_result_H is not None) and (len(df_result_H) > 0) and (df_result_KL is not None) and (len(df_result_KL) > 0):
        try:
            diversity_metrics_dict = diversity_metrics(f'./results/{experiment_target_id}_H.csv', f'./results/{experiment_target_id}_KL.csv')
            avg_seq_identity_full = diversity_metrics_dict['avg_seq_identity_full']
            avg_seq_identity_cdr3 = diversity_metrics_dict['avg_seq_identity_cdr3']
        except:
            avg_seq_identity_full = 0.0
            avg_seq_identity_cdr3 = 0.0

    else:
        avg_seq_identity_full = 0.0
        avg_seq_identity_cdr3 = 0.0

    number_of_sequences_returned = len(sampled_sequences)
    total_number_of_sequences_returned += number_of_sequences_returned    
    total_avg_seq_identity_full.append(avg_seq_identity_full)
    total_avg_seq_identity_cdr3.append(avg_seq_identity_cdr3)
    
    wandb.log({
        f"{target_sequence_id}_number_of_sequences_requested": number_of_sequences, 
        f"{target_sequence_id}_number_of_sequences_returned": number_of_sequences_returned, 
        f"{target_sequence_id}_percent_sequences_returned": number_of_sequences_returned/number_of_sequences,
        f"{target_sequence_id}_mean_anarci_h_score": mean_anarci_h_score,
        f"{target_sequence_id}_min_anarci_h_score": min_anarci_h_score,
        f"{target_sequence_id}_mean_anarci_kl_score": mean_anarci_kl_score,
        f"{target_sequence_id}_min_anarci_kl_score": min_anarci_kl_score,

        f"{target_sequence_id}_avg_seq_identity_full": avg_seq_identity_full,
        f"{target_sequence_id}_avg_seq_identity_cdr3": avg_seq_identity_cdr3,
    })

    print(f'Total Sequences Asked For: {number_of_sequences}, Total Sequences Returned: {len(sampled_sequences)}, Percent Returned: {len(sampled_sequences)/number_of_sequences}')

if len(total_df_result_H) > 0:
    total_mean_anarci_h_score = total_df_result_H['score'].mean()
    total_min_anarci_h_score = total_df_result_H['score'].min()
else:
    total_mean_anarci_h_score = 0
    total_min_anarci_h_score = 0
    
if len(total_df_result_KL) > 0:
    total_mean_anarci_kl_score = total_df_result_KL['score'].mean()
    total_min_anarci_kl_score = total_df_result_KL['score'].min()
else:
    total_mean_anarci_kl_score = 0
    total_min_anarci_kl_score = 0
    
wandb.log({
        f"total_number_of_sequences_requested": total_number_of_sequences_requested, 
        f"total_number_of_sequences_returned": total_number_of_sequences_returned, 
        f"total_percent_sequences_returned": total_number_of_sequences_returned/total_number_of_sequences_requested,
        f"total_mean_anarci_h_score": total_mean_anarci_h_score,
        f"total_min_anarci_h_score": total_min_anarci_h_score,
        f"total_mean_anarci_kl_score": total_mean_anarci_kl_score,
        f"total_min_anarci_kl_score": total_min_anarci_kl_score,
        f"mean_of_avg_seq_identity_full": np.mean(total_avg_seq_identity_full),
        f"mean_of_avg_seq_identity_cdr3": np.mean(total_avg_seq_identity_cdr3),
    })

wandb.finish()

Using 6 GPUs!
Total Sequences Asked For: 10, Total Sequences Returned: 10, Percent Returned: 1.0
Total Sequences Asked For: 10, Total Sequences Returned: 9, Percent Returned: 0.9
Total Sequences Asked For: 10, Total Sequences Returned: 8, Percent Returned: 0.8




0,1
PD1_avg_seq_identity_cdr3,▁
PD1_avg_seq_identity_full,▁
PD1_mean_anarci_h_score,▁
PD1_mean_anarci_kl_score,▁
PD1_min_anarci_h_score,▁
PD1_min_anarci_kl_score,▁
PD1_number_of_sequences_requested,▁
PD1_number_of_sequences_returned,▁
PD1_percent_sequences_returned,▁
SARS-CoV2_avg_seq_identity_cdr3,▁

0,1
PD1_avg_seq_identity_cdr3,51.23457
PD1_avg_seq_identity_full,84.4164
PD1_mean_anarci_h_score,181.64
PD1_mean_anarci_kl_score,177.17
PD1_min_anarci_h_score,174.9
PD1_min_anarci_kl_score,160.6
PD1_number_of_sequences_requested,10.0
PD1_number_of_sequences_returned,10.0
PD1_percent_sequences_returned,1.0
SARS-CoV2_avg_seq_identity_cdr3,49.53704


CPU times: user 7min 20s, sys: 12.4 s, total: 7min 33s
Wall time: 5min 25s


In [5]:
# runs.summary.map((row, index) => {Run_name: row.run.name, Train Loss: row["avg_train_loss"], Validation Loss: row["avg_val_loss"], Test Loss: row["test_loss"], Percent Valid Sequences Returned: row["total_percent_sequences_returned"], "H Chain: Mean ANARCI Score": row['total_mean_anarci_h_score'], "KL Chain: Mean ANARCI Score": row['total_mean_anarci_kl_score']})
# runs.summary.map((row, index) => {Run_name: row.run.name, "Train Loss": row["avg_train_loss"], "Validation Loss": row["avg_val_loss"], "Test Loss": row["test_loss"], "Percent Valid Sequences Returned": row["total_percent_sequences_returned"]})

In [6]:
# runs.summary.map((row, index) => {Run_name: row.run.name, Train_Loss: row["avg_train_loss"], Validation_Loss: row["avg_val_loss"], Test_Loss: row["test_loss"], Percent_Valid_Sequences_Returned: row["total_percent_sequences_returned"], HChain_Mean_ANARCI_Score: row['total_mean_anarci_h_score'], KLChain_Mean_ANARCI_Score: row['total_mean_anarci_kl_score']})