This notebook has four sections:
* Data processing
* Setting up the model architecture
* Defining the training loop
* Running a limited test of the model

# Data Processing

In [1]:
from Deprecated_Dataset import processData
from torch_geometric.nn import to_hetero
from torch.optim.lr_scheduler import StepLR, ExponentialLR
from torch_geometric import nn
from torch_geometric.nn import GATConv
import torch
from torch.nn import SiLU
from DistMultMod import DistMultMod
import numpy as np
from tqdm.notebook import tqdm
import os

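Use the shared data-processing code to build the heterogeneous graph and its data loaders: the `p*` loaders are used for pre-training below, and the `f*` loaders for the subsequent training stage.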

In [18]:
# Build the graph and the loaders (arguments passed straight through to processData).
(data,
 ptrain_loader, pval_loader, ptest_loader,
 ftrain_loader, fval_loader, ftest_loader,
 local_idx_map) = processData(256, 2048)
display(data)
# Batch counts: one line for the p* (train/val/test) loaders, one for the f* loaders.
for loaders in ((ptrain_loader, pval_loader, ptest_loader),
                (ftrain_loader, fval_loader, ftest_loader)):
    print(*map(len, loaders))

  kg = pd.read_csv(path + r'/data/kg.csv')
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  raw_pretrain_indices['relation'] = raw_pretrain_indices['relation'].map(name_to_num)


HeteroData(
  gene_protein={ x=[27671, 256] },
  drug={ x=[7957, 256] },
  effect_phenotype={ x=[15311, 256] },
  disease={ x=[17080, 256] },
  biological_process={ x=[28642, 256] },
  molecular_function={ x=[11169, 256] },
  cellular_component={ x=[4176, 256] },
  exposure={ x=[818, 256] },
  pathway={ x=[2516, 256] },
  anatomy={ x=[14035, 256] },
  (anatomy, anatomy_anatomy, anatomy)={ edge_index=[2, 28064] },
  (anatomy, anatomy_protein_absent, gene_protein)={ edge_index=[2, 19887] },
  (anatomy, anatomy_protein_present, gene_protein)={ edge_index=[2, 1518203] },
  (biological_process, bioprocess_bioprocess, biological_process)={ edge_index=[2, 105772] },
  (biological_process, bioprocess_protein, gene_protein)={ edge_index=[2, 144805] },
  (cellular_component, cellcomp_cellcomp, cellular_component)={ edge_index=[2, 9690] },
  (cellular_component, cellcomp_protein, gene_protein)={ edge_index=[2, 83402] },
  (disease, contraindication, drug)={ edge_index=[2, 30675] },
  (disease, di

2285 286 286
17 3 3


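Disease-similarity search to improve zero-shot embeddings: each disease is represented by an indicator vector over its neighbouring genes/proteins, phenotypes, exposures and diseases. Its top-k most similar diseases (by dot product) are then stored for use by the model.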

In [4]:
# Generate the one-hot vectors for disease similarity computation
def generateOverallOneHot(disease_idx,data):
    """Build the neighbour-indicator vector of a single disease node.

    The result concatenates four blocks, in this order: the gene_protein,
    effect_phenotype, exposure and disease neighbours of ``disease_idx``.
    Each block has one entry per node of that type. An entry is 1 where an
    edge from the disease exists and 0 otherwise. The exception is the
    effect_phenotype block: it is the *sum* of the negative- and
    positive-phenotype indicators, so it holds 2 when both relations link
    the disease to the same phenotype.

    Args:
        disease_idx: index of the disease within the 'disease' node type.
        data: heterogeneous graph indexable by node type (``.num_nodes``) and
            by ``(src, relation, dst)`` edge type (``.edge_index``, with source
            indices in row 0 and destination indices in row 1).

    Returns:
        1-D integer ``torch.Tensor`` of length
        num_gene_protein + num_effect_phenotype + num_exposure + num_disease.
    """

    def generateOneHot(node_idx, edge_type, data):
        """Return a 0/1 numpy vector over ``edge_type[2]`` nodes marking neighbours of ``node_idx``."""
        edge_index = data[edge_type].edge_index
        # Mask in torch rather than converting the whole edge_index to numpy on
        # every call; .cpu() keeps this working after the graph is moved to GPU.
        neighbors = edge_index[1][edge_index[0] == node_idx].cpu().numpy()

        one_hot = np.zeros(data[edge_type[2]].num_nodes, dtype=int)
        one_hot[neighbors] = 1  # duplicate edges still yield 1
        return one_hot

    # Indicator blocks for the neighbour types used for similarity, concatenated in a fixed order.
    geneOneHot = generateOneHot(disease_idx, ('disease', 'disease_protein', 'gene_protein'), data)
    effectOneHot = (generateOneHot(disease_idx, ('disease', 'disease_phenotype_negative', 'effect_phenotype'), data)
                    + generateOneHot(disease_idx, ('disease', 'disease_phenotype_positive', 'effect_phenotype'), data))
    exposureOneHot = generateOneHot(disease_idx, ('disease', 'exposure_disease', 'exposure'), data)
    diseaseOneHot = generateOneHot(disease_idx, ('disease', 'disease_disease', 'disease'), data)
    overallOneHot = torch.tensor(np.hstack([geneOneHot, effectOneHot, exposureOneHot, diseaseOneHot]))

    return overallOneHot

def constructDiseaseSimilarity(k=10, graph=None):
    """Compute, for every disease, its k most similar diseases.

    Each disease gets an indicator vector with the same layout as
    ``generateOverallOneHot``. The blocks are gene_protein, effect_phenotype,
    exposure and disease neighbours, and the effect block sums the negative-
    and positive-phenotype indicators. Similarity is the dot product of two
    such vectors, and a disease is never counted as similar to itself.

    The whole computation is vectorized: the indicator matrix is filled with
    scatter-adds and all pairwise similarities come from a single matrix
    product. This replaces the per-disease / per-pair Python loops.

    Args:
        k: number of neighbours to keep per disease (must be <= number of diseases).
        graph: heterogeneous graph to use; defaults to the notebook-level ``data``.

    Returns:
        Float tensor of shape (num_diseases, 2, k).
        ``[:, 0, :]`` holds the top-k similarities normalized to sum to 1, or
        all zeros if every one of them is 0. ``[:, 1, :]`` holds the matching
        disease indices, stored as floats. The order among equal similarities
        follows ``torch.topk``.
    """
    if graph is None:
        graph = data  # notebook-level graph returned by processData

    # (neighbour node type, relations feeding that block), in column order.
    # Both phenotype relations write into the same block and are summed.
    blocks = [
        ('gene_protein', ['disease_protein']),
        ('effect_phenotype', ['disease_phenotype_negative', 'disease_phenotype_positive']),
        ('exposure', ['exposure_disease']),
        ('disease', ['disease_disease']),
    ]

    num_diseases = graph['disease'].num_nodes
    num_possible_neighbors = sum(graph[node_type].num_nodes for node_type, _ in blocks)

    oneHots = torch.zeros(num_diseases, num_possible_neighbors)
    offset = 0
    for node_type, relations in blocks:
        width = graph[node_type].num_nodes
        for relation in relations:
            edge_index = graph[('disease', relation, node_type)].edge_index.cpu()
            # Only disease rows 0..num_diseases-1 are represented (as in the per-disease loop).
            edge_index = edge_index[:, edge_index[0] < num_diseases]
            if edge_index.numel() == 0:
                continue
            # Deduplicate (disease, neighbour) pairs so each relation contributes
            # at most 1 per cell, matching a 0/1 indicator per relation.
            pair_keys = torch.unique(edge_index[0] * width + edge_index[1])
            rows = pair_keys // width
            cols = pair_keys % width + offset
            oneHots.index_put_((rows, cols), torch.ones(rows.numel()), accumulate=True)
        offset += width

    # All pairwise dot products at once. Entries are small integer counts,
    # so the float32 results are exact.
    similarity_matrix = oneHots @ oneHots.T
    similarity_matrix.fill_diagonal_(0)  # exclude each disease from its own neighbours
    del oneHots

    topk = torch.topk(similarity_matrix, k, dim=1)
    totals = topk.values.sum(dim=1, keepdim=True)
    # Rows whose top-k similarities are all zero get all-zero weights instead of 0/0.
    weights = torch.where(totals == 0, torch.zeros_like(topk.values), topk.values / totals)

    disease_similarity_storage = torch.stack([weights, topk.indices.to(weights.dtype)], dim=1)
    return disease_similarity_storage

11083 11269 11107

In [5]:
# Compute the disease-similarity table only when it is not already cached in the
# working directory; it is expensive and only needs to be produced once.
similarity_file = 'disease_similarity.pt'
if similarity_file not in os.listdir():
    similarity_table = constructDiseaseSimilarity()
    torch.save(similarity_table, similarity_file)

# Architecture

The model uses a GAT-based encoder (six `GATConv` layers with SiLU activations) to produce node embeddings, and a DistMult-based decoder to score candidate links.

In [6]:
class KGLinkPredictor(torch.nn.Module):
    """GAT encoder plus DistMult-style decoder for knowledge-graph link prediction.

    The encoder is a stack of six ``GATConv`` layers (self-loops disabled) with
    SiLU activations between them. It maps node features of size
    ``in_channels`` to ``hidden_channels``. The decoder scores triples using the
    node embeddings cached on it by the most recent ``forward`` call.
    """

    def __init__(self, in_channels, hidden_channels, data, similarity_path, local_idx_map):
        super().__init__()

        # GAT -> (SiLU -> GAT) x 5; only the first layer consumes the raw input width.
        layers = []
        for depth in range(6):
            if depth > 0:
                layers.append(SiLU(inplace=True))
            width_in = in_channels if depth == 0 else hidden_channels
            layers.append(
                (GATConv(width_in, hidden_channels, add_self_loops=False), 'x, edge_index -> x')
            )
        self.Encoder = nn.Sequential('x, edge_index', layers)

        # NOTE(review): the second argument is the graph's total edge count. If
        # DistMultMod follows PyG's DistMult(num_nodes, num_relations, ...) order,
        # this sizes the relation table per *edge* rather than per relation type.
        # Confirm against DistMultMod.
        self.Decoder = DistMultMod(data.num_nodes, data.num_edges, hidden_channels, data, similarity_path, local_idx_map)

        self.data = data

    def forward(self, head_indices, relations, tail_indices):
        """Encode the full graph, cache the stacked node embeddings on the decoder, and return sigmoid scores."""
        per_type = self.Encoder(self.data.x_dict, self.data.edge_index_dict)
        # Rows are stacked in the dict's iteration order. This is assumed to match
        # the global node indexing the decoder expects — TODO confirm.
        self.Decoder.node_emb = torch.vstack(list(per_type.values()))
        logits = self.Decoder(head_indices, relations, tail_indices)
        return torch.sigmoid(logits)

    def loss(self, head_index, relation, tail_index):
        """Decoder loss for a batch; relies on node embeddings cached by the last ``forward`` call."""
        return self.Decoder.loss(head_index, relation, tail_index)

# Training

First, we pre-train on the full-relation data (the `p*` loaders).

In [7]:
def train(data, train_dataLoader, val_dataloader, model, optimizer, scheduler, device):
    """Run one training epoch followed by a validation pass.

    Each batch is a 2-D index tensor whose columns 0, 1 and 2 are passed to the
    model as head indices, relation ids and tail indices.

    Args:
        data: unused; kept so existing calls keep working.
        train_dataLoader: iterable of training index batches (must support ``len``).
        val_dataloader: iterable of validation index batches (must support ``len``).
        model: module with ``forward(heads, relations, tails)`` and
            ``loss(heads, relations, tails)``.
        optimizer: optimizer over the model's parameters.
        scheduler: currently not stepped; kept so existing calls keep working.
        device: device the index batches are moved to.

    Returns:
        ``(mean training loss, mean validation loss)`` per batch. The mean is
        ``nan`` for an empty loader instead of raising ZeroDivisionError.
    """
    model.train()
    train_loss = 0.0
    for idx_data in train_dataLoader:
        idx_data = idx_data.to(device)
        heads, relations, tails = idx_data[:, 0], idx_data[:, 1], idx_data[:, 2]
        optimizer.zero_grad()
        # The scores are discarded. For KGLinkPredictor the forward pass is what
        # refreshes the decoder's node embeddings that loss() then uses.
        model(heads, relations, tails)
        loss = model.loss(heads, relations, tails)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Validation needs no gradients. Skipping autograd avoids building a graph per batch.
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for idx_data in val_dataloader:
            idx_data = idx_data.to(device)
            loss = model.loss(idx_data[:, 0], idx_data[:, 1], idx_data[:, 2])
            val_loss += loss.item()

    num_train, num_val = len(train_dataLoader), len(val_dataloader)
    return (train_loss / num_train if num_train else float('nan'),
            val_loss / num_val if num_val else float('nan'))

# Test

Run the untrained model once on three hand-picked (head, relation, tail) triples as a smoke test.

In [14]:
# Build the model on whatever device is available, instead of assuming CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
data.to(device)
testModel = KGLinkPredictor(256, 64, data, 'disease_similarity.pt', local_idx_map).to(device)
# to_hetero rewrites the homogeneous GAT encoder for the graph described by data.metadata().
testModel.Encoder = to_hetero(testModel.Encoder, data.metadata())

# Three hand-picked (head, relation, tail) index triples, created once on the model's device.
smoke_heads = torch.tensor([0, 0, 2], device=device)
smoke_relations = torch.tensor([0, 1, 2], device=device)
smoke_tails = torch.tensor([2, 1, 3], device=device)

with torch.no_grad():
    out = testModel(smoke_heads, smoke_relations, smoke_tails)
    loss = testModel.loss(smoke_heads, smoke_relations, smoke_tails)

display(out, loss)

tensor([0.1771, 0.8381, 0.2668], device='cuda:0')

tensor(1.0918, device='cuda:0')

Training test run: one full pre-training epoch over every batch of `ptrain_loader`, followed by validation on `pval_loader`.

In [9]:
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
data.to(device)
testModel.to(device)

testOptimizer = torch.optim.Adam(testModel.parameters(), lr=0.001)
# gamma=0.1 would cut the learning rate tenfold per scheduler step; note that
# the train() function in this notebook does not currently call scheduler.step().
testScheduler = ExponentialLR(testOptimizer, gamma=0.1)

# One pre-training epoch (the author estimates ~25 min/epoch).
train(
    data,
    ptrain_loader,
    pval_loader,
    testModel,
    testOptimizer,
    testScheduler,
    device,
)

(0.12267233642629252, 0.0942043760022917)

In [10]:
# Re-initialize only the decoder before the next training stage on the f* loaders;
# the encoder weights learned during pre-training are kept. Exactly which
# parameters are reset is defined by DistMultMod.reset_parameters (confirm there).
testModel.Decoder.reset_parameters()

In [13]:
# Train on the f* loaders for up to 20 epochs (the author estimates ~2 min/epoch).
for epoch in tqdm(range(20)):
    train_loss, val_loss = train(
        data=data,
        train_dataLoader=ftrain_loader,
        val_dataloader=fval_loader,
        model=testModel,
        optimizer=testOptimizer,
        scheduler=testScheduler,
        device=device,
    )
    print("Epoch {} | Train Loss: {} | Val Loss: {}".format(epoch + 1, train_loss, val_loss))

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1 | Train Loss: 0.04879272641504512 | Val Loss: 0.050193200508753456
Epoch 2 | Train Loss: 0.04923012743101401 | Val Loss: 0.04971953108906746
Epoch 3 | Train Loss: 0.051049404284533334 | Val Loss: 0.056397306422392525
Epoch 4 | Train Loss: 0.04690680565202937 | Val Loss: 0.0409269571925203
Epoch 5 | Train Loss: 0.04652618529165492 | Val Loss: 0.05027736102541288
Epoch 6 | Train Loss: 0.0464226261657827 | Val Loss: 0.04729254295428594
Epoch 7 | Train Loss: 0.043016956133000994 | Val Loss: 0.06325493132074674
Epoch 8 | Train Loss: 0.04566275328397751 | Val Loss: 0.05355566864212354
Epoch 9 | Train Loss: 0.04513511828639928 | Val Loss: 0.04592655785381794
Epoch 10 | Train Loss: 0.051285226117162144 | Val Loss: 0.06168110171953837
Epoch 11 | Train Loss: 0.04599643827361219 | Val Loss: 0.048447027802467346


KeyboardInterrupt: 

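The encoder output for a node (below) differs markedly from its input feature vector, which shows that the encoder is transforming the embeddings.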

In [16]:
# Compare the encoder's output for the first gene_protein node with its input features.
# This is inspection only, so run the full-graph encoder without building an autograd graph.
with torch.no_grad():
    encoded = testModel.Encoder(testModel.data.x_dict, testModel.data.edge_index_dict)
display("Meaningful embedding ", encoded['gene_protein'][0])
# The label below reflects the author's description of the input features.
display("Randomly initialized tensor ", testModel.data.x_dict['gene_protein'][0])

'Meaningful embedding '

tensor([-34.0878, -10.9848,   3.2778,  15.8652,   1.1965, -57.8402,  14.5811,
         -3.2915,  12.9621,  11.9062,  -6.9801,   9.7309,   6.8755,   6.1691,
         -8.6140, -11.2621,  -5.6652,   1.9878,  -5.0708, -11.2943,   8.9231,
          8.1831,  12.9974, -33.6114,  -3.9108,   2.9071, -14.2387,  52.5912,
         -2.3146,  13.5081,  -6.8878,  52.1212,  26.7223, -17.1594, -23.2313,
         27.5976,  19.7402,  -4.6048, -11.9684,  -8.2866,  -8.7745,   1.0524,
        -20.5567,  17.9088,  19.5952,  16.3532,  -6.2202, -16.1388,  -0.0686,
         25.3577,  -7.3419,  -0.4495,  33.4917,   8.2671,   3.1699,  11.9944,
         -3.9773,  -2.8941, -16.1103,   0.1276,  -4.9487, -17.1192,   2.1534,
         -5.7091], device='cuda:0', grad_fn=<SelectBackward0>)

'Randomly initialized tensor '

tensor([-7.9961e-03, -6.1600e-03, -3.3313e-03, -2.1216e-03,  2.9461e-03,
        -1.4158e-03,  4.1348e-03,  1.4422e-02, -2.8717e-03,  1.2612e-03,
        -6.2864e-03, -5.6060e-03, -1.4166e-02, -8.7319e-03,  7.5883e-03,
        -1.6285e-03, -4.0831e-03,  7.6047e-03,  4.7692e-03, -3.3069e-03,
         3.9118e-03,  1.0287e-03,  1.1578e-02,  1.4036e-02,  1.3397e-02,
         1.4138e-02, -1.0554e-05, -7.0151e-03, -4.7827e-04,  7.1407e-04,
         5.6455e-03, -6.3677e-03, -1.1812e-02, -2.7951e-03, -3.0000e-03,
        -1.3138e-02, -5.3270e-04, -1.0024e-02, -7.1597e-03, -1.0843e-02,
         3.4279e-03,  1.4008e-02, -5.3258e-03,  1.3800e-02,  9.4782e-03,
         1.3524e-02, -3.5637e-03,  5.8734e-03,  3.4998e-03,  5.4336e-03,
        -1.1303e-02, -4.7218e-03, -1.0765e-04,  1.2134e-02, -5.3067e-03,
         7.1311e-04,  1.0168e-02, -3.8810e-03,  6.9090e-03,  7.4683e-03,
         5.2419e-03,  9.5700e-03, -9.6588e-03, -1.0856e-02,  1.4338e-02,
         1.1300e-02, -9.1476e-03, -9.2345e-03, -1.7

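Edge predictions for the same three triples change after training (compare with the smoke-test output earlier in the notebook).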

In [15]:
# Re-score the smoke-test triples with the trained decoder only. No encoder pass
# is run here, so the decoder presumably uses the node embeddings cached by the
# last forward call.
with torch.no_grad():
    trained_scores = testModel.Decoder(
        torch.tensor([0, 0, 2]),
        torch.tensor([0, 1, 2]),
        torch.tensor([2, 1, 3]),
    )
    display(trained_scores.sigmoid())

tensor([0.1713, 0.8381, 0.3036], device='cuda:0')