# AIR - Exercise in Google Colab

## Colab Preparation

Open the notebook via Google Drive -> right click -> Open with -> Colab

**Get a GPU**

Toolbar -> Runtime -> Change Runtime Type -> GPU

**Mount Google Drive**

* Download the data and clone your GitHub repo into your Google Drive folder
* Use Google Drive as the connection between GitHub and Colab (you could also access GitHub directly, but re-submitting credentials might be annoying)
* Commit to GitHub locally from the synced Drive folder

**Keep Alive**

During training, Google Colab tends to disconnect idle sessions. This might help: https://medium.com/@shivamrawat_756/how-to-prevent-google-colab-from-disconnecting-717b88a128c0

**Get Started**

Run the following cells to mount Google Drive and install the needed Python packages. PyTorch comes pre-installed.

In [None]:
#%pip install -r ../requirements.txt

In [None]:
import torch

# Sanity-check the runtime: print the installed PyTorch version, whether CUDA
# is usable, and a small random tensor.
print("Version:",torch.__version__)
print("Has GPU:",torch.cuda.is_available()) # should print True once a GPU runtime is selected
print("Random tensor:",torch.rand(10)) # confirms that basic tensor creation works

# Download Data

In [None]:
#TODO check if downloading data is desired

from pathlib import Path
import requests
import zipfile

DATA_PATH = Path("../data")
DATA_PATH.mkdir(exist_ok=True, parents=True)
DATA_ZIP = DATA_PATH / "data.zip"
DATA_URL = "https://owncloud.tuwien.ac.at/index.php/s/QA4LEtxdBokqdNx/download"

GLOVE_ZIP = DATA_PATH / "glove.42B.300d.zip"
GLOVE_URL = "http://nlp.stanford.edu/data/glove.42B.300d.zip"


def _download(url, target, chunk_size=1 << 20):
    """Stream ``url`` into the file ``target`` without buffering it in memory.

    The body is written to ``<target>.part`` first and only moved to ``target``
    after the transfer finished, so an interrupted download is never mistaken
    for a complete archive by the ``exists()`` checks below.

    Args:
        url: Address to fetch.
        target: ``pathlib.Path`` of the file to create.
        chunk_size: Number of bytes requested per chunk.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    partial = target.with_name(target.name + ".part")
    with requests.get(url, stream=True, timeout=60) as response:
        # Fail loudly instead of saving an HTML error page as the zip file.
        response.raise_for_status()
        with open(partial, "wb") as out:
            for chunk in response.iter_content(chunk_size=chunk_size):
                out.write(chunk)
    partial.replace(target)


# Exercise data: download, unpack and flatten everything into DATA_PATH.
if not DATA_ZIP.exists():
    _download(DATA_URL, DATA_ZIP)

    with zipfile.ZipFile(DATA_ZIP, 'r') as zip_ref:
        zip_ref.extractall(DATA_PATH)

    # Move the files two levels below "air-exercise-2" directly into DATA_PATH.
    # The matches are collected first because the tree is modified in the loop.
    for f in list(DATA_PATH.rglob("air-exercise-2/*/*")):
        f.rename(DATA_PATH / f.name)

    # Remove the now-empty sub-directories, then the extracted root itself.
    for f in list(DATA_PATH.rglob("air-exercise-2/*")):
        if f.is_dir():
            f.rmdir()

    (DATA_PATH / "air-exercise-2").rmdir()


# Pre-trained GloVe vectors: download and unpack next to the exercise data.
if not GLOVE_ZIP.exists():
    _download(GLOVE_URL, GLOVE_ZIP)

    with zipfile.ZipFile(GLOVE_ZIP, 'r') as zip_ref:
        zip_ref.extractall(DATA_PATH)

# Re Ranker

### Global Variables

In [None]:
# Directory that holds the training checkpoints (stored as config["checkpoint_dir"]).
# TODO: point this at your own checkpoint folder.
CHECKPOINT_FOLDER_PATH = "../checkpoints"

In [None]:
# Upper bound of the training loop and number of per-epoch checkpoints that
# get_best_model() looks for. TODO: confirm the desired number of epochs.
MAX_EPOCHS = 10
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = ("cuda" if torch.cuda.is_available() else "cpu")

## Imports

In [None]:
from allennlp.common import Params, Tqdm
from allennlp.common.util import prepare_environment
from allennlp.data.dataloader import PyTorchDataLoader

# Pass fixed seed values (42) for the random/numpy/pytorch seeds to AllenNLP's
# prepare_environment -- presumably so that runs are reproducible.
params = Params({'random_seed': 42, 'numpy_seed': 42, 'pytorch_seed': 42})
prepare_environment(params)

import torch
import torch.nn as nn
import torch.optim as optim

import os

from allennlp.data.vocabulary import Vocabulary

from allennlp.modules.token_embedders import Embedding
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder

from data_loading import *
from model_knrm import *
from model_tk import *
from core_metrics import *

## config and embeddings

In [None]:
# Central settings used by the cells below: input/output paths, the learning
# rate and the checkpoint directory. Change the paths to your data directory.
config = {
    "vocab_directory": "../data/allen_vocab_lower_10",
    "pre_trained_embedding": "../data/glove.42B.300d.txt",
    "train_data": "../data/triples.train.tsv",
    "validation_data": "../data/msmarco_tuples.validation.tsv",
    "test_data":"../data/msmarco_tuples.test.tsv",
    "learning_rate": 1e-3,
    # NOTE(review): "num_epochs" is not read anywhere in the visible cells; the
    # training loop is bounded by MAX_EPOCHS instead -- confirm and unify.
    "num_epochs": 10,
    "qrels_file": "../data/msmarco_qrels.txt",
    "fira_tuples": "../data/fira-22.tuples.tsv",
    "fira_qrels_baseline": "../data/fira-22.baseline-qrels.tsv",
    "fira_qrels_part1": "../out/fira-22.qrels.tsv",
    "checkpoint_dir": CHECKPOINT_FOLDER_PATH
}

# Load the vocabulary and build a trainable 300-dimensional token embedding
# initialised from the pre-trained embedding file.
vocab = Vocabulary.from_files(config["vocab_directory"])
tokens_embedder = Embedding(vocab=vocab,
                           pretrained_file= config["pre_trained_embedding"],
                           embedding_dim=300,
                           trainable=True,
                           padding_index=0)
# This single embedder instance is handed to every KNRM/TK model built below.
word_embedder = BasicTextFieldEmbedder({"tokens": tokens_embedder})

In [None]:
def find_latest_checkpoint(model_used):
    """Return the path of the newest checkpoint saved for ``model_used``.

    Only files named ``checkpoint_<model_used>_epoch_<N>.pt`` inside
    ``config["checkpoint_dir"]`` are considered. Checkpoints of other model
    types and unrelated files starting with "checkpoint" are ignored, so the
    returned path always refers to a file that was actually listed.

    Args:
        model_used: Model identifier embedded in the checkpoint file names.

    Returns:
        The path (str) of the matching checkpoint with the highest epoch
        number. When the directory is missing or holds no matching file, the
        tuple ``(None, 0)`` is returned; this "nothing found" value is kept
        unchanged for backward compatibility.
    """
    # Local import: `re` is not imported explicitly in the visible notebook cells.
    import re

    pattern = re.compile(rf"checkpoint_{re.escape(model_used)}_epoch_(\d+)\.pt")

    try:
        file_names = os.listdir(config["checkpoint_dir"])
    except FileNotFoundError:
        # No checkpoint directory means no checkpoint to resume from.
        return None, 0

    # (epoch, file name) pairs; keeping the real file name avoids rebuilding a
    # name that might differ from what is on disk (e.g. zero-padded epochs).
    candidates = [
        (int(match.group(1)), match.group(0))
        for match in map(pattern.fullmatch, file_names)
        if match
    ]
    if not candidates:
        return None, 0

    _, latest_checkpoint = max(candidates)
    return os.path.join(config["checkpoint_dir"], latest_checkpoint)

In [None]:
def train_model(model_used, checkpoint = False, device=("cuda" if torch.cuda.is_available() else "cpu")):
    """Train a KNRM or TK re-ranker and validate it after every epoch.

    Each epoch runs over the training triples with a margin ranking loss,
    writes ``checkpoint_<model_used>_epoch_<N>.pt`` to
    ``config["checkpoint_dir"]`` and then scores the validation tuples
    (validation loss plus the metrics from ``calculate_metrics_plain``).
    Results are printed; nothing is returned.

    Args:
        model_used: "knrm" or "tk".
        checkpoint: When truthy, resume model and optimizer state from the
            path returned by ``find_latest_checkpoint``. If no checkpoint is
            found, training starts from scratch.
        device: Device the model and the batches are moved to.

    Raises:
        Exception: If ``model_used`` is neither "knrm" nor "tk".
    """

    ##### TRIPLES LOADER ######
    _triple_reader = IrTripleDatasetReader(lazy=True, max_doc_length=180, max_query_length=30)
    _triple_reader = _triple_reader.read(config["train_data"])
    _triple_reader.index_with(vocab)
    triple_loader = PyTorchDataLoader(_triple_reader, batch_size=32)

    ##### TUPLES VALIDATION MSMARCO LOADER ######
    _tuple_reader_validation = IrLabeledTupleDatasetReader(lazy=True, max_doc_length=180, max_query_length=30)
    _tuple_reader_validation = _tuple_reader_validation.read(config["validation_data"])
    _tuple_reader_validation.index_with(vocab)
    validation_loader = PyTorchDataLoader(_tuple_reader_validation, batch_size=128)

    if model_used == "knrm":
        model = KNRM(word_embedder, n_kernels=11)
    elif model_used == "tk":
        model = TK(word_embedder, n_kernels=11, n_layers = 2, n_tf_dim = 300, n_tf_heads = 10)
    else:
        raise Exception("Model type not supported")

    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"]) # Adam optimizer

    start_epoch = 0
    if checkpoint:
        checkpoint_path = find_latest_checkpoint(model_used)
        # find_latest_checkpoint signals "nothing found" with the tuple (None, 0).
        if isinstance(checkpoint_path, tuple):
            checkpoint_path = checkpoint_path[0]

        if checkpoint_path is None:
            print(f"No checkpoint found for {model_used}, starting from epoch 0")
        else:
            print(checkpoint_path)
            # map_location lets a checkpoint written on a GPU be resumed on a CPU-only runtime.
            saved_state = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(saved_state['model_state_dict'])
            optimizer.load_state_dict(saved_state['optimizer_state_dict'])
            # The stored value is the number of completed epochs, i.e. the next epoch index.
            start_epoch = saved_state['epoch']

    qrels = load_qrels(config["qrels_file"])
    loss_function = nn.MarginRankingLoss(margin=1.0) # Margin ranking loss as loss function

    print('Model',model,'total parameters:', sum(p.numel() for p in model.parameters() if p.requires_grad))
    print('Network:', model)

    # torch.save does not create missing directories.
    os.makedirs(config["checkpoint_dir"], exist_ok=True)

    validation_losses = []

    for epoch in range(start_epoch, MAX_EPOCHS):
        model.train()
        total_loss = 0.0
        batch_counter = 0

        for batch in Tqdm.tqdm(triple_loader):
            #if(batch_counter > 2):
            #    break
            # Move every tensor of the nested batch dict to the target device.
            batch = {k1: {k2: {k3: v3.to(device) for k3, v3 in v2.items()} for k2, v2 in v1.items()} for k1, v1 in batch.items()}
            batch_counter += 1
            optimizer.zero_grad()
            positive = model.forward(batch['query_tokens'], batch['doc_pos_tokens'])
            negative = model.forward(batch['query_tokens'], batch['doc_neg_tokens'])
            target = torch.ones(positive.size()).to(device)  # target tensor with ones, since positive should be ranked higher than negative
            loss = loss_function(positive, negative, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch + 1}, Loss: {total_loss / batch_counter}")

        # checkpoint
        checkpoint_path = os.path.join(config["checkpoint_dir"], f"checkpoint_{model_used}_epoch_{epoch + 1}.pt")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            #'loss': total_loss / batch_counter,
        }, checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")

        # Validation Loop
        model.eval()
        rankings = {}
        total_validation_loss = 0.0
        validation_batches = 0

        with torch.no_grad():
            for batch in Tqdm.tqdm(validation_loader):
                # Only the token fields are moved to the device; the other batch entries stay untouched.
                batch.update({k1: {k2: {k3: v3.to(device) for k3, v3 in v2.items()} for k2, v2 in v1.items()} for k1, v1 in batch.items() if k1 in ["query_tokens", "doc_tokens"]})
                outputs = model.forward(batch['query_tokens'], batch['doc_tokens'])

                # NOTE(review): the validation "loss" compares the scores against a
                # constant tensor of ones, not against positive/negative pairs -- confirm this is intended.
                target = torch.ones(outputs.size()).to(device)
                loss = loss_function(outputs, target, target)
                total_validation_loss += loss.item()
                validation_batches += 1

                for i, query_id in enumerate(batch['query_id']):
                    if query_id not in rankings:
                        rankings[query_id] = []
                    rankings[query_id].append((batch['doc_id'][i], outputs[i].item()))

        #calculate validation loss
        average_validation_loss = total_validation_loss / validation_batches
        validation_losses.append(average_validation_loss)
        print(f"Epoch {epoch + 1}, Validation Loss: {average_validation_loss}")

        # Convert unrolled results to ranked results
        ranked_results = unrolled_to_ranked_result(rankings)

        # Calculate and print validation metrics
        validation_metrics = calculate_metrics_plain(ranked_results, qrels, binarization_point=1)
        print('Validation Metrics:')
        for metric in validation_metrics:
            print('{}: {}'.format(metric, validation_metrics[metric]))

        # # Early stopping check
        # TODO CHECK if early stopping desired
        # if len(validation_losses) > 2:
        #     if validation_losses[-1] > sum(validation_losses[-3:-1]) / 2:
        #         print(f"Stopping early at epoch {epoch + 1}, validation loss increases over 2 epochs average.")
        #         break

In [None]:
# TODO INCLUDE IF TRAINING DESIRED
# checkpoint=True if training should be started from latest available checkpoint

#train_model("knrm", checkpoint=False) 
#train_model("tk", checkpoint=False)

## Load model with lowest validation loss

In [None]:
def calculate_validation_loss(model_path, model_used, device=("cuda" if torch.cuda.is_available() else "cpu")):
    """Compute the average validation loss of one saved checkpoint.

    Args:
        model_path: Path of the checkpoint file to evaluate.
        model_used: "knrm" or "tk"; selects the architecture to instantiate.
        device: Device the model and the batches are moved to.

    Returns:
        The mean per-batch loss over the validation tuples, or the sentinel
        ``10000`` when ``model_path`` does not exist.

    Raises:
        Exception: If ``model_used`` is neither "knrm" nor "tk".
    """
    #load model
    if model_used == "knrm":
        model = KNRM(word_embedder, n_kernels=11)
    elif model_used == "tk":
        model = TK(word_embedder, n_kernels=11, n_layers = 2, n_tf_dim = 300, n_tf_heads = 10)
    else:
        raise Exception("Model type not supported")

    model.to(device)

    #check if model exists
    try:
        model_checkpoint = torch.load(model_path, map_location=device)
    except FileNotFoundError:
        print(f"Path to model does not exist: {model_path}")
        # Sentinel equal to the starting "best" value in get_best_model, so a
        # missing checkpoint is never selected there.
        return 10000

    model.load_state_dict(model_checkpoint['model_state_dict'])

    # set validation loader
    _tuple_reader_validation = IrLabeledTupleDatasetReader(lazy=True, max_doc_length=180, max_query_length=30)
    _tuple_reader_validation = _tuple_reader_validation.read(config["validation_data"])
    _tuple_reader_validation.index_with(vocab)
    validation_loader = PyTorchDataLoader(_tuple_reader_validation, batch_size=128)

    model.eval()
    loss_function = nn.MarginRankingLoss(margin=1.0)
    total_validation_loss = 0.0
    validation_batches = 0

    with torch.no_grad():
        for batch in Tqdm.tqdm(validation_loader):
            # Only the token fields are moved to the device.
            batch.update({k1: {k2: {k3: v3.to(device) for k3, v3 in v2.items()} for k2, v2 in v1.items()} for k1, v1 in batch.items() if k1 in ["query_tokens", "doc_tokens"]})
            outputs = model.forward(batch['query_tokens'], batch['doc_tokens'])

            # NOTE(review): the loss compares the scores against a constant tensor
            # of ones, mirroring the validation step in train_model -- confirm this is intended.
            target = torch.ones(outputs.size()).to(device)
            loss = loss_function(outputs, target, target)
            total_validation_loss += loss.item()
            validation_batches += 1

    # Per-query rankings are not collected here: only the loss is returned.
    return total_validation_loss / validation_batches

In [None]:
# pick the checkpoint with the smallest validation loss
def get_best_model(model_type):
    """Scan the per-epoch checkpoints of ``model_type`` and return the best path.

    For every epoch from 1 to ``MAX_EPOCHS`` the expected checkpoint path is
    built and passed to ``calculate_validation_loss``. The path with the
    strictly smallest loss below the initial bound of 10000 is returned;
    ``None`` is returned when no checkpoint beats that bound.
    """
    candidate_paths = [
        os.path.join(CHECKPOINT_FOLDER_PATH, f"checkpoint_{model_type}_epoch_{number}.pt")
        for number in range(1, MAX_EPOCHS + 1)
    ]

    lowest_loss = 10000
    winner_path = None
    winner_epoch = None

    for epoch_number, candidate in enumerate(candidate_paths, start=1):
        print(f"Loading epoch {epoch_number} of {model_type} model, searching for path {candidate}")
        loss_value = calculate_validation_loss(candidate, model_type)
        print(f"validation loss in epoch {epoch_number}: {loss_value}")

        # Strict comparison: on ties the earlier epoch is kept.
        if loss_value < lowest_loss:
            lowest_loss, winner_path, winner_epoch = loss_value, candidate, epoch_number

    print(f"model with lowest validation loss of {model_type} was in epoch {winner_epoch}: {winner_path}")

    return winner_path


In [None]:
# Best model = checkpoint with the lowest validation loss. Each call evaluates
# every per-epoch checkpoint of the given model type on the validation data.
best_model_path_knrm = get_best_model("knrm")
best_model_path_tk = get_best_model("tk")

In [None]:
# TK performed better, so rebuild a TK model with the hyper-parameters used in
# training and restore the weights of its best checkpoint.
best_model = TK(word_embedder, n_kernels=11, n_layers = 2, n_tf_dim = 300, n_tf_heads = 10)
# map_location maps the stored tensors onto the currently selected device.
model_checkpoint = torch.load(best_model_path_tk, map_location=device)

best_model.load_state_dict(model_checkpoint['model_state_dict'])
best_model.to(device)

## Testing

In [None]:
def test_model_msmarco(model, device=("cuda" if torch.cuda.is_available() else "cpu")):
    """Score the MS MARCO test tuples with ``model`` and report the metrics.

    The metrics computed by ``calculate_metrics_plain`` (binarization point 1)
    are printed and written to ``../out/test_metrics_msmarco_tk_best.txt``.
    Nothing is returned.

    Args:
        model: Re-ranker to evaluate; it is switched to eval mode.
        device: Device the token tensors of each batch are moved to.
    """
    reader = IrLabeledTupleDatasetReader(lazy=True, max_doc_length=180, max_query_length=30)
    dataset = reader.read(config["test_data"])
    dataset.index_with(vocab)
    test_loader = PyTorchDataLoader(dataset, batch_size=128)

    token_fields = ["query_tokens", "doc_tokens"]

    model.eval()
    scores_by_query = {}
    with torch.no_grad():
        for batch in Tqdm.tqdm(test_loader):
            # Move only the token tensors to the device; ids stay as they are.
            on_device = {}
            for field, indexers in batch.items():
                if field not in token_fields:
                    continue
                on_device[field] = {
                    indexer: {key: tensor.to(device) for key, tensor in tensors.items()}
                    for indexer, tensors in indexers.items()
                }
            batch.update(on_device)

            outputs = model.forward(batch['query_tokens'], batch['doc_tokens'])
            for i, query_id in enumerate(batch['query_id']):
                scores_by_query.setdefault(query_id, []).append((batch['doc_id'][i], outputs[i].item()))

    # Turn the per-query (doc_id, score) lists into ranked results.
    ranked_results = unrolled_to_ranked_result(scores_by_query)

    qrels = load_qrels(config["qrels_file"])

    # Compute the test metrics and echo them to the console.
    test_metrics = calculate_metrics_plain(ranked_results, qrels, binarization_point=1)
    separator = '#####################'
    print(separator)
    print("metrics on msmacro testset")
    for name in test_metrics:
        print(f"{name}: {test_metrics[name]}")
    print(separator)

    # Persist the same metrics; a previous report is removed first.
    output_file = "../out/test_metrics_msmarco_tk_best.txt"

    if os.path.exists(output_file):
        os.remove(output_file)

    with open(output_file, "w") as report:
        report.write(separator + '\n')
        report.write("Dataset: msmarco\n")
        report.write(f"Model used: {model}\n")
        report.write(separator + '\n')
        for name in test_metrics:
            report.write(f"{name}: {test_metrics[name]}\n")
        report.write(separator + '\n')

In [None]:
# Evaluate the selected model on the MS MARCO test set.
test_model_msmarco(best_model)

### FiRA 

In [None]:
def test_model_fira(model, qrels_path, judgement_type, chosen_binarization_point, device=("cuda" if torch.cuda.is_available() else "cpu")):
    """Score the FiRA tuples with ``model`` and evaluate against given qrels.

    The raw predictions and the metrics are written to files under ``../out``
    whose names contain ``judgement_type`` and ``chosen_binarization_point``;
    the metrics are also printed. Nothing is returned.

    Args:
        model: Re-ranker to evaluate; it is switched to eval mode.
        qrels_path: Path of the qrels file passed to ``load_qrels``.
        judgement_type: Label used in the printed header and output file names.
        chosen_binarization_point: Forwarded to ``calculate_metrics_plain``.
        device: Device the token tensors of each batch are moved to.
    """
    ##### FiRA tuple loader ######
    reader = IrLabeledTupleDatasetReader(lazy=True, max_doc_length=180, max_query_length=30)
    dataset = reader.read(config["fira_tuples"])
    dataset.index_with(vocab)
    fira_loader = PyTorchDataLoader(dataset, batch_size=128)

    # Relevance judgements to evaluate against.
    judgements = load_qrels(qrels_path)
    # This entry is only stored in the shared config; it is not read in this function.
    config["output_file_fira_baseline"] = f"../out/test_metrics_fira_{judgement_type}_binarizationpoint_{chosen_binarization_point}.txt"

    token_fields = ["query_tokens", "doc_tokens"]

    model.eval()
    scores_by_query = {}
    predictions = []
    with torch.no_grad():
        for batch in Tqdm.tqdm(fira_loader):
            # Move only the token tensors to the device; ids stay as they are.
            on_device = {}
            for field, indexers in batch.items():
                if field not in token_fields:
                    continue
                on_device[field] = {
                    indexer: {key: tensor.to(device) for key, tensor in tensors.items()}
                    for indexer, tensors in indexers.items()
                }
            batch.update(on_device)

            outputs = model.forward(batch['query_tokens'], batch['doc_tokens'])
            for i, query_id in enumerate(batch['query_id']):
                scores_by_query.setdefault(query_id, []).append((batch['doc_id'][i], outputs[i].item()))
                predictions.append((query_id, batch['doc_id'][i], outputs[i].item()))

    # Turn the per-query (doc_id, score) lists into ranked results.
    ranked_results = unrolled_to_ranked_result(scores_by_query)

    # Compute the FiRA metrics and echo them to the console.
    fira_test_metrics = calculate_metrics_plain(ranked_results, judgements, binarization_point=chosen_binarization_point)
    separator = '#####################'
    print(separator)
    print(f'FiRA Test Metrics with Judgement {judgement_type}:')
    for name in fira_test_metrics:
        print(f"{name}: {fira_test_metrics[name]}")
    print(separator)

    # One tab-separated line per scored (query, document) pair.
    predictions_file = f"../out/fira_predictions_{judgement_type}_tk_binarizationpoint_{chosen_binarization_point}.txt"
    with open(predictions_file, "w", encoding='utf-8') as out:
        for query_id, doc_id, score in predictions:
            out.write(f"{query_id}\t{doc_id}\t{score}\n")

    output_file_fira = f"../out/fira_metrics_tk_judgementtype_{judgement_type}_binarizationpoint_{chosen_binarization_point}.txt"

    # Metrics report with a short header describing the run.
    with open(output_file_fira, "w", encoding='utf-8') as report:
        report.write(separator + '\n')
        report.write("Dataset: FiRA\n")
        report.write("Model used: tk\n")
        report.write(f"binarization point: {chosen_binarization_point}\n")
        report.write(separator + '\n')
        report.write(f'FiRA Test Metrics with Judgement {judgement_type}:\n')
        for name in fira_test_metrics:
            report.write(f"{name}: {fira_test_metrics[name]}\n")
        report.write(separator + '\n')

In [None]:
# Evaluate on FiRA with both judgement files (baseline qrels and the part 1
# qrels), each with binarization points 1 and 2.
test_model_fira(best_model, config['fira_qrels_baseline'], 'baseline', 1)
test_model_fira(best_model, config['fira_qrels_baseline'], 'baseline', 2)
test_model_fira(best_model, config['fira_qrels_part1'], 'part1', 1)
test_model_fira(best_model, config['fira_qrels_part1'], 'part1', 2)

### create passage results needed for part 3

In [None]:
# Score the MS MARCO test tuples with the best model and build ranked results.
model = best_model #set best tk model

msmarco_fira_qa_path = "../data/msmarco-fira-21.qrels.qa-answers.tsv"

# NOTE(review): fira_loader is built from the qa-answers file but is not used in
# the visible cells -- the loop below iterates test_loader. Confirm which
# dataset is intended for the part 3 passage results.
_tuple_reader_fira = IrLabeledTupleDatasetReader(lazy=True, max_doc_length=180, max_query_length=30)
_tuple_reader_fira = _tuple_reader_fira.read(msmarco_fira_qa_path)
_tuple_reader_fira.index_with(vocab)
fira_loader = PyTorchDataLoader(_tuple_reader_fira, batch_size=128)

# Loader over the MS MARCO test tuples (config["test_data"]).
_tuple_reader = IrLabeledTupleDatasetReader(lazy=True, max_doc_length=180, max_query_length=30)
_tuple_reader = _tuple_reader.read(config["test_data"])
_tuple_reader.index_with(vocab)
test_loader = PyTorchDataLoader(_tuple_reader, batch_size=128)

model.eval()
# query_id -> list of (doc_id, score) pairs in the order they were scored
rankings = {}
with torch.no_grad():
    for batch in Tqdm.tqdm(test_loader):
        # Only the token fields are moved to the device.
        batch.update({k1: {k2: {k3: v3.to(device) for k3, v3 in v2.items()} for k2, v2 in v1.items()} for k1, v1 in batch.items() if k1 in ["query_tokens", "doc_tokens"]})
        outputs = model.forward(batch['query_tokens'], batch['doc_tokens'])
        for i, query_id in enumerate(batch['query_id']):
            if query_id not in rankings:
                rankings[query_id] = []
            rankings[query_id].append((batch['doc_id'][i], outputs[i].item()))

# Convert unrolled results to ranked results
ranked_results = unrolled_to_ranked_result(rankings)
# NOTE(review): qrels is loaded here but not used in the visible cells that follow.
qrels = load_qrels(config["qrels_file"])


In [None]:
# Write the first entry of every query's ranked result list to a TSV file.

output_file_predictions = "../out/msmarco_predictions_qa-answer-top1.txt"

# Start from a clean file: drop a report left over from an earlier run.
if os.path.exists(output_file_predictions):
    os.remove(output_file_predictions)

with open(output_file_predictions, 'w') as f:
    f.write("query_id\tdoc_id\n")
    for key, value in ranked_results.items():
        # Queries without any ranked entry produce no line.
        if not value:
            continue
        # only keep the first entry (top 1)
        f.write("{}\t{}\n".format(key, value[0]))