In [None]:
!pip install transformers

In [None]:
import pandas as pd
import numpy as np
from zipfile import ZipFile
from sklearn.model_selection import train_test_split
from typing import Dict
import torch
from torch.utils.data import DataLoader,Dataset
import torch.nn.functional as F
from sklearn.metrics import f1_score
from operator import itemgetter
from sklearn.metrics import precision_score
import pickle
import time
from tqdm import tqdm

from transformers import AutoModel
from transformers import AutoTokenizer

In [None]:
# Print the installed transformers version so the environment used for a run is
# visible in the notebook output.
import transformers
print(transformers.__version__)

In [None]:
# Pick the compute device once: CUDA when PyTorch can see a GPU, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    print(f'There are {torch.cuda.device_count()} GPU(s) available.')
    print('Device name:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')

In [None]:
# Name of the pretrained checkpoint; the same string is used for the tokenizer
# here and for AutoModel.from_pretrained further down, so the two stay in sync.
# Alternative checkpoint tried previously:
# transformer_model = "SpanBERT/spanbert-large-cased"
transformer_model = 'bert-base-cased'
# collate_fn requests return_offsets_mapping=True from this tokenizer.
# NOTE(review): offset mappings are presumably only available from "fast"
# tokenizers - confirm AutoTokenizer resolves to one for the chosen checkpoint.
tokenizer = AutoTokenizer.from_pretrained(transformer_model)

In [None]:
# dev.tsv is read with its first line as the header and is bound to the name
# `test`; it is the split scored at the end of the notebook.
test=pd.read_csv('/kaggle/input/nlphw3/dev.tsv',sep='\t')
# train.tsv is read with header=None and then borrows dev.tsv's column names.
# NOTE(review): this assumes train.tsv has no header line and the same column
# order as dev.tsv; if it does have a header, that line becomes a data row - confirm.
data=pd.read_csv('/kaggle/input/nlphw3/train.tsv',sep='\t',header=None)
data.columns=test.columns

In [None]:
# Hold out 20% of the training rows for validation (fixed seed for repeatability)
# and renumber each split from 0 so rows can be addressed by position.
train, val = (
    split.reset_index(drop=True)
    for split in train_test_split(data, test_size=0.2, random_state=4)
)

In [None]:
class CorefData(torch.utils.data.Dataset):
    """Map-style dataset over a coreference DataFrame.

    Expects the columns "Text", "A-coref", "B-coref", "Pronoun-offset",
    "A-offset" and "B-offset". Each item is the tuple
    ``(text, label, pronoun_offset, a_offset, b_offset)`` where ``label`` is
    0 when "A-coref" is truthy, 1 when only "B-coref" is truthy, and 2 otherwise.

    Rows are addressed by *position*, so the dataset works regardless of the
    DataFrame's index (e.g. a split that was not passed through ``reset_index``).
    """

    def __init__(self, data):
        # Kept from the original: shows the split size when the dataset is built.
        print(data.shape)
        self.lemmas = data["Text"]

        # Vectorised three-way label encoding; np.select takes the first
        # matching condition, so A wins when both flags are truthy.
        is_a = data["A-coref"].astype(bool).to_numpy()
        is_b = data["B-coref"].astype(bool).to_numpy()
        self.label = pd.Series(
            np.select([is_a, is_b], [0, 1], default=2), index=data.index
        )

        self.pronoun_offset = data['Pronoun-offset']
        self.A_offset = data['A-offset']
        self.B_offset = data['B-offset']

    def __len__(self):
        """Return the number of examples."""
        return len(self.lemmas)

    def __getitem__(self, idx):
        """Return the ``idx``-th example (positional) as a 5-tuple."""
        # .iloc is positional; plain Series[idx] is label-based and raises
        # KeyError (or returns the wrong row) when the index is not 0..n-1.
        return (
            self.lemmas.iloc[idx],
            self.label.iloc[idx],
            self.pronoun_offset.iloc[idx],
            self.A_offset.iloc[idx],
            self.B_offset.iloc[idx],
        )
    
    
    
    
    
    

In [None]:
def get_word_index(offsets, offset_list):
    """Map one character offset per example to a token position.

    Args:
        offsets: integer tensor of shape (batch, seq_len, 2) holding the
            (start, end) character span of every token, as produced by a
            tokenizer called with ``return_offsets_mapping=True``.
        offset_list: one character offset per example in the batch.

    Returns:
        A 1-D ``torch.long`` tensor of shape (batch,) with, for each example,
        the index of the first token that starts exactly at the offset. If no
        token starts there, the first token whose span contains the offset is
        used instead.

    Raises:
        ValueError: if no token of an example covers the requested offset.

    The input tensor is not modified.
    """
    # Work on a copy so the caller's offset mapping is left untouched.
    offsets = offsets.clone()

    # Tokens whose span is (0, 0) are set to (-1, -1) so that a real character
    # offset of 0 can never match them.
    # NOTE(review): presumably these are the special/padding tokens - confirm.
    empty_span = (offsets[:, :, 0] == 0) & (offsets[:, :, 1] == 0)
    offsets[empty_span] = -1

    word_indexes = []
    for i, offset in enumerate(offset_list):
        offset = int(offset)
        starts, ends = offsets[i, :, 0], offsets[i, :, 1]

        hits = torch.nonzero(starts == offset).flatten()
        if hits.numel() == 0:
            # The offset falls inside a token rather than on its first character.
            hits = torch.nonzero((starts <= offset) & (offset < ends)).flatten()
        if hits.numel() == 0:
            raise ValueError(
                f"No token covers character offset {offset} in batch item {i}"
            )
        word_indexes.append(int(hits[0]))

    # Always (batch,), including for a batch of one.
    return torch.tensor(word_indexes, dtype=torch.long)


def collate_fn(batch):
    """Turn a list of dataset examples into one padded, tokenized batch.

    Each example is read positionally: 0 = text, 1 = label, 2 = pronoun
    character offset, 3 = A character offset, 4 = B character offset.

    Returns the tokenizer output extended with:
        'label'       - list of the examples' labels,
        'Pronoun_loc' - token positions of the pronouns,
        'A_loc'       - token positions of mention A,
        'B_loc'       - token positions of mention B.
    """
    def field(position):
        # Gather one positional field across the whole batch.
        return [example[position] for example in batch]

    encoded = tokenizer(
        field(0),
        return_tensors="pt",
        padding=True,
        is_split_into_words=False,
        return_offsets_mapping=True,
    )
    char_spans = encoded['offset_mapping']
    encoded['label'] = field(1)

    # Convert each character offset into the index of the matching token.
    for key, position in (('Pronoun_loc', 2), ('A_loc', 3), ('B_loc', 4)):
        encoded[key] = get_word_index(char_spans, field(position))
    return encoded

In [None]:
# Wrap each split in a CorefData dataset (each constructor prints the split's shape).
train_dataset = CorefData(train)
val_dataset = CorefData(val)
test_dataset = CorefData(test)

# Batches of 64 built by collate_fn; only the training split is shuffled.
train_dataloader = DataLoader(
    train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn
)
val_dataloader = DataLoader(
    val_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn
)
test_dataloader = DataLoader(
    test_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn
)


In [None]:
# Smoke test: build the first validation batch through collate_fn and print it,
# then stop - this only checks that tokenization and offset lookup run.
for i in val_dataloader:
    print(i)
    break

In [None]:
class Coref(torch.nn.Module):
    """Three-way pronoun coreference classifier on top of a frozen transformer.

    Pipeline: sum of the encoder's last four hidden states -> 2-layer BiLSTM ->
    [mean over real tokens, pronoun - A/2, pronoun - B/2] -> ReLU/dropout MLP.

    Args:
        pre_trained_transformer_model: Hugging Face style encoder exposing
            ``config.hidden_size`` and returning ``hidden_states`` when called
            with ``output_hidden_states=True``. All its parameters are frozen.
        lstm_hidden_size: hidden size per LSTM direction. The default (786) is
            kept unchanged so checkpoints saved by the original class still load.
        fc_hidden_size: width of the hidden dense layer.
        num_classes: number of output logits.
        dropout: dropout probability used between LSTM layers and in the MLP.

    NOTE(review): ``model.train()`` also puts the frozen encoder in training
    mode, so any dropout inside it stays active while training - confirm that
    this is intended.
    """

    def __init__(self, pre_trained_transformer_model, lstm_hidden_size=786,
                 fc_hidden_size=1024, num_classes=3, dropout=0.5):
        super(Coref, self).__init__()
        self.transformer_model = pre_trained_transformer_model
        # The encoder is used as a fixed feature extractor.
        for param in self.transformer_model.parameters():
            param.requires_grad = False
        self.dropout = torch.nn.Dropout(dropout)
        self.lstm = torch.nn.LSTM(
            input_size=self.transformer_model.config.hidden_size,
            hidden_size=lstm_hidden_size,
            batch_first=True,
            bidirectional=True,
            num_layers=2,
            dropout=dropout,
        )
        self.relu = torch.nn.ReLU()
        # Three concatenated bidirectional feature vectors feed the MLP.
        self.fc1 = torch.nn.Linear((lstm_hidden_size * 2) * 3, fc_hidden_size)
        self.fc2 = torch.nn.Linear(fc_hidden_size, num_classes)

    def forward(self, batch):
        """Return logits of shape (batch, num_classes).

        ``batch`` must provide 'input_ids', 'attention_mask', 'Pronoun_loc',
        'A_loc' and 'B_loc'. The attention mask is assumed to mark a contiguous
        run of real tokens followed by padding (right padding), with at least
        one real token per example.
        """
        # Follow the module's own placement instead of a notebook-global device.
        model_device = self.fc2.weight.device
        input_ids = batch['input_ids'].to(model_device)
        attention_mask = batch['attention_mask'].to(model_device)

        # Request hidden states explicitly so this works even if the encoder was
        # loaded without output_hidden_states=True.
        transformers_outputs = self.transformer_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # Sum of the last four hidden-state tensors.
        embed_out = torch.stack(transformers_outputs.hidden_states[-4:], dim=0).sum(dim=0)

        # Pack the sequences so the LSTM never reads padding; otherwise the
        # backward direction starts on pad tokens and the result for an example
        # depends on how much padding its batch happens to need.
        lengths = attention_mask.to(torch.int64).sum(dim=1)
        packed = torch.nn.utils.rnn.pack_padded_sequence(
            embed_out, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.lstm(packed)
        output, _ = torch.nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=embed_out.shape[1]
        )

        # Padded positions come back as zeros, so sum / length is the mean over
        # real tokens only.
        out_sent = output.sum(dim=1) / lengths.unsqueeze(1).to(output.dtype)

        rows = torch.arange(output.shape[0], device=output.device)
        out_pron = output[rows, torch.as_tensor(batch['Pronoun_loc'], device=output.device), :]
        out_A = output[rows, torch.as_tensor(batch['A_loc'], device=output.device), :]
        out_B = output[rows, torch.as_tensor(batch['B_loc'], device=output.device), :]

        # NOTE(review): operator precedence makes this pron - (mention / 2),
        # not (pron - mention) / 2. Left as in the original; confirm the intent.
        out_pron_A = out_pron - out_A / 2
        out_pron_B = out_pron - out_B / 2

        total_out = torch.cat((out_sent, out_pron_A, out_pron_B), 1)
        relu1 = self.relu(total_out)
        dense1 = self.fc1(self.dropout(relu1))
        relu2 = self.relu(dense1)
        preds = self.fc2(self.dropout(relu2))
        return preds

In [None]:
# Load the encoder with output_hidden_states=True: Coref.forward reads
# `hidden_states` from the encoder output.
pre_trained_transformer_model=AutoModel.from_pretrained(transformer_model, output_hidden_states=True)
# Coref freezes every encoder parameter in its constructor.
model=Coref(pre_trained_transformer_model).to(device)
# Adam with its default hyperparameters over all of the model's parameters
# (the frozen encoder weights are passed too, but have requires_grad=False).
optimizer = torch.optim.Adam(model.parameters())
# Coref outputs three logits per example and the labels are class ids 0/1/2.
criterion = torch.nn.CrossEntropyLoss()

In [None]:
def train_model(model, iterator, optimizer, criterion):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: module called as ``model(batch)`` that returns logits.
        iterator: sized iterable of batches; each batch exposes its integer
            class labels under ``batch['label']``.
        optimizer: optimizer stepping ``model``'s trainable parameters.
        criterion: loss called as ``criterion(logits, labels)``.

    Returns:
        Sum of the batch losses divided by ``len(iterator)``.

    Raises:
        ZeroDivisionError: if ``iterator`` is empty.
    """
    model.train()
    epoch_loss = 0.0

    for batch in tqdm(iterator):
        optimizer.zero_grad()

        predictions = model(batch)
        predictions = predictions.view(-1, predictions.shape[-1])

        # Build the labels directly on the logits' device, so the function does
        # not depend on a notebook-global `device`.
        tags = torch.as_tensor(
            batch['label'], dtype=torch.long, device=predictions.device
        ).view(-1)

        loss = criterion(predictions, tags)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    return epoch_loss / len(iterator)

def evaluate(model, iterator, criterion):
    """Compute the mean per-batch loss without updating the model.

    Args:
        model: module called as ``model(batch)`` that returns logits.
        iterator: sized iterable of batches; each batch exposes its integer
            class labels under ``batch['label']``.
        criterion: loss called as ``criterion(logits, labels)``.

    Returns:
        Sum of the batch losses divided by ``len(iterator)``.

    Raises:
        ZeroDivisionError: if ``iterator`` is empty.
    """
    model.eval()
    epoch_loss = 0.0

    with torch.no_grad():
        for batch in iterator:
            predictions = model(batch)
            predictions = predictions.view(-1, predictions.shape[-1])

            # Labels go straight to the logits' device; no reliance on a
            # notebook-global `device`.
            tags = torch.as_tensor(
                batch['label'], dtype=torch.long, device=predictions.device
            ).view(-1)

            epoch_loss += criterion(predictions, tags).item()

    return epoch_loss / len(iterator)

def predict(model, iterator):
    """Collect arg-max class predictions and gold labels, batch by batch.

    Returns:
        ``(pred, tag)`` - two lists with one inner list per batch: the
        predicted class ids and the corresponding labels from ``batch['label']``.
    """
    batch_predictions, batch_labels = [], []
    model.eval()

    with torch.no_grad():
        for batch in iterator:
            gold = torch.tensor(batch['label']).view(-1)

            logits = model(batch)
            logits = logits.view(-1, logits.shape[-1])
            # Highest-scoring class for each row.
            best = torch.argmax(logits, dim=1)

            batch_predictions.append(best.tolist())
            batch_labels.append(gold.tolist())

    return batch_predictions, batch_labels

def epoch_time(start_time, end_time):
    """Split the span between two timestamps (in seconds) into whole minutes and seconds.

    Returns:
        ``(minutes, seconds)`` as ints, both truncated toward zero.
    """
    duration = end_time - start_time
    minutes = int(duration / 60)
    # Whatever is left after removing the whole minutes, truncated.
    seconds = int(duration - 60 * minutes)
    return minutes, seconds

In [None]:
# Number of full passes over the training split.
N_EPOCHS = 5

# Lowest validation loss seen so far; used to decide when to checkpoint.
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()
    
    # One optimisation pass over the training data, then a no-grad pass over
    # the validation data.
    train_loss= train_model(model, train_dataloader, optimizer, criterion)
    valid_loss= evaluate(model, val_dataloader, criterion)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    # Checkpoint only when validation loss improves, so 'tut1-model.pt' always
    # holds the best-so-far weights.
    # NOTE(review): nothing in this notebook loads this file again; the later
    # cells use the in-memory (last-epoch) model.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut1-model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f}')

In [None]:
# Predict on the held-out split (loaded from dev.tsv). Both results are lists
# with one inner list per batch.
# NOTE(review): this uses the model as it stands after the final epoch, not the
# best-validation checkpoint saved to 'tut1-model.pt' - confirm that is intended.
pred,label=predict(model, test_dataloader)

In [None]:
# Flatten the per-batch lists into one flat list per variable.
# Not idempotent: running this cell twice fails, because the elements are then
# plain ints rather than lists.
pred = [value for per_batch in pred for value in per_batch]
label = [value for per_batch in label for value in per_batch]

In [None]:
from sklearn.metrics import f1_score
# Per-class F1: average=None returns one score per label
# (0 = A corefers, 1 = B corefers, 2 = neither).
# scikit-learn's signature is f1_score(y_true, y_pred), so the gold labels go first.
f1_score(label, pred, average=None)

In [None]:
# Save the current (last-epoch) weights; this is a separate file from the
# best-validation checkpoint 'tut1-model.pt' written during training.
torch.save(model.state_dict(), 'model-hw2.pth')

In [None]:
# config0: add all emb
# config1:add emb/2
# config2:sub emb/2 array([0.77453581, 0.79146919, 0.55045872])
# config3: cosine  array([0.78947368, 0.81339713, 0.56363636])

In [None]:
# spanbert: array([0.78074866, 0.79816514, 0.53061224])
# bert: array([0.77804296, 0.78010471, 0.56074766])
# array([0.81909548, 0.78061224, 0.59322034])