In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
TRAIN_PATH = '/content/drive/MyDrive/NLP_WSD/data2/AnonymizedClinicalAbbreviationsAndAcronymsDataSet.txt'
MODEL_PATH = '/content/drive/MyDrive/NLP_WSD/model_weights/data2_stratified.pth'

import json
import torch
import torch.optim as optim
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel

In [3]:
with open(TRAIN_PATH, 'r', encoding='utf-8', errors='ignore') as file:
    contents = file.readlines()

In [4]:
from sklearn.model_selection import train_test_split
from collections import Counter

# Extract word senses to use for stratification
word_senses = [item.split('|')[1] for item in contents]

# Count occurrences of each word sense
word_sense_counts = Counter(word_senses)

# Oversample rare word senses
min_count = 2  # Minimum number of instances required per word sense
augmented_contents = []
augmented_word_senses = []

for item, sense in zip(contents, word_senses):
    count = word_sense_counts[sense]
    if count < min_count:
        # Duplicate rare examples to reach the minimum count
        augmented_contents.extend([item] * (min_count - count + 1))
        augmented_word_senses.extend([sense] * (min_count - count + 1))
    augmented_contents.append(item)
    augmented_word_senses.append(sense)

In [5]:
# Perform stratified split based on word senses
train_lines, val_test_lines, train_word_senses, val_test_word_senses = train_test_split(
    augmented_contents, augmented_word_senses, test_size=0.4, random_state=42, stratify=augmented_word_senses)

val_lines, test_lines = train_test_split(
    val_test_lines, test_size=0.5, random_state=42)  # 0.25 * 0.8 = 0.2

# Verify the sizes
print(len(train_lines), len(val_lines), len(test_lines))

22604 7535 7535


In [6]:
def read_corpus(content, tokenizer, max_length=128, word_sense_dict=None):
    data = []
    is_dict_provided = word_sense_dict is not None
    word_sense_dict = word_sense_dict or {}

    for item in content:
        split_item = item.split('|')
        word = split_item[0]
        sense = split_item[1]
        if not is_dict_provided:
            word_sense_dict.setdefault(word, set()).add(sense)

    for index, item in enumerate(content):
        split_item = item.split('|')
        word = split_item[0]
        sense = split_item[1]
        sentence = split_item[-1]

        # Update the word_sense_dict with new senses if they don't exist
        if word not in word_sense_dict:
            word_sense_dict[word] = set(sense)
        else:
            if sense not in word_sense_dict[word]:
                word_sense_dict[word].add(sense)

        # Positive examples
        pos_input = tokenizer(sentence + ' [SEP] ' + word + ' [SEP] ' + sense,
                              padding='max_length', max_length=max_length,
                              truncation=True, return_tensors='pt')
        data.append((pos_input['input_ids'], pos_input['attention_mask'], 1))

        # Negative examples
        for word_sense in word_sense_dict[word]:
            if word_sense != sense:
                neg_input = tokenizer(sentence + ' [SEP] ' + word + ' [SEP] ' + word_sense,
                                      padding='max_length', max_length=max_length,
                                      truncation=True, return_tensors='pt')
                data.append((neg_input['input_ids'], neg_input['attention_mask'], 0))
    return data, word_sense_dict

In [7]:
class WSDDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [8]:
def collate_fn(batch):
    input_ids = [item[0] for item in batch]
    attention_masks = [item[1] for item in batch]
    labels = [item[2] for item in batch]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    attention_masks = pad_sequence(attention_masks, batch_first=True, padding_value=0)
    labels = torch.tensor(labels)

    return input_ids, attention_masks, labels

In [9]:
def train(model, dataloader, optimizer, criterion, device, print_every=10):
    model.train()
    total_loss = 0

    total_batches = len(dataloader)
    print(f"Total number of batches: {total_batches}")

    for batch_idx, (inputs_ids, attention_masks, labels) in enumerate(dataloader):
        batch_size, _, seq_length = inputs_ids.size()
        inputs_ids = inputs_ids.view(batch_size, seq_length)
        attention_masks = attention_masks.view(batch_size, seq_length)
        labels = labels.view(-1)

        inputs_ids = inputs_ids.to(device)
        attention_masks = attention_masks.to(device)
        labels = labels.to(device).long()

        optimizer.zero_grad()
        outputs = model(inputs_ids, attention_masks)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        if (batch_idx + 1) % print_every == 0:
            print(f"Batch {batch_idx + 1}/{total_batches}")
            print(f"Loss: {loss.item()}")
            print("-" * 80)

    return total_loss / len(dataloader)

In [10]:
def evaluate(model, dataloader, tokenizer, device, print_every=10):
    model.eval()
    correct = 0
    total = 0
    total_batches = len(dataloader)
    with torch.no_grad():
        for i, (inputs_ids, attention_masks, labels) in enumerate(dataloader):
            batch_size, _, seq_length = inputs_ids.size()
            inputs_ids = inputs_ids.view(batch_size, seq_length)
            attention_masks = attention_masks.view(batch_size, seq_length)
            labels = labels.view(-1)

            inputs_ids = inputs_ids.to(device)
            attention_masks = attention_masks.to(device)
            labels = labels.to(device)

            outputs = model(input_ids=inputs_ids, attention_mask=attention_masks)
            predictions = torch.argmax(outputs, dim=1)

            for j in range(len(predictions)):
                if labels[j] != -100:
                    pred_choice = predictions[j].item()
                    true_choice = labels[j].item()

                    if pred_choice == true_choice:
                        correct += 1
                    total += 1

            if (i + 1) % print_every == 0:
                print(f"Processed {i + 1}/{total_batches} batches.")
                current_accuracy = (correct / total) if total > 0 else 0
                print(f"Current Accuracy: {current_accuracy:.4f}")

    accuracy = correct / total
    return accuracy

In [11]:
class BertWSDModel(nn.Module):
    def __init__(self, bert_model):
        super(BertWSDModel, self).__init__()
        self.bert = bert_model
        self.linear = nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_output = outputs.last_hidden_state[:, 0, :]   # Pick the first element(CLS label) from each sequence
        logits = self.linear(cls_output)
        return logits

In [12]:
params = {
    'max_length': 512,
    'batch_size': 100,
    'learning_rate': 1e-5,
    'epoch': 3
}
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

bert_model = BertModel.from_pretrained("bert-base-uncased")
model = BertWSDModel(bert_model).to(device)
model.to(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]



config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

BertWSDModel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_

In [13]:
train_data, train_word_sense_dict = read_corpus(train_lines, tokenizer, max_length=params['max_length'])
valid_data, _ = read_corpus(val_lines, tokenizer, max_length=params['max_length'], word_sense_dict=train_word_sense_dict)
test_data, _ = read_corpus(test_lines, tokenizer, max_length=params['max_length'], word_sense_dict=train_word_sense_dict)
print('Finished reading data!')

train_dataset = WSDDataset(train_data)
valid_dataset = WSDDataset(valid_data)
test_dataset = WSDDataset(test_data)

train_loader = DataLoader(train_dataset, batch_size=params['batch_size'], shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=params['batch_size'], shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=params['batch_size'], shuffle=False, collate_fn=collate_fn)
print('Finished loading data!')

Finished reading data!
Finished loading data!


In [14]:
# Zero-shot evaluation
zero_shot_valid_accuracy = evaluate(model, valid_loader, tokenizer, device, print_every=50)
print(f"Zero-shot validation accuracy: {zero_shot_valid_accuracy}")

Processed 1/405 batches.
Current Accuracy: 0.2700
Processed 2/405 batches.
Current Accuracy: 0.2550
Processed 3/405 batches.
Current Accuracy: 0.2500
Processed 4/405 batches.
Current Accuracy: 0.2325
Processed 5/405 batches.
Current Accuracy: 0.2260
Processed 6/405 batches.
Current Accuracy: 0.2150
Processed 7/405 batches.
Current Accuracy: 0.2086
Processed 8/405 batches.
Current Accuracy: 0.2188
Processed 9/405 batches.
Current Accuracy: 0.2189
Processed 10/405 batches.
Current Accuracy: 0.2230
Processed 11/405 batches.
Current Accuracy: 0.2227
Processed 12/405 batches.
Current Accuracy: 0.2225
Processed 13/405 batches.
Current Accuracy: 0.2215
Processed 14/405 batches.
Current Accuracy: 0.2293
Processed 15/405 batches.
Current Accuracy: 0.2293
Processed 16/405 batches.
Current Accuracy: 0.2275
Processed 17/405 batches.
Current Accuracy: 0.2253
Processed 18/405 batches.
Current Accuracy: 0.2228
Processed 19/405 batches.
Current Accuracy: 0.2195
Processed 20/405 batches.
Current Accura

In [15]:
# Fine-tuning
epochs = params['epoch']
optimizer = optim.AdamW(model.parameters(), lr=params['learning_rate'])
criterion = nn.CrossEntropyLoss(ignore_index=-100)
for epoch in range(epochs):
    train_loss = train(model, train_loader, optimizer, criterion, device)
    valid_accuracy = evaluate(model, valid_loader, tokenizer, device)
    print(f"Epoch {epoch + 1}/{epochs}")
    print(f"Training loss: {train_loss}")
    print(f"Validation accuracy: {valid_accuracy}")

Total number of batches: 1197
Batch 10/1197
Loss: 0.4153246283531189
--------------------------------------------------------------------------------
Batch 20/1197
Loss: 0.4849298596382141
--------------------------------------------------------------------------------
Batch 30/1197
Loss: 0.5022841095924377
--------------------------------------------------------------------------------
Batch 40/1197
Loss: 0.4260963499546051
--------------------------------------------------------------------------------
Batch 50/1197
Loss: 0.4568127691745758
--------------------------------------------------------------------------------
Batch 60/1197
Loss: 0.3837433159351349
--------------------------------------------------------------------------------
Batch 70/1197
Loss: 0.4533541202545166
--------------------------------------------------------------------------------
Batch 80/1197
Loss: 0.3981844186782837
--------------------------------------------------------------------------------
Batch 90/1

In [16]:
# Save model weights
save_path = MODEL_PATH
torch.save(model.state_dict(), save_path)

# Predict

## Predict single sentence

In [None]:
MODEL_PATH = '/content/drive/MyDrive/NLP_WSD/model_weights/data2_stratified.pth'
params = {
    'max_length': 512,
    'batch_size': 100,
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained("bert-base-uncased")
model = BertWSDModel(bert_model).to(device)
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
model.eval()
# Example sentence
sentence = "This is an example sentence for word sense disambiguation.jjjj"
# Tokenize the input sentence
inputs = tokenizer(sentence, return_tensors='pt', max_length=params['max_length'], truncation=True, padding='max_length')
# Move inputs to the appropriate device
input_ids = inputs['input_ids'].to(device)
attention_mask = inputs['attention_mask'].to(device)

# Make predictions
with torch.no_grad():  # Disable gradient calculation
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    predictions = torch.argmax(outputs, dim=-1)

# Convert predictions to numpy array (if needed)
predictions = predictions.cpu().numpy()

print(predictions)

[0]


## Predict in batch

In [1]:
# Define some functions
from google.colab import drive
drive.mount('/content/drive')

import json
import torch
import torch.optim as optim
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split
from collections import Counter

TRAIN_PATH = '/content/drive/MyDrive/NLP_WSD/data2/AnonymizedClinicalAbbreviationsAndAcronymsDataSet.txt'

with open(TRAIN_PATH, 'r', encoding='utf-8', errors='ignore') as file:
    contents = file.readlines()

# Extract word senses to use for stratification
word_senses = [item.split('|')[1] for item in contents]

# Count occurrences of each word sense
word_sense_counts = Counter(word_senses)

# Oversample rare word senses
min_count = 2  # Minimum number of instances required per word sense
augmented_contents = []
augmented_word_senses = []

for item, sense in zip(contents, word_senses):
    count = word_sense_counts[sense]
    if count < min_count:
        # Duplicate rare examples to reach the minimum count
        augmented_contents.extend([item] * (min_count - count + 1))
        augmented_word_senses.extend([sense] * (min_count - count + 1))
    augmented_contents.append(item)
    augmented_word_senses.append(sense)

# Perform stratified split based on word senses
train_lines, val_test_lines, train_word_senses, val_test_word_senses = train_test_split(
    augmented_contents, augmented_word_senses, test_size=0.4, random_state=42, stratify=augmented_word_senses)

val_lines, test_lines = train_test_split(
    val_test_lines, test_size=0.5, random_state=42)

# Verify the sizes
print(len(train_lines), len(val_lines), len(test_lines))

def read_corpus(content, tokenizer, max_length=128, word_sense_dict=None):
    data = []
    is_dict_provided = word_sense_dict is not None
    word_sense_dict = word_sense_dict or {}

    for item in content:
        split_item = item.split('|')
        word = split_item[0]
        sense = split_item[1]
        if not is_dict_provided:
            word_sense_dict.setdefault(word, set()).add(sense)

    for index, item in enumerate(content):
        split_item = item.split('|')
        word = split_item[0]
        sense = split_item[1]
        sentence = split_item[-1]

        # Update the word_sense_dict with new senses if they don't exist
        if word not in word_sense_dict:
            word_sense_dict[word] = set(sense)
        else:
            if sense not in word_sense_dict[word]:
                word_sense_dict[word].add(sense)

        # Positive examples
        pos_input = tokenizer(sentence + ' [SEP] ' + word + ' [SEP] ' + sense,
                              padding='max_length', max_length=max_length,
                              truncation=True, return_tensors='pt')
        data.append((pos_input['input_ids'], pos_input['attention_mask'], 1))

        # Negative examples
        for word_sense in word_sense_dict[word]:
            if word_sense != sense:
                neg_input = tokenizer(sentence + ' [SEP] ' + word + ' [SEP] ' + word_sense,
                                      padding='max_length', max_length=max_length,
                                      truncation=True, return_tensors='pt')
                data.append((neg_input['input_ids'], neg_input['attention_mask'], 0))
    return data, word_sense_dict


class WSDDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def collate_fn(batch):
    input_ids = [item[0] for item in batch]
    attention_masks = [item[1] for item in batch]
    labels = [item[2] for item in batch]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    attention_masks = pad_sequence(attention_masks, batch_first=True, padding_value=0)
    labels = torch.tensor(labels)

    return input_ids, attention_masks, labels

class BertWSDModel(nn.Module):
    def __init__(self, bert_model):
        super(BertWSDModel, self).__init__()
        self.bert = bert_model
        self.linear = nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_output = outputs.last_hidden_state[:, 0, :]   # Pick the first element(CLS label) from each sequence
        logits = self.linear(cls_output)
        return logits

Mounted at /content/drive
22604 7535 7535


In [2]:
def predict(model, dataloader, tokenizer, device, print_every=10):
    model.eval()
    total_batches = len(dataloader)
    correct = 0
    total = 0
    list_1 = []
    list_2 = []
    with torch.no_grad():
        for i, (input_ids, attention_masks, labels) in enumerate(dataloader):
            batch_size, _, seq_length = input_ids.size()
            input_ids = input_ids.view(batch_size, seq_length)
            attention_masks = attention_masks.view(batch_size, seq_length)
            labels = labels.view(-1)

            input_ids = input_ids.to(device)
            attention_masks = attention_masks.to(device)
            labels = labels.to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_masks)
            predictions = torch.argmax(outputs, dim=1)

            for j in range(len(predictions)):
                if labels[j] != -100:  # Assuming -100 is used for padding/ignore index
                    pred_choice = predictions[j].item()
                    true_choice = labels[j].item()

                    if pred_choice != true_choice:

                        # Print the input_ids, actual tokens, predicted label, and true label
                        input_id = input_ids[j].cpu().numpy()
                        tokens = tokenizer.convert_ids_to_tokens(input_id, skip_special_tokens=False)

                        # Calculate the length of input_ids without padding
                        padding_token_id = tokenizer.pad_token_id
                        length_without_padding = (input_ids[j] != padding_token_id).sum().item()

                        # Find the positions of the [SEP] tokens
                        sep_positions = [index for index, token in enumerate(tokens) if token == '[SEP]']

                        if len(sep_positions) >= 2:
                            sentence_tokens = tokens[:sep_positions[-1]]
                            word_tokens = tokens[sep_positions[0] + 1:sep_positions[1]]
                            if len(sep_positions) >= 3:
                                word_sense_tokens = tokens[sep_positions[1] + 1:sep_positions[2]]
                            else:
                                word_sense_tokens = tokens[sep_positions[1] + 1:]

                            sentence = tokenizer.convert_tokens_to_string(sentence_tokens)
                            word = tokenizer.convert_tokens_to_string(word_tokens)
                            word_sense = tokenizer.convert_tokens_to_string(word_sense_tokens)

                            print(f"\n* Context: {sentence}")
                            print(f"* Abbreviation: {word}")
                            print(f"* Word Sense: {word_sense}")
                            list_1.append((tokens, length_without_padding, pred_choice, true_choice, word, word_sense, sentence))
                        else:
                            print(f"* Tokens: {tokens}")
                            list_2.append((tokens, length_without_padding, pred_choice, true_choice))

                        print(f"* Length of context: {length_without_padding}")
                        print(f"* Predicted Label: {pred_choice}")
                        print(f"* True Label: {true_choice}")
                    else:
                        correct += 1
                    total += 1

            if (i + 1) % print_every == 0:
                print(f"Processed {i + 1}/{total_batches} batches.")
                current_accuracy = (correct / total) if total > 0 else 0
    accuracy = correct / total
    return accuracy, correct, total, list_1, list_2

In [3]:
# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
params = {
    'max_length': 512,
    'batch_size': 100,
}
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Load model
bert_model = BertModel.from_pretrained("bert-base-uncased")
model = BertWSDModel(bert_model).to(device)
MODEL_PATH = '/content/drive/MyDrive/NLP_WSD/model_weights/data2_stratified.pth'
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))

# Load data
with open('/content/drive/MyDrive/NLP_WSD/train_word_sense_dict', 'r') as json_file:
    train_word_sense_dict = json.load(json_file)
loaded_dict = {k: set(v) for k, v in train_word_sense_dict.items()}
test_data, _ = read_corpus(test_lines, tokenizer, max_length=params['max_length'], word_sense_dict=loaded_dict)
test_dataset = WSDDataset(test_data)
test_loader = DataLoader(test_dataset, batch_size=params['batch_size'], shuffle=False, collate_fn=collate_fn)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]



config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [4]:
# Predict
model.eval()
test_dataset = WSDDataset(test_data)
test_loader = DataLoader(test_dataset, batch_size=params['batch_size'], shuffle=False, collate_fn=collate_fn)
accuracy, correct, total, list_1, list_2 = predict(model, test_loader, tokenizer, device, print_every=1)
print (accuracy, correct, total)
with open('/content/drive/MyDrive/NLP_WSD/stratified_list1.json', 'w') as file:
    json.dump(list_1, file, indent=4)
with open('/content/drive/MyDrive/NLP_WSD/stratified_list2.json', 'w') as file:
    json.dump(list_2, file, indent=4)


* Context: [CLS] abdominal examination shows nontender , no distention , increased bowel sounds . no vascular bruit . he has no peripheral edema . the ekg monitoring after dc cardioversion showed sinus rhythm . only 2 brief episodes , last one of 2 seconds of brief atrial fibrillation . the patient otherwise is stable , to be discharged home . [SEP] dc [SEP] discontinue
* Abbreviation: dc
* Word Sense: discontinue
* Length of context: 81
* Predicted Label: 0
* True Label: 1

* Context: [CLS] abdominal examination shows nontender , no distention , increased bowel sounds . no vascular bruit . he has no peripheral edema . the ekg monitoring after dc cardioversion showed sinus rhythm . only 2 brief episodes , last one of 2 seconds of brief atrial fibrillation . the patient otherwise is stable , to be discharged home . [SEP] dc [SEP] direct current
* Abbreviation: dc
* Word Sense: direct current
* Length of context: 79
* Predicted Label: 1
* True Label: 0
Processed 1/389 batches.

* Contex