In [1]:
%%capture
!pip install transformers datasets
!pip install accelerate -U
!pip install evaluate

In [2]:
import os
import pandas as pd
import gzip

# Increase the maximum number of columns to display
pd.set_option('display.max_columns', None)
pd.set_option('display.expand_frame_repr', False)

# Directory where your .csv.gz files are stored
directory = './heavy-sequences/'

In [3]:
for filename in os.listdir(directory):
    print(filename)

.DS_Store
config.json
ERR4077966_Heavy_IGHE.csv.gz
checkpoint-1220
SRR3099408_Heavy_Bulk.csv.gz
pytorch_model.bin
checkpoint-2440
1279049_1_Heavy_IGHM.csv.gz
SRR8283845_1_Heavy_IGHA.csv.gz
SRR5811777_1_Heavy_IGHE.csv.gz
SRR4297400_Heavy_Bulk.csv.gz
checkpoint-3660
SRR12326777_1_Heavy_IGHE.csv.gz


In [4]:
# List to store all the dataframes
dfs = []

for filename in os.listdir(directory):
    if filename.endswith('.csv.gz'):  # Only process .csv.gz files
        with gzip.open(directory + filename, 'rt') as f:
            # Save the first line (metadata) to a variable
            metadata = next(f)
            # Now load the rest of the file into a DataFrame
            df = pd.read_csv(f)

            # Define columns to drop
            columns_to_drop = [
                'locus', 'stop_codon', 'vj_in_frame', 'v_frameshift', 'productive', 'rev_comp', 'complete_vdj', 'v_call',
                'd_call', 'j_call', 'sequence_alignment', 'germline_alignment', 'sequence',
                'germline_alignment_aa', 'v_alignment_start', 'v_alignment_end', 'd_alignment_start', 'd_alignment_end',
                'j_alignment_start', 'j_alignment_end', 'v_sequence_alignment', 'v_sequence_alignment_aa',
                'v_germline_alignment', 'v_germline_alignment_aa', 'd_sequence_alignment', 'd_sequence_alignment_aa',
                'd_germline_alignment', 'd_germline_alignment_aa', 'j_sequence_alignment', 'j_sequence_alignment_aa',
                'j_germline_alignment', 'j_germline_alignment_aa', 'junction', 'junction_length', 'junction_aa',
                'junction_aa_length', 'v_score', 'd_score', 'j_score', 'v_cigar', 'd_cigar', 'j_cigar', 'v_support',
                'd_support', 'j_support', 'v_identity', 'd_identity', 'j_identity', 'v_sequence_start', 'v_sequence_end',
                'v_germline_start', 'v_germline_end', 'd_sequence_start', 'd_sequence_end', 'd_germline_start',
                'd_germline_end', 'j_sequence_start', 'j_sequence_end', 'j_germline_start', 'j_germline_end', 'np1',
                'np1_length', 'np2', 'np2_length', 'c_region', 'Redundancy', 'ANARCI_numbering', 'ANARCI_status', 'fwr1_aa', 'cdr1', 'cdr2', 'cdr3', 'fwr1', 'fwr2', 'fwr3', 'fwr4', 'fwr2_aa', 'fwr3_aa', 'fwr4_aa', 'fwr1_start', 'fwr1_end', 'fwr2_start', 'fwr2_end',
                'fwr3_start', 'fwr3_end', 'fwr4_start', 'fwr4_end'
            ]

            df = df.drop(columns=columns_to_drop)

            # Replace each instance of double double quotes with a single double quote
            metadata = metadata.replace('""', '"')
            metadata = metadata[1:-1]

            # Remove extra quotes and leading/trailing whitespace
            metadata = metadata.replace('""', '"').strip()

            start_index_chain = metadata.find('"Chain": "') + len('"Chain": "')  # Find the start index of the "Chain" value
            end_index_chain = metadata.find('"', start_index_chain)  # Find the end index of the "Chain" value
            chain = metadata[start_index_chain:end_index_chain]  # Extract the value of "Chain"

            start_index_disease = metadata.find('"Disease": "') + len('"Disease": "')  # Find the start index of the "Disease" value
            end_index_disease = metadata.find('"', start_index_disease)  # Find the end indexApologies for the abrupt cutoff in the previous message. Here is the rest of the code:
                        # of the "Disease" value
            disease = metadata[start_index_disease:end_index_disease]  # Extract the value of "Disease"

            # Add new columns 'Chain' and 'Disease'
            df['Chain'] = chain
            df['Disease'] = disease
            df['meta'] = metadata

            # Append this DataFrame to the list
            dfs.append(df)

# Combine all the dataframes in the list into one big dataframe
final_df = pd.concat(dfs, ignore_index=True)
df = final_df

In [5]:
df

Unnamed: 0,sequence_alignment_aa,cdr1_aa,cdr2_aa,cdr3_aa,cdr1_start,cdr1_end,cdr2_start,cdr2_end,cdr3_start,cdr3_end,Chain,Disease,meta
0,VQLQESGPGLVKPSETLSLTCPVSGGSISTYYWSWIRKTPGKGLEW...,GGSISTYY,IYYSEST,ARVAGTYGGFGQLYFDY,84.0,107.0,159.0,179.0,294.0,344.0,Heavy,Tonsillitis,"{""Run"": ""ERR4077966"", ""Link"": ""https://doi.org..."
1,VQLQESGPGLVKPSETLSLTCNVSGGSISSGQWSWIRQPPGKGLEW...,GGSISSGQ,FYYSGST,AGDYGCRY,82.0,105.0,157.0,177.0,292.0,315.0,Heavy,Tonsillitis,"{""Run"": ""ERR4077966"", ""Link"": ""https://doi.org..."
2,QLQESGPGVVKPSETLSLTCTVSGGSISSGDHYWAWIRQPPGKGLE...,GGSISSGDHY,MYYSGTI,ARYVRASFDE,84.0,113.0,165.0,185.0,300.0,329.0,Heavy,Tonsillitis,"{""Run"": ""ERR4077966"", ""Link"": ""https://doi.org..."
3,VQLVESGGGVVQPGRSLRLSCAASGFTFNNYGMHWVRQAPGKGLEG...,GFTFNNYG,IWYDGDNK,ARAPYSTTGYFDY,85.0,108.0,160.0,183.0,298.0,336.0,Heavy,Tonsillitis,"{""Run"": ""ERR4077966"", ""Link"": ""https://doi.org..."
4,VQLVESGGGLVQPGGSLRLSCAASGFTFSSYEMNWVRQAPGKGLEW...,GFTFSSYE,ISSSDRTI,ARVSTQLYSQYSFDY,83.0,106.0,158.0,181.0,296.0,340.0,Heavy,Tonsillitis,"{""Run"": ""ERR4077966"", ""Link"": ""https://doi.org..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...
52029,DSVKGRFTHYRDNAKNTLYMQMNSLGDEDTAIFYCARTSLYVRGSY...,,,ARTSLYVRGSYEHY,,,,,112.0,153.0,Heavy,SARS-COV-2,"{""Run"": ""SRR12326777"", ""Link"": ""https://dx.doi..."
52030,YYNPSLHSRVTLSIHTSKRQFSLKLTSVTAADTAIYFCARGPPRLG...,,,ARGPPRLGFDY,,,2.0,3.0,118.0,150.0,Heavy,SARS-COV-2,"{""Run"": ""SRR12326777"", ""Link"": ""https://dx.doi..."
52031,YYNPSLNSRVTLSINTSKRQFSLKLTSVTAADTAIYFCARGPPRLG...,,,ARGPPRLGFDY,,,2.0,3.0,118.0,150.0,Heavy,SARS-COV-2,"{""Run"": ""SRR12326777"", ""Link"": ""https://dx.doi..."
52032,TYYHPSLNSRVTLSISTSKRQFSLKLTSVTAADTAIYFCARGPPRL...,,T,ARGPPRLGFDY,,,1.0,3.0,118.0,150.0,Heavy,SARS-COV-2,"{""Run"": ""SRR12326777"", ""Link"": ""https://dx.doi..."


In [6]:
# Step 1: Rename the column
df = df.rename(columns={'sequence_alignment_aa': 'sequence'})

# Step 2: Make 'sequence' the first column
cols = df.columns.tolist()
cols.insert(0, cols.pop(cols.index('sequence')))
df = df[cols]

print(df.head())
print(len(df))
df['sequence_length'] = df['sequence'].apply(len)
print(df['sequence_length'])
print(len(df))

                                            sequence     cdr1_aa   cdr2_aa            cdr3_aa  cdr1_start  cdr1_end  cdr2_start  cdr2_end  cdr3_start  cdr3_end  Chain      Disease                                               meta
0  VQLQESGPGLVKPSETLSLTCPVSGGSISTYYWSWIRKTPGKGLEW...    GGSISTYY   IYYSEST  ARVAGTYGGFGQLYFDY        84.0     107.0       159.0     179.0       294.0     344.0  Heavy  Tonsillitis  {"Run": "ERR4077966", "Link": "https://doi.org...
1  VQLQESGPGLVKPSETLSLTCNVSGGSISSGQWSWIRQPPGKGLEW...    GGSISSGQ   FYYSGST           AGDYGCRY        82.0     105.0       157.0     177.0       292.0     315.0  Heavy  Tonsillitis  {"Run": "ERR4077966", "Link": "https://doi.org...
2  QLQESGPGVVKPSETLSLTCTVSGGSISSGDHYWAWIRQPPGKGLE...  GGSISSGDHY   MYYSGTI         ARYVRASFDE        84.0     113.0       165.0     185.0       300.0     329.0  Heavy  Tonsillitis  {"Run": "ERR4077966", "Link": "https://doi.org...
3  VQLVESGGGVVQPGRSLRLSCAASGFTFNNYGMHWVRQAPGKGLEG...    GFTFNNYG  IWYDGDNK  

In [7]:
df.drop([ 'cdr2_start', 'cdr2_end', 'cdr3_start', 'cdr3_end', 'meta', 'sequence_length'], axis=1, inplace=True)
df.head(3)

Unnamed: 0,sequence,cdr1_aa,cdr2_aa,cdr3_aa,cdr1_start,cdr1_end,Chain,Disease
0,VQLQESGPGLVKPSETLSLTCPVSGGSISTYYWSWIRKTPGKGLEW...,GGSISTYY,IYYSEST,ARVAGTYGGFGQLYFDY,84.0,107.0,Heavy,Tonsillitis
1,VQLQESGPGLVKPSETLSLTCNVSGGSISSGQWSWIRQPPGKGLEW...,GGSISSGQ,FYYSGST,AGDYGCRY,82.0,105.0,Heavy,Tonsillitis
2,QLQESGPGVVKPSETLSLTCTVSGGSISSGDHYWAWIRQPPGKGLE...,GGSISSGDHY,MYYSGTI,ARYVRASFDE,84.0,113.0,Heavy,Tonsillitis


In [8]:
def map_cdr_to_sequence(full_seq, cdr1_seq, cdr2_seq, cdr3_seq):
    mapping = [0]*len(full_seq)
    for cdr_seq, num in zip([cdr1_seq, cdr2_seq, cdr3_seq], [1, 2, 3]):
        cdr_start = full_seq.find(cdr_seq)
        if cdr_start != -1:
            for i in range(len(cdr_seq)):
                mapping[cdr_start + i] = num
    return mapping

In [9]:
df['sequence'] = df['sequence'].astype(str)
df['cdr1_aa'] = df['cdr1_aa'].astype(str)
df['cdr2_aa'] = df['cdr2_aa'].astype(str)
df['cdr3_aa'] = df['cdr3_aa'].astype(str)
df = df.dropna(subset=['sequence', 'cdr1_aa', 'cdr2_aa', 'cdr3_aa'])

df['sequence_classification'] = df.apply(lambda row: map_cdr_to_sequence(row['sequence'], row['cdr1_aa'], row['cdr2_aa'], row['cdr3_aa']), axis=1)

In [10]:
from sklearn.model_selection import train_test_split

train_sequences, test_sequences, train_labels, test_labels = train_test_split(df['sequence'], df['sequence_classification'], test_size=0.25, shuffle=True)

In [11]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

train_tokenized = tokenizer(list(train_sequences))
test_tokenized = tokenizer(list(test_sequences))

In [12]:
from datasets import Dataset

train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)

train_dataset = train_dataset.add_column("labels", train_labels)
test_dataset = test_dataset.add_column("labels", test_labels)

In [13]:
train_dataset[0]

{'input_ids': [0,
  6,
  6,
  8,
  4,
  10,
  4,
  8,
  23,
  5,
  5,
  8,
  6,
  18,
  12,
  18,
  8,
  8,
  19,
  5,
  20,
  21,
  22,
  7,
  10,
  16,
  5,
  14,
  6,
  15,
  6,
  4,
  9,
  22,
  7,
  5,
  7,
  12,
  8,
  8,
  13,
  6,
  8,
  9,
  11,
  19,
  19,
  5,
  13,
  8,
  7,
  15,
  6,
  10,
  18,
  11,
  12,
  8,
  10,
  13,
  17,
  8,
  15,
  17,
  14,
  4,
  18,
  4,
  16,
  20,
  8,
  8,
  4,
  10,
  7,
  9,
  13,
  11,
  5,
  4,
  19,
  19,
  23,
  5,
  10,
  9,
  13,
  19,
  19,
  6,
  10,
  6,
  7,
  4,
  13,
  5,
  4,
  13,
  12,
  22,
  6,
  16,
  6,
  11,
  20,
  7,
  11,
  7,
  8,
  8,
  2],
 'attention_mask': [1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,


In [14]:
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

In [15]:
import torch
# Set the device      
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

Using device: mps


In [16]:
id2label = {
    0: "Non-CDR",
    1: "CDR1",
    2: "CDR2",
    3: "CDR3",
}
label2id = {
    "Non-CDR": 0,
    "CDR1": 1,
    "CDR2": 2,
    "CDR3": 3,

}

In [17]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer

model = AutoModelForTokenClassification.from_pretrained(
    "facebook/esm2_t6_8M_UR50D", num_labels=4)
model.to(device)

Some weights of the model checkpoint at facebook/esm2_t6_8M_UR50D were not used when initializing EsmForTokenClassification: ['lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing EsmForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for p

EsmForTokenClassification(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 320, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 320, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-5): 6 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=320, out_features=320, bias=True)
              (key): Linear(in_features=320, out_features=320, bias=True)
              (value): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((320,), eps=1e-05, el

In [18]:
from evaluate import load
import numpy as np

metric = load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    labels = labels.reshape((-1,))
    predictions = np.argmax(predictions, axis=2)
    predictions = predictions.reshape((-1,))
    predictions = predictions[labels!=-100]
    labels = labels[labels!=-100]
    return metric.compute(predictions=predictions, references=labels)

In [19]:
training_args = TrainingArguments(
    output_dir=directory,
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=1e-4,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    weight_decay=0.001,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    use_mps_device=True
)



In [20]:
trainer = Trainer(
    model,
    training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
    
)

trainer.train()

  src_lengths = attention_mask.sum(-1)


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0032,0.001638,0.999792
2,0.0009,0.000754,0.999889
3,0.0004,0.000523,0.999914


TrainOutput(global_step=3660, training_loss=0.011272839177908793, metrics={'train_runtime': 1020.9565, 'train_samples_per_second': 114.672, 'train_steps_per_second': 3.585, 'total_flos': 677492713481010.0, 'train_loss': 0.011272839177908793, 'epoch': 3.0})

In [29]:
def convert_array(input_array):
    output_array = []
    cdr_positions = {1: [None, None], 2: [None, None], 3: [None, None]}

    for idx, label in enumerate(input_array):
        if label == 'LABEL_0':
            output_array.append('0')
        elif label in ['LABEL_1', 'LABEL_2', 'LABEL_3']:
            label_num = int(label[-1])
            output_array.append(str(label_num))
            if cdr_positions[label_num][0] is None:
                cdr_positions[label_num][0] = idx
            cdr_positions[label_num][1] = idx

    output_string = ''.join(output_array)
    return output_string, cdr_positions


In [30]:
import torch

sequence = 'QVQLVQSGAEVRKPGASVKVSCKASGYSFTDYYMHWVRQAPGQGLEWMGWINPKSGGTNYAQRFQGRVTMTGDTSISAAYMDLASLTSDDTAVYYCVKDCGSGGLRDFWGQGTTVTVSS'
inputs = tokenizer(sequence, return_tensors="pt").to(device)
logits = model(**inputs).logits
predicted_token_class_ids = torch.argmax(logits, dim=-1)
predicted_token_class = [model.config.id2label[t] for t in predicted_token_class_ids[0].tolist()]

output_string, cdr_positions = convert_array(predicted_token_class)

print(output_string)

for cdr, positions in cdr_positions.items():
    start, end = positions
    if start is not None and end is not None:
        cdr_sequence = sequence[start:end+1]
        print(f"CDR{cdr}: {cdr_sequence}")


0000000000000000000000000111111110000000000000000022222222000000000000000000000000000000000000003333333333330000000000000
CDR1: GYSFTDYY
CDR2: INPKSGGT
CDR3: VKDCGSGGLRDF


In [31]:
model.save_pretrained(directory)