In [1]:
# from google.colab import drive
# drive.mount('/content/drive')

In [2]:
# Dataset locations (relative paths, resolved against the current working directory).
TRAIN_PATH = 'data/data1/train.json'  # training split
VALID_PATH = 'data/data1/dev.json'  # dev split, used for validation in main()

import json
import torch
import torch.optim as optim
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertForSequenceClassification

In [3]:
# Peek at the first training record to see the JSON schema.
# Decode explicitly as UTF-8 instead of relying on the platform's locale default.
with open(TRAIN_PATH, 'r', encoding='utf-8') as file:
    json_list = json.load(file)
json_list[0]

{'acronym': 20,
 'expansion': 'secrecy rate',
 'id': 'TR-0',
 'tokens': ['In',
  'summary',
  ',',
  'it',
  'is',
  'evident',
  'that',
  'their',
  'complexities',
  'are',
  'in',
  'increasing',
  'order',
  ':',
  'leakage',
  '-',
  'based',
  ',',
  'Max',
  '-',
  'SR',
  ',',
  'and',
  'generalized',
  'EDAS',
  '.']}

In [4]:
def read_corpus(file_name, tokenizer, max_length=128, word_sense_dict=None, max_items=1000):
    """Read a JSON corpus and pre-tokenize it into positive/negative observations.

    Every record must provide ``tokens`` (list of str), ``acronym`` (index of
    the ambiguous token inside ``tokens``) and ``expansion`` (the gold sense).
    For each record one positive observation (gold sense, label 1) is built,
    followed by one negative observation (label 0) for every other sense known
    for the same word, in sorted order.

    Args:
        file_name: Path of a JSON file containing a list of records.
        tokenizer: Callable invoked as ``tokenizer(text, padding='max_length',
            max_length=..., truncation=True, return_tensors='pt')`` that returns
            a mapping with ``'input_ids'`` and ``'attention_mask'`` entries.
        max_length: Length every encoded sequence is padded/truncated to.
        word_sense_dict: Optional existing ``{word: set_of_senses}`` inventory.
            It is copied; the caller's dict and sets are never modified.
        max_items: Only the first ``max_items`` records are used. ``None``
            reads the whole file.

    Returns:
        ``(data, word_sense_dict)``: ``data[i]`` is the list of
        ``(input_ids, attention_mask, label)`` tuples for record ``i`` with the
        positive observation first; ``word_sense_dict`` is the resulting
        inventory (a new dict).
    """
    inventory_provided = word_sense_dict is not None
    # Private copy: adding senses below must not leak into the caller's inventory.
    sense_inventory = {
        known_word: set(known_senses)
        for known_word, known_senses in (word_sense_dict or {}).items()
    }

    with open(file_name, 'r', encoding='utf-8') as file:
        records = json.load(file)[:max_items]

    def encode(sentence, word, sense, label):
        # Tokenize immediately to save processing time during training.
        encoded = tokenizer(sentence + ' [SEP] ' + word + ' [SEP] ' + sense,
                            padding='max_length', max_length=max_length,
                            truncation=True, return_tensors='pt')
        return encoded['input_ids'], encoded['attention_mask'], label

    if not inventory_provided:
        # Build the complete inventory up front so every record sees all senses.
        for item in records:
            sense_inventory.setdefault(item['tokens'][item['acronym']], set()).add(item['expansion'])

    data = []
    for item in records:
        word = item['tokens'][item['acronym']]
        sense = item['expansion']
        sentence = ' '.join(item['tokens'])

        # Make sure the gold sense is registered for this word.
        # NOTE(review): when an inventory is provided, senses from this file are
        # registered as records are visited, so the negatives of a record depend
        # on the records that precede it - confirm this is intended.
        sense_inventory.setdefault(word, set()).add(sense)

        obs = [encode(sentence, word, sense, 1)]
        # Sorted so the order of negatives does not depend on set iteration order.
        obs.extend(encode(sentence, word, other_sense, 0)
                   for other_sense in sorted(sense_inventory[word])
                   if other_sense != sense)
        data.append(obs)
    return data, sense_inventory

In [5]:
# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# read_corpus(TRAIN_PATH, tokenizer)[0]

In [6]:
class WSDDataset(Dataset):
    """Map-style ``Dataset`` that exposes an in-memory sequence of items.

    The wrapped object is stored as-is in ``self.data``; length and indexing
    are delegated to it unchanged.
    """

    def __init__(self, data):
        """Keep a reference to ``data`` (no copy is made)."""
        self.data = data

    def __len__(self):
        """Return the number of items in the wrapped sequence."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``."""
        return self.data[idx]

In [7]:
def collate_fn(batch):
    """Collate items with differing numbers of observations into dense tensors.

    ``batch`` is a list of items; each item is a list of
    ``(input_ids, attention_mask, label)`` observations. Items with fewer
    observations than the largest item in the batch are filled up with
    all-zero ``input_ids`` / ``attention_mask`` rows and the dummy label
    ``-100``.

    The padding branch indexes ``shape[2]`` of the flattened tensors, so each
    observation tensor is assumed to be 2-D (e.g. ``(1, seq_len)``).

    Returns:
        ``(input_ids, attention_masks, labels)`` stacked per item, i.e. with a
        leading batch axis followed by the observation axis.
    """
    # Flatten the batch into three parallel lists, one entry per observation.
    flat_ids = [obs[0] for item in batch for obs in item]
    flat_masks = [obs[1] for item in batch for obs in item]
    flat_labels = [obs[2] for item in batch for obs in item]

    input_ids = pad_sequence(flat_ids, batch_first=True, padding_value=0)
    attention_masks = pad_sequence(flat_masks, batch_first=True, padding_value=0)
    labels = torch.tensor(flat_labels)

    obs_counts = [len(item) for item in batch]
    max_obs = max(obs_counts)

    ids_per_item, masks_per_item, labels_per_item = [], [], []
    offset = 0
    for count in obs_counts:
        stop = offset + count
        item_ids = input_ids[offset:stop]
        item_masks = attention_masks[offset:stop]
        item_labels = labels[offset:stop]

        missing = max_obs - count
        if missing > 0:
            # Append all-zero observations so every item has max_obs rows.
            ids_fill = torch.zeros(missing, input_ids.shape[1], input_ids.shape[2],
                                   dtype=input_ids.dtype)
            masks_fill = torch.zeros(missing, attention_masks.shape[1], attention_masks.shape[2],
                                     dtype=attention_masks.dtype)
            # -100 marks the filler rows as "not a real observation".
            labels_fill = torch.full((missing,), -100, dtype=torch.long)

            item_ids = torch.cat([item_ids, ids_fill], dim=0)
            item_masks = torch.cat([item_masks, masks_fill], dim=0)
            item_labels = torch.cat([item_labels, labels_fill], dim=0)

        ids_per_item.append(item_ids)
        masks_per_item.append(item_masks)
        labels_per_item.append(item_labels)
        offset = stop

    return (torch.stack(ids_per_item),
            torch.stack(masks_per_item),
            torch.stack(labels_per_item))

In [8]:
def train(model, dataloader, optimizer, criterion, device, print_every=1):
    """Run one training epoch and return the mean loss per batch.

    Each batch is ``(input_ids, attention_masks, labels)`` where the first two
    are 4-D ``(batch, observations, 1, seq_len)`` and ``labels`` holds one
    entry per observation, with ``-100`` marking padding observations.

    Padding observations are removed before the forward pass, so the model
    never processes them and ``criterion`` does not need a matching
    ``ignore_index``.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``
            whose output exposes ``.logits``.
        dataloader: Sized iterable of batches as described above.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device the batch tensors are moved to.
        print_every: Log the batch loss every ``print_every`` batches; a falsy
            value disables logging.

    Returns:
        Sum of batch losses divided by the number of batches, or ``0.0`` when
        the dataloader is empty.
    """
    model.train()
    total_loss = 0.0

    # Print the total number of batches
    total_batches = len(dataloader)
    print(f"Total number of batches: {total_batches}")

    for batch_idx, (inputs_ids, attention_masks, labels) in enumerate(dataloader):
        # Flatten (batch, observations, 1, seq_len) -> (batch * observations, seq_len)
        batch_size, num_observations, _, seq_length = inputs_ids.size()
        inputs_ids = inputs_ids.view(batch_size * num_observations, seq_length)
        attention_masks = attention_masks.view(batch_size * num_observations, seq_length)
        labels = labels.view(-1)

        # Drop padding observations before they reach the device or the model.
        keep = labels != -100
        if not keep.any():
            # Nothing to learn from; a loss over zero rows would be NaN.
            continue

        inputs_ids = inputs_ids[keep].to(device)
        attention_masks = attention_masks[keep].to(device)
        labels = labels[keep].to(device).long()

        optimizer.zero_grad()
        outputs = model(input_ids=inputs_ids, attention_mask=attention_masks)
        loss = criterion(outputs.logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # Print intermediate results every `print_every` batches
        if print_every and (batch_idx + 1) % print_every == 0:
            print(f"Batch {batch_idx + 1}/{total_batches}")
            print(f"Loss: {loss.item()}")
            print("-" * 80)

    if total_batches == 0:
        return 0.0
    return total_loss / total_batches

In [9]:
def evaluate(model, dataloader, tokenizer, device):
    """Compute classification accuracy over all non-padding observations.

    Accuracy is measured per observation (one encoded input with its 0/1
    label), not per item. Observations labelled ``-100`` are padding; they are
    removed before the forward pass and never counted.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``
            whose output exposes ``.logits``.
        dataloader: Iterable of ``(input_ids, attention_masks, labels)``
            batches with 4-D ``(batch, observations, 1, seq_len)`` inputs.
        tokenizer: Unused; kept so existing call sites keep working.
        device: Device the batch tensors are moved to.

    Returns:
        ``correct / total`` as a float, or ``0.0`` when there was nothing to
        score.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs_ids, attention_masks, labels in dataloader:
            # Flatten (batch, observations, 1, seq_len) -> (batch * observations, seq_len)
            batch_size, num_observations, _, seq_length = inputs_ids.size()
            inputs_ids = inputs_ids.view(batch_size * num_observations, seq_length)
            attention_masks = attention_masks.view(batch_size * num_observations, seq_length)
            labels = labels.view(-1)

            # Skip padding observations entirely.
            keep = labels != -100
            if not keep.any():
                continue

            inputs_ids = inputs_ids[keep].to(device)
            attention_masks = attention_masks[keep].to(device)
            labels = labels[keep].to(device)

            outputs = model(input_ids=inputs_ids, attention_mask=attention_masks)
            predictions = torch.argmax(outputs.logits, dim=1)

            # Count on-device and synchronise once per batch instead of once per row.
            correct += (predictions == labels).sum().item()
            total += labels.numel()

    if total == 0:
        return 0.0
    return correct / total

In [None]:
def main(params):
    """Fine-tune ``bert-base-uncased`` as a binary classifier and report validation accuracy.

    ``params`` must provide ``'max_length'``, ``'batch_size'``,
    ``'learning_rate'`` and ``'epoch'``. Progress and metrics are printed;
    nothing is returned.
    """
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # The validation corpus is read with the sense inventory built from training.
    train_data, train_senses = read_corpus(TRAIN_PATH, tokenizer, max_length=params['max_length'])
    valid_data, _ = read_corpus(
        VALID_PATH,
        tokenizer,
        max_length=params['max_length'],
        word_sense_dict=train_senses,
    )
    print('Finished reading data!')

    train_dataset = WSDDataset(train_data)
    valid_dataset = WSDDataset(valid_data)
    train_loader = DataLoader(
        train_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        collate_fn=collate_fn,
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=params['batch_size'],
        shuffle=False,
        collate_fn=collate_fn,
    )
    print('Finished loading data!')

    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
    model.to(device)

    # Zero-shot evaluation (disabled)
    # zero_shot_valid_accuracy = evaluate(model, valid_loader, tokenizer, device)
    # print(f"Zero-shot validation accuracy: {zero_shot_valid_accuracy}")

    # Fine-tuning: one training pass plus one validation pass per epoch.
    epochs = params['epoch']
    optimizer = optim.Adam(model.parameters(), lr=params['learning_rate'])
    criterion = nn.CrossEntropyLoss()
    for epoch_number in range(1, epochs + 1):
        train_loss = train(model, train_loader, optimizer, criterion, device)
        valid_accuracy = evaluate(model, valid_loader, tokenizer, device)
        print(f"Epoch {epoch_number}/{epochs}")
        print(f"Training loss: {train_loss}")
        print(f"Validation accuracy: {valid_accuracy}")

    # Final evaluation of the fine-tuned model.
    finetuned_valid_accuracy = evaluate(model, valid_loader, tokenizer, device)
    print(f"Finetuned validation accuracy: {finetuned_valid_accuracy}")

In [10]:
# Hyperparameters read by main().
params = {
    'max_length': 128,  # forwarded to read_corpus: length each encoded input is padded/truncated to
    'batch_size': 2,  # items per DataLoader batch (each item is a list of observations)
    'learning_rate': 3e-5,  # learning rate given to the Adam optimizer
    'epoch': 3  # number of train + validate passes
}
# Runs data loading, fine-tuning and evaluation end to end.
main(params)