In [None]:
import evaluate
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from pyfaidx import Fasta
from torch.nn.functional import one_hot
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    pipeline,
)

In [None]:
# Inputs are read from the current working directory:
#   - train_sequences.fasta: sequence records, opened through pyfaidx
#   - train_terms.tsv: tab-separated annotation table
seq = Fasta("train_sequences.fasta")
train_terms = pd.read_csv("train_terms.tsv", sep="\t")

In [None]:
# Sample a subset of the train dataset if needed.
# A fixed random_state keeps the sampled rows -- and therefore the label
# vocabulary derived from them in the next cell -- identical on every
# Restart & Run All.
train_terms = train_terms.sample(frac=0.0001, random_state=42)

In [None]:
# Label vocabulary: one integer id per distinct value of the 'term' column.
unique_terms = train_terms['term'].unique()
id2label = {idx: term for idx, term in enumerate(unique_terms)}
label2id = {term: idx for idx, term in enumerate(unique_terms)}

# Map each entry id to its sequence string. The entry id is taken from the
# second '|'-separated field of the FASTA record name.
seqs = {seq[key].name.split('|')[1] : seq[key][:].seq for key in seq.keys()}


def _multi_hot(terms):
    """Return a float vector with 1.0 at the label id of every term in ``terms``.

    Positions are *set* rather than summed, so a term listed more than once
    for the same entry still yields 1.0 (a valid binary target) instead of 2.0.
    """
    encoded = np.zeros(unique_terms.size, dtype=float)
    encoded[[label2id[term] for term in terms]] = 1.0
    return encoded


# Collect the terms of each entry first and build a single label vector per
# entry. This avoids materialising one dense one-hot vector per annotation row
# (rows x classes floats) only to sum them away afterwards.
terms_per_entry = train_terms.groupby('EntryID')['term'].agg(list)

# Result: one row per EntryID (index) with the columns 'label' and 'seq'.
# Looking sequences up with seqs[x] raises KeyError for an entry that has no
# sequence in the FASTA file, so a mismatch between the two inputs fails loudly.
train_terms = pd.DataFrame({
    'label': terms_per_entry.map(_multi_hot),
    'seq': terms_per_entry.index.to_series().map(lambda x: seqs[x]),
})

In [None]:
# Tokenizer published with the facebook/esm2_t6_8M_UR50D checkpoint.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")


def batch_tokenize(examples):
    """Tokenize the 'seq' field of a batch of examples.

    Args:
        examples: mapping that holds the sequences of the batch under 'seq'.

    Returns:
        Whatever the tokenizer returns for those sequences.
    """
    sequences = examples['seq']
    return tokenizer(sequences)

# DataFrame -> Dataset -> tokenized in batches -> 90/10 train/test split.
dataset = (
    Dataset.from_pandas(train_terms[['seq', 'label']])
    .map(batch_tokenize, batched=True)
    .train_test_split(test_size=0.1)
)

In [None]:
# Sequence-classification model on top of the pretrained checkpoint,
# configured for multi-label targets and carrying the term <-> id mappings.
model = AutoModelForSequenceClassification.from_pretrained(
    "facebook/esm2_t6_8M_UR50D",
    problem_type="multi_label_classification",
    id2label=id2label,
    label2id=label2id,
)

# Training output is written under ./model; no external experiment tracker is used.
training_args = TrainingArguments(
    output_dir="./model",
    disable_tqdm=False,
    report_to="none",
)

# The collator pads the examples of each batch with the tokenizer.
padding_collator = DataCollatorWithPadding(tokenizer=tokenizer)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    data_collator=padding_collator,
)

In [None]:
# Run training; the value returned by train() is shown as the cell output.
train_output = trainer.train()
train_output

In [None]:
# Replace <checkpoint> with the checkpoint directory you want under ./model.
# The tokenizer is named explicitly: the Trainer was not handed the tokenizer,
# so (depending on the transformers version) the checkpoint directory may hold
# only the model and no tokenizer files, and pipeline() could not load one
# from it. top_k=None asks the pipeline for the scores of all labels.
classifier = pipeline(
    task="text-classification",
    model="./model/<checkpoint>",
    tokenizer="facebook/esm2_t6_8M_UR50D",
    top_k=None,
)

In [None]:
# Placeholder input: substitute the amino acid sequence to classify.
query_sequence = '<some amino acid sequence>'
classifier(query_sequence)