In [None]:
import torch
from transformers import AutoTokenizer

from load import load_dataset, load_test_dataset
from models.baseline import get_embeddings
from models.gat import GATModel
from utils import solution_from_embeddings, get_metrics

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Enable IPython's autoreload extension in mode 2, so edits to the imported project
# modules (load, models.*, utils) are picked up without restarting the kernel.
# Keep comments on their own lines: text after a line magic is parsed as its argument.
%load_ext autoreload
%autoreload 2

In [None]:
# Hugging Face checkpoint name; used for the tokenizer and handed to GATModel.
model_name = "sentence-transformers/all-MiniLM-L6-v2"
# Passed to GATModel as `nout`.
# NOTE(review): presumably must equal the text encoder's output width — confirm.
embeddings_dim = 384

tokenizer = AutoTokenizer.from_pretrained(model_name)
# load_dataset returns a (train, validation) pair of loaders built with this tokenizer.
train_loader, val_loader = load_dataset(tokenizer)

# Build the model first, then move it to the selected device as a separate step.
model = GATModel(
    model_name=model_name,
    nout=embeddings_dim,
    num_node_features=300,
    nhid=600,
    graph_hidden_channels=600,
)
model = model.to(device)

In [None]:
# Location of the trained checkpoint whose weights are restored for inference.
save_path = (
    "./outputs/saved/sentence-transformers+all-MiniLM-L6-v2 + GAT 10 heads/model13.pt"
)
print("Loading best model...")
# map_location remaps the stored tensors onto the device selected for this session, so a
# checkpoint written on a GPU still loads when the notebook falls back to the CPU.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load(save_path, map_location=device)
# The checkpoint is a dict; the model weights live under "model_state_dict".
model.load_state_dict(checkpoint["model_state_dict"])
# Switch the model to evaluation mode before running inference.
model.eval()

In [None]:
# load_test_dataset returns two loaders; both are forwarded to get_embeddings below.
test_loader, test_text_loader = load_test_dataset(tokenizer)

# Pull the two encoders out of the restored model, then embed the test set with them.
graph_encoder = model.get_graph_encoder()
text_encoder = model.get_text_encoder()
graph_embeddings, text_embeddings = get_embeddings(
    graph_encoder, text_encoder, test_loader, test_text_loader, device
)

In [None]:
# Turn the test embeddings into a submission and write it to solution.csv in the
# current working directory (plain string: the f-prefix had no placeholders).
solution_from_embeddings(graph_embeddings, text_embeddings, save_to="solution.csv")

In [None]:
# Evaluate the restored model on val_loader and print whatever get_metrics returns.
val_metrics = get_metrics(model, val_loader, device=device)
print(val_metrics)