In [3]:
import torch
from tokenizers import Tokenizer
from torch.utils.data import DataLoader

from data import EncoderDataset, GraphDataset, RegressionDataset
from encoder import EncoderModel
from graph import GraphModel
from model import Model
from regression import RegressionModel

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Reproducibility: EncoderModel (next cell) may create freshly initialized layers, so fix
# torch's RNG before anything stochastic runs.
SEED = 42
torch.manual_seed(SEED)

data_path = "../data/data.feather"
# The tokenizer must come from the same checkpoint as the encoder (encoder_model_name in the
# next cell). SciBERT uses its own "scivocab" vocabulary; tokenizing with bert-base-uncased
# produces token IDs that silently index the wrong rows of SciBERT's embedding matrix.
# NOTE(review): if EncoderDataset loads via tokenizers.Tokenizer.from_pretrained, confirm this
# hub repo ships a tokenizer.json (otherwise load it through transformers.AutoTokenizer).
tokenizer_name = "allenai/scibert_scivocab_uncased"
encoder_dataset = EncoderDataset(data_path, tokenizer_name)
# Fixed order (shuffle=False) so batch i always maps to the same rows of the data.
encoder_loader = DataLoader(
    encoder_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=encoder_dataset.collate_fn,
)

In [5]:
# Hyperparameters handed to EncoderModel; how each key is interpreted is defined in encoder.py.
encoder_config = dict(
    n_layers=2,
    n_heads=2,
    n_hidden=128,
)

# Pretrained checkpoint for the encoder. The tokenizer used by EncoderDataset must be built from
# this same checkpoint, otherwise token IDs and embedding rows disagree.
encoder_model_name = "allenai/scibert_scivocab_uncased"
encoder_model = EncoderModel(encoder_model_name, encoder_config)

Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# Encode the whole corpus once, batch by batch.
# NOTE(review): assumes encoder_model(batch) returns a tensor whose first dimension is the
# batch size — confirm in encoder.py.
# torch.cat joins along that batch dimension, so a final batch smaller than batch_size works
# and the result has one row per sample. torch.stack required every batch to be the same size.
# no_grad: inference only, so autograd graphs for every batch are not kept in memory.
# NOTE(review): `embeddings` is not referenced by any later cell shown here.
with torch.no_grad():
    embeddings = torch.cat(
        [encoder_model(batch) for batch in encoder_loader],
        dim=0,
    )

In [None]:
# Graph-stage data: GraphDataset receives the encoder module itself plus the dataset's
# underlying table (encoder_dataset.data).
# NOTE(review): the precomputed `embeddings` tensor is not passed here — confirm whether
# GraphDataset is meant to re-run the encoder or to consume those embeddings.
graph_dataset = GraphDataset(encoder_model, encoder_dataset.data)

# batch_size is left at DataLoader's default (1); order is fixed.
# NOTE(review): graph_loader is not referenced by any later cell shown here.
graph_loader = DataLoader(dataset=graph_dataset, shuffle=False)

In [None]:
# Graph-stage model, constructed with no arguments (its defaults are defined in graph.py).
graph_model = GraphModel()

In [None]:
# Regression-stage data: pairs the raw table with the graph model's `embeddings` attribute.
# NOTE(review): graph_model was only just constructed and nothing has run it since, so
# confirm `graph_model.embeddings` holds meaningful values at this point.
regression_dataset = RegressionDataset(
    encoder_dataset.data,
    graph_model.embeddings,
)

# Default batch_size (1), fixed order.
# NOTE(review): regression_loader is not referenced by any later cell shown here.
regression_loader = DataLoader(dataset=regression_dataset, shuffle=False)

In [None]:
# Regression-stage model, constructed with no arguments (its defaults are defined in regression.py).
regression_model = RegressionModel()

In [None]:
# Combine the encoder, graph and regression models into a single Model wrapper (defined in model.py).
model = Model(encoder_model, graph_model, regression_model)

In [None]:
# Call Model.setup with the encoder's DataLoader; what setup does is defined in model.py.
# NOTE(review): graph_loader and regression_loader are built above but never passed anywhere
# shown — confirm setup derives the downstream data itself.
model.setup(encoder_loader)

In [None]:
# NOTE(review): if Model subclasses torch.nn.Module and does not override train(), this call only
# switches its submodules to training mode and performs no optimization — confirm in model.py.
model.train()