This notebook has 3 purposes:
* Setting up the model architecture
* Defining the training loop
* Running a limited test of the model

# Architecture

In [1]:
from torch_geometric import nn
from torch_geometric.nn import GATConv
import torch
from torch.nn import SiLU
from DistMultMod import DistMultMod

The model uses a GAT-based encoder to produce node embeddings and a DistMult-based decoder that scores (head, relation, tail) triples for link prediction.

In [2]:
class KGLinkPredictor(torch.nn.Module):
    """Knowledge-graph link predictor: GAT encoder + DistMult-style decoder.

    The encoder is three ``GATConv`` layers with SiLU activations in between.
    It maps input node features to ``hidden_channels``-dimensional node
    embeddings. The decoder (``DistMultMod``) scores (head, relation, tail)
    triples using the embeddings the encoder hands to it.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Size of the node embeddings produced by the encoder;
            also passed to ``DistMultMod`` as its embedding size.
        data: Graph object. ``data.num_nodes`` and ``data`` itself are
            forwarded to ``DistMultMod``.
        num_relations: Optional override for the second argument passed to
            ``DistMultMod``. Defaults to ``data.num_edges`` (the original
            behaviour).
    """

    def __init__(self, in_channels, hidden_channels, data, num_relations=None):
        super().__init__()

        # ``nn`` is ``torch_geometric.nn`` here. Its Sequential takes an input
        # signature string, so each GATConv receives both x and edge_index.
        # add_self_loops=False is presumably there so the stack can be
        # converted with ``to_hetero``: self-loops are undefined on bipartite
        # edge types.
        self.Encoder = nn.Sequential('x, edge_index', [
            (GATConv(in_channels, hidden_channels, add_self_loops=False), 'x, edge_index -> x'),
            SiLU(inplace=True),
            (GATConv(hidden_channels, hidden_channels, add_self_loops=False), 'x, edge_index -> x'),
            SiLU(inplace=True),
            (GATConv(hidden_channels, hidden_channels, add_self_loops=False), 'x, edge_index -> x'),
        ])

        # NOTE(review): if DistMultMod's second argument is the number of
        # relation *types* (as in PyG's DistMult), ``data.num_edges`` (the total
        # edge count) allocates one relation embedding per edge. For a
        # HeteroData graph, ``len(data.edge_types)`` is likely the intended
        # value. Pass it via ``num_relations``. TODO: verify against DistMultMod.
        if num_relations is None:
            num_relations = data.num_edges
        self.Decoder = DistMultMod(data.num_nodes, num_relations, hidden_channels, data)

    def encode(self, x, edge_index):
        """Run the encoder and store its output on ``self.Decoder.node_emb``.

        Args:
            x: Node features. This is a dict keyed by node type once the
                encoder has been wrapped with ``to_hetero``.
            edge_index: Edge indices. This is a dict keyed by edge type in the
                heterogeneous case.

        Returns:
            The encoder output, i.e. the same object stored on the decoder.
        """
        emb = self.Encoder(x, edge_index)
        # The decoder is presumably expected to read its node embeddings from
        # this attribute. Verify in DistMultMod.
        self.Decoder.node_emb = emb
        return emb

    def forward(self, x, edge_index, head_indices, relations, tail_indices):
        """Encode the graph, then score the given (head, relation, tail) triples.

        Returns:
            ``torch.sigmoid`` of the decoder scores, so every value lies in
            [0, 1]. This is presumably one value per triple.
        """
        self.encode(x, edge_index)
        return torch.sigmoid(self.Decoder(head_indices, relations, tail_indices))

    def loss(self, head_index, relation, tail_index, x=None, edge_index=None):
        """Return ``self.Decoder.loss`` for the given triples.

        If ``x`` and ``edge_index`` are given, the graph is re-encoded first.
        The loss then uses fresh embeddings, and (outside ``torch.no_grad``)
        gradients can reach the encoder. Otherwise the decoder uses whatever
        ``self.Decoder.node_emb`` currently holds, e.g. the embeddings left by
        the most recent ``forward``/``encode`` call.

        Raises:
            ValueError: if only one of ``x`` / ``edge_index`` is given.
        """
        if x is not None or edge_index is not None:
            if x is None or edge_index is None:
                raise ValueError("x and edge_index must be passed together")
            self.encode(x, edge_index)
        return self.Decoder.loss(head_index, relation, tail_index)

We set up a small test model to check that the architecture is valid (this check is currently commented out).

In [3]:
# NOTE(review): stale smoke test. KGLinkPredictor now takes
# (in_channels, hidden_channels, data), so the 4-argument call below no longer
# matches its signature. It is kept commented out for reference; the live
# smoke test is further down under "Test".
# testModel = KGLinkPredictor(8,8,5,3)

# testEmbedding = torch.randn(5,8)
# testEdges = torch.tensor([[0,1,2,3,4],[1,2,3,4,0]])

# testPrediction = testModel(testEmbedding, testEdges, torch.tensor([0,0,2]), torch.tensor([0,0,0]), torch.tensor([2,1,3]))

# display(testPrediction)

# Training

First, we pre-train on all relation types (full-relation pre-training). This step is not implemented yet.

In [4]:
# TODO: implement the full-relation pre-training loop.
# def train(dataLoader, model, optimizer, device):

Next, we fine-tune on drug–disease relations only. This step is not implemented yet.

In [5]:
# TODO: implement drug-disease fine-tuning (pending completion of the Dataset module).
# def finetune():
    # Will complete after Dataset completion

# Test

In [6]:
from Dataset import processData
from torch_geometric.nn import to_hetero

First, we load and process the data.

In [7]:
(data,
 ptrain_loader, pval_loader, ptest_loader,
 ftrain_loader, fval_loader, ftest_loader) = processData(128, 128)
display(data)

# Print loader lengths: first the "p" (presumably pre-training) train/val/test
# loaders, then the "f" (presumably fine-tuning) ones.
for split_loaders in ((ptrain_loader, pval_loader, ptest_loader),
                      (ftrain_loader, fval_loader, ftest_loader)):
    print(*(len(loader) for loader in split_loaders))

  kg = pd.read_csv(path + r'\data\kg.csv')


HeteroData(
  gene_protein={ x=[27671, 128] },
  drug={ x=[7957, 128] },
  effect_phenotype={ x=[15311, 128] },
  disease={ x=[17080, 128] },
  biological_process={ x=[28642, 128] },
  molecular_function={ x=[11169, 128] },
  cellular_component={ x=[4176, 128] },
  exposure={ x=[818, 128] },
  pathway={ x=[2516, 128] },
  anatomy={ x=[14035, 128] },
  (anatomy, anatomy_anatomy, anatomy)={ edge_index=[2, 28064] },
  (gene_protein, anatomy_protein_absent, anatomy)={ edge_index=[2, 39774] },
  (gene_protein, anatomy_protein_present, anatomy)={ edge_index=[2, 3036406] },
  (biological_process, bioprocess_bioprocess, biological_process)={ edge_index=[2, 105772] },
  (gene_protein, bioprocess_protein, biological_process)={ edge_index=[2, 289610] },
  (cellular_component, cellcomp_cellcomp, cellular_component)={ edge_index=[2, 9690] },
  (gene_protein, cellcomp_protein, cellular_component)={ edge_index=[2, 166804] },
  (drug, contraindication, disease)={ edge_index=[2, 61350] },
  (disease, d

50629 6329 6329
501 63 63


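We run the model once (a single forward pass without gradients) on the full graph. Note that every returned score is exactly 0.5, which is $\sigma(0)$ to the displayed precision. This means the decoder currently produces a raw score of 0 for every triple, which should be investigated.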

In [8]:
testModel = KGLinkPredictor(128, 8, data)
# Convert the homogeneous GAT stack so it runs over every node/edge type of the HeteroData graph.
testModel.Encoder = to_hetero(testModel.Encoder, data.metadata())

# Ten arbitrary (head, relation, tail) triples for a smoke test.
sample_heads = torch.arange(10)
sample_relations = torch.tensor([1, 3, 5, 7, 9, 0, 2, 4, 6, 8])
sample_tails = torch.arange(100, 1001, 100)

with torch.no_grad():
    out = testModel(data.x_dict, data.edge_index_dict, sample_heads, sample_relations, sample_tails)
out

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000])

In [9]:
# NOTE(review): this call currently fails with
# "IndexError: index 17062 is out of bounds for dimension 0 with size 14035".
# 17062 is not one of the indices passed here, so it is presumably generated
# inside DistMultMod.loss. 14035 matches the number of `anatomy` nodes in the
# `data` output above. TODO: fix the decoder's indexing in DistMultMod.
loss_heads, loss_relations, loss_tails = (
    torch.tensor([0, 0, 2]),
    torch.tensor([0, 1, 2]),
    torch.tensor([2, 1, 3]),
)
testModel.loss(loss_heads, loss_relations, loss_tails)

IndexError: index 17062 is out of bounds for dimension 0 with size 14035