In [1]:
import os
import cv2
import time
import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, Dataset
from torch.utils.data import RandomSampler

from torchvision import transforms
import torchvision.models as models

from matplotlib import pyplot as plt
import copy

from Bio import SeqIO



In [2]:
!pip install evaluate

Collecting evaluate
  Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.4/81.4 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: evaluate
Successfully installed evaluate-0.4.0
[0m

# Prepare Dataset
## Import

In [3]:
# Root directory of the competition data; change this single value to run the
# notebook against a copy of the data stored elsewhere.
DATA_DIR = "/kaggle/input/aist4010-spring2023-a2/data"

# FASTA files for the three splits.
DIR_TRAIN = f"{DATA_DIR}/train.fasta"
DIR_VALID = f"{DATA_DIR}/val.fasta"
DIR_TEST = f"{DATA_DIR}/test.fasta"

In [4]:
# Maps each category name to its integer class id (0-13).
# Ids are assigned from the position in the list below, so the order of the
# entries must not be changed.
arg_dict = {
    category: class_id
    for class_id, category in enumerate([
        'aminoglycoside',
        'macrolide-lincosamide-streptogramin',
        'polymyxin',
        'fosfomycin',
        'trimethoprim',
        'bacitracin',
        'quinolone',
        'multidrug',
        'chloramphenicol',
        'tetracycline',
        'rifampin',
        'beta_lactam',
        'sulfonamide',
        'glycopeptide',
    ])
}

In [5]:
# X = sequence, y = class id in 0..14
def get_class(ind):
    """Return the integer class id for a FASTA record id.

    The id is split on "|". When the first field is 'sp' the record is given
    class 14; otherwise the fourth field is looked up in ``arg_dict``
    (ids 0-13).

    Args:
        ind: the record id string, with fields separated by "|".

    Returns:
        The integer class id.

    Raises:
        IndexError: a non-'sp' id has fewer than four "|"-separated fields.
        KeyError: the fourth field is not a key of ``arg_dict``.
    """
    fields = ind.split("|")
    return 14 if fields[0] == 'sp' else arg_dict[fields[3]]

In [6]:
# Parse the training FASTA file: collect one sequence string and one class id
# (derived from the record id by get_class) per record.
train_X = []
train_y = []
for record in SeqIO.parse(DIR_TRAIN, "fasta"):
    train_X.append(str(record.seq))
    train_y.append(get_class(record.id))
# Show only the first five entries of each list.
print(train_X[:5])
print(train_y[:5])

['MSLNEPIKKVSIVIPVYNEQESLPALIDRTTAACKLLTQAYEIILVDDGSSDNSAELLTTAANDPDSQIIAVLLNRNYGQHSAIMAGFNQVSGDLIITLDADLQNPPEEIPRLVHVAEEGYDVVGTVRANRQDSLFRKTASRMINMMIQRATGKSMGDYGCMLRAYRRHIVEAMLHCHERSTFIPILANTFARRTTEITVHHAEREFGNSKYSLMRLINLMYDLITCLTTTPLRLLSLVGSAIALLGFTFSVLLVALRLIFGPEWAGGGVFTLFAVLFMFIGAQFVGMGLLGEYIGRIYNDVRARPRYFVQKVVGAEQTENNQDVEK', 'MQKPVLIASAALICAAVIGIAVYATGSAKKDAGGFAGYPPVKVALASVERRVVPRVFDGVGELEAGRQVQVAAEAAGRITRIAFESGQQVQQGQLLVQLNDAVEQAELIRLKAQLRNAEILHARARKLVERNVASQEQLDNAVAARDMALGAVRQTQALIDQKAISAPFSGQLGIRRVHLGQYLGVAEPVASLVDARTLKSNFSLDESTSPELKLGQPLEVLVDAYPGRSFPARISAIDPLIGKSRTVQVQALLDNPEGLLAAGMFASIRVSRKADAPSLSVPETAVTYTAYGDTVFVAHQDGDRPLSAKRVSVRIGERWDGRVEILQGLAEGDRVVTSGQINLSDGMAVEPVKEDTLSSAAPPVPVAGR', 'MTALLELKGIRRSYQSGGETVDVLQDVSLTINAGELVAIIGASGSGKSTLMNILGCLDKPSAGIYRVAGQDVATLDNDALAALRREHFGFIFQRYHLLPHLSAAHNVEVPAVYAGLGKHERRERANMLLTRLGLEERVNYQPNQLSGGQQQRVSIARALMNGGQVILADEPTGALDSHSSVEVMAILKQLQQQGHTVIIVTHDPNVAAQAERIIEIKDGRIMADSGSKTVPTVVASEAVSLAPSAPSWQQLAGRFREALLMAWRAMSANKMRTALTMLGIIIGIASVVSILVV

In [7]:
# Parse the validation FASTA file: collect one sequence string and one class id
# (derived from the record id by get_class) per record.
val_X = []
val_y = []
for record in SeqIO.parse(DIR_VALID, "fasta"):
    val_X.append(str(record.seq))
    val_y.append(get_class(record.id))
# Show only the first five entries of each list.
print(val_X[:5])
print(val_y[:5])

['MTDVIKAIILGIIEGLTEFLPVSSTGHLILAGNLLSFEGDAAITFKIVIQLGAVMAVLILYWKRYLEIGANMIRMDFSKSKGLNVIHMILAMLPALILYLLFKDTIKSQLFGPTPVLIGLVVGGVLMIIAARNRRTETADTIDGINYKQAFGIGLFQCLALWPGFSRSGSTISGGLLLGTSQKAAADFTFIISVPVMFGASLLDLYDSRDLLSSDDLILMLIGFATSFLVAMIAVVTFIKLIKRLRLEWFALYRFVLAALFYLIVIQ', 'MDIIFALKALVMGLVEGFTEFLPISSTGHLILAGSLLDLQRNVSKEVIDVFEIVIQAGAILAVCWEYRARIASVLSGLTSDHKARKFVLNLIVAFLPLAVLGLAVGKHIKAVLFKPVPVALAFIIGGFVILWAERRAKTNPTAVRIHSVEDMSVSDALKVGFAQAFALIPGTSRSGATIIGGMLFGLSRKAGTEFSFFLAIPTLLCATFYSLYKERALLSADLTGFFSIGTVAAFVSAFLCVRWLLRYISSHDFTVFAWYRIVFGLVVIVTSYTGMVAWVD', 'MKQNIRQREGVFLENIIFVIKSVILGIVEGITEFLPVSSTGHLVIFQNLIGFKGITDKYVEMYTYVIQLGAILAVIVLYWRKIVETLINFFPGKVSYEKSGFRFWFIIFIACIPGGVFGILLDDLAEQYLFSPVTVAITLFLGALWMIYAENTFKNKSAANIRNSLGSDLKITTRQAVIIGLFQCLAIIPGMSRSASTIIGGWISGLSTVAAAEFSFFLAIPVMVGMSFLKIFKIGGLLSLTHLELISLGVGFAVSFGVALIVIEKFISYLQKKPMKIFAVYRIIFAVVVLITGFLGIF', 'MFKITLCALLITASCSTFAAPQQINDIVHRTITPLIEQQKIPGMAVAVIYQGKPYYFTWGYADIAKKQPVTQQTLFELGSVSKTFTGVLGGDAIARGEIKLSDPATKYWPELTAKQWNGITLLHLATYTAGGLPLQVPD

## Tokenize

Available ESM-2 checkpoints:

| Checkpoint name | Num layers | Num parameters |
|---|---|---|
| esm2_t48_15B_UR50D | 48 | 15B |
| esm2_t36_3B_UR50D | 36 | 3B |
| esm2_t33_650M_UR50D | 33 | 650M |
| esm2_t30_150M_UR50D | 30 | 150M |
| esm2_t12_35M_UR50D | 12 | 35M |
| esm2_t6_8M_UR50D | 6 | 8M |

In [8]:
# Training configuration.
# epochs: number of epochs passed to model.fit.
# model_checkpoint: Hugging Face hub id of the pretrained ESM-2 checkpoint that
# both the tokenizer and the classifier are loaded from.
epochs = 4
model_checkpoint = "facebook/esm2_t12_35M_UR50D"

In [9]:
from transformers import AutoTokenizer

# Load the tokenizer from the same checkpoint that the model is later loaded
# from (model_checkpoint).
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

Downloading (…)okenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

In [10]:
# Sanity check: tokenize a single training sequence and inspect the result.
tokenizer(train_X[0])

{'input_ids': [0, 20, 8, 4, 17, 9, 14, 12, 15, 15, 7, 8, 12, 7, 12, 14, 7, 19, 17, 9, 16, 9, 8, 4, 14, 5, 4, 12, 13, 10, 11, 11, 5, 5, 23, 15, 4, 4, 11, 16, 5, 19, 9, 12, 12, 4, 7, 13, 13, 6, 8, 8, 13, 17, 8, 5, 9, 4, 4, 11, 11, 5, 5, 17, 13, 14, 13, 8, 16, 12, 12, 5, 7, 4, 4, 17, 10, 17, 19, 6, 16, 21, 8, 5, 12, 20, 5, 6, 18, 17, 16, 7, 8, 6, 13, 4, 12, 12, 11, 4, 13, 5, 13, 4, 16, 17, 14, 14, 9, 9, 12, 14, 10, 4, 7, 21, 7, 5, 9, 9, 6, 19, 13, 7, 7, 6, 11, 7, 10, 5, 17, 10, 16, 13, 8, 4, 18, 10, 15, 11, 5, 8, 10, 20, 12, 17, 20, 20, 12, 16, 10, 5, 11, 6, 15, 8, 20, 6, 13, 19, 6, 23, 20, 4, 10, 5, 19, 10, 10, 21, 12, 7, 9, 5, 20, 4, 21, 23, 21, 9, 10, 8, 11, 18, 12, 14, 12, 4, 5, 17, 11, 18, 5, 10, 10, 11, 11, 9, 12, 11, 7, 21, 21, 5, 9, 10, 9, 18, 6, 17, 8, 15, 19, 8, 4, 20, 10, 4, 12, 17, 4, 20, 19, 13, 4, 12, 11, 23, 4, 11, 11, 11, 14, 4, 10, 4, 4, 8, 4, 7, 6, 8, 5, 12, 5, 4, 4, 6, 18, 11, 18, 8, 7, 4, 4, 7, 5, 4, 10, 4, 12, 18, 6, 14, 9, 22, 5, 6, 6, 6, 7, 18, 11, 4, 18, 5, 7, 4, 1

In [11]:
# Tokenize the training and validation sequences. With truncation=True and
# max_length=1024, encoded sequences longer than 1024 tokens are cut to that length.
train_X_token = tokenizer(train_X, max_length=1024, truncation=True)
val_X_token = tokenizer(val_X, max_length=1024, truncation=True)

In [12]:
# NOTE: this rebinds the name `Dataset`, which the first cell imported from
# torch.utils.data. From this cell on, `Dataset` is the Hugging Face class.
from datasets import Dataset

# Build Hugging Face datasets from the tokenizer output and attach the integer
# class ids as a "labels" column.
train_dataset = Dataset.from_dict(train_X_token).add_column("labels", train_y)
val_dataset = Dataset.from_dict(val_X_token).add_column("labels", val_y)

train_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 21209
})

# Configure Model

In [13]:
from transformers import TFAutoModelForSequenceClassification

# One output per class id that get_class can return: every id in arg_dict
# (0-13) plus the extra class 14 assigned to 'sp' records, i.e. 15 in total.
num_labels = len(arg_dict) + 1
model = TFAutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)

Downloading (…)lve/main/config.json:   0%|          | 0.00/778 [00:00<?, ?B/s]

Downloading (…)"tf_model.h5";:   0%|          | 0.00/134M [00:00<?, ?B/s]

Some layers from the model checkpoint at facebook/esm2_t12_35M_UR50D were not used when initializing TFEsmForSequenceClassification: ['lm_head']
- This IS expected if you are initializing TFEsmForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFEsmForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some layers of TFEsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
# Arguments shared by the training and validation pipelines.
_tf_dataset_kwargs = dict(batch_size=8, tokenizer=tokenizer)

# Training batches are shuffled; validation batches keep the dataset order.
tf_train_set = model.prepare_tf_dataset(train_dataset, shuffle=True, **_tf_dataset_kwargs)
tf_val_set = model.prepare_tf_dataset(val_dataset, shuffle=False, **_tf_dataset_kwargs)

In [15]:
from transformers import AdamWeightDecay

# No loss is passed, so the model's internal loss computation is used (see the
# message printed below this cell).
# NOTE(review): only one value (2e-5, presumably the learning rate) is given to
# AdamWeightDecay; any weight-decay setting is left at the library default —
# confirm that is intended.
model.compile(optimizer=AdamWeightDecay(2e-5), metrics=["accuracy"])

No loss specified in compile() - the model's internal loss computation will be used as the loss. Don't panic - this is a common way to train TensorFlow models in Transformers! To disable this behaviour please pass a loss argument, or explicitly pass `loss=None` if you do not want your model to compute a loss.


In [16]:
# Fine-tune for `epochs` epochs on the training batches, using the validation
# batches as validation data.
model.fit(tf_train_set, validation_data=tf_val_set, epochs=epochs)

Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


<keras.callbacks.History at 0x7fda645cc2d0>

# Inference

In [17]:
# Parse the test FASTA file. Record ids are kept alongside the sequences so
# each prediction can later be written next to its id.
test_X = []
test_id = []
for record in SeqIO.parse(DIR_TEST, "fasta"):
    test_X.append(str(record.seq))
    test_id.append(record.id)
# Show only the first five entries of each list.
print(test_X[:5])
print(test_id[:5])

['MYTKNAAIVLRLMTESDLPMLHAWLNRPHIVEWWGGEDKRPTLGEVLEHYSPRVLAEQAVVPYIAMLDDEPIGYAQSYTALGSGDGWWEDETDPGVRGIDQSLANPSQLNKGLGTTLVRSLVELLFSDPAVSKIQTDPSPNNHRAIRCYEKAGFAQDKIILTPDGPAVYMVQTRQAFESQRNAA', 'MLILTKTAGVFFKPSKRKVYEFLRSFNFHPGTLFLHKIVLGIETSCDDTAAAVVDETGNVLGEAIHSQTEVHLKTGGIVPPAAQQLHRENIQRIVQEALSASGVSPSDLSAIATTIKPGLALSLGVGLSFSLQLVGQLKKPFIPIHHMEAHALTIRLTNKVEFPFLVLLISGGHCLLALVQGVSDFLLLGKSLDIAPGDMLDKVARRLSLIKHPECSTMSGGKAIEHLAKQGNRFHFDIKPPLHHAKNCDFSFTGLQHVTDKIIMKKEKEEGIEKGQILSSAADIAATVQHTMACHLVKRTHRAILFCKQRDLLPQNNAVLVASGGVASNFYIRRALEILTNATQCTLLCPPPRLCTDNGIMIAWNGIERLRAGLGILHDIEGIRYEPKCPLGVDISKEVGEASIKVPQLKMEI', 'MRFTLLAFALAVALPAAHASAAEAPLPQLRAYTVDASWLQPMAPLQVADHTWQIGTEDLTALLVQTAEGAVLLDGGMPQMAGHLLDNMKLRGVAPQDLRLILLSHAHADHAGPVAELKRRTGAHVAANAETAVLLARGGSNDLHFGDGITYPPASADRIIMDGEVVTVGGIAFTAHFMPGHTPGSTAWTWTDTRDGKPVRIAYADSLSAPGYQLKGNPRYPRLIEDYKRSFATVRALPCDLLLTPHPGASNWNYAAGSKASAEALTCNAYADAAEKKFDAQLARETAGTR', 'MQNAHRSDTGAAALTGTPEKLLPTQPETGSFQVVLDDVVRAPGGRPLLDGVNQSVALGERVGIIGENGSGKSTLLRMLAGVDRPDGGQVLVRAPGGCG

In [18]:
# Apply the same 1024-token cap used for the training/validation tokenization,
# so test sequences are preprocessed exactly like the data the model was
# fine-tuned on. Without max_length, truncation falls back to the tokenizer's
# own default limit, which need not be 1024.
test_X_token = tokenizer(test_X, max_length=1024, truncation=True)

In [19]:
# Relies on `Dataset` being the Hugging Face class imported from `datasets`
# earlier in the notebook (not the torch.utils.data one from the first cell).
test_dataset = Dataset.from_dict(test_X_token)

# shuffle=False: predictions are later paired with test_id by position, so the
# batches must stay in dataset order.
# NOTE(review): confirm that no trailing partial batch is dropped with these
# settings, otherwise the number of predictions would not match len(test_id).
tf_test_set = model.prepare_tf_dataset(
    test_dataset,
    batch_size=8,
    shuffle=False,
    tokenizer=tokenizer
)

In [20]:
# Run the fine-tuned model over the test batches; the result exposes the raw
# per-class scores as `.logits`.
test_pred = model.predict(tf_test_set)



In [21]:
# Inspect the raw prediction output.
test_pred

TFSequenceClassifierOutput(loss=None, logits=array([[ 6.5642085 , -0.45424613, -0.17627501, ..., -1.6900734 ,
        -0.54871094,  0.4278924 ],
       [-0.8514332 , -1.9312547 , -1.7375091 , ..., -3.4121675 ,
        -0.7003106 ,  8.489716  ],
       [ 0.26391718, -1.4327399 , -1.1234194 , ..., -0.8700567 ,
        -1.6571702 , -0.41038987],
       ...,
       [ 0.18184584, -1.5998174 , -1.2056975 , ..., -0.73529553,
        -1.5380272 , -0.5527186 ],
       [-1.2070526 , -2.1450577 , -1.2531396 , ..., -3.017218  ,
         0.58309746,  7.2840786 ],
       [-0.55176646,  0.55046   , -1.4520098 , ..., -0.5998809 ,
        -0.08526545, -1.0393775 ]], dtype=float32), hidden_states=None, attentions=None)

In [22]:
# Predicted class for each test record = index of the largest logit in its row.
test_label = test_pred.logits.argmax(axis=1)
test_label

array([ 0, 14, 11, ..., 11, 14,  9])

In [23]:
# Write the submission file: one row per test record, with the record id and
# the predicted class id, and no index column.
data = {"id": test_id, "label": test_label}
df = pd.DataFrame(data=data)
df.to_csv(path_or_buf='submission.csv', index=False)