In [None]:
!pip install datasets transformers

In [None]:
from transformers import pipeline, BertTokenizer, AutoModelForTokenClassification, BertModel, AdamW, BertPreTrainedModel
from tqdm.auto import tqdm
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from torch.optim import Adam
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score

In [None]:
# Checkpoint used to tag every input text with part-of-speech labels.
pos_model_name = "QCRI/bert-base-multilingual-cased-pos-english"
pos_tokenizer = BertTokenizer.from_pretrained(pos_model_name)

# Load the tagger first, then move it to CUDA when it is available.
pos_model = AutoModelForTokenClassification.from_pretrained(pos_model_name)
pos_model = pos_model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

# Token-classification pipeline; device is 0 when CUDA is available, else -1.
pos_pipeline = pipeline(
    "ner",
    model=pos_model,
    tokenizer=pos_tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)

# Tokenizer that is later handed to prepare_data to encode the texts.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def prepare_data(tokenizer, entries, pos_pipeline, max_length=512):
    """Turn GAP entries into a GAPDataset of candidate examples with POS-tag ids.

    Every entry produces two examples: one for candidate ``A`` and one for
    candidate ``B``, labelled with ``int(entry['A-coref'])`` and
    ``int(entry['B-coref'])`` respectively.

    Args:
        tokenizer: Callable invoked as ``tokenizer(texts, padding='max_length',
            truncation=True, max_length=max_length, return_tensors='pt')``.
        entries: Iterable of mappings with the keys 'Text', 'Pronoun', 'A',
            'B', 'A-coref' and 'B-coref'.
        pos_pipeline: Callable that maps a text to a sequence of dicts, each
            carrying the predicted tag under the 'entity' key.
        max_length: Width of the encodings and of every POS-tag row.

    Returns:
        GAPDataset built from the encodings, the label list and a
        ``(num_examples, max_length)`` long tensor of POS-tag ids (0 = padding).

    Raises:
        KeyError: If the tagger emits a tag that is missing from ``label2id``.
    """
    label2id = {
        "#": 7, "$": 6, "''": 5, ",": 2, "-LRB-": 17, "-RRB-": 32, ".": 4, ":": 3, "CC": 8,
        "CD": 9, "DT": 10, "EX": 11, "FW": 12, "IN": 13, "JJ": 14, "JJR": 15, "JJS": 16,
        "LS": 18, "MD": 19, "NN": 20, "NNP": 21, "NNPS": 22, "NNS": 23, "O": 0, "PDT": 24,
        "POS": 25, "PRP": 26, "PRP$": 27, "RB": 28, "RBR": 29, "RBS": 30, "RP": 31, "SYM": 33,
        "TO": 34, "UH": 35, "VB": 36, "VBD": 37, "VBG": 38, "VBN": 39, "VBP": 40, "VBZ": 41,
        "WDT": 42, "WP": 43, "WP$": 44, "WRB": 45, "``": 1
    }

    def _pos_ids(text):
        # Strict lookup on purpose: an unknown tag should fail loudly.
        return [label2id[tag['entity']] for tag in pos_pipeline(text)]

    def _fit(tags):
        # Truncate rows that are too long and right-pad short ones with 0, so
        # every row is exactly max_length wide and the tensor is rectangular.
        clipped = tags[:max_length]
        return clipped + [0] * (max_length - len(clipped))

    texts = []
    labels = []
    pos_tags_indices = []

    for entry in tqdm(entries, desc="Processing entries"):
        # NOTE(review): the tokenizer presumably adds its own [CLS]/[SEP], so
        # these literal markers may end up duplicated -- confirm this is intended.
        text_a = f"[CLS] {entry['Text']} [SEP] {entry['Pronoun']} [SEP] {entry['A']}"
        text_b = f"[CLS] {entry['Text']} [SEP] {entry['Pronoun']} [SEP] {entry['B']}"

        # NOTE(review): the tagger uses its own tokenizer (and, by default, the
        # Hugging Face pipeline presumably drops tokens labelled "O"), so these
        # ids are likely not position-aligned with `tokenizer`'s ids -- confirm.
        pos_indices_a = _pos_ids(text_a)
        pos_indices_b = _pos_ids(text_b)

        texts.extend([text_a, text_b])
        labels.extend([int(entry['A-coref']), int(entry['B-coref'])])
        pos_tags_indices.extend([pos_indices_a, pos_indices_b])

    encodings = tokenizer(
        texts,
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors='pt',
    )
    pos_tags_padded = torch.tensor(
        [_fit(tags) for tags in pos_tags_indices],
        dtype=torch.long,
    )

    return GAPDataset(encodings, labels, pos_tags_padded)

class GAPDataset(Dataset):
    """Map-style dataset pairing tokenizer encodings with labels and POS-tag ids.

    Args:
        encodings: Mapping exposing 'input_ids' and 'attention_mask', each
            indexable by example position.
        labels: Per-example labels, indexable by integer.
        pos_tags: Per-example POS-tag id rows, indexable by integer
            (a tensor or a nested list).
    """

    def __init__(self, encodings, labels, pos_tags):
        self.encodings = encodings
        self.labels = labels
        self.pos_tags = pos_tags

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict with input ids, mask, label and POS tags."""
        return {
            'input_ids': self.encodings['input_ids'][idx],
            'attention_mask': self.encodings['attention_mask'][idx],
            # as_tensor accepts plain Python values and existing tensors alike;
            # unlike torch.tensor it does not warn (or copy) when handed a
            # tensor that already has the requested dtype.
            'labels': torch.as_tensor(self.labels[idx], dtype=torch.long),
            'pos_tags': torch.as_tensor(self.pos_tags[idx], dtype=torch.long),
        }

    def __len__(self):
        """Number of examples, taken from the length of the 'input_ids' encodings."""
        return len(self.encodings['input_ids'])

# Fetch the train and test splits of the GAP dataset.
dataset = load_dataset("gap", split='train')
dataset_test = load_dataset("gap", split='test')

# Tokenizer used to encode the candidate texts.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tag and encode both splits, train first.
prepared_dataset = prepare_data(tokenizer=tokenizer, entries=dataset, pos_pipeline=pos_pipeline)
prepared_test_dataset = prepare_data(tokenizer=tokenizer, entries=dataset_test, pos_pipeline=pos_pipeline)

# Only the training loader is shuffled.
loader = DataLoader(dataset=prepared_dataset, shuffle=True, batch_size=32)
loader_test = DataLoader(dataset=prepared_test_dataset, shuffle=False, batch_size=32)

In [None]:
# Persist both prepared datasets to the working directory.
torch.save(obj=prepared_dataset, f='./train.pt')
torch.save(obj=prepared_test_dataset, f='./test.pt')

In [None]:
class GAPDataset(Dataset):
    """Map-style dataset pairing tokenizer encodings with labels and POS-tag ids.

    Declared in this cell as well, presumably so the datasets saved to disk can
    be loaded from here without re-running the preparation cell.

    Args:
        encodings: Mapping exposing 'input_ids' and 'attention_mask', each
            indexable by example position.
        labels: Per-example labels, indexable by integer.
        pos_tags: Per-example POS-tag id rows, indexable by integer
            (a tensor or a nested list).
    """

    def __init__(self, encodings, labels, pos_tags):
        self.encodings = encodings
        self.labels = labels
        self.pos_tags = pos_tags

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict with input ids, mask, label and POS tags."""
        return {
            'input_ids': self.encodings['input_ids'][idx],
            'attention_mask': self.encodings['attention_mask'][idx],
            # as_tensor accepts plain Python values and existing tensors alike;
            # unlike torch.tensor it does not warn (or copy) when handed a
            # tensor that already has the requested dtype.
            'labels': torch.as_tensor(self.labels[idx], dtype=torch.long),
            'pos_tags': torch.as_tensor(self.pos_tags[idx], dtype=torch.long),
        }

    def __len__(self):
        """Number of examples, taken from the length of the 'input_ids' encodings."""
        return len(self.encodings['input_ids'])

# The .pt files hold whole pickled GAPDataset objects rather than plain tensors,
# so they need full unpickling: weights_only=False keeps this working on PyTorch
# releases where torch.load defaults to weights_only=True.
# SECURITY: full unpickling can execute arbitrary code -- only load files that
# were written by this notebook.
prepared_dataset = torch.load('./train.pt', weights_only=False)
prepared_test_dataset = torch.load('./test.pt', weights_only=False)

# Only the training loader is shuffled.
loader = DataLoader(prepared_dataset, batch_size=16, shuffle=True)
loader_test = DataLoader(prepared_test_dataset, batch_size=16, shuffle=False)

In [None]:
class CorefResolver(BertPreTrainedModel):
    """BERT encoder plus POS-tag embeddings feeding a binary classifier head.

    Args:
        config: Model configuration; ``hidden_size`` and ``hidden_dropout_prob``
            are read from it.
        pos_tag_dim: Size of each POS-tag embedding vector.
        num_pos_tags: Number of rows in the POS-tag embedding table; index 0 is
            the padding row.
    """

    def __init__(self, config, pos_tag_dim=50, num_pos_tags=46):
        super().__init__(config)
        self.bert = BertModel(config)
        self.pos_tag_embeddings = nn.Embedding(num_pos_tags, pos_tag_dim, padding_idx=0)
        # The head ends in a Sigmoid, so the model emits probabilities in
        # [0, 1]; callers must not apply another sigmoid to its output.
        self.classifier = nn.Sequential(
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size + pos_tag_dim, 256),
            nn.ReLU(),
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, input_ids, attention_mask, pos_tags, labels=None):
        """Score each example in the batch.

        Args:
            input_ids: Token ids passed to the BERT encoder.
            attention_mask: Attention mask passed to the BERT encoder.
            pos_tags: POS-tag ids; must match the encoder output on every
                dimension but the last so the two can be concatenated
                (assumes shape (batch, seq_len) -- TODO confirm).
            labels: Optional binary targets; when given, a BCE loss is computed.

        Returns:
            ``(loss, probabilities)`` when ``labels`` is given, otherwise just
            the probabilities. Despite the ``logits`` variable name these are
            post-sigmoid values, one per example.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        pos_embeddings = self.pos_tag_embeddings(pos_tags)
        combined = torch.cat((sequence_output, pos_embeddings), dim=-1)

        # NOTE(review): this averages over every position on dim 1 and does not
        # consult attention_mask, so padded positions contribute too -- confirm
        # that is intended.
        combined = torch.mean(combined, dim=1)

        # Drop only the trailing size-1 dimension. A bare squeeze() would also
        # collapse a batch of one down to a 0-dim tensor.
        logits = self.classifier(combined).squeeze(-1)

        loss = None
        if labels is not None:
            loss_fct = nn.BCELoss()
            logits = logits.view(-1)
            loss = loss_fct(logits, labels.float())

        return (loss, logits) if loss is not None else logits


In [None]:
# Use CUDA when it is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Start from the pretrained encoder weights, then move the model to the device.
model = CorefResolver.from_pretrained('bert-base-uncased', pos_tag_dim=512, num_pos_tags=50)
model = model.to(device)

optimizer = Adam(params=model.parameters(), lr=1e-5)


def train(model, data_loader, optimizer, epochs=5):
    """Train ``model`` for ``epochs`` passes over ``data_loader``.

    Args:
        model: Module called as ``model(input_ids, attention_mask, pos_tags,
            labels)`` and returning a ``(loss, output)`` pair.
        data_loader: Iterable of batches; each batch is a mapping with the
            tensors 'input_ids', 'attention_mask', 'pos_tags' and 'labels'.
        optimizer: Optimizer bound to the model's parameters.
        epochs: Number of passes over ``data_loader``.

    Returns:
        None. The running loss is shown on the progress bar and the average
        loss of each epoch is printed.

    Batches that raise ``RuntimeError`` are reported and skipped; they are
    excluded from both the running and the per-epoch average.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.train()

    for epoch in range(epochs):
        total_loss = 0.0
        completed_steps = 0  # batches that finished a full optimizer step
        progress_bar = tqdm(data_loader, desc=f"Epoch {epoch+1}", leave=False)

        for batch in progress_bar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            pos_tags = batch['pos_tags'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            try:
                loss, _ = model(input_ids, attention_mask, pos_tags, labels)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                completed_steps += 1

                # Divide by our own step counter: the bar's last_print_n only
                # advances when the display refreshes, not on every batch.
                progress_bar.set_postfix(loss=f"{total_loss / completed_steps:.4f}")

            except RuntimeError as e:
                # Best effort: report the failing batch and keep training.
                print(f"Error during training: {str(e)}")
                continue

        # Guard against an empty loader (or an epoch where every batch failed).
        avg_loss = total_loss / completed_steps if completed_steps else float("nan")
        print(f"Epoch {epoch+1}: Average Loss = {avg_loss:.4f}")


# Fine-tune on the training loader using the default number of epochs.
train(model, data_loader=loader, optimizer=optimizer)


In [None]:
def evaluate_model(model, data_loader, threshold=0.7):
    """Compute binary-classification metrics for ``model`` over ``data_loader``.

    Args:
        model: Module called as ``model(input_ids, attention_mask, pos_tags)``
            that returns one probability in [0, 1] per example.
        data_loader: Iterable of batches; each batch is a mapping with the
            tensors 'input_ids', 'attention_mask', 'pos_tags' and 'labels'.
        threshold: An example is predicted positive when its probability is
            strictly greater than this value.

    Returns:
        Dict with the keys "Accuracy", "Precision", "Recall" and "F1 Score".
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    y_pred = []
    y_true = []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            pos_tags = batch['pos_tags'].to(device)
            labels = batch['labels']

            # The model output is already a probability (its classifier ends in
            # a sigmoid). Applying sigmoid again would squash every score into
            # roughly [0.5, 0.73] and distort the threshold. reshape(-1) also
            # covers a 0-dim output produced by a single-example batch.
            probabilities = model(input_ids, attention_mask, pos_tags).reshape(-1)
            predictions = (probabilities > threshold).long()

            y_pred.extend(predictions.cpu().tolist())
            y_true.extend(labels.reshape(-1).cpu().tolist())

    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)

    return {
        "Accuracy": accuracy,
        "Precision": precision,
        "Recall": recall,
        "F1 Score": f1,
    }

# Score the trained model on the test loader and print the metric dict.
results = evaluate_model(model=model, data_loader=loader_test)
print("Evaluation Results:", results)
