In [None]:
!pip install transformers, wandb, ml-things, pandas, tqdm

In [1]:
import os

import torch
import wandb
import pandas as pd
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
from ml_things import plot_dict, plot_confusion_matrix
from sklearn.metrics import classification_report, accuracy_score
from transformers import (set_seed,
                          GPT2Config,
                          AdamW,
                          get_linear_schedule_with_warmup,
                          GPT2ForSequenceClassification)

In [2]:
# --- Run configuration ---------------------------------------------------
# Fix the random seed before anything stochastic happens.
set_seed(123)

# Checkpoint to start from and number of output classes for the classifier head.
model_name_or_path = 'gpt2'
n_labels = 2

# Training hyper-parameters; `max_length` is the fixed token length every
# example is truncated/padded to by the collator.
epochs = 4
batch_size = 32
max_length = 384

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# Start a Weights & Biases run in the "mhc2seq" project; the training cells
# below send their metrics to this run via `wandb.log` / `wandb.watch`.
wandb.init(project="mhc2seq")

  return LooseVersion(v) >= LooseVersion(check)
  return LooseVersion(v) >= LooseVersion(check)
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
wandb: Currently logged in as: franknoh. Use `wandb login --relogin` to force relogin
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  from IPython.core.display import HTML, display  # type: ignore


  from IPython.core.display import HTML, display  # type: ignore


  from IPython.core.display import HTML, display  # type: ignore


  from IPython.core.display import HTML, display  # type: ignore


  from IPython.core.display import HTML, display  # type: ignore


In [4]:
class MhcSeqDataset(Dataset):
    """Rows of (MHC name, sequence, label) with the MHC name resolved to its aligned sequence.

    The row file at `path` is a TSV with columns `Mhc`, `Seq` and `Pred`.
    The alignment file at `align_path` is a TSV with columns `Mhc` and `Seq`
    and maps an MHC name to its aligned sequence.

    Rows whose MHC name has no entry in the alignment file are kept in the
    length of the dataset; when such a row is requested, the next row that
    *can* be resolved is returned instead (wrapping around to the first row
    after the last one).
    """

    def __init__(self, path, align_path='data/mhc.tsv'):
        """Load the row file and the MHC alignment table.

        Args:
            path: TSV file with `Mhc`, `Seq` and `Pred` columns.
            align_path: TSV file with `Mhc` and `Seq` columns. Defaults to the
                location that was previously hard-coded.
        """
        df = pd.read_csv(path, sep='\t')
        self.alignseq = pd.read_csv(align_path, sep='\t')
        self.mhc = df['Mhc'].tolist()
        self.seq = df['Seq'].tolist()
        self.pred = df['Pred'].tolist()

        # Build the name -> aligned-sequence lookup once, instead of scanning
        # the whole DataFrame on every `mhc2seq` call. `setdefault` keeps the
        # first occurrence of a name, matching the old "first matching row"
        # behaviour. Missing names (NaN) never matched before, so skip them.
        self._mhc_lookup = {}
        names = self.alignseq['Mhc'].tolist()
        aligned_seqs = self.alignseq['Seq'].tolist()
        for name, aligned in zip(names, aligned_seqs):
            if pd.notna(name):
                self._mhc_lookup.setdefault(name, aligned)

    def mhc2seq(self, mhc):
        """Return the aligned sequence for `mhc`, or `''` when the name is unknown."""
        return self._mhc_lookup.get(mhc, '')

    def __len__(self):
        """Number of rows in the row file (including rows with unknown MHC names)."""
        return len(self.mhc)

    def __getitem__(self, item):
        """Return the example at `item` as a dict with keys `mhc`, `seq` and `pred`.

        Raises:
            IndexError: if `item` is outside the dataset.
            ValueError: if no row at all has a resolvable MHC name.
        """
        total = len(self.mhc)
        if not -total <= item < total:
            raise IndexError('dataset index out of range')
        item %= total

        # Skip forward over rows without an alignment. The scan is bounded by
        # the dataset size and wraps around, so a run of unresolvable rows at
        # the end of the file can no longer index past the last row.
        for _ in range(total):
            aligned = self.mhc2seq(self.mhc[item])
            if aligned != '':
                return {
                    'mhc': aligned,
                    'seq': self.seq[item],
                    'pred': float(self.pred[item])
                }
            item = (item + 1) % total
        raise ValueError('no row has an MHC name present in the alignment table')

In [5]:
class Gpt2ClassificationCollator(object):
    """Collate dataset items into fixed-length token-id tensors plus labels.

    Each item is a dict with keys `mhc`, `seq` and `pred`. The two sequences
    are wrapped in `<mhc>…</mhc><seq>…</seq>` markers, tokenized against the
    fixed vocabulary below (a token's id is its position in `self.vocab`),
    truncated to `max_sequence_len` and right-padded with `pad_idx`.
    """

    def __init__(self, max_sequence_len):
        """Store the vocabulary, padding id and the fixed output length."""
        self.vocab = ['<pad>', '<mhc>', '</mhc>', '<seq>', '</seq>', '.', '*', 'L', 'A', 'G', 'V', 'E', 'S', 'I', 'K', 'R', 'D', 'T', 'P', 'N', 'Q', 'F', 'Y', 'M', 'H', 'C', 'W', 'X', 'U', 'B', 'Z', 'O']
        self.pad_idx = 0
        self.max_sequence_len = max_sequence_len

    def __call__(self, sequences):
        """Build one batch.

        Args:
            sequences: list of dicts with keys `mhc`, `seq` and `pred`.

        Returns:
            dict with `input_ids` (one row of `max_sequence_len` ids per item)
            and `labels` (tensor built from the `pred` values).
        """
        seq = [f"<mhc>{sequence['mhc']}</mhc><seq>{sequence['seq']}</seq>" for sequence in sequences]
        labels = [sequence['pred'] for sequence in sequences]
        inputs = self.encode(seq)
        inputs.update({'labels':torch.tensor(labels)})
        return inputs

    def _tokenize(self, text):
        """Return at most `max_sequence_len` vocabulary ids for `text`.

        Raises:
            ValueError: if the text contains something that matches no
                vocabulary token. Previously this case never terminated,
                because the remaining text was never consumed.
        """
        ids = []
        pos = 0
        # Stop as soon as the fixed length is reached: anything beyond it
        # would be truncated away anyway.
        while pos < len(text) and len(ids) < self.max_sequence_len:
            # First matching token in vocabulary order wins. Matching at an
            # offset avoids re-slicing the string on every token.
            for idx, token in enumerate(self.vocab):
                if text.startswith(token, pos):
                    ids.append(idx)
                    pos += len(token)
                    break
            else:
                raise ValueError(
                    f"cannot tokenize {text[pos]!r} at position {pos} of {text!r}: "
                    "no matching vocabulary token"
                )
        return ids

    def encode(self, sequence):
        """Tokenize, truncate and right-pad every string in `sequence`.

        Args:
            sequence: iterable of strings made of vocabulary tokens.

        Returns:
            dict with a single key `input_ids`: a tensor with one row of
            `max_sequence_len` ids per input string.

        Raises:
            ValueError: if a string contains text outside the vocabulary.
        """
        result = []
        for seq in sequence:
            ids = self._tokenize(seq)
            padding_length = self.max_sequence_len - len(ids)
            result.append(ids + [self.pad_idx] * padding_length)
        return {'input_ids':torch.tensor(result)}

In [6]:
def train(dataloader, optimizer_, scheduler_, device_):
    """Run one training pass over `dataloader` using the notebook-level `model`.

    For every batch: forward pass, log the batch loss to wandb under "loss",
    backward pass, clip the gradient norm to 1.0, then step the optimizer and
    the scheduler.

    Args:
        dataloader: yields dicts of tensors that include a `labels` entry.
        optimizer_: optimizer stepped once per batch.
        scheduler_: learning-rate scheduler stepped once per batch.
        device_: device every batch tensor is moved to.

    Returns:
        (true_labels, predicted_labels, mean_batch_loss) for the pass.
    """
    global model
    epoch_targets, epoch_guesses = [], []
    running_loss = 0

    model.train()
    for batch in tqdm(dataloader, total=len(dataloader)):
        # Record the targets while the batch is still on the CPU.
        epoch_targets.extend(batch['labels'].numpy().flatten().tolist())

        # Every tensor (labels included) is cast to long before the forward pass.
        device_batch = {name: tensor.type(torch.long).to(device_) for name, tensor in batch.items()}

        model.zero_grad()
        loss, logits = model(**device_batch)[:2]

        step_loss = loss.item()
        running_loss += step_loss
        wandb.log({"loss": step_loss})

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer_.step()
        scheduler_.step()

        # Predicted class = index of the largest logit.
        batch_guesses = logits.detach().cpu().numpy().argmax(axis=-1)
        epoch_guesses.extend(batch_guesses.flatten().tolist())

    return epoch_targets, epoch_guesses, running_loss / len(dataloader)

In [7]:
def validation(dataloader, device_):
    """Evaluate the notebook-level `model` on `dataloader` without updating it.

    Args:
        dataloader: yields dicts of tensors that include a `labels` entry.
        device_: device every batch tensor is moved to.

    Returns:
        (true_labels, predicted_labels, mean_batch_loss) for the pass.
    """
    global model
    seen_targets, seen_guesses = [], []
    running_loss = 0

    model.eval()
    for batch in tqdm(dataloader, total=len(dataloader)):
        # Record the targets while the batch is still on the CPU.
        seen_targets.extend(batch['labels'].numpy().flatten().tolist())

        # Every tensor (labels included) is cast to long before the forward pass.
        device_batch = {name: tensor.type(torch.long).to(device_) for name, tensor in batch.items()}

        # No gradients are needed for evaluation.
        with torch.no_grad():
            loss, logits = model(**device_batch)[:2]
            running_loss += loss.item()
            # Predicted class = index of the largest logit.
            batch_guesses = logits.detach().cpu().numpy().argmax(axis=-1)
            seen_guesses.extend(batch_guesses.flatten().tolist())

    return seen_targets, seen_guesses, running_loss / len(dataloader)

In [8]:
# Build the sequence classifier from the pretrained checkpoint with `n_labels` outputs.
model_config = GPT2Config.from_pretrained(pretrained_model_name_or_path=model_name_or_path, num_labels=n_labels)
model = GPT2ForSequenceClassification.from_pretrained(pretrained_model_name_or_path=model_name_or_path, config=model_config)

# The embedding table needs one row per collator token. The collator vocabulary
# has 32 entries (ids 0..31); the previous hard-coded size of 31 left the last
# id without an embedding row, which fails as soon as that token appears in a
# batch. Deriving both values from the collator keeps them in sync.
token_spec = Gpt2ClassificationCollator(max_length)
model.resize_token_embeddings(len(token_spec.vocab))
model.config.pad_token_id = token_spec.pad_idx

model.to(device)
# Let wandb track the model during training.
wandb.watch(model)

Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[]

In [9]:
# One collator instance is shared by both loaders; it pads/truncates to `max_length`.
gpt2_classificaiton_collator = Gpt2ClassificationCollator(max_length)

# Training split: shuffled on every epoch.
train_dataset = MhcSeqDataset(path='data/train.tsv')
print('Created `train_dataset` with %d examples!' % len(train_dataset))
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=gpt2_classificaiton_collator,
)
print('Created `train_dataloader` with %d batches!' % len(train_dataloader))

# Validation split: kept in file order.
valid_dataset = MhcSeqDataset(path='data/test.tsv')
print('Created `valid_dataset` with %d examples!' % len(valid_dataset))
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=gpt2_classificaiton_collator,
)
print('Created `eval_dataloader` with %d batches!' % len(valid_dataloader))

Created `train_dataset` with 22197 examples!
Created `train_dataloader` with 694 batches!
Created `valid_dataset` with 9514 examples!
Created `eval_dataloader` with 298 batches!


In [10]:
# Optimizer over all model parameters.
optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)

# The scheduler is stepped once per batch, so its horizon is batches-per-epoch
# times the number of epochs. No warm-up steps are used.
total_steps = epochs * len(train_dataloader)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)

# Per-epoch histories filled in by the training loop.
all_loss = dict(train_loss=[], val_loss=[])
all_acc = dict(train_acc=[], val_acc=[])



In [None]:
# One training pass followed by one validation pass per epoch; the mean loss
# and the accuracy of each pass are appended to the history dicts.
for epoch in range(epochs):
    train_labels, train_predict, train_loss = train(train_dataloader, optimizer, scheduler, device)
    train_acc = accuracy_score(train_labels, train_predict)

    valid_labels, valid_predict, val_loss = validation(valid_dataloader, device)
    val_acc = accuracy_score(valid_labels, valid_predict)

    for history, key, value in (
        (all_loss, 'train_loss', train_loss),
        (all_loss, 'val_loss', val_loss),
        (all_acc, 'train_acc', train_acc),
        (all_acc, 'val_acc', val_acc),
    ):
        history[key].append(value)

  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(


  0%|          | 0/694 [00:00<?, ?it/s]

In [None]:
# Send the four per-epoch history lists to the wandb run in a single log call.
epoch_history = {
    'train_loss': all_loss['train_loss'],
    'val_loss': all_loss['val_loss'],
    'train_acc': all_acc['train_acc'],
    'val_acc': all_acc['val_acc'],
}
wandb.log(epoch_history)

In [None]:
# Plot the loss curves, then the accuracy curves, with the same axis labels.
axis_labels = {'use_xlabel': 'Epochs', 'use_ylabel': 'Value'}
plot_dict(all_loss, use_linestyles=['-', '--'], **axis_labels)
plot_dict(all_acc, use_linestyles=['-', '--'], **axis_labels)

In [None]:
# Final pass over the validation loader with the trained model.
true_labels, predictions_labels, avg_epoch_loss = validation(valid_dataloader, device)

# Per-class precision / recall / F1 for the two labels.
evaluation_report = classification_report(
    true_labels,
    predictions_labels,
    labels=[0, 1],
    target_names=['0', '1'],
)
print(evaluation_report)

# Normalized confusion matrix for the same predictions.
plot_confusion_matrix(
    y_true=true_labels,
    y_pred=predictions_labels,
    classes=['0', '1'],
    normalize=True,
    magnify=0.1,
)

In [None]:
# Persist the trained weights (state dict only). The target directory is
# created first so the save cannot fail on a missing `models/` folder after
# the whole training run has already completed.
checkpoint_path = 'models/mhc2seq.pt'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)