This notebook has 3 purposes:
* Setting up the model architecture
* Defining the training loop
* Running a limited test of the model

# Architecture

In [2]:
import torch_geometric
from torch_geometric import nn
from torch_geometric.nn import GATConv
from torch_geometric.nn.kge import DistMult
import torch
from torch.nn import SiLU

This model pairs a GAT-based encoder, which produces node embeddings, with a DistMult-based decoder for link prediction.

In [3]:
class KGLinkPredictor(torch.nn.Module):
    """Link-prediction model with a GAT encoder and a DistMult decoder.

    The encoder stacks three ``GATConv`` layers with ``SiLU`` activations in
    between and turns node features into ``hidden_channels``-sized node
    embeddings. The decoder is a ``DistMult`` model built for ``num_nodes``
    nodes and ``num_relations`` relation types.

    NOTE(review): ``loss`` hands only node/relation indices to the decoder, so
    nothing in this class feeds the encoder output into the decoder's scoring.
    Confirm whether the two parts are meant to be coupled.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Output size of every encoder layer.
        num_nodes: Number of nodes handed to the decoder.
        num_relations: Number of relation types handed to the decoder.
        decoder_channels: Embedding size handed to the decoder. Defaults to
            1, the value that was previously hard-coded.
    """

    def __init__(self, in_channels, hidden_channels, num_nodes, num_relations,
                 decoder_channels=1):
        super().__init__()

        # Each GATConv consumes (x, edge_index) and overwrites x, so the
        # activations in between operate on the node embeddings only.
        self.Encoder = nn.Sequential('x, edge_index', [
            (GATConv(in_channels, hidden_channels), 'x, edge_index -> x'),
            SiLU(inplace=True),
            (GATConv(hidden_channels, hidden_channels), 'x, edge_index -> x'),
            SiLU(inplace=True),
            (GATConv(hidden_channels, hidden_channels), 'x, edge_index -> x'),
        ])

        self.Decoder = DistMult(num_nodes, num_relations, decoder_channels)

    def forward(self, x, edge_index):
        """Encode node features into node embeddings.

        Args:
            x: Node feature tensor passed to the encoder.
            edge_index: Graph connectivity passed to every encoder layer.

        Returns:
            The encoder output for ``x``. The previous implementation
            computed this value but returned the decoder module instead.
        """
        return self.Encoder(x, edge_index)

    def loss(self, head_index, relation, tail_index):
        """Return the decoder's loss for the given triples.

        Args:
            head_index: Indices of the head nodes.
            relation: Indices of the relation types.
            tail_index: Indices of the tail nodes.

        Returns:
            Whatever ``self.Decoder.loss`` returns for these arguments.
        """
        return self.Decoder.loss(head_index, relation, tail_index)

Test that the model architecture is valid.

In [45]:
# Smoke test: build a tiny model and check that the encoder and the decoder
# loss both run and produce outputs of the expected form.

# Fix the seed so the random features and initial weights repeat across runs.
torch.manual_seed(0)

testModel = KGLinkPredictor(in_channels=8, hidden_channels=4, num_nodes=5, num_relations=3)

# Toy graph: 5 nodes with 8 features each, edges 0->1, 1->2, 2->3, 3->4, 4->0.
testFeatures = torch.randn(5, 8)
testEdges = torch.tensor([[0,1,2,3,4],[1,2,3,4,0]])

# Keep the raw features under their own name so re-reading the cell makes the
# input/output distinction obvious; testEmbedding is the encoder output.
testEmbedding = testModel.Encoder(testFeatures, testEdges)
assert testEmbedding.shape == (5, 4), "encoder should emit one hidden_channels-sized row per node"

# One (head, relation, tail) triple per edge of the toy graph.
testHeads = torch.tensor([0,1,2,3,4])
testRelations = torch.tensor([0,1,2,0,1])
testTails = torch.tensor([1,2,3,4,0])
loss = testModel.loss(testHeads, testRelations, testTails)
assert loss.dim() == 0 and torch.isfinite(loss), "decoder loss should be a finite scalar"

display(testEmbedding)
display(loss)

tensor([[-0.0366, -0.0850, -0.0413,  0.0348],
        [ 0.0393, -0.1528,  0.0971,  0.1557],
        [ 0.0674, -0.1946,  0.1168,  0.2126],
        [ 0.0333, -0.2110,  0.0365,  0.1925],
        [-0.0348, -0.1303, -0.0739,  0.0702]], grad_fn=<AddBackward0>)

tensor(0.9725, grad_fn=<MeanBackward0>)

# Training

First, we do full-relation pre-training.

In [None]:
def pretrain():
    """Full-relation pre-training step (placeholder).

    Will be completed after Dataset completion.

    Raises:
        NotImplementedError: Always, until the training loop is written.
    """
    # A def whose body is only a comment is a syntax error; raising keeps the
    # cell runnable and makes an accidental call fail loudly.
    raise NotImplementedError("pretrain() will be completed after Dataset completion")

Next, we fine-tune specifically on the drug-disease relation.

In [None]:
def finetune():
    """Drug-disease relation fine-tuning step (placeholder).

    Will be completed after Dataset completion.

    Raises:
        NotImplementedError: Always, until the training loop is written.
    """
    # A def whose body is only a comment is a syntax error; raising keeps the
    # cell runnable and makes an accidental call fail loudly.
    raise NotImplementedError("finetune() will be completed after Dataset completion")