This notebook has 3 purposes:
* Setting up the model architecture
* Defining the training loop
* Running a limited test of the model

# Architecture

In [6]:
import torch_geometric
from torch_geometric import nn
from torch_geometric.nn import GATConv
from torch_geometric.nn.kge import DistMult
import torch
from torch.nn import SiLU

The model uses a GAT-based encoder to produce node embeddings and a DistMult-based decoder to score candidate links (link prediction).

In [8]:
class KGLinkPredictor(torch.nn.Module):
    """GAT encoder plus DistMult decoder for knowledge-graph link prediction.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of every GAT layer and, by default, the size of
            the DistMult entity/relation embeddings.
        num_nodes: Number of entities embedded by the DistMult decoder.
        num_relations: Number of relation types embedded by the DistMult decoder.
        decoder_channels: Optional override for the DistMult embedding size;
            defaults to ``hidden_channels`` when ``None``.

    NOTE(review): ``loss`` scores triples using DistMult's own embedding
    tables; the node embeddings returned by ``forward`` are never fed into
    the decoder. TODO: wire the encoder output into the scoring function if
    the encoder and decoder are meant to be trained jointly.
    """

    def __init__(self, in_channels, hidden_channels, num_nodes, num_relations,
                 decoder_channels=None):
        super().__init__()

        # Three stacked GAT layers with SiLU activations in between; the
        # string specs tell PyG's Sequential how to route (x, edge_index).
        self.Encoder = nn.Sequential('x, edge_index', [
            (GATConv(in_channels, hidden_channels), 'x, edge_index -> x'),
            SiLU(inplace=True),
            (GATConv(hidden_channels, hidden_channels), 'x, edge_index -> x'),
            SiLU(inplace=True),
            (GATConv(hidden_channels, hidden_channels), 'x, edge_index -> x'),
        ])

        # A 1-dimensional embedding would reduce every DistMult score to a
        # product of three scalars, so default to the model width instead.
        if decoder_channels is None:
            decoder_channels = hidden_channels
        self.Decoder = DistMult(num_nodes, num_relations, decoder_channels)

    def forward(self, x, edge_index):
        """Run the GAT encoder and return the resulting node embeddings.

        Args:
            x: Node feature matrix with ``in_channels`` features per node.
            edge_index: Graph connectivity passed to every GAT layer.

        Returns:
            Node embeddings produced by the final GAT layer.
        """
        return self.Encoder(x, edge_index)

    def loss(self, head_index, relation, tail_index):
        """Return the DistMult training loss for a batch of (head, relation, tail) triples.

        Delegates directly to ``DistMult.loss``; see the class note about the
        encoder output not being used here.
        """
        return self.Decoder.loss(head_index, relation, tail_index)

# Training

First, we pre-train on all relation types.

In [None]:
def pretrain():
    """Pre-train the model on all relation types in the knowledge graph.

    TODO: implement once the dataset is complete. Until then, calling this
    raises instead of silently doing nothing. A comment-only body was a
    SyntaxError that broke Restart & Run All.
    """
    raise NotImplementedError("pretrain() will be implemented after the dataset is complete")

Next, we fine-tune on drug–disease relations only.

In [None]:
def finetune():
    """Fine-tune the pre-trained model on drug–disease relations only.

    TODO: implement once the dataset is complete. Until then, calling this
    raises instead of silently doing nothing. A comment-only body was a
    SyntaxError that broke Restart & Run All.
    """
    raise NotImplementedError("finetune() will be implemented after the dataset is complete")