This notebook has 3 purposes:
* Setting up the model architecture
* Defining the training loop
* Running a limited test of the model

# Architecture

In [1]:
import torch_geometric
from torch_geometric import nn
from torch_geometric.nn import GATConv
from torch_geometric.nn.kge import DistMult
import torch
from torch.nn import SiLU
from DistMultMod import DistMultMod

This model uses a GAT-based encoder to produce node embeddings and a DistMult-based decoder to perform link prediction.

In [2]:
class KGLinkPredictor(torch.nn.Module):
    """Knowledge-graph link predictor: GAT encoder followed by a DistMult-style decoder.

    The encoder stacks three ``GATConv`` layers with SiLU activations between
    them and maps input node features to node embeddings. The decoder is the
    project-local ``DistMultMod``; before any triple is scored its ``node_emb``
    attribute is overwritten with the encoder output.

    NOTE(review): this relies on ``DistMultMod`` accepting a plain tensor as
    ``node_emb`` -- confirm against the ``DistMultMod`` implementation.

    Args:
        in_channels: Feature size passed to the first ``GATConv`` layer.
        hidden_channels: Output size of every ``GATConv`` layer; also
            forwarded to the decoder.
        num_nodes: Forwarded to ``DistMultMod``.
        num_relations: Forwarded to ``DistMultMod``.
    """

    def __init__(self, in_channels, hidden_channels,num_nodes,num_relations):
        super().__init__()

        # `nn` is torch_geometric.nn here: its Sequential takes an explicit
        # signature string so edge_index can be routed to every GATConv layer.
        # SiLU(inplace=True) overwrites the preceding layer's output tensor.
        self.Encoder = nn.Sequential('x, edge_index', [
            (GATConv(in_channels,hidden_channels), 'x, edge_index -> x'),
            SiLU(inplace=True),
            (GATConv(hidden_channels,hidden_channels), 'x, edge_index -> x'),
            SiLU(inplace=True),
            (GATConv(hidden_channels,hidden_channels), 'x, edge_index -> x')
        ])

        self.Decoder = DistMultMod(num_nodes, num_relations, hidden_channels)

    def _encode_into_decoder(self, x, edge_index):
        """Run the encoder and install its output as the decoder's node embeddings."""
        self.Decoder.node_emb = self.Encoder(x, edge_index)

    def forward(self, x, edge_index, head_indices, relations, tail_indices):
        """Encode the graph, then score the given (head, relation, tail) triples.

        Args:
            x: Node features fed to the encoder.
            edge_index: Graph connectivity fed to the encoder.
            head_indices: Head node indices of the triples to score.
            relations: Relation indices of the triples to score.
            tail_indices: Tail node indices of the triples to score.

        Returns:
            The decoder's scores for the triples passed through a sigmoid.
        """
        self._encode_into_decoder(x, edge_index)
        return torch.sigmoid(self.Decoder(head_indices, relations, tail_indices))

    def loss(self, head_index, relation, tail_index, x=None, edge_index=None):
        """Return the decoder's loss for the given triples.

        Args:
            head_index: Head node indices of the triples.
            relation: Relation indices of the triples.
            tail_index: Tail node indices of the triples.
            x: Optional node features. When given together with
                ``edge_index`` the encoder is re-run first, so the loss is
                computed on fresh embeddings instead of whatever a previous
                ``forward`` call left on the decoder.
            edge_index: Optional graph connectivity; must accompany ``x``.

        Returns:
            Whatever ``DistMultMod.loss`` returns for the triples.

        Raises:
            ValueError: If only one of ``x`` / ``edge_index`` is supplied.
        """
        if (x is None) != (edge_index is None):
            raise ValueError("x and edge_index must be provided together")
        if x is not None:
            self._encode_into_decoder(x, edge_index)
        # Without x/edge_index this uses the decoder's current node_emb,
        # i.e. the original behaviour.
        return self.Decoder.loss(head_index, relation, tail_index)

We set up a test model to make sure that our architecture is valid

In [142]:
# Smoke test on a tiny graph: 5 nodes with 8-dim features joined in a directed
# cycle (0 -> 1 -> 2 -> 3 -> 4 -> 0), with 3 relation types.
# Seed first so the random features and weight initialisation are reproducible.
torch.manual_seed(0)

testModel = KGLinkPredictor(8,8,5,3)

testEmbedding = torch.randn(5,8)
testEdges = torch.tensor([[0,1,2,3,4],[1,2,3,4,0]])

# Three (head, relation, tail) triples to score.
testHeads = torch.tensor([0,0,2])
testRelations = torch.tensor([0,0,0])
testTails = torch.tensor([2,1,3])

testPrediction = testModel(testEmbedding, testEdges, testHeads, testRelations, testTails)

# One sigmoid score per triple; the range check also catches NaNs.
assert testPrediction.shape == (3,), f"unexpected output shape {tuple(testPrediction.shape)}"
assert ((testPrediction >= 0) & (testPrediction <= 1)).all(), "scores must lie in [0, 1]"

display(testPrediction)

tensor([0.4979, 0.4988, 0.4964], grad_fn=<SigmoidBackward0>)

# Training

First, we run full-relation pre-training.

In [None]:
def pretrain():
    """Full-relation pre-training stage (placeholder, not implemented yet).

    Raises:
        NotImplementedError: Always, until the Dataset is available.
    """
    # TODO: Will complete after Dataset completion
    raise NotImplementedError("pretrain() will be implemented after Dataset completion")

Next, we fine-tune the model with a focus on the drug-disease relation.

In [None]:
def finetune():
    """Drug-disease relation fine-tuning stage (placeholder, not implemented yet).

    Raises:
        NotImplementedError: Always, until the Dataset is available.
    """
    # TODO: Will complete after Dataset completion
    raise NotImplementedError("finetune() will be implemented after Dataset completion")