In [1]:
# pip install torch
# pip install tokenizers
# pip install transformers

import torch
from torch.utils.data import DataLoader, Dataset
from tokenizers import Tokenizer
from models.progen.modeling_progen import ProGenForCausalLM
import os
from seq import ab_number as abn
import numpy as np
import subprocess



# Make sure the ANARCI executable can be found on PATH.
# The directory can be overridden with the ANARCI_BIN_DIR environment variable;
# the default is the conda environment location this notebook was written for.
ANARCI_BIN_DIR = os.environ.get('ANARCI_BIN_DIR', '/root/miniconda3/envs/anarci/bin')

# Rebuild PATH with the ANARCI directory first and listed only once, so that
# re-running this cell does not keep prepending the same entry.
_current_path = os.environ.get('PATH', '')
_other_entries = (
    [entry for entry in _current_path.split(os.pathsep) if entry != ANARCI_BIN_DIR]
    if _current_path
    else []
)
os.environ['PATH'] = os.pathsep.join([ANARCI_BIN_DIR] + _other_entries)

def run_anarci(sequence):
    """Run the ``ANARCI`` command-line tool on ``sequence`` and parse its stdout.

    ANARCI is invoked with ``--scheme aho``. The parser relies on a fixed
    layout of the text output: the summary line is expected at index 5 and the
    numbered residue rows from index 7 onwards.

    Parameters
    ----------
    sequence : str
        Amino-acid sequence passed to ``ANARCI --sequence``.

    Returns
    -------
    tuple
        ``(species, e_value, score, heavy_chain, light_chain)``.
        ``species``, ``e_value`` and ``score`` are the raw string fields of the
        summary line (``None`` when the output could not be parsed).
        ``heavy_chain`` / ``light_chain`` are NumPy arrays holding the third
        column of every three-column row whose first column is ``'H'`` /
        ``'L'``; they are empty when nothing was parsed.

    Notes
    -----
    Unparseable output is deliberately not an error: the caller treats empty
    arrays as "no antibody chains found". Only the parsing failures that the
    fixed layout can produce are tolerated; anything else (for example a
    ``KeyboardInterrupt``) propagates instead of being silently swallowed.
    """
    # Argument list (no shell) so the sequence is never interpreted by a shell.
    result = subprocess.run(
        ['ANARCI', '--sequence', sequence, '--scheme', 'aho'],
        capture_output=True,
        text=True,
    )
    output_lines = result.stdout.split('\n')

    species = None
    e_value = None
    score = None
    heavy_chain = np.array([])
    light_chain = np.array([])

    # The summary line lives at index 5, so at least six lines are required.
    if len(output_lines) <= 5:
        return species, e_value, score, heavy_chain, light_chain

    try:
        # Summary line looks like '#|species|chain|e-value|score|start|end|',
        # which splits into exactly eight fields (first and last are padding).
        _, species, _chain_type, e_value, score, _seq_start, _seq_end, _ = (
            output_lines[5].split('|')
        )
    except ValueError:
        # Wrong number of fields: treat as "ANARCI found nothing".
        return species, e_value, score, heavy_chain, light_chain

    h_seq = []
    l_seq = []
    for raw_row in output_lines[7:]:
        fields = [token for token in raw_row.split(' ') if token != '']
        # Only plain three-column rows are kept; any other row is skipped.
        if len(fields) != 3:
            continue
        if fields[0] == 'H':
            h_seq.append(fields[2])
        elif fields[0] == 'L':
            l_seq.append(fields[2])

    heavy_chain = np.array(h_seq)
    light_chain = np.array(l_seq)

    return species, e_value, score, heavy_chain, light_chain


def predict_sequence(model, tokenizer, sequence, device='cuda:0', number_of_sequences=1,
                     max_length=1024, top_p=0.9, temperature=0.8):
    """Sample continuations of ``sequence`` and keep those with both antibody chains.

    The prompt is tokenized, passed to ``model.generate`` with sampling
    enabled, and every decoded result is checked with ``run_anarci``. A result
    is kept only when ANARCI reports at least one heavy-chain and one
    light-chain residue.

    Parameters
    ----------
    model
        Object exposing ``generate(...)``; assumed to already live on
        ``device`` -- TODO confirm against the caller.
    tokenizer
        Object exposing ``encode(str).ids`` and ``decode_batch(list)``.
    sequence : str
        Prompt sequence to continue.
    device : str
        Device the prompt tensor is moved to.
    number_of_sequences : int
        Forwarded as ``num_return_sequences``.
    max_length : int
        Forwarded to ``generate`` (default 1024, the previously hard-coded value).
    top_p : float
        Forwarded to ``generate`` (default 0.9, the previously hard-coded value).
    temperature : float
        Forwarded to ``generate`` (default 0.8, the previously hard-coded value).

    Returns
    -------
    list of str
        Decoded sequences (with every ``'2'`` character removed) for which
        ``run_anarci`` returned non-empty heavy and light chains. Empty list
        when nothing was decoded.
    """
    # Tokenize the prompt and add a batch dimension.
    prompt_ids = tokenizer.encode(sequence).ids
    input_tensor = torch.tensor([prompt_ids]).to(device)

    # First id produced when encoding the pad token string.
    pad_token_id = tokenizer.encode('<|pad|>').ids[0]

    # Generation needs no gradients.
    with torch.no_grad():
        output = model.generate(
            input_tensor,
            max_length=max_length,
            pad_token_id=pad_token_id,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            num_return_sequences=number_of_sequences,
        )

    # One list of token ids per generated row.
    decoded = tokenizer.decode_batch(output.detach().cpu().tolist())
    if len(decoded) == 0:
        return []

    # '2' is the stop token in the decoded text; strip it everywhere.
    candidates = [text.replace('2', '') for text in decoded]

    # Keep only candidates in which ANARCI finds both a heavy and a light chain.
    # A distinct loop name is used so the `sequence` argument is not shadowed.
    sequence_with_heavy_and_light_chains = []
    for candidate in candidates:
        print(candidate)
        _species, _e_value, _score, heavy_chain, light_chain = run_anarci(candidate)
        if (len(heavy_chain) > 0) and (len(light_chain) > 0):
            sequence_with_heavy_and_light_chains.append(candidate)

    return sequence_with_heavy_and_light_chains

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# sequence = 'RVQPTESIVRFPNITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYLYRLFRKSNLKPFERDISTEIYQAGSTPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNFGGGGGSGGGGSGGGGSGGGGSGGGGSNILLSQPPLIEKLEGKKATFSVKAGDNVLINGFIVRGTQAKKVIIRAIGPSLTAFGVTDALADPTLELHDGTGALIASNDNWQTTIIGGIITHDQVQEIQDSGHAPGDGRESAIIADLPPGNYTAIVRGVNSTIGVALVEVYDLSPDANSILGNISTRSFVQTGDNVMIGGFIVQGTQPKRVIIRAIGPELSQYGVPDALANPTLELHDGSGALIGSNDNWQHTIIGGIITSDQVQDIQNSGHAPGDGRESAIIANLPPGNYTAIVRGVNSTTGVALVEVYDLSPGASSTLGNISTRSFVQTGDNVMIGGFIVQGTQPKRVIIRAIGPELSQYGVPDALADPTLELHDGTGALIASNDNWQHTIIGGIITSDQVQDIQNSGHAPGDGRESAIIADLPPGNYTAIVRGVNSTTGVALVEVYDLSPGASSTLGNISTRSFVQTGDNVMIGGFIVQGTQPKRVIIRAIGPELSQYGVPDALADPTLELHDGTGALIASNDNWQHTIIGGIITSDQVQDIQNSGHAPGDGRESAIIADLPPGNYTAIVRGVNSTIGVALVEVYDLSPGASSTLGNISTRSFVQTGDNVMIGGFIVQGTQPKRVIIRAIGPELSQYGVPDALADPTLELHDGSGALIASNDNWQHTIIGGIITSDQVQDIQNSGHAPGDGRESAIIADLPPGNYTAIVRGVNSTTGVALVEVYDLSPGASSTLGNISTRSFVQTGDNVMIGGFIVQGTQPKRVIIRAIGPELSQYGVPDALADPTLELHDGSGALIASND'
# result = subprocess.run(['ANARCI', '--sequence', sequence, '--scheme', 'aho'], capture_output=True, text=True)

# sequence_results = result.stdout.split('\n')
# sequence_results
# # len(sequence_results)

In [3]:
# len(target_sequence)

In [10]:
%%time

# --- Model selection (uncomment exactly one) ---
# model_path = './model_checkpoints/fine_tuned_progen2-small'
# model_path = './model_checkpoints/progen2-xlarge'
model_path = './model_checkpoints/fine_tuned_progen2-large'

# Use the first GPU when CUDA is available; otherwise fall back to the CPU
# instead of failing on `.to('cuda:0')`.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Initialize the model first
model = ProGenForCausalLM.from_pretrained(model_path).to(device)

# Check if multiple GPUs are available and use ProGen's parallelization
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model.parallelize() # ProGen's parallelize method
else:
    print(f'Device: {device}')

tokenizer = Tokenizer.from_file('tokenizer.json')

model_id = 'fine_tuned_large_zeroshot'

# --- Target selection (uncomment exactly one target_id / target_sequence pair) ---

#Small start of antibody sequence for testing 
# target_id = 'TEST'
# target_sequence = '''EVQLVESGGGLVQPGGSLRLSC'''  

start_of_antibody_sequence = 'EVQLVESGGGLVQPGGSLRLSC'

# target_id = 'PD1'
# target_sequence = 'MQIPQAPWPVVWAVLQLGWRPGWFLDSPDRPWNPPTFSPALLVVTEGDNATFTCSFSNTSESFVLNWYRMSPSNQTDKLAAFPEDRSQPGQDCRFRVTQLPNGRDFHMSVVRARRNDSGTYLCGAISLAPKAQIKESLRAELRVTERRAEVPTAHPSPSPRPAGQFQTLVVGVVGGLLGSLVLLVWVLAVICSRAARGTIGARRTGQPLKEDPSAVPVFSVDYGELDFQWREKTPEPPVPCVPEQTEYATIVFPSGMGTSSPARRGSADGPRSAQPLRPEDGHCSWPLGGGGGSGGGGSGGGGS'

target_id = 'SARS-CoV2'
target_sequence = 'RVQPTESIVRFPNITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYLYRLFRKSNLKPFERDISTEIYQAGSTPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNFGGGGGSGGGGSGGGGS'

# target_id = 'vWF'     
# target_sequence = 'DLVFLLDGSSRLSEAEFEVLKAFVVDMMERLRISQKWVRVAVVEYHDGSHAYIGLKDRKRPSELRRIASQVKYAGSQVASTSEVLKYTLFQIFSKIDRPEASRITLLLMASQEPQRMSRNFVRYVQGLKKKKVIVIPVGIGPHANLKQIRLIEKQAPENKAFVLSSVDELEQQRDEIGGGGGSGGGGSGGGGS'

# Optional: seed generation with the start of an antibody heavy chain.
# Left disabled, matching the previous behaviour (the prompt is the target only).
# target_sequence = target_sequence + start_of_antibody_sequence

number_of_sequences = 3

sequences = predict_sequence(model, tokenizer, target_sequence, device, number_of_sequences=number_of_sequences)

df_result_H, df_result_KL = abn.number_seqs_as_df(sequences)

# `to_csv` does not create missing directories, so make sure the output folder exists.
os.makedirs('./results', exist_ok=True)

if (df_result_H is not None) and (len(df_result_H) > 0):
    df_result_H['model_id'] = f'{model_id}_{target_id}_H'
    df_result_H.to_csv(f'./results/{model_id}_{target_id}_H.csv')

if (df_result_KL is not None) and (len(df_result_KL) > 0):
    df_result_KL['model_id'] = f'{model_id}_{target_id}_KL'
    df_result_KL.to_csv(f'./results/{model_id}_{target_id}_KL.csv')

# Report the yield as an actual percentage (the label says "Percent"), and
# avoid a ZeroDivisionError if number_of_sequences is set to 0.
fraction_returned = len(sequences) / number_of_sequences if number_of_sequences else 0.0
print(f'Total Sequences Asked For: {number_of_sequences}, Total Sequences Returned: {len(sequences)}, Percent Returned: {fraction_returned:.1%}')

Loading checkpoint shards: 100%|██████████| 2/2 [00:17<00:00,  8.57s/it]


Device: cuda:0
RVQPTESIVRFPNITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYLYRLFRKSNLKPFERDISTEIYQAGSTPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNFGGGGGSGGGGSGGGGSQVQLVQSGAEVKKPGASVKVSCKASGYTFTSYAIHWVRQAPGQRLEWMGWIKAGNGNTRYSQKFQGRVTITRDTSASTAYMELSSLRSEDTAVYYCALLTVITPDDAFDIWGQGTMVTVSSASTKGPSVFPLAPSSKSTSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKKVEPKSCGGGGSGGGGSGGGGSDIQLTQSPDSLAVSLGERATINCKSSQSVLYSSINKNYLAWYQQKPGQPPKLLIYWASTRESGVPDRFSGSGSGTDFTLTISSLQAEDVAVYYCQQYYSTPLTFGGGTKVEIKRTVAAPSVFIFPPSDEQLKSGTASVVCLLNNFYPREAKVQWKVDNALQSGNSQESVTEQDSKDSTYSLSSTLTLSKADYEKHKVYACEVTHQGLSSPVTKSFNRGEC
RVQPTESIVRFPNITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYLYRLFRKSNLKPFERDISTEIYQAGSTPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNFGGGGGSGGGGSGGGGSEVQLVESGGGLIQPGGSLRLSCAASGITVSSNYMSWVRQAPGKGLEW

In [5]:
# Display the target sequence currently held in the kernel.
# NOTE(review): the saved output below is the PD1 sequence, while the generation
# cell now selects 'SARS-CoV2' -- this output looks stale; re-run after Restart & Run All.
target_sequence

'MQIPQAPWPVVWAVLQLGWRPGWFLDSPDRPWNPPTFSPALLVVTEGDNATFTCSFSNTSESFVLNWYRMSPSNQTDKLAAFPEDRSQPGQDCRFRVTQLPNGRDFHMSVVRARRNDSGTYLCGAISLAPKAQIKESLRAELRVTERRAEVPTAHPSPSPRPAGQFQTLVVGVVGGLLGSLVLLVWVLAVICSRAARGTIGARRTGQPLKEDPSAVPVFSVDYGELDFQWREKTPEPPVPCVPEQTEYATIVFPSGMGTSSPARRGSADGPRSAQPLRPEDGHCSWPLGGGGGSGGGGSGGGGS'

In [6]:
# Display the heavy-chain numbering table returned by abn.number_seqs_as_df.
# NOTE(review): the saved output carries model_id '..._PD1_H', i.e. it comes from
# an earlier PD1 run rather than the currently selected target -- confirm and re-run.
df_result_H

Unnamed: 0,Id,domain_no,hmm_species,chain_type,e-value,score,seqstart_index,seqend_index,identity_species,v_gene,...,120,121,122,123,124,125,126,127,128,model_id
0,Sequence,0,human,H,4.7e-59,189.3,304,430,,,...,Q,G,T,M,V,T,V,S,S,fine_tuned_large_zeroshot_PD1_H
1,Sequence,0,human,H,1.5e-59,190.9,304,427,,,...,Q,G,T,T,V,T,V,S,S,fine_tuned_large_zeroshot_PD1_H


In [7]:
# Display the light-chain (K/L) numbering table returned by abn.number_seqs_as_df.
# NOTE(review): the saved output carries model_id '..._PD1_KL', i.e. it comes from
# an earlier PD1 run rather than the currently selected target -- confirm and re-run.
df_result_KL

Unnamed: 0,Id,domain_no,hmm_species,chain_type,e-value,score,seqstart_index,seqend_index,identity_species,v_gene,...,119,120,121,122,123,124,125,126,127,model_id
0,Sequence,1,human,K,3.8999999999999997e-56,179.6,551,657,,,...,G,Q,G,T,R,L,E,I,K,fine_tuned_large_zeroshot_PD1_KL
1,Sequence,1,human,K,9.6e-56,178.3,548,655,,,...,G,Q,G,T,K,V,E,I,K,fine_tuned_large_zeroshot_PD1_KL


In [8]:
# Reference strings kept for manual inspection; neither name is used by the
# other cells visible in this notebook.
#
# `target` appears to be the PD1 target sequence followed by the
# 'EVQLVESGGGLVQPGGSLRLSC' antibody start used in the generation cell.
# NOTE(review): this triple-quoted literal embeds newline characters AND the
# leading indentation spaces of the continuation lines; strip all whitespace
# (e.g. ''.join(target.split())) before passing it to the tokenizer.
target = '''MQIPQAPWPVVWAVLQLGWRPGWFLDSPDRPWNPPTFSPALLVVTEGDNATFTCSFSNTSESFVLNWYRMSPSNQTDKLAAFPEDRSQPGQDCRFRVTQLPNGRDFHMSVVRARRNDSGTYLCGAISLAPKAQ
                     IKESLRAELRVTERRAEVPTAHPSPSPRPAGQFQTLVVGVVGGLLGSLVLLVWVLAVICSRAARGTIGARRTGQPLKEDPSAVPVFSVDYGELDFQWREKTPEPPVPCVPEQTEYATIVFPSGMGTSSPARRG
                     SADGPRSAQPLRPEDGHCSWPLGGGGGSGGGGSGGGGSEVQLVESGGGLVQPGGSLRLSC'''

# `example_sequence` starts with the same text as `target` and continues with
# further residues -- presumably a complete target + antibody example; verify.
# NOTE(review): this literal also embeds newline characters between its lines.
example_sequence = '''MQIPQAPWPVVWAVLQLGWRPGWFLDSPDRPWNPPTFSPALLVVTEGDNATFTCSFSNTSESFVLNWYRMSPSNQTDKLAAFPEDRSQPGQDCRFRVTQLPNGRDFHMSVVRARRNDSGTYLCGAI
SLAPKAQIKESLRAELRVTERRAEVPTAHPSPSPRPAGQFQTLVVGVVGGLLGSLVLLVWVLAVICSRAARGTIGARRTGQPLKEDPSAVPVFSVDYGELDFQWREKTPEPPVPCVPEQTEYATIVFPS
GMGTSSPARRGSADGPRSAQPLRPEDGHCSWPLGGGGGSGGGGSGGGGSEVQLVESGGGLVQPGGSLRLSCAASGFTFSSYGMHWVRQAPGKGLEWVAVIWYDGSNKYYADSVKGRFTISRDNSKNTLY
LQMNSLRAEDTAVYYCARDYGTGDYYYDYWGQGTLVTVSSGGGGSGGGGSGGGGSDIQMTQSPSSLSASVGDRVTITCRASQSISSYLNWYQQKPGKAPKLLIYAASSLQSGVPSRFSGSGSGTDFTLT
ISSLQPEDFATYYCQQSYSTLWTFGQGTKVEIK'''

In [9]:
# Choose the compute device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda
