In [None]:
import torch
from torch import optim
from transformers import AutoTokenizer
import os
import time
from typing import Union
from utils import contrastive_loss
import numpy as np
import torch.nn as nn
from sklearn.metrics import label_ranking_average_precision_score
from sklearn.metrics.pairwise import cosine_similarity
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.loader import DataLoader as GeometricDataLoader
from load import load_dataset
from models.gat import GATEncoder
from models.baseline import TextEncoder

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# IPython magics: load the autoreload extension and enable mode 2.
%load_ext autoreload
%autoreload 2

# IPython magic: load the TensorBoard notebook extension.
%load_ext tensorboard

In [None]:
# Name of the pretrained checkpoint used for both the tokenizer and the text encoder.
model_name = "sentence-transformers/all-MiniLM-L6-v2"
# Output size requested from the graph encoder.
# NOTE(review): presumably this must equal the text encoder's output size — confirm.
embeddings_dim = 384

tokenizer = AutoTokenizer.from_pretrained(model_name)
train_loader, val_loader = load_dataset(tokenizer)

# Hyper-parameters of the graph attention encoder, grouped so they can be
# inspected or logged in one place.
gat_hparams = dict(
    num_node_features=300,
    nout=embeddings_dim,
    mlp_hid=1000,
    att_hidden_dim=600,
    att_out_dim=1000,
    nheads=20,
    dropout=0.1,
    alpha=0.02,
    attention_depth=1,
)
graph_encoder = GATEncoder(**gat_hparams).to(device)

text_encoder = TextEncoder(model_name).to(device)

In [None]:
def get_metrics(
    graph_encoder: nn.Module,
    text_encoder: nn.Module,
    loader: GeometricDataLoader,
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
):
    """Evaluate both encoders on ``loader``.

    Both encoders are switched to eval mode (and left in eval mode) and run
    without gradient tracking. For every batch the contrastive loss between
    the graph and text embeddings is accumulated, and all embeddings are
    collected to build a text-vs-graph cosine similarity matrix.

    Args:
        graph_encoder: Model called on the graph part of each batch.
        text_encoder: Model called with ``(input_ids, attention_mask)``.
        loader: Iterable of batches exposing ``input_ids`` and
            ``attention_mask`` attributes; both are popped from the batch
            before it is handed to the graph encoder.
        device: Device the inputs are moved to before the forward passes.

    Returns:
        A tuple ``(mean_loss, lrap)`` where ``mean_loss`` is the contrastive
        loss averaged over batches and ``lrap`` is the label ranking average
        precision of the similarity matrix scored against the identity matrix
        (i.e. the i-th text is expected to rank the i-th graph first).

    Raises:
        ValueError: If ``loader`` yields no batch.
    """
    graph_encoder.eval()
    text_encoder.eval()
    graph_chunks = []
    text_chunks = []
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in loader:
            input_ids = batch.input_ids
            batch.pop("input_ids")
            attention_mask = batch.attention_mask
            batch.pop("attention_mask")
            graph_batch = batch

            x_graph = graph_encoder(graph_batch.to(device))
            x_text = text_encoder(input_ids.to(device), attention_mask.to(device))

            total_loss += contrastive_loss(x_graph, x_text).item()
            num_batches += 1

            # Keep one float64 array per batch instead of one Python list per
            # row: same values as ``tolist()`` but far cheaper to build.
            graph_chunks.append(x_graph.detach().cpu().double().numpy())
            text_chunks.append(x_text.detach().cpu().double().numpy())

    if num_batches == 0:
        raise ValueError("get_metrics received an empty loader")

    graph_embeddings = np.concatenate(graph_chunks, axis=0)
    text_embeddings = np.concatenate(text_chunks, axis=0)

    # Rows index texts, columns index graphs.
    similarity = cosine_similarity(text_embeddings, graph_embeddings)

    return (
        total_loss / num_batches,
        label_ranking_average_precision_score(np.eye(len(similarity)), similarity),
    )


def train(
    graph_encoder: nn.Module,
    text_encoder: nn.Module,
    graph_optimizer: optim.Optimizer,
    text_optimizer: optim.Optimizer,
    train_loader: GeometricDataLoader,
    val_loader: GeometricDataLoader,
    nb_epochs: int = 5,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    load_from: Union[str, None] = None,
    save_name: str = "model",
):
    """Alternately train the graph and text encoders with a contrastive loss.

    On odd epochs only the graph encoder is updated (the text encoder runs in
    eval mode without gradients); on even epochs the roles are swapped. After
    every epoch the validation loss and score are computed with
    ``get_metrics`` and logged to TensorBoard, and a checkpoint is written to
    ``./outputs/<save_name><epoch>.pt`` whenever the epoch matches the best
    validation loss or the best validation score seen so far.

    Args:
        graph_encoder: Model called on the graph part of each batch.
        text_encoder: Model called with ``(input_ids, attention_mask)``.
        graph_optimizer: Optimizer stepping the graph encoder.
        text_optimizer: Optimizer stepping the text encoder.
        train_loader: Training batches exposing ``input_ids`` and
            ``attention_mask`` attributes.
        val_loader: Validation batches, forwarded to ``get_metrics``.
        nb_epochs: Exclusive upper bound of the epoch counter. Epochs are
            numbered from 1 and the loop is ``range(epoch + 1, nb_epochs)``,
            so a fresh run executes ``nb_epochs - 1`` epochs.
        device: Device the inputs (and a loaded checkpoint) are mapped to.
        load_from: Optional checkpoint path to resume from. Restores both
            models, both optimizers, the epoch counter and the stored
            validation loss/score.
        save_name: File-name prefix of the checkpoints written by this run.

    Returns:
        ``(save_path, best_validation_loss, best_validation_score)``.
        ``save_path`` is the last checkpoint written by this call; when none
        was written it is ``load_from`` (which may be ``None``).
    """
    writer = SummaryWriter()
    epoch = 0
    loss = 0
    loss_averager = 0
    time1 = time.time()
    print_every = 50
    best_validation_loss = 1e100
    best_validation_score = 0
    output_dir = "./outputs/"
    # Defined up-front so the final ``return`` is valid even when no epoch
    # produces a checkpoint (zero epochs to run, or no improvement on resume).
    save_path = load_from

    try:
        if load_from is not None:
            # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
            # map_location lets a checkpoint saved on GPU be resumed on CPU (and vice versa).
            checkpoint = torch.load(load_from, map_location=device)
            graph_encoder.load_state_dict(checkpoint["graph_encoder_state_dict"])
            text_encoder.load_state_dict(checkpoint["text_encoder_state_dict"])
            graph_optimizer.load_state_dict(checkpoint["graph_optimizer_state_dict"])
            text_optimizer.load_state_dict(checkpoint["text_optimizer_state_dict"])
            best_validation_loss = checkpoint["val_loss"]
            best_validation_score = checkpoint["val_score"]
            epoch = checkpoint["epoch"]
            print(
                "Loaded model from {}, best_validation_score={}, best validation loss={}".format(
                    load_from, best_validation_score, best_validation_loss
                )
            )

        for e in range(epoch + 1, nb_epochs):
            print("--------------------EPOCH {}--------------------".format(e))
            # Alternate which encoder is trained: graph encoder on odd epochs
            # (including the first one), text encoder on even epochs.
            if e % 2 == 1:
                print("Training only graph encoder")
                text_encoder.eval()
                graph_encoder.train()
                train_text = False
            else:
                print("Training only text encoder")
                text_encoder.train()
                graph_encoder.eval()
                train_text = True

            for batch_idx, batch in enumerate(train_loader):
                input_ids = batch.input_ids
                batch.pop("input_ids")
                attention_mask = batch.attention_mask
                batch.pop("attention_mask")
                graph_batch = batch

                if train_text:
                    # The frozen graph encoder only provides targets.
                    with torch.no_grad():
                        x_graph = graph_encoder(graph_batch.to(device))
                    x_text = text_encoder(
                        input_ids.to(device), attention_mask.to(device)
                    )
                    current_loss = contrastive_loss(x_graph, x_text)
                    text_optimizer.zero_grad()
                    current_loss.backward()
                    text_optimizer.step()
                else:
                    # The frozen text encoder only provides targets.
                    with torch.no_grad():
                        x_text = text_encoder(
                            input_ids.to(device), attention_mask.to(device)
                        )
                    x_graph = graph_encoder(graph_batch.to(device))
                    current_loss = contrastive_loss(x_graph, x_text)
                    graph_optimizer.zero_grad()
                    current_loss.backward()
                    graph_optimizer.step()

                loss += current_loss.item()
                loss_averager += 1

                if batch_idx % print_every == 0 and batch_idx > 0:
                    # Report the mean loss since the previous report, then
                    # reset the running accumulators.
                    loss /= loss_averager
                    time2 = time.time()
                    print(
                        "Iteration: {0}, Time: {1:.4f} s, training loss: {2:.4f}".format(
                            batch_idx, time2 - time1, loss
                        )
                    )
                    writer.add_scalar(
                        "Loss/train", loss, e * len(train_loader) + batch_idx
                    )
                    loss = 0
                    loss_averager = 0

            # Global step at the end of this epoch, on the same axis as the
            # training-loss steps above.
            step = (e + 1) * len(train_loader)

            print(
                "Computing metrics on validation set... (time={:.4f}s)".format(
                    time.time() - time1
                )
            )
            val_loss, val_score = get_metrics(
                graph_encoder, text_encoder, val_loader, device=device
            )
            writer.add_scalar("Loss/val", val_loss, step)
            writer.add_scalar("Score/val", val_score, step)

            writer.flush()

            print(
                "Epoch " + str(e) + " finished with val_loss " + str(val_loss),
                "and val_score",
                val_score,
            )

            best_validation_loss = min(best_validation_loss, val_loss)
            best_validation_score = max(best_validation_score, val_score)

            # Checkpoint when this epoch is the best so far on either metric.
            if best_validation_loss == val_loss or best_validation_score == val_score:
                print("Saving checkpoint... ", end="")
                # torch.save does not create missing directories.
                os.makedirs(output_dir, exist_ok=True)
                save_path = os.path.join(output_dir, save_name + str(e) + ".pt")
                torch.save(
                    {
                        "epoch": e,
                        "graph_encoder_state_dict": graph_encoder.state_dict(),
                        "text_encoder_state_dict": text_encoder.state_dict(),
                        "graph_optimizer_state_dict": graph_optimizer.state_dict(),
                        "text_optimizer_state_dict": text_optimizer.state_dict(),
                        # Stored as plain Python floats so the checkpoint holds
                        # no NumPy scalar objects.
                        "val_loss": float(val_loss),
                        "val_score": float(val_score),
                    },
                    save_path,
                )
                print("done : {}".format(save_path))
    finally:
        # Always release the TensorBoard writer, even if training is interrupted.
        writer.close()

    return save_path, best_validation_loss, best_validation_score

In [None]:
# Both encoders get their own AdamW optimizer with identical hyper-parameters.
adamw_kwargs = dict(lr=1e-5, betas=(0.9, 0.999), weight_decay=0.01)

graph_optimizer = optim.AdamW(graph_encoder.parameters(), **adamw_kwargs)
text_optimizer = optim.AdamW(text_encoder.parameters(), **adamw_kwargs)

# Resume alternated training from an existing checkpoint; only the path of
# the checkpoint returned by train() is kept.
save_path, _, _ = train(
    graph_encoder=graph_encoder,
    text_encoder=text_encoder,
    graph_optimizer=graph_optimizer,
    text_optimizer=text_optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    nb_epochs=50,
    device=device,
    save_name="alternated",
    load_from="./outputs/alternated_fix_30.pt",
)