# One-sided learning

In some cases we may want to train only the graph encoder against the initial (pretrained) text embeddings, because the graph encoder may need many iterations to learn a good representation of the graphs. Training this way is much faster than also training the text encoder at the same time. Overall performance will naturally be lower, but we can then fine-tune the whole model starting from this graph encoder, which will hopefully give us a better graph encoder in the end.

In [None]:
import torch
from torch import optim
from transformers import AutoTokenizer

import numpy as np
from torch_geometric.loader import DataLoader
from transformers import PreTrainedTokenizer

import torch.nn as nn
from utils import train
from models.baseline import TextEncoder
from models.gat.gat import GATEncoder
from datasets import OneSidedDataset
from models.gat import GATModel

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


%load_ext autoreload
%autoreload 2

%load_ext tensorboard

## Train only the graph encoder

In [None]:
# Output size of the text embeddings; all-MiniLM-L6-v2 produces 384-d vectors,
# so this must change together with `model_name`.
embeddings_dim = 384
# Pretrained sentence-transformers checkpoint shared by the tokenizer and encoder.
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
text_encoder = TextEncoder(model_name).to(device)

Build the training and validation datasets, using the text encoder above to provide the text embeddings:

In [None]:
def load_dataset(
    tokenizer: PreTrainedTokenizer,
    text_encoder: nn.Module,
    device: torch.device,
    batch_size: int = 32,
    root=".",
    features=None,
    shuffle=True,
    num_workers: int = 4,
):
    """Build train/validation loaders over ``OneSidedDataset``.

    Args:
        tokenizer: Tokenizer forwarded to each ``OneSidedDataset``.
        text_encoder: Text encoder forwarded to each ``OneSidedDataset``.
        device: Device forwarded to each ``OneSidedDataset``.
        batch_size: Batch size used by both loaders.
        root: Directory containing the ``data/`` folder.
        features: Forwarded unchanged to ``OneSidedDataset``; ``None`` means
            an empty list (a fresh list per call, never a shared default).
        shuffle: Whether to shuffle the *training* loader. The validation
            loader is never shuffled, so validation metrics computed over a
            batch see the same batches every epoch and stay comparable.
        num_workers: Worker processes for each ``DataLoader``.

    Returns:
        ``(train_loader, val_loader)``.
    """
    if features is None:
        features = []

    # NOTE(review): allow_pickle=True unpickles arbitrary objects; only load
    # trusted, locally produced files here.
    gt = np.load(f"{root}/data/token_embedding_dict.npy", allow_pickle=True)[()]

    def _make_dataset(split: str):
        # The two splits differ only by `split`; everything else is shared.
        return OneSidedDataset(
            root=f"{root}/data/",
            gt=gt,
            split=split,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            device=device,
            features=features,
        )

    train_dataset = _make_dataset("train")
    val_dataset = _make_dataset("val")

    # NOTE(review): if OneSidedDataset runs `text_encoder` (on `device`) inside
    # __getitem__, CUDA in forked workers can fail with num_workers > 0; confirm.
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    return train_loader, val_loader


# Default settings: batch size 32, data under ./data, shuffled training batches.
train_loader, val_loader = load_dataset(
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    device=device,
)

In [None]:
class OneSidedModel(nn.Module):
    """Wrap a graph encoder so it can be trained without a text encoder.

    Only ``graph_encoder`` holds parameters. The text-side target is taken
    from ``graph_batch.y``. NOTE(review): presumably the dataset fills ``y``
    with precomputed text embeddings; confirm against ``OneSidedDataset``.
    """

    def __init__(
        self,
        graph_encoder,
    ):
        """Store ``graph_encoder`` as the model's only submodule."""
        super().__init__()
        self.graph_encoder = graph_encoder

    def forward(self, graph_batch, input_ids, attention_mask):
        """Return ``(graph_encoder(graph_batch), graph_batch.y)``.

        ``input_ids`` and ``attention_mask`` are ignored. They are accepted
        presumably so the same training loop can drive this model and the
        full two-tower model.
        """
        graph_encoded = self.graph_encoder(graph_batch)
        return graph_encoded, graph_batch.y

    def get_text_encoder(self):
        """Always raise: this model has no text encoder."""
        raise NotImplementedError(
            "OneSidedModel has no text encoder; only the graph encoder is trained."
        )

    def get_graph_encoder(self):
        """Return the wrapped graph encoder."""
        return self.graph_encoder

In [None]:
# NOTE(review): 300 is presumably the node-feature dimension (the full GATModel
# below uses num_node_features=300). The output size matches the text embeddings.
graph_encoder = GATEncoder(300, embeddings_dim)
graph_encoder = graph_encoder.to(device)

model = OneSidedModel(graph_encoder)
model = model.to(device)

In [None]:
# OneSidedModel only contains the graph encoder, so AdamW updates just that.
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=5e-5,
    betas=(0.9, 0.999),
    weight_decay=0.01,
)

# The first value returned by `train` is kept as `save_path`; the other two
# are discarded. Checkpoints are written under the "one_side" name.
save_path, _, _ = train(
    model,
    optimizer,
    train_loader,
    val_loader,
    save_name="one_side",
    device=device,
    nb_epochs=50,
)

## Finish training the full model

In [None]:
load_from = "./outputs/one_side26.pt"
save_path = "./outputs/one_side26_full.pt"

# map_location lets a checkpoint written on a GPU be restored on a CPU-only
# host (and vice versa); the weights are copied into full_model below anyway.
checkpoint = torch.load(load_from, map_location=device)

full_model = GATModel(
    model_name=model_name,
    num_node_features=300,
    nout=embeddings_dim,
).to(device)

# OneSidedModel registers its encoder as `graph_encoder`, so every key in its
# state dict carries this prefix. Strip it so the weights load directly into
# full_model.graph_encoder. Keys without the prefix are kept unchanged.
prefix = "graph_encoder."
new_state_dict = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in checkpoint["model_state_dict"].items()
}

full_model.graph_encoder.load_state_dict(new_state_dict)

optimizer = optim.AdamW(
    full_model.parameters(), lr=5e-5, betas=(0.9, 0.999), weight_decay=0.01
)

# Save the combined model (pretrained graph encoder + fresh text side) with a
# fresh optimizer state so `train(..., load_from=save_path)` can start from it.
# NOTE(review): "epoch" is carried over from the one-sided run; confirm whether
# `train` resumes its epoch counter from it when load_from is given.
torch.save(
    {
        "epoch": checkpoint["epoch"],
        "model_state_dict": full_model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "val_loss": checkpoint["val_loss"],
        "val_score": checkpoint["val_score"],
    },
    save_path,
)

In [None]:
# Fine-tune the whole model, starting from the converted checkpoint saved above.
# NOTE(review): these loaders were built from OneSidedDataset; confirm they also
# yield the input_ids/attention_mask that the full GATModel consumes.
save_path, _, _ = train(
    full_model,
    optimizer,
    train_loader,
    val_loader,
    load_from=save_path,
    save_name="one_side_full",
    device=device,
    nb_epochs=50,
)