In [None]:
from google.colab import drive

drive.mount('/content/drive', force_remount=True)
repository = 'evaluating_factuality_word_definitions'

%cd /content/drive/My Drive/{repository}

In [None]:
%%capture
!pip install datasets~=2.18.0
!pip install einops~=0.8.0
!pip install rank_bm25~=0.2.2
!pip install wandb~=0.17.5

In [1]:
import gc
import torch
import random
from models.claim_verification_model import ClaimVerificationModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torch.utils.data import DataLoader
from losses.atomic_fact_loss import AtomicFactsLoss
from dataset.def_dataset import DefinitionDataset
from matplotlib import pyplot as plt
from tqdm import tqdm
import numpy as np
from torch.cuda.amp import GradScaler, autocast
from datetime import datetime
from torch import optim
from sklearn.metrics import classification_report
import wandb

In [None]:
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

# Evaluation

In [2]:
def evaluate(ev_model, dataloader, loss_function):
    gt_labels = []
    pr_labels = []
    all_loss = []

    for batch in tqdm(dataloader):
        ev_model.eval()
        model_input = batch['model_input']
        claim_mask = model_input.get('claim_mask')
        with torch.no_grad():
            logits = ev_model(input_ids=model_input['input_ids'], 
                              attention_mask=model_input['attention_mask'])['logits']
            predicted = torch.softmax(logits, dim=-1)

            labels = batch['labels']
            loss = loss_function(labels, predicted[:, :1], claim_mask)
            
            predicted_label = torch.argmax(predicted, dim=-1).unsqueeze(1)
            predicted_label = (predicted_label * claim_mask.unsqueeze(2)).squeeze(2)            
            predicted_label = torch.all(predicted_label == 0, dim=1)   # index 0 == SUPPORTED
            
        gt_labels.extend(batch['labels'].tolist())
        pr_labels.extend(predicted_label.tolist())
        all_loss.append(loss)

    loss = sum(all_loss) / len(all_loss)
    return loss.item(), classification_report(gt_labels, pr_labels, digits=4)

# Training

In [3]:
torch.cuda.empty_cache()
gc.collect()

75

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

verification_model = ClaimVerificationModel(model).to(device)

In [5]:
from datasets import load_dataset

def map_dataset(example, sentence_ordering):
    if example.get('selected_evidence_lines'):
        evidence_lines = example['selected_evidence_lines'].split(',')
        if sentence_ordering == 'top_last':
            evidence_lines = evidence_lines[1:] + [evidence_lines[0]]
        elif sentence_ordering == 'reverse':
            evidence_lines.reverse()
        example['evidence_lines'] = ','.join(evidence_lines)
    return example

sentence_ordering = 'keep'
dataset = load_dataset("lukasellinger/filtered_fever-claim_verification")
dataset = dataset.map(map_dataset, fn_kwargs={'sentence_ordering': sentence_ordering})
dataset = dataset.remove_columns("atomic_facts")  # try without atomic facts as splitting error prune

train_dataset = DefinitionDataset(dataset['train'], tokenizer, mode='validation', model='claim_verification')
train_dataloader = DataLoader(train_dataset, shuffle=True,
                              collate_fn=train_dataset.collate_fn,
                              batch_size=32)
dev_dataset = DefinitionDataset(dataset['dev'], tokenizer, mode='train', model='claim_verification')
dev_dataloader = DataLoader(dev_dataset, shuffle=True,
                            collate_fn=dev_dataset.collate_fn,
                            batch_size=32)

Filter:   0%|          | 0/29237 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1978 [00:00<?, ? examples/s]

In [6]:
len(train_dataloader)

5455

## Gradient Accumulation

### AtomicFacts-Loss

As we are combining REFUTES and NOT ENOUGH INFO into one Label 'NOT VERIFIABLE WRT KNOWLEDGE BASE' we can use Binary Cross Entropy Loss

In [None]:
settings = {
    "learning_rate": 1e-6,
    "dataset": "FEVER",
    "epochs": 6,
    "patience": 2,
    "gradient_accumulation": 16,
    "seed": seed,
    'sentence_ordering': sentence_ordering,
    'mode': 'including not enough info'
    }

# start a new wandb run to track this script
wandb.init(
    project="claim_verification",
    config=settings
)

In [12]:
optimizer = optim.AdamW(verification_model.parameters(), lr=settings.get('learning_rate'))
criterion = AtomicFactsLoss(pos_weight=2) #BCELoss()

timestamp = datetime.now().strftime("%m-%d_%H-%M")

num_epochs = settings.get('epochs') 
patience = settings.get('patience')   
gradient_accumulation = settings.get('gradient_accumulation')
trace_train = []
trace_val = []

verification_model.zero_grad()
use_amp = True
scaler = GradScaler(enabled=use_amp, init_scale=1)

checkpoint = {
    'model': verification_model.model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scaler': scaler.state_dict(),
    'settings': settings,
    'epoch': 0
}

torch.save(checkpoint, f'claim_verification_start_{timestamp}')
#wandb.save('claim_verification_start') # does not work in colab

#checkpoint = torch.load(f'training_05-15_15-55')
#optimizer.load_state_dict(checkpoint['optimizer'])
#scaler.load_state_dict(checkpoint['scaler'])

best_loss = np.inf
epoch = 0
for epoch in range(num_epochs):
    bar_desc = "Epoch %d of %d | Iteration" % (epoch + 1, num_epochs)
    train_iterator = tqdm(train_dataloader, desc=bar_desc)

    train_loss = 0
    print('Train ...')
    for step, batch in enumerate(train_iterator):
        verification_model.train()
        model_input = batch["model_input"]
        claim_mask = model_input.get('claim_mask')

        with autocast():
            logits = verification_model(input_ids=model_input['input_ids'], 
                                        attention_mask=model_input['attention_mask'])['logits']
            predicted = torch.softmax(logits, dim=-1)
            predicted = predicted[:, :1]
            labels = batch['labels']
            loss = criterion(labels, predicted, claim_mask)
            train_loss += loss.detach().item()
            loss = (loss / gradient_accumulation)
        scaler.scale(loss).backward()
        
        if (step + 1) % gradient_accumulation == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1e-1) #1e-2

            scaler.step(optimizer)
            scaler.update()
            optimizer.step()
            optimizer.zero_grad()

    trace_train.append(train_loss / len(train_dataloader))
    # validation
    with torch.no_grad():
        val_loss, report = evaluate(verification_model, dev_dataloader, criterion)
        trace_val.append(val_loss)
        print(
            f'Epoch {epoch + 1}/{num_epochs}, Training Loss: {(train_loss / len(train_dataloader)):.4f}, Validation Loss: {val_loss:.4f}')
        print(report)
        wandb.log({"val_loss": val_loss, "train_loss": train_loss / len(train_dataloader), "report": report})

        if val_loss < best_loss:
            best_loss = val_loss
            best_epoch = epoch
            best_state = {key: value.cpu() for key, value in
                          verification_model.state_dict().items()}
            verification_model.save(f'claim_verification_model_intermediate_{timestamp}_epoch{epoch}_{sentence_ordering}')
        else:
            if epoch >= best_epoch + patience:
                break

verification_model.load_state_dict(best_state)

checkpoint = {
    'model': verification_model.model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scaler': scaler.state_dict(),
    'settings': settings,
    'epoch': epoch
}

torch.save(checkpoint, f'claim_verification_done_{timestamp}_{sentence_ordering}')
#wandb.save(f'claim_verification_done_{timestamp}')  # does not work in colab
wandb.finish()

plt.plot(trace_train, label='train')
plt.plot(trace_val, label='validation')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)



Train ...


Epoch 1 of 10 | Iteration:   0%|          | 0/5455 [04:09<?, ?it/s]

KeyboardInterrupt



In [56]:
criterion = AtomicFactsLoss() #BCELoss()
val_loss, report = evaluate(verification_model, train_dataloader, criterion)
print(val_loss)
print(report)

  1%|          | 55/5455 [17:18<28:19:19, 18.88s/it]


KeyboardInterrupt: 

In [None]:
test_dataset = DefinitionDataset(dataset['test'], tokenizer, mode='train', model='evidence_selection')
test_dataloader = DataLoader(test_dataset, shuffle=True, collate_fn=test_dataset.collate_fn, batch_size=10)

test_loss, report = evaluate(model, test_dataloader)
print(test_loss)
print(report)