In [1]:
from google.colab import drive
drive.mount('/content/drive/')

import os
os.chdir('/content/drive/My Drive/Colab Notebooks/refine-epitope-deep-learning')

Mounted at /content/drive/


In [None]:
!pip install transformers
!pip install optuna
!pip install SentencePiece

# Preprocess data 

In [3]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
import torch
from transformers import TrainingArguments, Trainer
from transformers import BertTokenizer, BertForSequenceClassification
#from transformers import EarlyStoppingCallback
from transformers import XLNetModel, XLNetTokenizer, XLNetForSequenceClassification
from transformers import EarlyStoppingCallback

In [4]:
# Create torch dataset
class Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        if self.labels:
            item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.encodings["input_ids"])

In [5]:
def preprocess_data(data):

    # Preprocess data
    X = list(data["sequence"])
    y = list(data["label"])
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.1)
    X_train_tokenized = tokenizer(X_train, padding=True, truncation=True, max_length=512)
    X_val_tokenized = tokenizer(X_val, padding=True, truncation=True, max_length=512)

    train_dataset = Dataset(X_train_tokenized, y_train)
    val_dataset = Dataset(X_val_tokenized, y_val)
    return train_dataset, val_dataset

In [6]:
df_train = pd.read_csv("./input/data_train.csv")

sequence_formatted = []
for seq in df_train['sequence'].values:
  sequence_formatted.append(" ".join(seq))

data = pd.DataFrame({'sequence':sequence_formatted, 'label':df_train['label'].tolist()})

data_op = data[:int(len(data)/7)]


# Define pretrained tokenizer and model
batch_size=8
model_name = "xlnet-base-cased"

tokenizer = XLNetTokenizer.from_pretrained(model_name, do_lower_case=True)
model = XLNetForSequenceClassification.from_pretrained(model_name, num_labels=2)



train_dataset_op, val_dataset_op = preprocess_data(data_op)
train_dataset, val_dataset = preprocess_data(data)

# ----- 2. Fine-tune pretrained model -----#
# Define Trainer parameters
def compute_metrics(p):
    
    pred, labels = p
    pred = np.argmax(pred, axis=1)

    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    recall = recall_score(y_true=labels, y_pred=pred)
    precision = precision_score(y_true=labels, y_pred=pred)
    f1 = f1_score(y_true=labels, y_pred=pred)

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

# Define Trainer
args = TrainingArguments(
    f"{model_name}-finetuned-classification",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    
    #evaluation_strategy ='steps',
    #eval_steps = 50, # Evaluation and Save happens every 50 steps
    #save_total_limit = 5, # Only last 5 models are saved. Older ones are deleted.
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model='f1',
    optim="adamw_torch"
)

def model_init():
    return model

trainer = Trainer(
    model_init=model_init,
    args=args,
    train_dataset=train_dataset_op,
    eval_dataset=val_dataset_op,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
  #  callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]
)


best_run = trainer.hyperparameter_search(n_trials=5, direction="maximize")



Downloading:   0%|          | 0.00/779k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/760 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/445M [00:00<?, ?B/s]

Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['sequence_summary.summary.weight', 'logits_proj.bias', 'logits_proj.weight', 'sequence_summary.summary.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions a

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.402374,0.75,1.0,0.5,0.666667
2,No log,0.237207,0.916667,0.857143,1.0,0.923077


***** Running Evaluation *****
  Num examples = 12
  Batch size = 8
Saving model checkpoint to xlnet-base-cased-finetuned-classification/run-0/checkpoint-14
Configuration saved in xlnet-base-cased-finetuned-classification/run-0/checkpoint-14/config.json
Model weights saved in xlnet-base-cased-finetuned-classification/run-0/checkpoint-14/pytorch_model.bin
tokenizer config file saved in xlnet-base-cased-finetuned-classification/run-0/checkpoint-14/tokenizer_config.json
Special tokens file saved in xlnet-base-cased-finetuned-classification/run-0/checkpoint-14/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 12
  Batch size = 8
Saving model checkpoint to xlnet-base-cased-finetuned-classification/run-0/checkpoint-28
Configuration saved in xlnet-base-cased-finetuned-classification/run-0/checkpoint-28/config.json
Model weights saved in xlnet-base-cased-finetuned-classification/run-0/checkpoint-28/pytorch_model.bin
tokenizer config file saved in xlnet-base-cased-finetune

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.238262,0.916667,0.857143,1.0,0.923077
2,No log,0.223975,0.916667,0.857143,1.0,0.923077
3,No log,0.227401,0.916667,0.857143,1.0,0.923077
4,No log,0.240175,0.916667,0.857143,1.0,0.923077
5,No log,0.223336,0.916667,0.857143,1.0,0.923077


***** Running Evaluation *****
  Num examples = 12
  Batch size = 8
Saving model checkpoint to xlnet-base-cased-finetuned-classification/run-1/checkpoint-14
Configuration saved in xlnet-base-cased-finetuned-classification/run-1/checkpoint-14/config.json
Model weights saved in xlnet-base-cased-finetuned-classification/run-1/checkpoint-14/pytorch_model.bin
tokenizer config file saved in xlnet-base-cased-finetuned-classification/run-1/checkpoint-14/tokenizer_config.json
Special tokens file saved in xlnet-base-cased-finetuned-classification/run-1/checkpoint-14/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 12
  Batch size = 8
Saving model checkpoint to xlnet-base-cased-finetuned-classification/run-1/checkpoint-28
Configuration saved in xlnet-base-cased-finetuned-classification/run-1/checkpoint-28/config.json
Model weights saved in xlnet-base-cased-finetuned-classification/run-1/checkpoint-28/pytorch_model.bin
tokenizer config file saved in xlnet-base-cased-finetune

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.285204,0.916667,0.857143,1.0,0.923077


***** Running Evaluation *****
  Num examples = 12
  Batch size = 8
Saving model checkpoint to xlnet-base-cased-finetuned-classification/run-2/checkpoint-14
Configuration saved in xlnet-base-cased-finetuned-classification/run-2/checkpoint-14/config.json
Model weights saved in xlnet-base-cased-finetuned-classification/run-2/checkpoint-14/pytorch_model.bin
tokenizer config file saved in xlnet-base-cased-finetuned-classification/run-2/checkpoint-14/tokenizer_config.json
Special tokens file saved in xlnet-base-cased-finetuned-classification/run-2/checkpoint-14/special_tokens_map.json


Training completed. Do not forget to share your model on huggingface.co/models =)


Loading best model from xlnet-base-cased-finetuned-classification/run-2/checkpoint-14 (score: 0.923076923076923).
[32m[I 2022-06-21 01:56:06,979][0m Trial 2 finished with value: 3.6968864468864466 and parameters: {'learning_rate': 1.6103970383052136e-05, 'num_train_epochs': 1, 'seed': 19, 'per_device_train_batch_size': 8}. 

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.257046,0.916667,0.857143,1.0,0.923077
2,No log,0.278371,0.916667,0.857143,1.0,0.923077
3,No log,0.290361,0.833333,0.833333,0.833333,0.833333
4,No log,0.251457,0.916667,0.857143,1.0,0.923077
5,No log,0.249401,0.916667,0.857143,1.0,0.923077


***** Running Evaluation *****
  Num examples = 12
  Batch size = 8
Saving model checkpoint to xlnet-base-cased-finetuned-classification/run-3/checkpoint-14
Configuration saved in xlnet-base-cased-finetuned-classification/run-3/checkpoint-14/config.json
Model weights saved in xlnet-base-cased-finetuned-classification/run-3/checkpoint-14/pytorch_model.bin
tokenizer config file saved in xlnet-base-cased-finetuned-classification/run-3/checkpoint-14/tokenizer_config.json
Special tokens file saved in xlnet-base-cased-finetuned-classification/run-3/checkpoint-14/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 12
  Batch size = 8
Saving model checkpoint to xlnet-base-cased-finetuned-classification/run-3/checkpoint-28
Configuration saved in xlnet-base-cased-finetuned-classification/run-3/checkpoint-28/config.json
Model weights saved in xlnet-base-cased-finetuned-classification/run-3/checkpoint-28/pytorch_model.bin
tokenizer config file saved in xlnet-base-cased-finetune

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.254451,0.916667,0.857143,1.0,0.923077


***** Running Evaluation *****
  Num examples = 12
  Batch size = 8
Saving model checkpoint to xlnet-base-cased-finetuned-classification/run-4/checkpoint-14
Configuration saved in xlnet-base-cased-finetuned-classification/run-4/checkpoint-14/config.json
Model weights saved in xlnet-base-cased-finetuned-classification/run-4/checkpoint-14/pytorch_model.bin
tokenizer config file saved in xlnet-base-cased-finetuned-classification/run-4/checkpoint-14/tokenizer_config.json
Special tokens file saved in xlnet-base-cased-finetuned-classification/run-4/checkpoint-14/special_tokens_map.json


Training completed. Do not forget to share your model on huggingface.co/models =)


Loading best model from xlnet-base-cased-finetuned-classification/run-4/checkpoint-14 (score: 0.923076923076923).
[32m[I 2022-06-21 01:58:04,977][0m Trial 4 finished with value: 3.6968864468864466 and parameters: {'learning_rate': 1.0325169188158175e-06, 'num_train_epochs': 1, 'seed': 3, 'per_device_train_batch_size': 32}. 

## Set the model with the best parameters and run it on the full dataset

In [11]:
for n, v in best_run.hyperparameters.items():
    setattr(trainer.args, n, v)

trainer.train_dataset=train_dataset
trainer.eval_dataset=val_dataset

trainer.train()

***** Running training *****
  Num examples = 743
  Num Epochs = 2
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 1
  Total optimization steps = 186


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.854726,0.795181,0.933333,0.466667,0.622222
2,No log,0.312155,0.891566,0.862069,0.833333,0.847458


***** Running Evaluation *****
  Num examples = 83
  Batch size = 8


Saving model checkpoint to xlnet-base-cased-finetuned-classification/checkpoint-93
Configuration saved in xlnet-base-cased-finetuned-classification/checkpoint-93/config.json
Model weights saved in xlnet-base-cased-finetuned-classification/checkpoint-93/pytorch_model.bin
tokenizer config file saved in xlnet-base-cased-finetuned-classification/checkpoint-93/tokenizer_config.json
Special tokens file saved in xlnet-base-cased-finetuned-classification/checkpoint-93/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 83
  Batch size = 8
Saving model checkpoint to xlnet-base-cased-finetuned-classification/checkpoint-186
Configuration saved in xlnet-base-cased-finetuned-classification/checkpoint-186/config.json
Model weights saved in xlnet-base-cased-finetuned-classification/checkpoint-186/pytorch_model.bin
tokenizer config file saved in xlnet-base-cased-finetuned-classification/checkpoint-186/tokenizer_config.json
Special tokens file saved in xlnet-base-cased-finetuned-cla

TrainOutput(global_step=186, training_loss=0.37586675664430025, metrics={'train_runtime': 178.3775, 'train_samples_per_second': 8.331, 'train_steps_per_second': 1.043, 'total_flos': 423332095414272.0, 'train_loss': 0.37586675664430025, 'epoch': 2.0})

In [8]:
# Load trained model
#model_path = "xlnet-base-cased-finetuned-classification/checkpoint-3168"
#model = XLNetForSequenceClassification.from_pretrained(model_path, num_labels=2)


# Define test trainer
#trainer = Trainer(model)

In [9]:
# ----- 3. Predict -----#
# Load test data
#test_data = pd.read_csv("test.csv")
test = pd.read_csv("./input/data_test.csv")

sequence_formatted = []
for seq in test['sequence'].values:
  sequence_formatted.append(" ".join(seq))

test_data = pd.DataFrame({'sequence':sequence_formatted, 'label':test['label'].tolist()})


X_test = list(test_data["sequence"])
X_test_tokenized = tokenizer(X_test, padding=True, truncation=True, max_length=512)

# Create torch dataset
test_dataset = Dataset(X_test_tokenized)

# Make prediction
raw_pred, _, _ = trainer.predict(test_dataset)

# Preprocess raw predictions
y_pred = np.argmax(raw_pred, axis=1)

***** Running Prediction *****
  Num examples = 207
  Batch size = 8


In [10]:
from sklearn.metrics import accuracy_score, classification_report
from sklearn.metrics import roc_auc_score

print("ROC_AUC:", roc_auc_score(test_data['label'], y_pred))

print(classification_report(test_data['label'], y_pred))

ROC_AUC: 0.8515087853323147
              precision    recall  f1-score   support

           0       0.86      0.91      0.88       119
           1       0.86      0.80      0.83        88

    accuracy                           0.86       207
   macro avg       0.86      0.85      0.86       207
weighted avg       0.86      0.86      0.86       207

