In [1]:
# Mount Google Drive so files under MyDrive are reachable from this Colab runtime.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [2]:
# Location of the acronym-disambiguation splits on the mounted Drive.
DATA_DIR = '/content/drive/MyDrive/NLP_WSD/data1'
TRAIN_PATH = DATA_DIR + '/train.json'
VALID_PATH = DATA_DIR + '/dev.json'

import json
import torch
import torch.optim as optim
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel

In [3]:
# Peek at the first training record to see the fields available
# (acronym index, expansion, id, tokens).
with open(TRAIN_PATH) as file:
    json_list = json.load(file)

json_list[0]

{'acronym': 20,
 'expansion': 'secrecy rate',
 'id': 'TR-0',
 'tokens': ['In',
  'summary',
  ',',
  'it',
  'is',
  'evident',
  'that',
  'their',
  'complexities',
  'are',
  'in',
  'increasing',
  'order',
  ':',
  'leakage',
  '-',
  'based',
  ',',
  'Max',
  '-',
  'SR',
  ',',
  'and',
  'generalized',
  'EDAS',
  '.']}

In [4]:
def read_corpus(file_name, tokenizer, max_length=128, word_sense_dict=None, limit=5000):
    """Build tokenized (input_ids, attention_mask, label) examples from a JSON split.

    Each JSON item needs ``tokens`` (list of str), ``acronym`` (index into
    ``tokens``) and ``expansion`` (the gold sense). Every item yields:

    * one positive example (gold expansion, label 1), and
    * one negative example (label 0) for every other candidate expansion of
      the same acronym.

    Each example encodes ``sentence [SEP] acronym [SEP] expansion``.

    Args:
        file_name: Path to a JSON file containing a list of items.
        tokenizer: Hugging Face-style tokenizer. It is called with a text pair
            and must return a mapping with ``input_ids`` and ``attention_mask``.
        max_length: Length every encoding is padded/truncated to.
        word_sense_dict: Optional mapping acronym -> iterable of expansions
            (e.g. the inventory returned for the training split). When given,
            negatives are drawn only from this inventory, so they do not depend
            on the order of items in ``file_name``. The dict passed in is never
            modified.
        limit: Maximum number of items read from the file; ``None`` reads all.
            Defaults to 5000, the cap previously hard-coded here.

    Returns:
        ``(data, word_senses)``. ``data`` is a list of
        ``(input_ids, attention_mask, label)`` tuples. ``word_senses`` is a new
        dict mapping each acronym to the set of expansions known after reading
        this file (the reference inventory plus this file's gold senses).
    """
    with open(file_name, 'r') as file:
        items = json.load(file)
    if limit is not None:
        items = items[:limit]

    if word_sense_dict is None:
        # Training split: the candidate inventory is every expansion seen in this file.
        reference_senses = {}
        for item in items:
            word = item['tokens'][item['acronym']]
            reference_senses.setdefault(word, set()).add(item['expansion'])
    else:
        # Evaluation split: candidates come only from the supplied inventory.
        reference_senses = word_sense_dict

    # Fresh copy so a caller-supplied inventory is never mutated.
    word_senses = {word: set(senses) for word, senses in reference_senses.items()}

    def encode(sentence, word, sense, label):
        # Encode the sentence and "word [SEP] sense" as a text pair. With
        # truncation=True (longest-first), an over-long input then loses tokens
        # from the sentence rather than the candidate sense at the end.
        # NOTE(review): for BERT tokenizers this should give the same
        # input_ids as the former single string "sentence [SEP] word [SEP] sense"
        # when nothing is truncated — confirm if switching tokenizers.
        encoded = tokenizer(sentence, word + ' [SEP] ' + sense,
                            padding='max_length', max_length=max_length,
                            truncation=True, return_tensors='pt')
        return encoded['input_ids'], encoded['attention_mask'], label

    data = []
    for item in items:
        word = item['tokens'][item['acronym']]
        gold_sense = item['expansion']
        sentence = ' '.join(item['tokens'])
        word_senses.setdefault(word, set()).add(gold_sense)

        data.append(encode(sentence, word, gold_sense, 1))
        # sorted() gives a reproducible example order; set iteration order
        # of strings changes between interpreter runs.
        for candidate in sorted(reference_senses.get(word, ())):
            if candidate != gold_sense:
                data.append(encode(sentence, word, candidate, 0))
    return data, word_senses

In [5]:
class WSDDataset(Dataset):
    """Map-style dataset over an already-built list of examples.

    Items are handed back exactly as stored; for ``read_corpus`` output these
    are ``(input_ids, attention_mask, label)`` tuples.
    """

    def __init__(self, data):
        # Hold a reference to the caller's sequence; no copy is made.
        self.data = data

    def __len__(self):
        n_examples = len(self.data)
        return n_examples

    def __getitem__(self, idx):
        example = self.data[idx]
        return example

In [6]:
def collate_fn(batch):
    """Combine a list of ``(input_ids, attention_mask, label)`` examples into batch tensors.

    The input_ids and attention_mask tensors are padded with 0 along their
    first dimension via ``pad_sequence(batch_first=True)``. The labels become a
    1-D tensor. For ``read_corpus`` output (already padded to ``max_length``),
    each tensor is presumably ``(1, max_length)``, giving a
    ``(batch, 1, max_length)`` result.

    Returns:
        ``(input_ids, attention_masks, labels)``.
    """
    id_tensors = []
    mask_tensors = []
    label_values = []
    for example in batch:
        id_tensors.append(example[0])
        mask_tensors.append(example[1])
        label_values.append(example[2])

    padded_ids = pad_sequence(id_tensors, batch_first=True, padding_value=0)
    padded_masks = pad_sequence(mask_tensors, batch_first=True, padding_value=0)
    return padded_ids, padded_masks, torch.tensor(label_values)

In [7]:
def train(model, dataloader, optimizer, criterion, device, print_every=10):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` that
            returns logits accepted by ``criterion``.
        dataloader: Sized iterable yielding ``(input_ids, attention_mask, labels)``.
            input_ids/attention_mask must be 3-D with a middle dimension of 1
            (they are reshaped to ``(batch, seq)`` below).
        optimizer: Optimizer over ``model``'s parameters; stepped once per batch.
        criterion: Loss taking ``(logits, labels)`` with integer class labels.
        device: Device the batch tensors are moved to.
        print_every: Print the current batch loss every this many batches.

    Returns:
        The sum of batch losses divided by the number of batches, as a float;
        ``0.0`` when ``dataloader`` is empty (instead of dividing by zero).
    """
    model.train()
    total_loss = 0.0

    total_batches = len(dataloader)
    print(f"Total number of batches: {total_batches}")

    for batch_idx, (inputs_ids, attention_masks, labels) in enumerate(dataloader):
        # Drop the singleton middle dimension: (batch, 1, seq) -> (batch, seq).
        batch_size, _, seq_length = inputs_ids.size()
        inputs_ids = inputs_ids.view(batch_size, seq_length).to(device)
        attention_masks = attention_masks.view(batch_size, seq_length).to(device)
        labels = labels.view(-1).to(device).long()

        optimizer.zero_grad()
        outputs = model(inputs_ids, attention_masks)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Read the scalar once; each .item() forces a device sync.
        batch_loss = loss.item()
        total_loss += batch_loss

        if (batch_idx + 1) % print_every == 0:
            print(f"Batch {batch_idx + 1}/{total_batches}")
            print(f"Loss: {batch_loss}")
            print("-" * 80)

    if total_batches == 0:
        return 0.0
    return total_loss / total_batches

In [8]:
def evaluate(model, dataloader, tokenizer, device, print_every=10):
    """Return the fraction of labelled examples whose argmax prediction matches the label.

    With data built by ``read_corpus``, each example is one
    (sentence, acronym, candidate sense) pair labelled 1/0. The result is
    therefore pair-level binary accuracy, not "chose the right sense" accuracy.
    Labels equal to -100 are skipped.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``
            that returns logits of shape ``(batch, num_classes)``.
        dataloader: Sized iterable yielding ``(input_ids, attention_mask, labels)``
            with 3-D id/mask tensors whose middle dimension is 1.
        tokenizer: Unused; kept so existing call sites keep working.
        device: Device the batch tensors are moved to.
        print_every: Print running accuracy every this many batches.

    Returns:
        ``correct / total`` as a float, or ``0.0`` when no label was scored
        (e.g. an empty dataloader).
    """
    model.eval()
    correct = 0
    total = 0
    total_batches = len(dataloader)
    with torch.no_grad():
        for i, (inputs_ids, attention_masks, labels) in enumerate(dataloader):
            # Drop the singleton middle dimension: (batch, 1, seq) -> (batch, seq).
            batch_size, _, seq_length = inputs_ids.size()
            inputs_ids = inputs_ids.view(batch_size, seq_length).to(device)
            attention_masks = attention_masks.view(batch_size, seq_length).to(device)
            labels = labels.view(-1).to(device)

            outputs = model(input_ids=inputs_ids, attention_mask=attention_masks)
            predictions = torch.argmax(outputs, dim=1)

            # Vectorised count: one host sync per batch instead of one per example.
            scored = labels != -100
            correct += (predictions[scored] == labels[scored]).sum().item()
            total += scored.sum().item()

            if (i + 1) % print_every == 0:
                print(f"Processed {i + 1}/{total_batches} batches.")
                current_accuracy = (correct / total) if total > 0 else 0
                print(f"Current Accuracy: {current_accuracy:.4f}")

    return correct / total if total > 0 else 0.0

In [9]:
class BertWSDModel(nn.Module):
    """Encoder plus a linear head that scores the first-token representation.

    The wrapped encoder is run on ``(input_ids, attention_mask)``. Its hidden
    state at position 0 ([CLS] for BERT tokenizers) is mapped to
    ``num_labels`` logits.

    Args:
        bert_model: Encoder module whose output exposes ``last_hidden_state``
            of shape ``(batch, seq, hidden)``. Its width is read from
            ``bert_model.config.hidden_size`` when available, else 768.
        num_labels: Number of output logits (default 2: wrong / right sense).
    """

    def __init__(self, bert_model, num_labels=2):
        super(BertWSDModel, self).__init__()
        self.bert = bert_model
        # Use the encoder's own width instead of hard-coding 768, so other
        # sizes (e.g. bert-large) also work; 768 is kept as the fallback.
        config = getattr(bert_model, 'config', None)
        hidden_size = getattr(config, 'hidden_size', 768)
        self.linear = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 of every sequence ([CLS]) summarises the input pair.
        cls_output = outputs.last_hidden_state[:, 0, :]
        logits = self.linear(cls_output)
        return logits

In [10]:
# Hyper-parameters. max_length=512 matches the size of BERT's
# position-embedding table (Embedding(512, 768) in the model repr).
params = {
    'max_length': 512,
    'batch_size': 100,
    'learning_rate': 1e-5,
    'epoch': 3
}
torch.manual_seed(0)  # seed torch's global RNG for reproducible runs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

bert_model = BertModel.from_pretrained("bert-base-uncased")
model = BertWSDModel(bert_model).to(device)
model  # display the architecture; it is already on `device`

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]



config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

BertWSDModel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_

In [11]:
# Build examples for both splits. The dev split reuses the training sense
# inventory as its candidate set.
max_len = params['max_length']
train_data, train_word_sense_dict = read_corpus(TRAIN_PATH, tokenizer, max_length=max_len)
valid_data, _ = read_corpus(
    VALID_PATH,
    tokenizer,
    max_length=max_len,
    word_sense_dict=train_word_sense_dict,
)
print('Finished reading data!')

train_dataset, valid_dataset = WSDDataset(train_data), WSDDataset(valid_data)

# Shared loader settings; only the training loader shuffles.
loader_kwargs = dict(batch_size=params['batch_size'], collate_fn=collate_fn)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)
print('Finished loading data!')

Finished reading data!
Finished loading data!


In [12]:
# Zero-shot evaluation: scores the validation set before any fine-tuning, so
# the classification head is still freshly initialised. This is an untrained
# baseline. Off by default because it costs a full validation pass; set to
# True to run it.
RUN_ZERO_SHOT_EVAL = False
if RUN_ZERO_SHOT_EVAL:
    zero_shot_valid_accuracy = evaluate(model, valid_loader, tokenizer, device)
    print(f"Zero-shot validation accuracy: {zero_shot_valid_accuracy}")

In [13]:
# Fine-tuning
epochs = params['epoch']
optimizer = optim.AdamW(model.parameters(), lr=params['learning_rate'])
criterion = nn.CrossEntropyLoss(ignore_index=-100)
for epoch in range(epochs):
    train_loss = train(model, train_loader, optimizer, criterion, device)
    valid_accuracy = evaluate(model, valid_loader, tokenizer, device)
    print(f"Epoch {epoch + 1}/{epochs}")
    print(f"Training loss: {train_loss}")
    print(f"Validation accuracy: {valid_accuracy}")

Total number of batches: 140
Batch 10/140
Loss: 0.6385166645050049
--------------------------------------------------------------------------------
Batch 20/140
Loss: 0.6717284917831421
--------------------------------------------------------------------------------
Batch 30/140
Loss: 0.6210108995437622
--------------------------------------------------------------------------------
Batch 40/140
Loss: 0.6683268547058105
--------------------------------------------------------------------------------
Batch 50/140
Loss: 0.5648550987243652
--------------------------------------------------------------------------------
Batch 60/140
Loss: 0.5725706815719604
--------------------------------------------------------------------------------
Batch 70/140
Loss: 0.5316883325576782
--------------------------------------------------------------------------------
Batch 80/140
Loss: 0.5498493313789368
--------------------------------------------------------------------------------
Batch 90/140
Loss: 