In [1]:
import torch
from transformers import AutoTokenizer

from load import load_dataset, load_test_dataset
from models.baseline import BaselineModel, get_embeddings
from utils import solution_from_embeddings, get_metrics

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

%load_ext autoreload
%autoreload 2

In [2]:
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
train_loader, val_loader = load_dataset(tokenizer)

model = BaselineModel(
    model_name=model_name,
    num_node_features=300,
    nout=768,
    nhid=300,
    graph_hidden_channels=300,
).to(
    device
)  # nout = bert model hidden dim

In [3]:
save_path = "./outputs/saved/model5.pt"
print("Loading best model...")
checkpoint = torch.load(save_path)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

Loading best model...


BaselineModel(
  (graph_encoder): GraphEncoder(
    (relu): ReLU()
    (ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (conv1): GCNConv(300, 300)
    (conv2): GCNConv(300, 300)
    (conv3): GCNConv(300, 300)
    (mol_hidden1): Linear(in_features=300, out_features=300, bias=True)
    (mol_hidden2): Linear(in_features=300, out_features=768, bias=True)
  )
  (text_encoder): TextEncoder(
    (bert): DistilBertModel(
      (embeddings): Embeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (transformer): Transformer(
        (layer): ModuleList(
          (0-5): 6 x TransformerBlock(
            (attention): MultiHeadSelfAttention(
              (dropout): Dropout(p=0.1, inplace=False)
              (q_lin): Linear(in_features=768, out_features=768, bias=True)
       

In [4]:
test_loader, test_text_loader = load_test_dataset(tokenizer)

graph_embeddings, text_embeddings = get_embeddings(
    model.get_graph_encoder(),
    model.get_text_encoder(),
    test_loader,
    test_text_loader,
    device,
)

  value = torch.cat(values, dim=cat_dim or 0, out=out)


In [5]:
solution_from_embeddings(
    graph_embeddings, text_embeddings, save_to=f"solution_{model_name}.csv"
)

In [6]:
print(get_metrics(model, val_loader))

  value = torch.cat(values, dim=cat_dim or 0, out=out)


(0.43589087105535257, 0.4194893265937901)
