In [None]:
!pip install datasets transformers

In [None]:
from transformers import BertTokenizer, BertModel, BertModel, BertPreTrainedModel
from tqdm.auto import tqdm
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from torch.optim import Adam
from torch import tensor
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score

In [None]:
# Fetch the uncased BERT WordPiece tokenizer and both GAP splits (train, then test).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
dataset, dataset_test = (load_dataset("gap", split=split_name) for split_name in ("train", "test"))

class GAPDataset(Dataset):
    """Map-style dataset pairing pre-tokenised encodings with per-example labels.

    Args:
        encodings: mapping from feature name (must include ``input_ids``) to a
            row-indexable container of per-example values: a tensor (as produced
            with ``return_tensors='pt'``) or a nested list.
        labels: one label per example, indexable by the same row index.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return a dict of tensors for example ``idx``, including ``labels``."""
        # as_tensor reuses existing tensors instead of copy-constructing them,
        # which avoids an extra copy and torch's "copy construct from a tensor"
        # UserWarning on every item; lists are still converted as before.
        item = {key: torch.as_tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.as_tensor(self.labels[idx])
        return item

    def __len__(self):
        # Subscript instead of attribute access so plain dicts work as well as
        # Hugging Face BatchEncoding objects.
        return len(self.encodings['input_ids'])

def encode_data(tokenizer, texts, max_length=512):
    """Tokenise texts into fixed-length PyTorch tensors.

    Every sequence is padded and truncated to exactly ``max_length`` tokens.

    Args:
        tokenizer: a callable Hugging Face tokenizer.
        texts: a single string, or any iterable of strings (list, tuple,
            generator, a dataset column object, ...).
        max_length: target sequence length in tokens.

    Returns:
        Whatever ``tokenizer`` returns for these arguments (a ``BatchEncoding``
        of PyTorch tensors for Hugging Face tokenizers).
    """
    # Hugging Face tokenizers only accept a str or a list/tuple of str, so
    # materialise any other iterable (e.g. a lazily evaluated dataset column).
    if not isinstance(texts, (str, list, tuple)):
        texts = list(texts)
    return tokenizer(texts, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')

# Build the train/test pipelines: encode each passage and label it with GAP's
# `A-coref` column (cast to int for BCE targets).
# NOTE(review): only `Text` is fed to the model. Nothing tells it which pronoun
# or which candidate `A` the label refers to, so the task as posed looks
# underspecified. Consider encoding the pronoun / `A` mention (e.g. as a second
# segment or with marker tokens). TODO confirm against the intended setup.
texts = dataset['Text']
labels = [int(a) for a in dataset['A-coref']]
encodings = encode_data(tokenizer, texts)
gap_dataset = GAPDataset(encodings, labels)
loader = DataLoader(gap_dataset, batch_size=8, shuffle=True)

texts_test = dataset_test['Text']
labels_test = [int(a) for a in dataset_test['A-coref']]
encodings_test = encode_data(tokenizer, texts_test)
gap_dataset_test = GAPDataset(encodings_test, labels_test)
# Evaluation gains nothing from shuffling; keep order fixed and leave the global
# RNG untouched by the test pass.
loader_test = DataLoader(gap_dataset_test, batch_size=8, shuffle=False)


In [None]:
class CorefResolver(BertPreTrainedModel):
    """Binary classifier on BERT: masked mean of token states -> MLP -> one logit.

    ``forward`` returns raw logits (no sigmoid); apply ``torch.sigmoid`` to get
    probabilities. Training uses ``BCEWithLogitsLoss``, which fuses the sigmoid
    into the loss for numerical stability.
    """

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        # MLP head producing one unnormalised score per example. No final
        # Sigmoid: the loss and any caller apply it.
        self.classifier = nn.Sequential(
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(256, 1),
        )
        # Standard Hugging Face hook for custom models: runs the library's
        # weight initialisation and final setup.
        self.post_init()

    def forward(self, input_ids, attention_mask, labels=None):
        """Score each sequence in the batch.

        Args:
            input_ids: token ids, shape ``(batch, seq_len)``.
            attention_mask: 1 for real tokens, 0 for padding; same shape.
            labels: optional 0/1 targets, shape ``(batch,)``.

        Returns:
            ``(loss, logits)`` when ``labels`` is given, otherwise ``logits``.
            ``logits`` has shape ``(batch,)`` and is not passed through a sigmoid.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]  # (batch, seq_len, hidden)
        # Masked mean pooling. With padding='max_length' most positions are
        # [PAD], and averaging them in would swamp the real tokens.
        mask = attention_mask.unsqueeze(-1).to(sequence_output.dtype)
        summed = (sequence_output * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)
        combined = summed / counts
        # squeeze(-1), not squeeze(): a batch of one stays shape (1,) instead of 0-d.
        logits = self.classifier(combined).squeeze(-1)
        loss = None
        if labels is not None:
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits, labels.float())

        return (loss, logits) if loss is not None else logits


In [None]:
# Seed torch's global RNG (CPU and CUDA) before anything stochastic happens.
# Classifier-head initialisation, dropout masks and the training DataLoader's
# shuffle order all draw from it, so re-runs start from the same state.
# GPU kernels may still be nondeterministic.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CorefResolver.from_pretrained('bert-base-uncased').to(device)
optimizer = Adam(model.parameters(), lr=1e-5)


def train(model, data_loader, optimizer, epochs=5):
    """Fine-tune ``model`` for ``epochs`` passes over ``data_loader``.

    ``model(input_ids, attention_mask, labels)`` must return ``(loss, logits)``.
    A batch that raises ``RuntimeError`` (e.g. CUDA out-of-memory) is reported
    and skipped rather than aborting the run. Skipped batches are excluded from
    the loss averages.

    Args:
        model: module to train; moved to CUDA when available.
        data_loader: iterable of dicts with ``input_ids``, ``attention_mask``
            and ``labels`` tensors.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``data_loader``.

    Returns:
        List with one entry per epoch: the mean loss over successfully processed
        batches, or ``nan`` if no batch succeeded in that epoch.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.train()

    history = []
    for epoch in range(epochs):
        total_loss = 0.0
        n_ok = 0
        n_skipped = 0
        progress_bar = tqdm(data_loader, desc=f"Epoch {epoch+1}", leave=False)

        for batch in progress_bar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            try:
                loss, _ = model(input_ids, attention_mask, labels)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            except RuntimeError as e:
                # Deliberate best effort: log and skip, but keep the skipped
                # batch out of the average instead of counting it as zero loss.
                print(f"Error during training: {str(e)}")
                n_skipped += 1
                continue

            n_ok += 1
            # Running mean over the batches processed so far this epoch.
            # tqdm's last_print_n lags the real step count, so count explicitly.
            progress_bar.set_postfix(loss=f"{total_loss / n_ok:.4f}")

        avg_loss = total_loss / n_ok if n_ok else float("nan")
        history.append(avg_loss)
        print(f"Epoch {epoch+1}: Average Loss = {avg_loss:.4f}")
        if n_skipped:
            print(f"Epoch {epoch+1}: skipped {n_skipped} batch(es) after errors")

    return history


# Fine-tune on the GAP training split with the default five epochs.
train(model=model, data_loader=loader, optimizer=optimizer, epochs=5)


In [None]:
def evaluate_model(model, data_loader, threshold=0.7):
    """Run ``model`` over ``data_loader`` and compute binary-classification metrics.

    The model is called as ``model(input_ids, attention_mask)`` and is expected
    to return raw logits. A sigmoid turns them into probabilities, and an example
    is predicted positive when its probability is strictly above ``threshold``.

    Args:
        model: module to evaluate; moved to CUDA when available.
        data_loader: iterable of dicts with ``input_ids``, ``attention_mask``
            and ``labels`` tensors.
        threshold: decision threshold on the sigmoid probability (default 0.7).

    Returns:
        Dict with accuracy, precision, recall, F1, the confusion matrix with
        rows/columns ordered ``[0, 1]``, and ROC AUC. AUC is ``nan`` when the
        labels contain only one class, where it is undefined.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    y_pred = []
    y_true = []
    y_scores = []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            logits = model(input_ids, attention_mask)

            # reshape(-1) copes with models that squeeze a size-1 batch to 0-d,
            # which would otherwise make list.extend fail.
            probabilities = torch.sigmoid(logits).reshape(-1).cpu()
            predictions = (probabilities > threshold).long()

            y_pred.extend(predictions.tolist())
            y_true.extend(batch['labels'].reshape(-1).tolist())
            y_scores.extend(probabilities.tolist())

    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    # Fix the label order so the matrix is always 2x2, even if one class is absent.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    # roc_auc_score raises when y_true has a single class; report nan instead.
    auc = roc_auc_score(y_true, y_scores) if len(set(y_true)) > 1 else float("nan")

    return {
        "Accuracy": accuracy,
        "Precision": precision,
        "Recall": recall,
        "F1 Score": f1,
        "Confusion Matrix": cm,
        "AUC": auc
    }

# Score the fine-tuned model on the held-out GAP test split and report the metrics.
results = evaluate_model(model=model, data_loader=loader_test)
print("Evaluation Results:", results, sep=" ")
