In [None]:
import ast
import os

import nltk
from nltk.corpus import semcor
from nltk.stem import WordNetLemmatizer
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
# from transformers import BertTokenizer
import pandas as pd 


In [None]:
# Raw SemCor training rows (an older copy lived at './SemCor/semcor_data.csv').
TRAIN_CSV = '../semcor.csv'
train_data = pd.read_csv(TRAIN_CSV)

In [None]:
# Raw evaluation rows.
TEST_CSV = './all.csv'
test_data = pd.read_csv(TEST_CSV)

In [None]:
from tqdm import tqdm
def _parse_list_field(raw):
    """Parse a stringified list such as "['a, b', 'c']" into a Python list.

    ``ast.literal_eval`` keeps items that themselves contain commas (glosses
    often do) intact. If the text is not a valid Python literal list, fall back
    to naive comma splitting with surrounding whitespace removed.
    """
    try:
        parsed = ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        parsed = None
    if isinstance(parsed, (list, tuple)):
        return list(parsed)
    return [item.strip() for item in raw.strip('[]').split(',')]


def getNewData(data, show_progress=True):
    """Expand raw SemCor rows into one row per gold sense.

    Each input row has a 'sentence' whose target is wrapped as
    '[TGT] word [TGT]', stringified candidate lists 'sense_keys' and 'glosses',
    and a stringified list 'target' of indices into those candidates. One output
    row is emitted per index in 'target'.

    Args:
        data: DataFrame with columns 'sentence', 'sense_keys', 'glosses', 'target'.
        show_progress: wrap the row loop in a tqdm progress bar.

    Returns:
        DataFrame with columns ['sentence', 'target_word', 'sense', 'gloss'].
        'sentence' has the '[TGT]' markers removed, and 'sense' has quote
        characters stripped.
    """
    marker = '[TGT]'
    columns = ['sentence', 'target_word', 'sense', 'gloss']
    records = []  # build rows in a list; concatenating per row is quadratic

    rows = data[['sentence', 'sense_keys', 'glosses', 'target']].itertuples(index=False, name=None)
    if show_progress:
        rows = tqdm(rows, total=len(data))

    for sentence, sense_keys, glosses, target in rows:
        start = sentence.find(marker)
        end = sentence.find(marker, start + 1) if start != -1 else -1
        # Text between the two markers; '' when the sentence is not marked.
        target_word = sentence[start + len(marker):end].strip() if end != -1 else ''
        sentence = sentence.replace(marker, '')

        senses = _parse_list_field(sense_keys)
        gloss_list = _parse_list_field(glosses)
        # One output row per gold index.
        for tgt in _parse_list_field(target):
            tgt = int(tgt)
            records.append({
                'sentence': sentence,
                'target_word': target_word,
                'sense': str(senses[tgt]).replace('"', '').replace("'", ''),
                'gloss': gloss_list[tgt],
            })

    return pd.DataFrame(records, columns=columns)




In [None]:
# Expand the first N_TRAIN_ROWS raw SemCor rows into one row per gold sense.
# NOTE: this rebinds `train_data`; re-run the CSV-loading cell before re-running this one.
N_TRAIN_ROWS = 30000
train_data = getNewData(train_data.iloc[:N_TRAIN_ROWS])

In [None]:
# Same per-sense expansion for every evaluation row (rebinds `test_data`).
test_data = getNewData(data=test_data)


In [None]:
# Preview the expanded training rows.
train_data.head(n=10)

In [None]:
# Lemma vocabulary: give each lemma (the text before '%' in a sense key) an index
# and collect the distinct sense keys seen for it in the training rows. The model
# restricts its predictions for a target word to lemma_2_sense[lemma].
target_word_idx = {}
idx_to_target = {}
sense_labels = []  # never populated; kept so any existing reference still resolves
lemma_2_sense = {}
for sense_label in train_data['sense']:
    # Strip spaces so keys match those built by preProcessDataset / getDataset.
    sense_label = sense_label.replace(' ', '')
    # Lower-cased to match the lower-cased lemma lookup in getDataset.
    lemma = sense_label.split('%')[0].lower()
    if lemma not in lemma_2_sense:
        lemma_2_sense[lemma] = []
        target_word_idx[lemma] = len(target_word_idx)
        idx_to_target[len(idx_to_target)] = lemma
    if sense_label not in lemma_2_sense[lemma]:
        lemma_2_sense[lemma].append(sense_label)

# Fallback entry for lemmas never seen in training; its only "sense" is '<unk>'.
target_word_idx['<unk>'] = len(target_word_idx)
idx_to_target[len(idx_to_target)] = '<unk>'
lemma_2_sense['<unk>'] = ['<unk>']


In [None]:
# Spot-check: the sense keys collected for one sample lemma.
sample_lemma = 'long'
lemma_2_sense[sample_lemma]

In [None]:
# NOTE(review): `new_data` is not read by any later cell in this notebook;
# this load looks vestigial — confirm before removing.
SEMCOR_LSTM_CSV = './semcor_lstm.csv'
new_data = pd.read_csv(SEMCOR_LSTM_CSV)

In [None]:
# Vocabulary builder: one pass over the training rows builds the word / sense
# lookup tables and the longest sentence length used for padding.

class preProcessDataset():
    """Build word and sense vocabularies from a sense-per-row DataFrame.

    Attributes set by __init__:
      word2idx / idx2word / vocab: word vocabulary; '<pad>' -> 0, '<unk>' -> 1.
      sense2idx / idx2sense / vocab_sense: sense-key vocabulary; '<unk>' -> 0.
      wordFreq: lower-cased token counts over all sentences (punctuation included).
      max_len: token count of the longest whitespace-split sentence, measured
        before punctuation tokens are dropped.
      target2idx: the notebook-level `target_word_idx` dict (not built here).
    """
    def __init__(self,data,min_freq):
        """
        Args:
            data: DataFrame with 'sentence' and 'sense' columns (e.g. getNewData output).
            min_freq: minimum total token count for a word to get its own index;
                rarer words are left to the '<unk>' entry.
        """
        self.data = data
        self.min_freq = min_freq
        self.word2idx = {}
        self.idx2word = {}
        self.vocab = []
        self.vocab_sense = []
        self.sense2idx = {}
        self.idx2sense = {}
        self.max_len = 0
        # NOTE(review): not used anywhere in this class.
        self.wordnet_lemmatizer = WordNetLemmatizer()

        # NOTE(review): never populated in this class.
        self.lemma_2_sense = {}
        self.wordFreq = {}
        # Shares the notebook-level lemma index instead of building its own.
        self.target2idx = target_word_idx
        # Reserved word ids: '<pad>' -> 0, '<unk>' -> 1.
        self.word2idx['<pad>'] = len(self.word2idx)
        self.idx2word[len(self.idx2word)] = '<pad>'
        self.vocab.append('<pad>')
        self.vocab.append('<unk>')
        self.word2idx['<unk>'] = len(self.word2idx)
        self.idx2word[len(self.idx2word)] = '<unk>'
        # Reserved sense id: '<unk>' -> 0.
        self.vocab_sense.append('<unk>')
        self.sense2idx['<unk>'] = len(self.sense2idx)
        self.idx2sense[len(self.idx2sense)] = '<unk>'
        


        data = self.data
        for i in tqdm(range(len(data))):
            sentence = data.iloc[i]['sentence']
            # NOTE(review): target_word is computed but not used below.
            target_word = data.iloc[i]['sense'].split('%')[0]
            target_word = target_word.lower()
            target_word = target_word.replace(' ','')
            # Sense key with all spaces removed; this exact string is the sense2idx key.
            sense_keys = data.iloc[i]['sense']
            sense_keys = sense_keys.replace(' ','')

            sentence = sentence.split()
            # count freq of words
            # Counts are cumulative: this whole sentence is counted before its
            # words are checked against min_freq in the next loop.
            for word in sentence:
                word = word.lower()
                if word not in self.wordFreq:
                    self.wordFreq[word] = 0
                self.wordFreq[word] += 1


            for word in sentence:
                word = word.lower()
                # punctuation marks
                if word in ['.',',','?','!',';',':','(',')','[',']','{','}',"'",'"']:
                    continue
                # A word gets its own index at the first sentence where its running
                # count reaches min_freq (until then it falls back to '<unk>'), so the
                # final vocabulary is every non-punctuation word whose total count is
                # >= min_freq; indices follow the order words reach that threshold.
                if self.wordFreq[word] < self.min_freq:
                    word = '<unk>'
                if word not in self.word2idx:
                    self.word2idx[word] = len(self.word2idx)
                    self.idx2word[len(self.idx2word)] = word
                    self.vocab.append(word)
            # Includes punctuation tokens, so for these same rows it is an upper
            # bound on the punctuation-free sequences getDataset builds.
            if len(sentence) > self.max_len:
                self.max_len = len(sentence)
            if sense_keys not in self.sense2idx:
                self.sense2idx[sense_keys] = len(self.sense2idx)
                self.idx2sense[len(self.idx2sense)] = sense_keys
                self.vocab_sense.append(sense_keys)
                

class getDataset(Dataset):
    """Map-style dataset of fixed-length word-id sequences with sense and lemma labels.

    Each item is ``(word_ids, sense_id, lemma_id)`` as tensors. ``word_ids`` always
    has exactly ``max_len`` entries:
      - punctuation tokens are dropped;
      - out-of-vocabulary words map to '<unk>';
      - longer sentences are truncated;
      - shorter ones are right-padded with '<pad>'.

    Sense keys and lemmas missing from ``sense2idx`` / ``target2idx`` map to their
    '<unk>' entries.
    """

    # Standalone tokens skipped before indexing (same list preProcessDataset skips).
    PUNCTUATION = frozenset(['.', ',', '?', '!', ';', ':', '(', ')', '[', ']', '{', '}', "'", '"'])

    def __init__(self, data, word2idx, sense2idx, max_len, target2idx, idx2word, wordFreq, vocab,
                 show_progress=True):
        """
        Args:
            data: DataFrame with 'sentence' and 'sense' columns.
            word2idx: word -> id; must contain '<pad>' and '<unk>'.
            sense2idx: sense key -> id; must contain '<unk>'.
            max_len: length every word-id sequence is truncated / padded to.
            target2idx: lemma -> id; must contain '<unk>'.
            idx2word, wordFreq, vocab: stored as attributes only.
            show_progress: wrap the row loop in a tqdm progress bar.
        """
        self.data = data
        self.word2idx = word2idx
        self.sense2idx = sense2idx
        self.max_len = max_len
        self.target2idx = target2idx
        self.idx2word = idx2word
        self.wordFreq = wordFreq
        self.vocab = vocab
        self.input_data = []
        self.sense_data = []
        self.target2word = []

        pad_idx = self.word2idx['<pad>']
        unk_idx = self.word2idx['<unk>']

        rows = range(len(data))
        if show_progress:
            rows = tqdm(rows)
        for i in rows:
            row = data.iloc[i]
            sentence = row['sentence']
            # Sense keys may carry stray spaces; the lemma is the text before '%'.
            sense_keys = row['sense'].replace(' ', '')
            target_word = sense_keys.split('%')[0].lower()

            sentence_idx = [
                self.word2idx[word] if word in self.word2idx else unk_idx
                for word in (token.lower() for token in sentence.split())
                if word not in self.PUNCTUATION
            ]
            # Truncate as well as pad: sentences longer than max_len (e.g. test rows,
            # when max_len came from the training rows) would otherwise yield ragged
            # sequences that the default DataLoader collate cannot stack.
            sentence_idx = sentence_idx[:self.max_len]
            sentence_idx.extend([pad_idx] * (self.max_len - len(sentence_idx)))
            self.input_data.append(sentence_idx)

            if sense_keys not in self.sense2idx:
                sense_keys = '<unk>'
            self.sense_data.append(self.sense2idx[sense_keys])
            if target_word not in self.target2idx:
                target_word = '<unk>'
            self.target2word.append(self.target2idx[target_word])

    def __len__(self):
        """Number of rows in the underlying DataFrame."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(word_ids, sense_id, lemma_id)`` tensors for row ``idx``."""
        return torch.tensor(self.input_data[idx]), torch.tensor(self.sense_data[idx]), torch.tensor(self.target2word[idx])




In [None]:
# Build vocabularies from the training rows; words whose total count is below
# MIN_FREQ get no index of their own (getDataset maps them to '<unk>').
MIN_FREQ = 2
# This cell rebinds `preProcessDataset` from the class to an instance, and later
# cells use that instance. Recover the class on re-runs so the cell stays idempotent
# (calling the instance again would raise TypeError).
_PreProcessDatasetCls = preProcessDataset if isinstance(preProcessDataset, type) else type(preProcessDataset)
preProcessDataset = _PreProcessDatasetCls(train_data, MIN_FREQ)


In [None]:
# Encode the training rows as padded word ids plus sense and lemma ids.
trainData = getDataset(
    train_data,
    word2idx=preProcessDataset.word2idx,
    sense2idx=preProcessDataset.sense2idx,
    max_len=preProcessDataset.max_len,
    target2idx=preProcessDataset.target2idx,
    idx2word=preProcessDataset.idx2word,
    wordFreq=preProcessDataset.wordFreq,
    vocab=preProcessDataset.vocab,
)

In [None]:
# Encode the test rows with the *training* vocabularies; unseen words, senses
# and lemmas fall back to their '<unk>' entries.
testData = getDataset(
    test_data,
    word2idx=preProcessDataset.word2idx,
    sense2idx=preProcessDataset.sense2idx,
    max_len=preProcessDataset.max_len,
    target2idx=preProcessDataset.target2idx,
    idx2word=preProcessDataset.idx2word,
    wordFreq=preProcessDataset.wordFreq,
    vocab=preProcessDataset.vocab,
)


In [None]:
from sklearn.model_selection import train_test_split

# 80/20 train/validation split. Splitting an index array with the same
# random_state yields the same membership as splitting the Dataset directly, but
# torch Subsets stay lazy instead of materialising every tensor triple in lists.
train_idx, valid_idx = train_test_split(np.arange(len(trainData)), test_size=0.2, random_state=42)
train_dataset = torch.utils.data.Subset(trainData, train_idx.tolist())
valid_dataset = torch.utils.data.Subset(trainData, valid_idx.tolist())

In [None]:
BATCH_SIZE = 16
# Shuffle only the training split; evaluation metrics are order-independent,
# so validation batches are kept deterministic.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
# Evaluation loader: same batch size as training; no shuffling needed for metrics.
test_dataloader = DataLoader(testData, batch_size=16, shuffle=False)

In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class biLSTMModel(nn.Module):
    """Bidirectional-LSTM sense classifier restricted to each target lemma's senses.

    Args:
        input_size: word-vocabulary size (rows of the embedding table).
        hidden_size: LSTM hidden size per direction.
        sense_vocab: sequence of sense keys; its length is the number of output classes.
        embedding_dim: word-embedding width.
        dataset: object exposing ``sense2idx`` (sense key -> output index).

    Relies on the notebook-level ``idx_to_target`` (lemma id -> lemma) and
    ``lemma_2_sense`` (lemma -> list of its sense keys) dictionaries.
    """
    def __init__(self,input_size,hidden_size,sense_vocab,embedding_dim,dataset):
        super(biLSTMModel,self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.sense_vocab = sense_vocab
        self.embedding_dim = embedding_dim
        self.dataset = dataset
        self.embedding = nn.Embedding(self.input_size,self.embedding_dim)
        self.lstm = nn.LSTM(self.embedding_dim,self.hidden_size,bidirectional=True)
        self.linear = nn.Linear(self.hidden_size*2,len(self.sense_vocab))
        self.sense2idx = self.dataset.sense2idx
        self.idx2word = idx_to_target  # lemma id -> lemma

    def forward(self,x,target_word):
        """Score every sense, keeping only the target lemma's candidates.

        Args:
            x: LongTensor of word ids, presumably (batch, seq_len); the permute below
               moves the embedded tensor to the (seq_len, batch, dim) layout nn.LSTM
               uses by default.
            target_word: LongTensor of lemma ids, one per batch row.

        Returns:
            (batch, len(sense_vocab)) logits in which every sense not belonging to
            the row's lemma is set to the dtype's minimum value, so softmax,
            CrossEntropyLoss and argmax only consider candidate senses.
        """
        x = self.embedding(x)
        x = x.permute(1,0,2)
        # NOTE(review): '<pad>' positions are fed to the LSTM as well, so the forward
        # direction's final state has also read the trailing padding.
        output,(hidden,cell) = self.lstm(x)
        # Last layer's forward and backward final states, concatenated.
        hidden = torch.cat((hidden[-2,:,:],hidden[-1,:,:]),dim=1)
        logits = self.linear(hidden)

        # Mask instead of overwriting logits with probabilities: mixing softmax
        # outputs with raw logits made the loss and argmax compare incomparable values.
        allowed = torch.zeros_like(logits, dtype=torch.bool)
        for i, lemma_id in enumerate(target_word.tolist()):
            lemma = self.idx2word[lemma_id]
            candidate_idx = [self.sense2idx[sense] for sense in lemma_2_sense[lemma]]
            allowed[i, candidate_idx] = True
        return logits.masked_fill(~allowed, torch.finfo(logits.dtype).min)
    

In [None]:

from torch.nn import CrossEntropyLoss
# Model, loss and optimiser.
SEED = 42
HIDDEN_SIZE = 128
EMBEDDING_DIM = 300
LEARNING_RATE = 0.001

torch.manual_seed(SEED)  # reproducible weight initialisation (and later shuffling)
model = biLSTMModel(len(preProcessDataset.word2idx), HIDDEN_SIZE, preProcessDataset.vocab_sense, EMBEDDING_DIM, preProcessDataset)
model = model.cuda()
# Loss targets are *sense* ids, so the ignored class must come from sense2idx:
# rows whose sense was never seen in training ('<unk>') are left out of the loss.
# (The previous word2idx['<pad>'] held the same value, 0, but named the wrong vocabulary.)
criterion = CrossEntropyLoss(ignore_index=preProcessDataset.sense2idx['<unk>'])
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
print(model)

In [None]:
# print(len(dataloader))

In [None]:
# Train for num_epochs, reporting train and validation loss/accuracy each epoch,
# then save the final weights.
num_epochs = 10
CKPT_PATH = './ckpts/biLSTM_model.pth'


def _run_epoch(loader, train):
    """Run one pass over `loader`, updating weights only when `train` is True.

    Puts the model in train/eval mode accordingly and disables autograd for eval.

    Returns:
        (mean loss per batch, number of correctly predicted senses)
    """
    model.train(train)
    total_loss = 0.0
    total_correct = 0
    with torch.set_grad_enabled(train):
        for sentence, sense, target_word in loader:
            sentence = sentence.cuda()
            sense = sense.cuda()
            target_word = target_word.cuda()
            output = model(sentence, target_word)
            loss = criterion(output, sense)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            total_correct += torch.sum(torch.argmax(output, dim=1) == sense).item()
    # max(..., 1) avoids ZeroDivisionError on an empty loader.
    return total_loss / max(len(loader), 1), total_correct


for epoch in range(num_epochs):
    train_loss, train_correct = _run_epoch(train_dataloader, train=True)
    print('Epoch : {}/{} | Loss : {:.4f} | Accuracy : {:.4f}'.format(epoch+1,num_epochs,train_loss,train_correct/len(train_dataset)))

    # Held-out validation split (eval mode, no gradients).
    valid_loss, valid_correct = _run_epoch(valid_dataloader, train=False)
    print('Epoch : {}/{} | Validation Loss : {:.4f} | Validation Accuracy : {:.4f}'.format(epoch+1,num_epochs,valid_loss,valid_correct/len(valid_dataset)))

# torch.save does not create missing directories.
os.makedirs(os.path.dirname(CKPT_PATH), exist_ok=True)
torch.save(model.state_dict(), CKPT_PATH)



In [None]:
# Evaluate the trained model on the held-out test set.
# Training leaves the model in train mode; switch to eval so layers such as
# dropout/batch-norm (if ever added) behave deterministically.
model.eval()
total_correct = 0
total_loss = 0
with torch.no_grad():
    for sentence, sense, target_word in test_dataloader:
        sentence = sentence.cuda()
        sense = sense.cuda()
        target_word = target_word.cuda()
        output = model(sentence, target_word)
        loss = criterion(output, sense)
        pred_sense = torch.argmax(output, dim=1)
        total_correct += torch.sum(pred_sense == sense).item()
        total_loss += loss.item()
    print('| Testing Loss : {:.4f} | Testing Accuracy : {:.4f}'.format(
        total_loss/len(test_dataloader), total_correct/len(testData)))
