In [None]:
from google.colab import drive

drive.mount('/content/drive', force_remount=True)
repository = 'evaluating_factuality_word_definitions'

%cd /content/drive/My Drive/{repository}

In [None]:
!pip install datasets
!pip install einops
!pip install rank_bm25

In [3]:
import gc
import torch
from models.claim_verification_model import ClaimVerificationModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torch.utils.data import DataLoader
from config import DB_URL
from datasets import Dataset
from dataset.def_dataset import DefinitionDataset
from matplotlib import pyplot as plt
from tqdm import tqdm
import numpy as np
from torch.cuda.amp import GradScaler, autocast
from datetime import datetime
from torch import optim
from torch.nn import BCELoss
from sklearn.metrics import classification_report

# Evaluation

In [40]:
def evaluate(ev_model, dataloader, loss_function=None):
    """Evaluate a 3-way classifier as a binary claim verifier.

    The softmax probabilities of output classes 1 and 2 are summed, giving a
    2-column distribution that is scored against one-hot targets built from
    ``batch['labels']``. Predictions are the argmax of that merged
    distribution, so they live in the same {0, 1} space as the labels.

    Args:
        ev_model: model exposing ``eval()``; called with ``input_ids`` and
            ``attention_mask`` keyword arguments and returning a mapping with a
            ``'logits'`` entry (assumed shape ``(batch, 3)`` — TODO confirm for
            other checkpoints).
        dataloader: iterable of batch dicts with ``'model_input'`` (containing
            ``'input_ids'`` and ``'attention_mask'``) and ``'labels'`` (integer
            tensor with values in {0, 1}).
        loss_function: callable applied to (probabilities, one-hot targets).
            Defaults to ``torch.nn.BCELoss()``.

    Returns:
        Tuple ``(mean_loss, report)``: the mean of the per-batch losses as a
        Python float, and ``classification_report`` over all predictions.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if loss_function is None:
        loss_function = torch.nn.BCELoss()

    ev_model.eval()
    gt_labels = []
    pr_labels = []
    batch_losses = []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            model_input = batch['model_input']
            logits = ev_model(input_ids=model_input['input_ids'],
                              attention_mask=model_input['attention_mask'])['logits']
            probs = torch.softmax(logits.float(), dim=-1)
            # Collapse output classes 1 and 2 into the single binary class 1.
            binary_probs = torch.stack((probs[:, 0], probs[:, 1] + probs[:, 2]), dim=1)

            labels = batch['labels'].to(binary_probs.device)
            targets = torch.zeros_like(binary_probs)
            targets.scatter_(1, labels.unsqueeze(1), 1)
            batch_losses.append(loss_function(binary_probs, targets).item())

            # Argmax over the merged distribution: the raw logits could yield
            # class 2, which does not exist in the binary label space.
            predicted = torch.argmax(binary_probs, dim=-1)
            gt_labels.extend(labels.tolist())
            pr_labels.extend(predicted.tolist())

    if not batch_losses:
        raise ValueError("evaluate() received an empty dataloader")

    mean_loss = sum(batch_losses) / len(batch_losses)
    return mean_loss, classification_report(gt_labels, pr_labels)

# Training

In [2]:
# Free unreachable Python objects first so the tensors they held are released,
# then return the now-unused cached CUDA blocks to the driver. Calling
# empty_cache() before gc.collect() cannot release memory still owned by
# uncollected tensors.
gc.collect()
torch.cuda.empty_cache()

NameError: name 'torch' is not defined

In [4]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Pretrained sequence-classification checkpoint that is fine-tuned below,
# together with its matching tokenizer.
model_name = "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Project wrapper around the HF model; this is the object that gets trained/evaluated.
verification_model = ClaimVerificationModel(model)
verification_model = verification_model.to(device)

In [5]:
# One row per (claim, selected evidence, source document) for a given split.
# NOTE: `set_type` is interpolated with str.format — only pass trusted literals.
dataset_query = """
select distinct dd.id, dd.claim, dd.label, docs.document_id, docs.text,
       docs.lines, se.evidence_lines as evidence_lines
from def_dataset dd
    join selected_evidence se on dd.id = se.claim_id
    join documents docs on docs.document_id = dd.evidence_wiki_url
where set_type='{set_type}'
"""
BATCH_SIZE = 16


def load_split(set_type, shuffle, batch_size=BATCH_SIZE):
    """Load one split from the database and wrap it for claim verification.

    Args:
        set_type: split name substituted into ``dataset_query`` ('train', 'dev', ...).
        shuffle: whether the DataLoader reshuffles every epoch.
        batch_size: DataLoader batch size.

    Returns:
        Tuple ``(raw_dataset, definition_dataset, dataloader)``.
    """
    raw = Dataset.from_sql(dataset_query.format(set_type=set_type), con=DB_URL)
    wrapped = DefinitionDataset(raw, tokenizer, mode='train', model='claim_verification')
    loader = DataLoader(wrapped, shuffle=shuffle,
                        collate_fn=wrapped.collate_fn,
                        batch_size=batch_size)
    return raw, wrapped, loader


train_dataset_raw, train_dataset, train_dataloader = load_split('train', shuffle=True)
# Validation is not shuffled: a fixed batch composition keeps the per-batch-mean
# validation loss (which drives early stopping) deterministic across epochs.
dev_dataset_raw, dev_dataset, dev_dataloader = load_split('dev', shuffle=False)

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Filter:   0%|          | 0/119 [00:00<?, ? examples/s]

Filter:   0%|          | 0/12 [00:00<?, ? examples/s]

## Gradient Accumulation

### BCE-Loss

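Since REFUTES and NOT ENOUGH INFO are merged into a single label, 'NOT VERIFIABLE WRT KNOWLEDGE BASE', the task becomes binary, so we can use binary cross-entropy (BCE) loss. On the model side, the softmax probabilities of output classes 1 and 2 are summed to obtain the probability of this merged class.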

In [24]:
# --- Hyper-parameters (values from the original experiment) ---
LEARNING_RATE = 1e-7
MAX_GRAD_NORM = 1e-1    # previously also tried 1e-2
num_epochs = 10
patience = 3            # epochs without validation improvement before stopping
gradient_accumulation = 64
use_amp = True
resume_from = None      # e.g. 'training_05-15_15-55' to restore optimizer/scaler state

optimizer = optim.AdamW(verification_model.parameters(), lr=LEARNING_RATE)
criterion = BCELoss()
scaler = GradScaler(enabled=use_amp, init_scale=1)

if resume_from is not None:
    checkpoint = torch.load(resume_from)
    optimizer.load_state_dict(checkpoint['optimizer'])
    scaler.load_state_dict(checkpoint['scaler'])

timestamp = datetime.now().strftime("%m-%d_%H-%M")
trace_train = []
trace_val = []


def binary_probs_and_targets(logits, labels):
    """Merge output classes 1 and 2 and build matching one-hot targets.

    Uses an out-of-place merge: modifying the softmax output in place would
    invalidate it for autograd (softmax's backward reuses its output).
    Computed in float32 so it is safe to feed into BCELoss.
    """
    probs = torch.softmax(logits.float(), dim=-1)
    merged = torch.stack((probs[:, 0], probs[:, 1] + probs[:, 2]), dim=1)
    targets = torch.zeros_like(merged)
    targets.scatter_(1, labels.to(merged.device).unsqueeze(1), 1)
    return merged, targets


verification_model.zero_grad()
best_loss = np.inf
best_epoch = 0
best_state = None
num_batches = len(train_dataloader)

for epoch in range(num_epochs):
    bar_desc = "Epoch %d of %d | Iteration" % (epoch + 1, num_epochs)
    train_iterator = tqdm(train_dataloader, desc=bar_desc)

    train_loss = 0
    print('Train ...')
    verification_model.train()  # evaluate() switches to eval mode each epoch
    for step, batch in enumerate(train_iterator):
        model_input = batch["model_input"]

        # Only the forward pass runs under autocast: BCELoss is banned inside
        # CUDA autocast regions, so the loss is computed afterwards in float32.
        with autocast(enabled=use_amp):
            logits = verification_model(input_ids=model_input['input_ids'],
                                        attention_mask=model_input['attention_mask'])['logits']
        predicted, labels_one_hot = binary_probs_and_targets(logits, batch['labels'])
        loss = criterion(predicted, labels_one_hot)
        train_loss += loss.detach().item()
        scaler.scale(loss / gradient_accumulation).backward()

        # Step at each accumulation boundary AND at the last batch of the epoch,
        # so a trailing partial group (or an epoch shorter than the group) still
        # updates the weights instead of leaking gradients into the next epoch.
        if (step + 1) % gradient_accumulation == 0 or step + 1 == num_batches:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(verification_model.parameters(), MAX_GRAD_NORM)
            # scaler.step() calls optimizer.step() itself (skipping it on inf/NaN
            # grads); an extra optimizer.step() would apply the update twice.
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

    epoch_train_loss = train_loss / num_batches
    trace_train.append(epoch_train_loss)

    # --- validation ---
    val_loss, report = evaluate(verification_model, dev_dataloader, criterion)
    trace_val.append(val_loss)
    print(f'Epoch {epoch + 1}/{num_epochs}, Training Loss: {epoch_train_loss:.4f}, Validation Loss: {val_loss:.4f}')
    print(report)

    if val_loss < best_loss:
        best_loss = val_loss
        best_epoch = epoch
        # clone(): on CPU, .cpu() returns the same storage, so without a copy the
        # "best" snapshot would keep changing as training continues.
        best_state = {key: value.detach().cpu().clone()
                      for key, value in verification_model.state_dict().items()}
        verification_model.save(f'verification_model_intermediate_{timestamp}')
        torch.save({
            'optimizer': optimizer.state_dict(),
            'scaler': scaler.state_dict()}, f'training_{timestamp}')
    elif epoch >= best_epoch + patience:
        break

# Restore the best weights (if any epoch improved, e.g. not all-NaN losses).
if best_state is not None:
    verification_model.load_state_dict(best_state)
verification_model.save(f'verification_model_{timestamp}')
torch.save({'optimizer': optimizer.state_dict(),
            'scaler': scaler.state_dict()}, f'training_{timestamp}')

fig, ax = plt.subplots()
ax.plot(trace_train, label='train')
ax.plot(trace_val, label='validation')
ax.set(title='Claim verification: BCE loss per epoch', xlabel='Epochs', ylabel='Loss')
ax.legend()
ax.grid(True)
plt.show()

Epoch 1 of 10 | Iteration:   0%|          | 0/1 [00:00<?, ?it/s]

Train ...


Epoch 1 of 10 | Iteration:   0%|          | 0/1 [01:30<?, ?it/s]


KeyboardInterrupt: 

In [41]:
# Re-run validation on the dev split with the model's current weights.
val_loss, report = evaluate(verification_model, dev_dataloader, criterion)
print(val_loss, report, sep='\n')

100%|██████████| 1/1 [00:00<00:00,  1.28it/s]

0.08677899837493896
              precision    recall  f1-score   support

           0       1.00      1.00      1.00         1
           1       1.00      1.00      1.00         2

    accuracy                           1.00         3
   macro avg       1.00      1.00      1.00         3
weighted avg       1.00      1.00      1.00         3






In [None]:
# Held-out evaluation on the 'test' split, using the same claim-verification
# preprocessing as the train/dev splits (was mistakenly 'evidence_selection').
test_dataset_raw = Dataset.from_sql(dataset_query.format(set_type='test'), con=DB_URL)
test_dataset = DefinitionDataset(test_dataset_raw, tokenizer, mode='train', model='claim_verification')
# No shuffling: evaluation does not need random order and this keeps the loss reproducible.
test_dataloader = DataLoader(test_dataset, shuffle=False, collate_fn=test_dataset.collate_fn, batch_size=10)

# Evaluate the trained wrapper (not the raw HF `model`), passing the loss
# explicitly so this cell does not depend on `criterion` from the training cell.
test_loss, report = evaluate(verification_model, test_dataloader, BCELoss())
print(test_loss)
print(report)