This notebook has 3 purposes:
* Setting up the model architecture
* Defining the training loop
* Running a limited test of the model

# Architecture

In [1]:
from torch_geometric import nn
from torch_geometric.nn import GATConv
import torch
from torch.nn import SiLU
from DistMultMod import DistMultMod

This model will have a GAT-based encoder to produce node embeddings, with a DistMult-based decoder for the purpose of link prediction

In [2]:
class KGLinkPredictor(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, data):
        super(KGLinkPredictor, self).__init__()
        
        self.Encoder = nn.Sequential('x, edge_index', [
            (GATConv(in_channels,hidden_channels,add_self_loops=False), 'x, edge_index -> x'),
            SiLU(inplace=True),
            (GATConv(hidden_channels,hidden_channels,add_self_loops=False), 'x, edge_index -> x'),
            SiLU(inplace=True),
            (GATConv(hidden_channels,hidden_channels,add_self_loops=False), 'x, edge_index -> x')
        ])
        
        self.Decoder = DistMultMod(data.num_nodes, data.num_edges, hidden_channels, data)
        
        self.data = data

    def forward(self, head_indices, relations, tail_indices):
        x = self.Encoder(self.data.x_dict,self.data.edge_index_dict)
        self.Decoder.node_emb = torch.vstack([*x.values()])
        return torch.sigmoid(self.Decoder(head_indices, relations, tail_indices))
    
    def loss(self, head_index, relation, tail_index):
        return self.Decoder.loss(head_index, relation, tail_index)
    

# Training

First, we do a full-relation pre-training

In [11]:
def train(data, train_dataLoader, val_dataloader, model, optimizer, scheduler, device):
    train_loss = 0
    for i,idx_data in enumerate(train_dataLoader):
        idx_data = idx_data.to(device)
        optimizer.zero_grad()
        model(idx_data[:,0], idx_data[:,1], idx_data[:,2])
        loss = model.loss(idx_data[:,0], idx_data[:,1], idx_data[:,2])
        loss.backward()
        optimizer.step()
        # scheduler.step()
        train_loss += loss.item() / len(idx_data)
        if i % 10 == 0:
            print("Batch loss ", loss.item())
    
    val_loss = 0
    for idx_data in val_dataloader:
        idx_data = idx_data.to(device)
        loss = model.loss(idx_data[:,0], idx_data[:,1], idx_data[:,2])
        val_loss += loss.item() / len(idx_data)
    return train_loss, val_loss

# Test

In [4]:
from Dataset import processData
from torch_geometric.nn import to_hetero
from torch.optim.lr_scheduler import StepLR

First we import and process our data

In [5]:
data, ptrain_loader, pval_loader, ptest_loader, ftrain_loader, fval_loader, ftest_loader = processData(128,128)
display(data)
print(len(ptrain_loader), len(pval_loader), len(ptest_loader))
print(len(ftrain_loader), len(fval_loader), len(ftest_loader))

  kg = pd.read_csv(path + r'\data\kg.csv')
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  pretrain_indices['relation'] = pretrain_indices['relation'].map(name_to_num)


HeteroData(
  gene_protein={ x=[27671, 128] },
  drug={ x=[7957, 128] },
  effect_phenotype={ x=[15311, 128] },
  disease={ x=[17080, 128] },
  biological_process={ x=[28642, 128] },
  molecular_function={ x=[11169, 128] },
  cellular_component={ x=[4176, 128] },
  exposure={ x=[818, 128] },
  pathway={ x=[2516, 128] },
  anatomy={ x=[14035, 128] },
  (anatomy, anatomy_anatomy, anatomy)={ edge_index=[2, 28064] },
  (gene_protein, anatomy_protein_absent, anatomy)={ edge_index=[2, 39774] },
  (gene_protein, anatomy_protein_present, anatomy)={ edge_index=[2, 3036406] },
  (biological_process, bioprocess_bioprocess, biological_process)={ edge_index=[2, 105772] },
  (gene_protein, bioprocess_protein, biological_process)={ edge_index=[2, 289610] },
  (cellular_component, cellcomp_cellcomp, cellular_component)={ edge_index=[2, 9690] },
  (gene_protein, cellcomp_protein, cellular_component)={ edge_index=[2, 166804] },
  (drug, contraindication, disease)={ edge_index=[2, 61350] },
  (disease, d

50629 6329 6329
533 67 67


We will run the model once on our data

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
testModel = KGLinkPredictor(128,8,data).to(device)
data.to(device)
testModel.Encoder = to_hetero(testModel.Encoder,data.metadata())

with torch.no_grad():
    out = testModel(torch.tensor([0,0,2]), torch.tensor([0,1,2]), torch.tensor([2,1,3]))
    loss = testModel.loss(torch.tensor([0,0,2]), torch.tensor([0,1,2]), torch.tensor([2,1,3]))

display(out,loss)

tensor([0.5000, 0.5000, 0.5000], device='cuda:0')

tensor(1., device='cuda:0')

Training Test Run for a couple batches

In [12]:
data.to(device)
testModel.to(device)
testOptimizer = torch.optim.Adam(testModel.parameters(), lr=0.01)
testScheduler = StepLR(testOptimizer, step_size=10, gamma=0.1)

train(data, ptrain_loader, pval_loader, testModel, testOptimizer, testScheduler, device)

Batch loss  0.35647904872894287
Batch loss  0.3337223529815674
Batch loss  0.35349926352500916
Batch loss  0.3597087264060974
Batch loss  0.3532677888870239
Batch loss  0.28600096702575684
Batch loss  0.27098560333251953
Batch loss  0.371176153421402
Batch loss  0.30356529355049133
Batch loss  0.3972756862640381
Batch loss  0.3064011335372925
Batch loss  0.4131999611854553
Batch loss  0.32202398777008057
Batch loss  0.2591433525085449


KeyboardInterrupt: 

We see the embeddings being modified

In [13]:
display("Meaningful embedding ", testModel.Encoder(testModel.data.x_dict,testModel.data.edge_index_dict)['gene_protein'][0])
display("Randomly initialized tensor ", testModel.data.x_dict['gene_protein'][0])

'Meaningful embedding '

tensor([ -46.5933,  -66.3595,  -71.1544,  -22.4638,  -56.1134,  -27.2626,
         -52.8214, -110.0290], device='cuda:0', grad_fn=<SelectBackward0>)

'Randomly initialized tensor '

tensor([-0.0067, -0.0112, -0.0135,  0.0087,  0.0057,  0.0015,  0.0065,  0.0107,
        -0.0027,  0.0011,  0.0050,  0.0104, -0.0057,  0.0102,  0.0098, -0.0102,
        -0.0023, -0.0041, -0.0139,  0.0055,  0.0128, -0.0024, -0.0025,  0.0011,
        -0.0114,  0.0112,  0.0014,  0.0124,  0.0072, -0.0099, -0.0014, -0.0123,
         0.0093,  0.0028,  0.0014, -0.0006,  0.0099,  0.0052, -0.0079,  0.0010,
         0.0067,  0.0063, -0.0114,  0.0101, -0.0113,  0.0089,  0.0022, -0.0046,
         0.0005,  0.0136,  0.0118,  0.0077, -0.0100,  0.0067, -0.0021,  0.0041,
        -0.0042, -0.0044, -0.0064,  0.0094,  0.0044,  0.0012,  0.0071,  0.0007,
        -0.0083,  0.0007,  0.0105, -0.0110,  0.0099,  0.0079,  0.0043, -0.0049,
         0.0042,  0.0016,  0.0074,  0.0094, -0.0090, -0.0023, -0.0122, -0.0135,
        -0.0092, -0.0136,  0.0013,  0.0036, -0.0107,  0.0105,  0.0144,  0.0091,
         0.0045,  0.0071,  0.0107,  0.0094,  0.0041,  0.0048,  0.0095, -0.0036,
        -0.0026, -0.0104, -0.0072, -0.00

Edge predictions change

In [18]:
with torch.no_grad():
    display(torch.sigmoid(testModel.Decoder(torch.tensor([0,0,2]), torch.tensor([0,1,2]), torch.tensor([2,1,3]))))

tensor([0.9742, 0.3328, 0.4530], device='cuda:0')