Google Colab Notebook for Training Transformers

In [1]:
!pip install transformers



In [2]:
import torch
from torch import nn
from transformers import ElectraModel
# from transformers import ElectraTokenizer
from transformers import BertTokenizer, BertModel

### BERT and Electra use the same tokenizer
# ELECTRA_TOKENIZER = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
# ELECTRA_BASE_TOKENIZER = ElectraTokenizer.from_pretrained('google/electra-base-discriminator')
BERT_TOKENIZER = BertTokenizer.from_pretrained('bert-base-uncased')

class Electra(nn.Module):
    def __init__(self, output_size: int, size = 'small', device='cpu'):
        super().__init__()
        self.device = device
        self.model = ElectraModel.from_pretrained(f'google/electra-{size}-discriminator').to(device)
        self.output = nn.Linear(self.model.config.hidden_size, output_size).to(device)

    # What happens when passing input into the model.
    def forward(self, sents, locs):
        sents = self.model(sents)[0]
        abbs = torch.stack([sents[n, idx, :] for n, idx in enumerate(locs)]) 
        return self.output(abbs)

class Bert(nn.Module):
    def __init__(self, output_size, device='cpu'):
        super().__init__()
        self.device = device
        self.model = BertModel.from_pretrained('bert-base-uncased').to(device)
        self.output = nn.Linear(self.model.config.hidden_size, output_size).to(device)

    # What happens when passing input into the model.
    def forward(self, sents, locs):
        sents = self.model(sents)[0]
        abbs = torch.stack([sents[n, idx, :] for n, idx in enumerate(locs)])
        return self.output(abbs)




In [3]:
torch.cuda.is_available()

True

In [16]:
import numpy as np

class MedalDatasetTokenizer(torch.utils.data.Dataset):

    def __init__(self, df, tokenizer, dictionary_file, max_length=256, device='cpu'):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.device = device
        self.df = df
        label_df = pd.read_csv(dictionary_file, sep='\t', index_col = "EXPANSION")
        self.label_ser = label_df["LABEL"].squeeze()


    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idxs):
        
        batch_df = self.df.iloc[idxs]
        
        ### Locations used in MeDAL paper
        locs = batch_df['LOCATION'].values

        ### Locations of tokens corresponding to abbreviations
        # locs = batch_df['LONG_LOC'].values+1

        # locs = np.zeros(batch_df['LOCATION'].size, dtype = int)
        label_strings = batch_df['LABEL'].values
        labels = self.label_ser[label_strings].to_numpy()

        batch_encode = self.tokenizer(batch_df['TEXT'].tolist(), max_length=self.max_length, \
                    padding=True, truncation = True)

        tokenized = batch_encode['input_ids']
        return torch.tensor(tokenized).to(self.device), torch.tensor(locs).to(self.device), \
            torch.tensor(labels).to(self.device)

In [5]:
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from datetime import datetime
import os
from time import time 
import pandas as pd

In [19]:
def train_loop(train_data, model, loss_fn, optimizer, train_loader, max = -1):
    
    size = len(train_data)

    # Switches model to training mode.
    model.train()
    
    # Stores loss at each step.
    loss_list = []

    # For computing accuracy
    correct = 0

    for batch, idx in enumerate(tqdm(train_loader)):

        X = train_data[idx][0]
        loc = train_data[idx][1]
        y = train_data[idx][2]

        # Compute prediction and loss
        pred = model(X, loc)
        loss = loss_fn(pred, y)
        
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record the loss and accuracy
        loss_value = loss.item()
        batch_correct = (pred.argmax(1) == y).type(torch.float).sum().item()
        correct += batch_correct
        loss_list.append(loss_value)

        # Terminate early for testing purposes
        if max > 0:
            if batch >= max:
                print("\nMax iterations reached.")
                break

        # Periodically print the loss
        if batch % 100 == 0 and batch != 0:
            print(f"\nBatch loss: {loss_value:>7f}")
    
    loss_list = np.array(loss_list)
    mean_loss = np.mean(loss_list)
    accuracy = correct/size
    print(f"Accuracy: {accuracy:>3f} | Average Loss: {mean_loss:>7f}\n")
    
    return mean_loss, accuracy

# Tests the model on the validation or test data
def valid_loop(valid_data, model, loss_fn, valid_loader, max = -1):

    # Switches model to evaluation mode
    model.eval()

    size = len(valid_data)
    loss_list = [] 
    correct = 0

    with torch.no_grad():
        for batch, idx in tqdm(enumerate(valid_loader)):
        # for batch, idx in enumerate(valid_loader):

            X = valid_data[idx][0]
            loc = valid_data[idx][1]
            y = valid_data[idx][2]

            pred = model(X, loc)

            loss_list.append(loss_fn(pred, y).item())
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

            if max > 0:
                if batch > max:
                    break

    valid_loss = np.mean(np.array(loss_list))
    correct /= size
    print(f"Validation| \nAccuracy: {correct:>3f} | Average loss: {valid_loss:>8f} \n")
    
    return valid_loss, correct

In [7]:
# Save the model's state_dict in its current state. 
# Saved file name records current time and epoch number
def save_model(model, save_dir):
    now = datetime.now()
    time_formatted = now.strftime("%d")+"_"+now.strftime("%H")+"_"+now.strftime("%M")
    torch.save(model.state_dict(), save_dir + f"_{time_formatted}_StateDict.pt")
    print("Model saved\n")

In [17]:
###
### Train the model
###

# Use a GPU if available
if torch.cuda.is_available():  
    dev = "cuda:0" 
else:  
    dev = "cpu" 

### Use a CPU for more clear error messages.
# dev = "cpu"

device = torch.device(dev)

# Set max to a positive integer to terminate training early.
# For debuggin purposes.
max = -1

tokenizer = BERT_TOKENIZER

# Data
num_abbr = "two_abbr"
folder = "drive/MyDrive/Bootcamp"
dictionary_file = f"{folder}/{num_abbr}/dict.txt"

train_df = pd.read_csv(f"{folder}/{num_abbr}/train_long_loc.csv")
train_data = MedalDatasetTokenizer(train_df, tokenizer, dictionary_file, device = device)

valid_df = pd.read_csv(f"{folder}/{num_abbr}/valid_long_loc.csv")
valid_data = MedalDatasetTokenizer(valid_df, tokenizer, dictionary_file, device = device)

# Hyperparameters
learning_rate = 4e-5 # Trains faster with similar acc as original 2e-5
batch_size = 16
epochs = 15

### Models 
output_size = 25 # Should be set to the size of the dictionary

# model = Bert(output_size, device)
model = Electra(output_size=output_size, size = 'small', device=device)

### Load a saved model. The correct model above must be initialized.
# path = f"{folder}/saves/eight_abbr_epoch3_27_23_21_StateDict.pt"
# model.load_state_dict(torch.load(path))

optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
loss_fn = nn.CrossEntropyLoss()

# Train the model
for t in range(epochs):
    print(f"\nEpoch {t+1}\n-------------------------------")
    
    # Randomly breaks the data into batches
    train_loader = DataLoader(
        range(len(train_data)), 
        shuffle=True, 
        batch_size=batch_size
    )
    valid_loader = DataLoader(
        range(len(valid_data)), 
        shuffle=True, 
        batch_size=batch_size
    )

    start = time()
    train_loss, train_accuracy = train_loop(train_data, model, loss_fn, optimizer, train_loader, max = max)
    end = time()
    print(f"Training time: {end-start:>0.1f} sec\n")

    valid_loss, valid_accuracy = valid_loop(valid_data, model, loss_fn, valid_loader, max = max)

    # Write the loss and accuracy to a file
    with open(f"{folder}/saves/loss.txt", "a") as file:
        file.writelines(f"\n{t+1},{train_loss},{train_accuracy},{valid_loss},{valid_accuracy}")

    save_model(model, f"{folder}/saves/{num_abbr}_epoch{t+1}")

Some weights of the model checkpoint at google/electra-small-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  0%|          | 0/294 [00:00<?, ?it/s]


Epoch 1
-------------------------------


 34%|███▍      | 101/294 [00:36<01:08,  2.80it/s]


Batch loss: 2.906245


 68%|██████▊   | 201/294 [01:12<00:33,  2.75it/s]


Batch loss: 2.685986


100%|██████████| 294/294 [01:46<00:00,  2.77it/s]
0it [00:00, ?it/s]

Accuracy: 0.223121 | Average Loss: 2.702567

Training time: 106.3 sec



97it [00:27,  3.53it/s]


Validation| 
Accuracy: 0.492248 | Average loss: 1.744218 



  0%|          | 0/294 [00:00<?, ?it/s]

Model saved


Epoch 2
-------------------------------


 34%|███▍      | 101/294 [00:36<01:12,  2.67it/s]


Batch loss: 1.313341


 68%|██████▊   | 201/294 [01:12<00:34,  2.73it/s]


Batch loss: 0.790381


100%|██████████| 294/294 [01:46<00:00,  2.77it/s]
0it [00:00, ?it/s]

Accuracy: 0.624441 | Average Loss: 1.318547

Training time: 106.0 sec



97it [00:27,  3.54it/s]


Validation| 
Accuracy: 0.699612 | Average loss: 1.011978 



  0%|          | 0/294 [00:00<?, ?it/s]

Model saved


Epoch 3
-------------------------------


 34%|███▍      | 101/294 [00:36<01:08,  2.83it/s]


Batch loss: 0.957758


 68%|██████▊   | 201/294 [01:11<00:33,  2.80it/s]


Batch loss: 0.569170


100%|██████████| 294/294 [01:44<00:00,  2.80it/s]
0it [00:00, ?it/s]

Accuracy: 0.794976 | Average Loss: 0.751690

Training time: 104.9 sec



97it [00:27,  3.55it/s]


Validation| 
Accuracy: 0.795220 | Average loss: 0.687936 



  0%|          | 0/294 [00:00<?, ?it/s]

Model saved


Epoch 4
-------------------------------


 34%|███▍      | 101/294 [00:36<01:14,  2.59it/s]


Batch loss: 0.857162


 68%|██████▊   | 201/294 [01:11<00:33,  2.74it/s]


Batch loss: 0.322483


100%|██████████| 294/294 [01:44<00:00,  2.82it/s]
0it [00:00, ?it/s]

Accuracy: 0.857994 | Average Loss: 0.475670

Training time: 104.4 sec



97it [00:27,  3.50it/s]


Validation| 
Accuracy: 0.814599 | Average loss: 0.613867 



  0%|          | 0/294 [00:00<?, ?it/s]

Model saved


Epoch 5
-------------------------------


 34%|███▍      | 101/294 [00:36<01:08,  2.81it/s]


Batch loss: 0.243932


 68%|██████▊   | 201/294 [01:11<00:32,  2.89it/s]


Batch loss: 0.504643


100%|██████████| 294/294 [01:44<00:00,  2.82it/s]
0it [00:00, ?it/s]

Accuracy: 0.890781 | Average Loss: 0.358516

Training time: 104.4 sec



97it [00:27,  3.54it/s]


Validation| 
Accuracy: 0.827519 | Average loss: 0.570200 



  0%|          | 0/294 [00:00<?, ?it/s]

Model saved


Epoch 6
-------------------------------


 34%|███▍      | 101/294 [00:36<01:06,  2.90it/s]


Batch loss: 0.265620


 68%|██████▊   | 201/294 [01:11<00:33,  2.79it/s]


Batch loss: 0.124233


100%|██████████| 294/294 [01:44<00:00,  2.83it/s]
0it [00:00, ?it/s]

Accuracy: 0.910794 | Average Loss: 0.260168

Training time: 104.0 sec



97it [00:27,  3.52it/s]


Validation| 
Accuracy: 0.846253 | Average loss: 0.516693 



  0%|          | 0/294 [00:00<?, ?it/s]

Model saved


Epoch 7
-------------------------------


 34%|███▍      | 101/294 [00:36<01:07,  2.86it/s]


Batch loss: 0.300805


 68%|██████▊   | 201/294 [01:10<00:31,  2.93it/s]


Batch loss: 0.151264


100%|██████████| 294/294 [01:43<00:00,  2.83it/s]
0it [00:00, ?it/s]

Accuracy: 0.929104 | Average Loss: 0.218219

Training time: 104.0 sec



97it [00:27,  3.52it/s]


Validation| 
Accuracy: 0.855943 | Average loss: 0.500714 



  0%|          | 0/294 [00:00<?, ?it/s]

Model saved


Epoch 8
-------------------------------


 34%|███▍      | 101/294 [00:36<01:10,  2.73it/s]


Batch loss: 0.260172


 68%|██████▊   | 201/294 [01:11<00:33,  2.81it/s]


Batch loss: 0.033072


100%|██████████| 294/294 [01:44<00:00,  2.82it/s]
0it [00:00, ?it/s]

Accuracy: 0.940600 | Average Loss: 0.179919

Training time: 104.4 sec



97it [00:27,  3.54it/s]


Validation| 
Accuracy: 0.856589 | Average loss: 0.520990 



  0%|          | 0/294 [00:00<?, ?it/s]

Model saved


Epoch 9
-------------------------------


 34%|███▍      | 101/294 [00:35<01:07,  2.85it/s]


Batch loss: 0.080694


 68%|██████▊   | 201/294 [01:11<00:33,  2.81it/s]


Batch loss: 0.243101


100%|██████████| 294/294 [01:43<00:00,  2.83it/s]
0it [00:00, ?it/s]

Accuracy: 0.943368 | Average Loss: 0.162641

Training time: 103.8 sec



97it [00:27,  3.54it/s]


Validation| 
Accuracy: 0.848837 | Average loss: 0.570280 



  0%|          | 0/294 [00:00<?, ?it/s]

Model saved


Epoch 10
-------------------------------


 34%|███▍      | 101/294 [00:35<01:10,  2.75it/s]


Batch loss: 0.044176


 68%|██████▊   | 201/294 [01:11<00:31,  2.93it/s]


Batch loss: 0.076815


100%|██████████| 294/294 [01:43<00:00,  2.84it/s]
0it [00:00, ?it/s]

Accuracy: 0.954865 | Average Loss: 0.140444

Training time: 103.6 sec



97it [00:27,  3.51it/s]


Validation| 
Accuracy: 0.841731 | Average loss: 0.570285 



  0%|          | 0/294 [00:00<?, ?it/s]

Model saved


Epoch 11
-------------------------------


 34%|███▍      | 101/294 [00:36<01:10,  2.75it/s]


Batch loss: 0.187778


 68%|██████▊   | 201/294 [01:11<00:33,  2.81it/s]


Batch loss: 0.263021


100%|██████████| 294/294 [01:44<00:00,  2.82it/s]
0it [00:00, ?it/s]

Accuracy: 0.961252 | Average Loss: 0.130489

Training time: 104.4 sec



97it [00:27,  3.52it/s]


Validation| 
Accuracy: 0.854005 | Average loss: 0.554476 



  0%|          | 0/294 [00:00<?, ?it/s]

Model saved


Epoch 12
-------------------------------


 34%|███▍      | 101/294 [00:35<01:06,  2.90it/s]


Batch loss: 0.411526


 68%|██████▊   | 201/294 [01:11<00:32,  2.87it/s]


Batch loss: 0.113710


100%|██████████| 294/294 [01:44<00:00,  2.83it/s]
0it [00:00, ?it/s]

Accuracy: 0.959123 | Average Loss: 0.120500

Training time: 104.0 sec



97it [00:27,  3.52it/s]


Validation| 
Accuracy: 0.868863 | Average loss: 0.523249 



  0%|          | 0/294 [00:00<?, ?it/s]

Model saved


Epoch 13
-------------------------------


 34%|███▍      | 101/294 [00:35<01:08,  2.82it/s]


Batch loss: 0.019488


 68%|██████▊   | 201/294 [01:11<00:33,  2.75it/s]


Batch loss: 0.010723


100%|██████████| 294/294 [01:44<00:00,  2.82it/s]
0it [00:00, ?it/s]

Accuracy: 0.977432 | Average Loss: 0.084329

Training time: 104.4 sec



97it [00:27,  3.52it/s]


Validation| 
Accuracy: 0.866279 | Average loss: 0.537227 



  0%|          | 0/294 [00:00<?, ?it/s]

Model saved


Epoch 14
-------------------------------


 34%|███▍      | 101/294 [00:35<01:10,  2.73it/s]


Batch loss: 0.013719


 68%|██████▊   | 201/294 [01:10<00:31,  2.96it/s]


Batch loss: 0.074761


100%|██████████| 294/294 [01:43<00:00,  2.84it/s]
0it [00:00, ?it/s]

Accuracy: 0.977220 | Average Loss: 0.074018

Training time: 103.7 sec



97it [00:27,  3.55it/s]


Validation| 
Accuracy: 0.869509 | Average loss: 0.550419 



  0%|          | 0/294 [00:00<?, ?it/s]

Model saved


Epoch 15
-------------------------------


 34%|███▍      | 101/294 [00:35<01:05,  2.95it/s]


Batch loss: 0.008758


 68%|██████▊   | 201/294 [01:11<00:35,  2.65it/s]


Batch loss: 0.048343


100%|██████████| 294/294 [01:44<00:00,  2.82it/s]
0it [00:00, ?it/s]

Accuracy: 0.986161 | Average Loss: 0.056378

Training time: 104.2 sec



97it [00:27,  3.52it/s]


Validation| 
Accuracy: 0.857881 | Average loss: 0.576169 

Model saved



In [20]:
###
### Test the model
###

# Use a GPU if available
if torch.cuda.is_available():  
    dev = "cuda:0" 
else:  
    dev = "cpu" 
### Use a CPU for more clear error messages.
# dev = "cpu"
device = torch.device(dev)

# Data
num_abbr = "two_abbr"
folder = "drive/MyDrive/Bootcamp"
dictionary_file = f"{folder}/{num_abbr}/dict.txt"

tokenizer = BERT_TOKENIZER
test_df = pd.read_csv(f"{folder}/{num_abbr}/test_long_loc.csv")
test_data = MedalDatasetTokenizer(test_df, tokenizer, dictionary_file, device = device)

### Models 
output_size = 25 # Should be set to the size of the dictionary

#model = Bert(output_size, device)
model = Electra(output_size=output_size, size = 'small', device=device)
#model = Electra(output_size=output_size, size = 'base', device=device)

### Load a saved model. The correct model above must be initialized.
path = f"{folder}/saves/Finished/ElectraSmall_TwoAbbr_OldLoc_Epoch12.pt"
model.load_state_dict(torch.load(path))
#model = torch.load(path)

batch_size = 16
test_loader = DataLoader(
        range(len(test_data)), 
        shuffle=True, 
        batch_size=batch_size
    )

loss_fn = nn.CrossEntropyLoss()
test_loss, test_accuracy = valid_loop(test_data, model, loss_fn, test_loader)
test_loss, test_accuracy

Some weights of the model checkpoint at google/electra-small-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
98it [00:28,  3.49it/s]

Validation| 
Accuracy: 0.881302 | Average loss: 0.460943 






(0.4609425602984444, 0.8813018506700702)