In [1]:
import torch
import torch.optim as optim
import torch_geometric
from torch.nn.functional import relu, sigmoid, softmax, mse_loss
from torch.nn import Linear, Module, Dropout, MSELoss, CrossEntropyLoss, BatchNorm1d

from tqdm import tqdm

from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, GATv2Conv
from egnn_pytorch import EGNN

import pandas as pd
import numpy as np

import os
import pickle
import gzip

# XLA_PYTHON_CLIENT_PREALLOCATE is an XLA/JAX client setting; no visible cell
# imports JAX, so this may be a leftover -- TODO confirm it is still needed.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

# Run on CUDA device 0 when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:{}".format(0))
else:
    device = torch.device("cpu")

In [2]:
# Read the three KIBA splits in train/val/test order; the first CSV column
# is used as the DataFrame index for each of them.
train, val, test = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in ('kiba/train.csv', 'kiba/val.csv', 'kiba/test.csv')
)

In [3]:
# Report (rows, columns) of every split, one line per split.
print(
    *(
        '{} {}'.format(label, frame.shape)
        for label, frame in (('Train dim:', train), ('val dim:', val), ('test dim:', test))
    ),
    sep='\n',
)

Train dim: (3449, 3)
val dim: (494, 3)
test dim: (973, 3)


In [4]:
# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# Only load 'drug.pkl.gz' when it comes from a trusted source.
with gzip.open('drug.pkl.gz', 'rb') as f:
    drug = pickle.load(f)

# Stable handle on the drug lookup. A later cell rebinds the name `drug` to a
# mini-batch, so the helper below must not resolve `drug` at call time;
# otherwise calling it (or re-running the loader cell) after that point would
# index the batch instead of the lookup.
_DRUG_GRAPHS = drug


def get_drug_dataloader(drugs, batch_size=100):
    """Build a DataLoader over the pre-computed graphs of the given drugs.

    Args:
        drugs: Iterable of keys into the drug lookup loaded from
            'drug.pkl.gz' (the notebook passes the 'Drug' column of a split).
        batch_size: Number of graphs per batch.

    Returns:
        A DataLoader over the looked-up entries, in the same order as
        ``drugs`` (no shuffle argument is passed).

    Raises:
        Whatever the lookup raises for an unknown key (KeyError if it is a
        dict -- the container type is not visible here).
    """
    dataset = [_DRUG_GRAPHS[i] for i in drugs]
    return DataLoader(dataset, batch_size=batch_size)

def get_protein_dataloader(proteins, batch_size=100):
    """Load one saved graph per protein id and wrap them in a DataLoader.

    Every id in ``proteins`` is read from 'protein_graphs/<id>.pt' with
    ``torch.load``; all graphs are loaded eagerly before the loader is built.

    Args:
        proteins: Iterable of protein identifiers (file stems).
        batch_size: Number of graphs per batch.

    Returns:
        A DataLoader over the loaded graphs, in the order of ``proteins``.
    """
    graphs = []
    for protein_id in proteins:
        graphs.append(torch.load('protein_graphs/{}.pt'.format(protein_id)))
    return DataLoader(graphs, batch_size=batch_size)

In [5]:
batch_size = 1

# Drug-graph loaders, one per split (train, val, test).
drug_train_loader, drug_val_loader, drug_test_loader = (
    get_drug_dataloader(split['Drug'], batch_size)
    for split in (train, val, test)
)

# Protein-graph loaders, built from the 'Target_ID' column of each split.
protein_train_loader, protein_val_loader, protein_test_loader = (
    get_protein_dataloader(split['Target_ID'], batch_size)
    for split in (train, val, test)
)

# Label loaders over the 'Y' column. All three loader families use the same
# batch size and row order, so they can be zipped together batch by batch.
train_y, val_y, test_y = (
    DataLoader(torch.Tensor(split['Y'].values).float(), batch_size=batch_size)
    for split in (train, val, test)
)



In [6]:
# Grab only the first aligned (drug, protein, label) mini-batch and move it to
# `device`; the `break` stops the iteration after one step.
# NOTE: this rebinds the module-level name `drug` (previously the lookup
# loaded from 'drug.pkl.gz') to a single mini-batch.
for drug, protein, true_y in zip(drug_train_loader, protein_train_loader, train_y):
    drug, protein, true_y = (item.to(device) for item in (drug, protein, true_y))
    break

In [7]:
# Two EGNN layers with 5-dimensional node features, placed on `device`.
layer1, layer2 = (EGNN(dim = 5).to(device) for _ in range(2))

# Add a leading batch axis of size 1: features become (1, N, 5) and
# coordinates (1, N, 3), with the coordinates cast to float.
feats = protein.x.view(1, len(protein.x), 5)
coors = protein.pos.view(1, len(protein.pos), 3).float()

# Feed the protein through both layers in sequence.
feats, coors = layer1(feats, coors)
feats, coors = layer2(feats, coors)

In [None]:
# Sanity check: shape of the batched protein feature tensor passed to EGNN.
protein.x.view(1, len(protein.x), 5).size()

In [None]:
# Sanity check: shape of the batched protein coordinate tensor passed to EGNN.
protein.pos.view(1, len(protein.pos), 3).float().size()

In [None]:
# Build a dense, symmetric (n, n, edge_dim) edge-feature tensor for the drug
# graph from its sparse (2, E) edge_index and per-edge edge_attr.
#
# n is taken from the number of node positions rather than
# max(edge_index[1]) + 1: a node that never appears as an edge destination
# (e.g. an isolated atom) would otherwise be dropped, and the tensor could not
# be viewed as (1, len(drug.pos), len(drug.pos), edge_dim) afterwards.
n = len(drug.pos)

# Move the edge data to the CPU once; adj_matrix is created on the default
# device, and per-edge indexing with GPU tensors would sync on every iteration.
edge_index = drug.edge_index.cpu()
edge_attr = drug.edge_attr.cpu()

# shape[1] is the per-edge feature width and, unlike edge_attr[0], is also
# defined when the graph has no edges.
adj_matrix = torch.zeros(n, n, edge_attr.shape[1])

# Write every edge in both directions. Edges are visited in their stored
# order, so if the same node pair occurs more than once the last entry wins.
for (src, dest), attr in zip(edge_index.t().tolist(), edge_attr):
    adj_matrix[src, dest] = attr
    adj_matrix[dest, src] = attr

In [None]:
# Fresh pair of EGNN layers for the drug graph: 5-dimensional node features
# and 2-dimensional edge features. They are not moved to `device`, so they
# stay wherever torch creates parameters by default.
layer1, layer2 = (EGNN(dim = 5, edge_dim=2) for _ in range(2))

In [None]:
# First EGNN layer on the drug graph. Inputs are moved to the CPU and given a
# leading batch axis of size 1; the dense edge tensor is passed as the third
# positional argument.
feats = drug.x.view(1, len(drug.x), 5).to('cpu')
coors = drug.pos.view(1, len(drug.pos), 3).to('cpu').float()
feats, coors = layer1(
    feats,
    coors,
    adj_matrix.view(1, len(drug.pos), len(drug.pos), 2),
)
feats

In [None]:
# Second EGNN layer: feed the previous features and coordinates back in,
# together with the same batched dense edge tensor.
feats, coors = layer2(
    feats,
    coors,
    adj_matrix.view(1, drug.pos.size(0), drug.pos.size(0), 2),
)
feats

In [None]:
# Display the coordinates returned by the most recent EGNN call
# (layer2 on the drug graph when the cells are run in notebook order).
coors