In [1]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.27.3-py3-none-any.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m31.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m85.8 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.13.3-py3-none-any.whl (199 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.8/199.8 KB[0m [31m24.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.13.3 tokenizers-0.13.2 transformers-4.27.3


In [2]:
import pandas as pd
import nltk
from nltk.corpus import wordnet
nltk.download('wordnet')
import numpy as np
from transformers import DistilBertTokenizer, TFDistilBertModel, AutoTokenizer, TrainingArguments, Trainer, DistilBertModel
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
import torch.nn as nn

[nltk_data] Downloading package wordnet to /root/nltk_data...


In [3]:
path = "sample_data/toyset.csv"
df = pd.read_csv(path, dtype="string")
df['Definition'] = df['Definition'].astype(str)
df = df[['Word', 'Definition']]
df

Unnamed: 0,Word,Definition
0,Geographical,"""Of or pertaining to geography."""
1,Inextricableness,"""The state of being inextricable."""
2,Papuars,"""The native black race of Papua or New Guinea ..."
3,dark-coated,covered with dark hair
4,Cesura,"""See Caesura."""
...,...,...
233,olive,a tree of some other species of olea or of som...
234,olive,evergreen tree cultivated in the mediterranean...
235,olive,an evergreen tree olea europaea cultivated sin...
236,olive,the tree has been cultivated for its fruit for...


In [4]:
'''Convert classes to numbers'''
word_dict = {} 
i = 0
for w in df['Word'].unique():
    word_dict[w] = i
    i += 1

'''Convert numbers back to words'''
idx2word = {v:k for k,v in word_dict.items()}

In [5]:
df_train, df_test = train_test_split(df[['Definition','Word']], test_size=0.2)
df_test, df_val = train_test_split(df_test[['Definition','Word']], test_size=0.5)

In [6]:
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

train_enc = tokenizer(df_train['Definition'].to_list(), padding=True, truncation=True, max_length=128)
test_enc = tokenizer(df_test['Definition'].to_list(), padding=True, truncation=True, max_length=128)
val_enc = tokenizer(df_val['Definition'].to_list(), padding=True, truncation=True, max_length=128)

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [7]:
class RevDictDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels
        self.labels = self.labels.to_list()
    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item
    def __len__(self):
        return len(self.labels)

In [8]:
'''One hot encoding of classes'''
train_label_enum = {k:j+1 for j, k in enumerate(df_train['Word'].unique())}
train_label_enum["<unk>"] = 0
train_num_labels = len(train_label_enum)
idx2token = {idx: token for token, idx in train_label_enum.items()}
df_train['labels'] = df_train['Word'].apply(lambda x: [1.0 if train_label_enum[x]==i else 0.0 for i in range(train_num_labels)])

labels = []
for word in list(df_val['Word']):
  if word in train_label_enum:
    labels.append(train_label_enum[word])
  else:
    labels.append(train_label_enum["<unk>"])
df_val['labels'] = labels

In [9]:
train_dataset = RevDictDataset(train_enc, df_train['labels'])
val_dataset = RevDictDataset(val_enc, df_val['labels'])

In [10]:
class BLmodel(nn.Module):
  def __init__(self):
    super(BLmodel, self).__init__()
    self.bert_model = DistilBertModel.from_pretrained('distilbert-base-uncased')
    # Define LSTM layers
    self.lstm_layer_1 = nn.LSTM(input_size=768, hidden_size=256, num_layers=1, batch_first=True, bidirectional=False)
    self.lstm_layer_2 = nn.LSTM(input_size=256, hidden_size=128, num_layers=1, batch_first=True, bidirectional=False)
    self.lstm_layer_3 = nn.LSTM(input_size=128, hidden_size=64, num_layers=1, batch_first=True, bidirectional=False)
    self.lstm_layer_4 = nn.LSTM(input_size=64, hidden_size=32, num_layers=1, batch_first=True, bidirectional=False)

    # Define output layer
    self.output_layer = nn.Linear(32, len(train_label_enum))

    # Define softmax activation
    self.softmax = nn.Softmax(dim=1)

  def forward(self, input_ids, attention_mask=None):
    # Get BERT embeddings
    outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
    bert_embedding = outputs[0]
    # Get LSTM outputs
    lstm_output_1, _ = self.lstm_layer_1(bert_embedding)
    lstm_output_2, _ = self.lstm_layer_2(lstm_output_1)
    lstm_output_3, _ = self.lstm_layer_3(lstm_output_2)
    lstm_output_4, _ = self.lstm_layer_4(lstm_output_3)

    # Get output and apply softmax
    output = self.output_layer(lstm_output_4[:, -1, :])
    output = self.softmax(output)
    return output

In [11]:
# Instantiate the DataLoader
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# Define the training function
def train(model, dataloader, optimizer, num_epochs):
  model.train()
  for epoch in range(num_epochs):
    for batch in dataloader:
      input_ids = batch['input_ids']
      labels = batch['labels']
      optimizer.zero_grad()
      outputs = model(input_ids)
      loss = nn.CrossEntropyLoss()(outputs.view(-1, len(train_label_enum)), labels)
      loss.backward()
      optimizer.step()
    print(f"Epoch {epoch+1} complete. Loss: {loss.item()}")

# Define the training parameters
model = BLmodel()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.05)
num_epochs = 5
train(model, train_loader, optimizer, 5)

Downloading pytorch_model.bin:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Epoch 1 complete. Loss: 4.370944499969482
Epoch 2 complete. Loss: 4.371566295623779
Epoch 3 complete. Loss: 4.37577486038208
Epoch 4 complete. Loss: 4.377648830413818
Epoch 5 complete. Loss: 4.374609470367432


In [12]:
def evaluate(model, loader):
  model.eval()
  val_predictions = []
  for batch in loader:
      input_ids = batch['input_ids']
      labels = batch['labels']
      outputs = model(input_ids)
      # print(outputs)
      _, predicted = torch.max(outputs, dim=1)
      val_predictions.extend(predicted.flatten().tolist())
  return val_predictions

val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True)
val_predictions = evaluate(model, val_loader)


In [13]:
df_val['predicted'] = [idx2token[idx] for idx in val_predictions]
df_val

Unnamed: 0,Definition,Word,labels,predicted
181,"""A leopard.""",Pardale,0,overhaul
24,tropical genus of small trees or shrubs,genus_Crateva,0,overhaul
235,an evergreen tree olea europaea cultivated sin...,olive,18,overhaul
218,difficult to bear burdensome oppressive laws,oppressive,21,overhaul
96,"""A pewfellow.""",Puefellow,0,overhaul
20,"""To tame.""",Entame,0,overhaul
101,a zoologist who studies fishes,ichthyologist,63,overhaul
91,"""From without inward; toward the inside; as t...",Inboard,52,overhaul
226,"""In a pouting or a sullen manner.""",Poutingly,0,overhaul
167,to sunburn,burn,19,overhaul
