In [None]:
# Notebook shell magic: install the Hugging Face `datasets` and `transformers` packages used below.
!pip install datasets transformers

In [None]:
from pprint import pprint
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, BertModel, BertPreTrainedModel
from datasets import load_dataset
import torch
import torch.nn as nn
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from torch.nn.utils.rnn import pad_sequence
from itertools import combinations
import random
from torch.optim import Adam
import os
from torch import tensor
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [None]:
class OntoNotesDataset(Dataset):
    """Mention-pair training dataset built from OntoNotes v5 (CoNLL-2012, english_v12).

    For every coreference entity in a document, its mentions are paired in
    order -- (0, 1), (2, 3), ... -- and an odd trailing mention is dropped.
    Each such pair is a positive example (label 1). When the document has
    other entities, each positive pair also yields one negative example
    (label 0): the pair's first mention matched with a randomly chosen mention
    of a different entity from the same document. Negatives are drawn with the
    ``random`` module, so seed it for reproducible pairs.

    ``self.pairs`` entries are ``[sentence1, mention1, sentence2, mention2, label]``
    (word lists). ``self.pos_tags`` holds a parallel entry with the POS-tag ids
    of those same four word lists, followed by the label.
    """

    def __init__(self, split='train', tokenizer_name='bert-base-uncased', max_length=512):
        self.dataset = load_dataset('conll2012_ontonotesv5', 'english_v12', split=split)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
        self.max_length = max_length
        self.pos_tags, self.pairs = self.prepare_pairs()

    @staticmethod
    def _index_document(document):
        """Number a document's sentences from 1 and group its mentions by entity id.

        Returns ``(mentions, mention_pos, sentences, sentence_pos)``:
        ``mentions[entity_id]`` is a list of ``[sentence_id, mention_words]`` in
        document order, ``mention_pos[entity_id]`` holds the matching
        ``[sentence_id, mention_pos_tags]``, and ``sentences`` /
        ``sentence_pos`` map a sentence id to its words / POS tags.
        """
        mentions = {}
        mention_pos = {}
        sentences = {}
        sentence_pos = {}
        for sentence_id, sentence in enumerate(document, start=1):
            words = sentence['words']
            pos = sentence['pos_tags']
            # Each coref span is [entity_id, start, end] with an inclusive end index.
            for entity_id, start, end in sentence['coref_spans']:
                mentions.setdefault(entity_id, []).append([sentence_id, words[start:end + 1]])
                mention_pos.setdefault(entity_id, []).append([sentence_id, pos[start:end + 1]])
            sentences[sentence_id] = words
            sentence_pos[sentence_id] = pos
        return mentions, mention_pos, sentences, sentence_pos

    @staticmethod
    def _build_example(sentences, sentence_pos, span1, pos_span1, span2, pos_span2, label):
        """Return the ``(pair, pos_entry)`` records for two mentions and a label."""
        pair = [sentences[span1[0]], span1[1], sentences[span2[0]], span2[1], label]
        pos_entry = [sentence_pos[span1[0]], pos_span1[1], sentence_pos[span2[0]], pos_span2[1], label]
        return pair, pos_entry

    def prepare_pairs(self):
        """Generate positive and negative mention pairs for every document.

        Returns:
            ``(all_pos_tags, pairs)``: parallel lists, one record per example,
            laid out as described in the class docstring.
        """
        pairs = []
        all_pos_tags = []
        for item in self.dataset:
            mentions, mention_pos, sentences, sentence_pos = self._index_document(item['sentences'])

            for entity_id, spans_of_entity in mentions.items():
                pos_tags_of_entity = mention_pos[entity_id]
                # Largest even count: an odd trailing mention has no partner.
                n = len(spans_of_entity) - len(spans_of_entity) % 2

                for i in range(0, n, 2):
                    pair, pos_entry = self._build_example(
                        sentences, sentence_pos,
                        spans_of_entity[i], pos_tags_of_entity[i],
                        spans_of_entity[i + 1], pos_tags_of_entity[i + 1], 1)
                    pairs.append(pair)
                    all_pos_tags.append(pos_entry)

                false_entities = [eid for eid in mentions if eid != entity_id]
                if not false_entities:
                    continue
                # One negative per positive pair: same first mention, random
                # mention of some other entity in this document.
                for i in range(0, n, 2):
                    random_entity = random.choice(false_entities)
                    j = random.randint(0, len(mentions[random_entity]) - 1)
                    pair, pos_entry = self._build_example(
                        sentences, sentence_pos,
                        spans_of_entity[i], pos_tags_of_entity[i],
                        mentions[random_entity][j], mention_pos[random_entity][j], 0)
                    pairs.append(pair)
                    all_pos_tags.append(pos_entry)

        return all_pos_tags, pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Encode example ``idx`` for the model.

        Returns a dict with ``input_ids`` and ``attention_mask`` (tokenizer
        output padded/truncated to ``self.max_length``), ``label`` (0-d long
        tensor) and ``pos_tags`` (long tensor of exactly ``self.max_length``
        ids, 0 used as filler).
        """
        sentence_one, word_one, sentence_two, word_two, label = self.pairs[idx]
        sentence_one_pos_tags, word_one_pos_tags, sentence_two_pos_tags, word_two_pos_tags, _ = self.pos_tags[idx]

        sentence_one_text = ' '.join(sentence_one)
        sentence_two_text = ' '.join(sentence_two)
        word_one_text = ' '.join(word_one)
        word_two_text = ' '.join(word_two)

        # NOTE(review): the tokenizer presumably also adds its own [CLS]/[SEP]
        # (add_special_tokens defaults to True), duplicating the outer markers;
        # kept unchanged so inputs match what the model was trained on.
        input_sequence = f"[CLS] {sentence_one_text} [SEP] {sentence_two_text} [SEP] {word_one_text} [SEP] {word_two_text} [SEP]"

        # NOTE(review): POS tags are per word while input_ids are per WordPiece
        # token, so positions line up only until the first word that is split
        # into several sub-tokens.
        pos_tag_sequence = ([0] + list(sentence_one_pos_tags) + [0] + list(sentence_two_pos_tags)
                            + [0] + list(word_one_pos_tags) + [0] + list(word_two_pos_tags) + [0])
        # Truncate and pad to max_length so pos_tags always has the same length
        # as the tokenizer output below; a longer sequence would make the
        # model's torch.cat fail.
        pos_tag_sequence = pos_tag_sequence[:self.max_length]
        pos_tag_sequence += [0] * (self.max_length - len(pos_tag_sequence))

        tokenized_inputs = self.tokenizer(input_sequence, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
        return {
            'input_ids': tokenized_inputs['input_ids'].squeeze(0),
            'attention_mask': tokenized_inputs['attention_mask'].squeeze(0),
            'label': torch.tensor(label, dtype=torch.long),
            'pos_tags': torch.tensor(pos_tag_sequence, dtype=torch.long),
        }


def collate_fn(batch):
    """Merge a list of dataset items into one batch dictionary.

    The 1-D ``input_ids``, ``attention_mask`` and ``pos_tags`` tensors are
    right-padded with 0 to the longest item and stacked along a new leading
    batch dimension; the per-item ``label`` values are gathered into a 1-D
    tensor.
    """
    def padded(key):
        # Stack one field across the batch, filling short rows with zeros.
        return pad_sequence([sample[key] for sample in batch], batch_first=True, padding_value=0)

    return {
        'input_ids': padded('input_ids'),
        'attention_mask': padded('attention_mask'),
        'label': torch.tensor([sample['label'] for sample in batch]),
        'pos_tags': padded('pos_tags'),
    }


In [None]:
class OntoNotesDatasetTest(Dataset):
    """Candidate-ranking test dataset built from OntoNotes v5 (CoNLL-2012, english_v12).

    For every coreference entity, mentions are paired in order -- (0, 1),
    (2, 3), ... -- and an odd trailing mention is dropped. For each gold pair
    (mention1, mention2) a group of examples is emitted: one negative
    (label 0) for every word of mention2's sentence whose surface form does
    not occur in mention2, then the gold pair itself (label 1). The size of
    each group is recorded in ``self.combination_lengths``.

    ``self.pairs`` entries are ``[sentence1, mention1, sentence2, candidate, label]``
    (word lists); ``self.pos_tags`` holds the parallel POS-tag ids plus label.
    """

    def __init__(self, split='test', tokenizer_name='bert-base-uncased', max_length=512):
        self.dataset = load_dataset('conll2012_ontonotesv5', 'english_v12', split=split)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
        self.max_length = max_length
        self.pos_tags, self.pairs, self.combination_lengths = self.prepare_pairs()

    @staticmethod
    def _index_document(document):
        """Number a document's sentences from 1 and group its mentions by entity id.

        Returns ``(mentions, mention_pos, sentences, sentence_pos)``:
        ``mentions[entity_id]`` is a list of ``[sentence_id, mention_words]`` in
        document order, ``mention_pos[entity_id]`` holds the matching
        ``[sentence_id, mention_pos_tags]``, and ``sentences`` /
        ``sentence_pos`` map a sentence id to its words / POS tags.
        """
        mentions = {}
        mention_pos = {}
        sentences = {}
        sentence_pos = {}
        for sentence_id, sentence in enumerate(document, start=1):
            words = sentence['words']
            pos = sentence['pos_tags']
            # Each coref span is [entity_id, start, end] with an inclusive end index.
            for entity_id, start, end in sentence['coref_spans']:
                mentions.setdefault(entity_id, []).append([sentence_id, words[start:end + 1]])
                mention_pos.setdefault(entity_id, []).append([sentence_id, pos[start:end + 1]])
            sentences[sentence_id] = words
            sentence_pos[sentence_id] = pos
        return mentions, mention_pos, sentences, sentence_pos

    def prepare_pairs(self):
        """Generate the candidate groups for every gold mention pair.

        Returns:
            ``(all_pos_tags, pairs, combination_lengths)`` where the first two
            are parallel per-example lists and ``combination_lengths[k]`` is the
            number of examples (negatives + gold) in the k-th group.
        """
        pairs = []
        all_pos_tags = []
        combination_lengths = []
        for item in self.dataset:
            mentions, mention_pos, sentences, sentence_pos = self._index_document(item['sentences'])

            for entity_id, spans_of_entity in mentions.items():
                pos_tags_of_entity = mention_pos[entity_id]
                # Largest even count: an odd trailing mention has no partner.
                n = len(spans_of_entity) - len(spans_of_entity) % 2

                for i in range(0, n, 2):
                    span1, span2 = spans_of_entity[i], spans_of_entity[i + 1]
                    pos_span1, pos_span2 = pos_tags_of_entity[i], pos_tags_of_entity[i + 1]
                    sentence1 = sentences[span1[0]]
                    sentence2 = sentences[span2[0]]
                    pos_sentence1 = sentence_pos[span1[0]]
                    pos_sentence2 = sentence_pos[span2[0]]
                    gold_words = set(span2[1])

                    combination_len = 0
                    # enumerate gives each word's own position, so repeated
                    # words (e.g. "the") get their own POS tag instead of the
                    # tag of their first occurrence.
                    for word_index, word in enumerate(sentence2):
                        if word in gold_words:
                            continue
                        pairs.append([sentence1, span1[1], sentence2, [word], 0])
                        all_pos_tags.append([pos_sentence1, pos_span1[1], pos_sentence2, [pos_sentence2[word_index]], 0])
                        combination_len += 1

                    # The gold pair closes its group.
                    pairs.append([sentence1, span1[1], sentence2, span2[1], 1])
                    all_pos_tags.append([pos_sentence1, pos_span1[1], pos_sentence2, pos_span2[1], 1])
                    combination_lengths.append(combination_len + 1)
        return all_pos_tags, pairs, combination_lengths

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Encode example ``idx`` for the model.

        Returns a dict with ``input_ids`` and ``attention_mask`` (tokenizer
        output padded/truncated to ``self.max_length``), ``label`` (0-d long
        tensor) and ``pos_tags`` (long tensor of exactly ``self.max_length``
        ids, 0 used as filler).
        """
        sentence_one, word_one, sentence_two, word_two, label = self.pairs[idx]
        sentence_one_pos_tags, word_one_pos_tags, sentence_two_pos_tags, word_two_pos_tags, _ = self.pos_tags[idx]

        sentence_one_text = ' '.join(sentence_one)
        sentence_two_text = ' '.join(sentence_two)
        word_one_text = ' '.join(word_one)
        word_two_text = ' '.join(word_two)

        # NOTE(review): the tokenizer presumably also adds its own [CLS]/[SEP]
        # (add_special_tokens defaults to True), duplicating the outer markers;
        # kept unchanged so inputs match what the model was trained on.
        input_sequence = f"[CLS] {sentence_one_text} [SEP] {sentence_two_text} [SEP] {word_one_text} [SEP] {word_two_text} [SEP]"

        # NOTE(review): POS tags are per word while input_ids are per WordPiece
        # token, so positions line up only until the first word that is split
        # into several sub-tokens.
        pos_tag_sequence = ([0] + list(sentence_one_pos_tags) + [0] + list(sentence_two_pos_tags)
                            + [0] + list(word_one_pos_tags) + [0] + list(word_two_pos_tags) + [0])
        # Truncate and pad to max_length so pos_tags always has the same length
        # as the tokenizer output below.
        pos_tag_sequence = pos_tag_sequence[:self.max_length]
        pos_tag_sequence += [0] * (self.max_length - len(pos_tag_sequence))

        tokenized_inputs = self.tokenizer(input_sequence, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
        return {
            'input_ids': tokenized_inputs['input_ids'].squeeze(0),
            'attention_mask': tokenized_inputs['attention_mask'].squeeze(0),
            'label': torch.tensor(label, dtype=torch.long),
            'pos_tags': torch.tensor(pos_tag_sequence, dtype=torch.long),
        }


def collate_fn(batch):
    """Collate dataset items into a dict of batched tensors.

    NOTE: this redefinition replaces any earlier ``collate_fn`` in the notebook.

    ``label`` values become one 1-D tensor; every other field is a 1-D tensor
    that is zero-padded on the right to the longest item and stacked with the
    batch dimension first.
    """
    collated = {}
    for field in ('input_ids', 'attention_mask', 'label', 'pos_tags'):
        column = [entry[field] for entry in batch]
        if field == 'label':
            collated[field] = torch.tensor(column)
        else:
            collated[field] = pad_sequence(column, batch_first=True, padding_value=0)
    return collated


In [None]:
class CorefResolver(BertPreTrainedModel):
    """BERT mention-pair scorer augmented with POS-tag embeddings.

    BERT token states are concatenated position-wise with a learned POS-tag
    embedding, averaged over the sequence, and fed to a two-layer MLP that
    ends in ``nn.Sigmoid``. The value this model returns (historically named
    ``logits``) is therefore a probability in [0, 1], not a raw logit.

    Args:
        config: BERT model configuration.
        pos_tag_dim: size of each POS-tag embedding.
        num_pos_tags: number of POS-tag ids; id 0 is the embedding's
            ``padding_idx``.
        masked_pooling: when True, average only over positions where
            ``attention_mask`` is 1. The default (False) averages over every
            position, padding included, which is how the model is trained
            elsewhere in this notebook; keep it False for those checkpoints.
    """

    def __init__(self, config, pos_tag_dim = 50, num_pos_tags = 51, masked_pooling=False):
        super().__init__(config)
        self.masked_pooling = masked_pooling
        self.bert = BertModel(config)
        self.pos_tag_embeddings = nn.Embedding(num_pos_tags, pos_tag_dim, padding_idx=0)
        self.classifier = nn.Sequential(
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size + pos_tag_dim, 256),
            nn.ReLU(),
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, input_ids, attention_mask, pos_tags, labels=None):
        """Score a batch of mention pairs.

        ``pos_tags`` must match ``input_ids`` in shape so the POS embeddings
        can be concatenated position-wise with BERT's token states.

        Returns:
            ``(loss, probs)`` when ``labels`` is given (``loss`` is the mean
            binary cross-entropy against ``labels``), otherwise ``probs``.
            ``probs`` is 1-D with one probability per batch item, including
            for a batch of size one.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        pos_embeddings = self.pos_tag_embeddings(pos_tags)
        combined = torch.cat((sequence_output, pos_embeddings), dim=-1)

        if self.masked_pooling:
            mask = attention_mask.unsqueeze(-1).to(combined.dtype)
            # clamp guards against division by zero for an all-padding row.
            pooled = (combined * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        else:
            pooled = torch.mean(combined, dim=1)

        # squeeze(-1) drops only the size-1 output dimension, so a batch of one
        # still yields shape (1,) rather than a 0-d tensor.
        probs = self.classifier(pooled).squeeze(-1)
        if labels is None:
            return probs

        # BCELoss (not BCEWithLogitsLoss) because the classifier already ends
        # in a sigmoid.
        probs = probs.view(-1)
        loss = nn.BCELoss()(probs, labels.float())
        return loss, probs

In [None]:
# Training data: positive/negative mention pairs from the OntoNotes train split,
# shuffled into batches of 32 and merged by collate_fn.
dataset = OntoNotesDataset(split='train', tokenizer_name='bert-base-uncased', max_length=512)
data_loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
# BERT weights come from the bert-base-uncased checkpoint; the extra layers
# (POS embeddings, classifier) are presumably newly initialised -- confirm in
# the from_pretrained log output.
model = CorefResolver.from_pretrained('bert-base-uncased')
optimizer = Adam(model.parameters(), lr=1e-5)


def train(model, data_loader, optimizer, epochs=5, save_path='drive/MyDrive/COMP442FinalProject/models'):
    """Fine-tune ``model`` and save a checkpoint after every epoch.

    Each batch from ``data_loader`` must be a dict with 'input_ids',
    'attention_mask', 'label' and 'pos_tags' tensors. The model is called as
    ``model(input_ids, attention_mask, pos_tags, labels)`` and must return
    ``(loss, outputs)``. A batch whose forward/backward/step raises
    ``RuntimeError`` is reported and skipped (best effort, e.g. for an
    out-of-memory batch); skipped batches do not count toward the average loss.

    After each epoch the model and optimizer state dicts are written to
    ``<save_path>/conll_pos_tag_embdim50_epoch_<epoch>.pt``.

    Args:
        model: module to train; moved to CUDA when available.
        data_loader: iterable of batch dicts.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``data_loader``.
        save_path: checkpoint directory, created if missing.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.train()

    for epoch in range(epochs):
        total_loss = 0.0
        num_steps = 0      # batches that completed successfully
        num_skipped = 0    # batches dropped because of a RuntimeError
        progress_bar = tqdm(data_loader, desc=f"Epoch {epoch+1}", leave=False)

        for batch in progress_bar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            pos_tags = batch['pos_tags'].to(device)
            optimizer.zero_grad()
            try:
                loss, _ = model(input_ids, attention_mask, pos_tags, labels)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            except RuntimeError as e:
                num_skipped += 1
                print(f"Error during training: {str(e)}")
                continue
            num_steps += 1
            # Running mean over the steps completed so far in this epoch.
            progress_bar.set_postfix(loss=f"{total_loss / num_steps:.4f}")

        # Average over successful steps only; NaN if every batch failed.
        avg_loss = total_loss / num_steps if num_steps else float('nan')
        print(f"Epoch {epoch+1}: Average Loss = {avg_loss:.4f}")
        if num_skipped:
            print(f"Epoch {epoch+1}: skipped {num_skipped} batch(es) after errors")

        os.makedirs(save_path, exist_ok=True)
        model_save_path = os.path.join(save_path, f"conll_pos_tag_embdim50_epoch_{epoch+1}.pt")
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }, model_save_path)
        print(f"Model saved to {model_save_path}")


# Run a single training epoch; train() writes a checkpoint at the end of each epoch.
train(model, data_loader, optimizer, epochs=1)

# **Span Index Prediction**

In [None]:
def evaluate_model(model, data_loader, thresholds=(0.6, 0.7, 0.8, 0.9, 0.95), max_batches=1000):
    """Score mention-pair candidates and report metrics at several thresholds.

    The model runs once over at most ``max_batches`` batches of
    ``data_loader`` (all batches when ``max_batches`` is None). Its outputs
    are treated as probabilities, because ``CorefResolver`` ends in a sigmoid.
    A pair is predicted positive when its probability is ``> threshold``.
    Accuracy, precision, recall and F1 are printed, plotted against the
    threshold, and returned.

    Args:
        model: called as ``model(input_ids, attention_mask, pos_tags)``.
        data_loader: yields dicts with 'input_ids', 'attention_mask',
            'pos_tags' and 'label' tensors.
        thresholds: decision thresholds to evaluate.
        max_batches: positive batch cap, or None for no cap.

    Returns:
        dict mapping metric name to a list with one value per threshold.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    thresholds = list(thresholds)

    # One forward pass serves every threshold: model.eval() disables dropout,
    # so re-running the model per threshold would only repeat the same scores.
    score_parts = []
    label_parts = []
    with torch.no_grad():
        for batch_index, batch in enumerate(tqdm(data_loader, desc="Evaluating")):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            pos_tags = batch['pos_tags'].to(device)

            probs = model(input_ids, attention_mask, pos_tags)
            # reshape(-1): flatten so batches of any size (including a 0-d
            # output for a single-item batch) concatenate into one 1-D array.
            score_parts.append(probs.reshape(-1).cpu().numpy())
            label_parts.append(batch['label'].reshape(-1).cpu().numpy())
            if max_batches is not None and batch_index + 1 >= max_batches:
                break

    if not label_parts:
        raise ValueError("data_loader yielded no batches to evaluate")
    scores = np.concatenate(score_parts)
    y_true = np.concatenate(label_parts)

    metrics = {
        "Accuracy": [],
        "Precision": [],
        "Recall": [],
        "F1 Score": []
    }
    for threshold in thresholds:
        y_pred = (scores > threshold).astype(int)
        metrics["Accuracy"].append(accuracy_score(y_true, y_pred))
        metrics["Precision"].append(precision_score(y_true, y_pred, zero_division=0))
        metrics["Recall"].append(recall_score(y_true, y_pred, zero_division=0))
        metrics["F1 Score"].append(f1_score(y_true, y_pred, zero_division=0))

    plt.figure(figsize=(10, 6))
    for metric, values in metrics.items():
        print(metric, values)
        plt.plot(thresholds, values, marker='o', label=metric)

    plt.title("Evaluation Metrics Across Different Thresholds")
    plt.xlabel("Threshold")
    plt.ylabel("Metric Value")
    plt.legend()
    plt.grid(True)
    plt.show()

    return metrics



# Rebuild the model and load the fine-tuned weights saved after epoch 1 by train().
# map_location='cpu' lets the checkpoint load on a machine without a GPU.
model = CorefResolver.from_pretrained('bert-base-uncased')
model_save_path = 'drive/MyDrive/COMP442FinalProject/models/conll_pos_tag_embdim50_epoch_1.pt'
checkpoint = torch.load(model_save_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])

# Test examples: for each gold mention pair, one negative candidate per word of the
# second sentence not found in the gold mention, followed by the gold pair itself.
dataset_test = OntoNotesDatasetTest(split='test', tokenizer_name='bert-base-uncased', max_length=512)
# combination_lengths[k] is the number of examples (negatives + gold) generated for
# the k-th gold pair; it is not used by the evaluation below.
combination_lengths = dataset_test.combination_lengths
# shuffle=False keeps each gold pair's candidates contiguous and in generation order.
loader_test = DataLoader(dataset_test, batch_size=1, shuffle=False, collate_fn=collate_fn)

metrics = evaluate_model(model, loader_test)

# **Binary classification on mention pairs**

In [None]:
def evaluate_model(model, data_loader, thresholds=(0.5, 0.6, 0.7, 0.8, 0.9, 0.95), max_batches=1000):
    """Evaluate pairwise coreference classification at several thresholds.

    The model runs once over at most ``max_batches`` batches of
    ``data_loader`` (all batches when ``max_batches`` is None). A pair is
    predicted positive when its score is ``> threshold``. Accuracy,
    precision, recall and F1 are plotted against the threshold and returned.

    Because the scores are computed once and reused, every threshold is
    evaluated on exactly the same pairs, even when ``data_loader`` shuffles.

    Args:
        model: called as ``model(input_ids, attention_mask, pos_tags)``.
        data_loader: yields dicts with 'input_ids', 'attention_mask',
            'pos_tags' and 'label' tensors.
        thresholds: decision thresholds to evaluate.
        max_batches: positive batch cap, or None for no cap.

    Returns:
        dict mapping metric name to a list with one value per threshold.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    thresholds = list(thresholds)

    score_parts = []
    label_parts = []
    with torch.no_grad():
        for batch_index, batch in enumerate(tqdm(data_loader, desc="Evaluating")):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            pos_tags = batch['pos_tags'].to(device)

            # The outputs are already probabilities: CorefResolver's classifier
            # ends in nn.Sigmoid. Applying torch.sigmoid again would squash
            # them into [0.5, 0.73], so no threshold >= 0.8 could ever fire.
            probs = model(input_ids, attention_mask, pos_tags)
            # reshape(-1): flatten so batches of any size concatenate cleanly.
            score_parts.append(probs.reshape(-1).cpu().numpy())
            label_parts.append(batch['label'].reshape(-1).cpu().numpy())
            if max_batches is not None and batch_index + 1 >= max_batches:
                break

    if not label_parts:
        raise ValueError("data_loader yielded no batches to evaluate")
    scores = np.concatenate(score_parts)
    y_true = np.concatenate(label_parts)

    metrics = {
        "Accuracy": [],
        "Precision": [],
        "Recall": [],
        "F1 Score": []
    }
    for threshold in thresholds:
        y_pred = (scores > threshold).astype(int)
        metrics["Accuracy"].append(accuracy_score(y_true, y_pred))
        metrics["Precision"].append(precision_score(y_true, y_pred, zero_division=0))
        metrics["Recall"].append(recall_score(y_true, y_pred, zero_division=0))
        metrics["F1 Score"].append(f1_score(y_true, y_pred, zero_division=0))

    plt.figure(figsize=(10, 6))
    for metric, values in metrics.items():
        plt.plot(thresholds, values, marker='o', label=metric)

    plt.title("Evaluation Metrics Across Different Thresholds")
    plt.xlabel("Threshold")
    plt.ylabel("Metric Value")
    plt.legend()
    plt.grid(True)
    plt.show()

    return metrics



# Reload the fine-tuned epoch-1 checkpoint for the pairwise-classification evaluation.
model = CorefResolver.from_pretrained('bert-base-uncased')
model_save_path = 'drive/MyDrive/COMP442FinalProject/models/conll_pos_tag_embdim50_epoch_1.pt'
checkpoint = torch.load(model_save_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])
# Test pairs built the same way as the training pairs: consecutive mentions of an
# entity (label 1) plus one randomly drawn other-entity mention per pair (label 0).
dataset_test = OntoNotesDataset(split='test', tokenizer_name='bert-base-uncased', max_length=512)
# NOTE(review): with shuffle=True and evaluate_model's 1000-batch cap, the set of
# test pairs that gets scored changes between runs unless torch/random are seeded.
loader_test = DataLoader(dataset_test, batch_size=1, shuffle=True, collate_fn=collate_fn)

results = evaluate_model(model, loader_test)
print("Evaluation Results:", results)