In [31]:
import os
import glob
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc, mean_squared_error, roc_auc_score, average_precision_score, f1_score, precision_recall_curve
import torch
import torch.nn.functional as F
import numpy as np
import scipy.sparse as sparse
from torch_geometric.data import HeteroData
from torch_geometric.nn import HANConv, Linear
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import seaborn as sns


In [14]:
# Directory holding the processed CSV shards and the cached sparse matrices.
# Set the MIDTERM_DATA_DIR environment variable to run the notebook on another
# machine; when it is unset the original location is used unchanged.
FOLDER_PATH = os.environ.get("MIDTERM_DATA_DIR", "/home/juan/Work/Midterm project/splited")
X_PATH = os.path.join(FOLDER_PATH, "X_data.npz")  # cached input matrix
Y_PATH = os.path.join(FOLDER_PATH, "Y_data.npz")  # cached target matrix
COLS_PATH = os.path.join(FOLDER_PATH, "feature_names.json")  # names of the input columns

In [21]:
# Build the sparse X/Y cache from the processed CSV shards. The work is skipped
# when both .npz files already exist, so re-running the notebook is cheap.
if not os.path.exists(X_PATH) or not os.path.exists(Y_PATH):

    FILE_PATTERN = os.path.join(FOLDER_PATH, "processed_*.csv")
    files = sorted(glob.glob(FILE_PATTERN))  # all processed files in FOLDER_PATH
    if not files:
        # Fail with a clear message instead of an IndexError on files[0].
        raise FileNotFoundError(
            f"No files match {FILE_PATTERN!r}; cannot build {X_PATH!r} / {Y_PATH!r}"
        )

    # Read only a few rows of the first file to determine the column structure.
    f_df = pd.read_csv(files[0], nrows=3)
    cols = f_df.columns
    x_cols = [c for c in cols if c.startswith("x ")]  # input columns
    y_cols = [c for c in cols if c.startswith("y ")]  # target columns

    # Persist the input column names next to the matrices.
    with open(COLS_PATH, 'w') as cols_file:
        json.dump(x_cols, cols_file)

    X_list = []
    Y_list = []

    # Convert each shard to CSR so only one dense shard is in memory at a time.
    for csv_path in files:
        df = pd.read_csv(csv_path)

        # Missing values are stored as 0, i.e. as absent entries once sparse.
        x_data = df[x_cols].fillna(0).values.astype(np.float32)
        y_data = df[y_cols].fillna(0).values.astype(np.float32)

        X_list.append(sparse.csr_matrix(x_data))
        Y_list.append(sparse.csr_matrix(y_data))

        # Drop the dense copies before loading the next shard.
        del df, x_data, y_data

    # Stack shards row-wise in sorted file order.
    X_final = sparse.vstack(X_list)
    Y_final = sparse.vstack(Y_list)

    sparse.save_npz(X_PATH, X_final)
    sparse.save_npz(Y_PATH, Y_final)

In [32]:
# Load the cached sparse matrices written by the preprocessing cell.

X = sparse.load_npz(X_PATH)
Y = sparse.load_npz(Y_PATH)

# X is (patients x input features); Y contributes the number of reaction columns.
num_patients, num_input_features = X.shape
num_reactions = Y.shape[1]

# One value per line: patients, input features, reactions.
print(num_patients, num_input_features, num_reactions, sep="\n")

# 80/20 train/test split over row indices, with a fixed seed.
indices = np.arange(num_patients)
train_idx, test_idx = train_test_split(indices, random_state=1, test_size=0.2)

661271
5286
10488


In [33]:
# Sanity check: does this PyTorch install see a CUDA device? (value shown as cell output)
torch.cuda.is_available()

True

In [45]:
class HeteroGraphDataset(Dataset):
    """Index-only dataset: each item is a stored row index, not a sample.

    The matrices and node counts are kept on the instance, but
    ``__getitem__`` returns just the selected entry of ``indices``.
    """

    def __init__(self, X, Y, indices, num_input_nodes, num_output_nodes):
        """Keep references to the matrices, the index subset and node counts."""
        self.X, self.Y = X, Y
        self.indices = indices
        self.num_input_nodes, self.num_output_nodes = num_input_nodes, num_output_nodes

    def __len__(self):
        """Return how many indices this subset holds."""
        n_items = len(self.indices)
        return n_items

    def __getitem__(self, idx):
        """Return the ``idx``-th entry of ``indices``."""
        row_id = self.indices[idx]
        return row_id

def graph_collate_fn(batch_indices):
    """Collate a batch of row indices into one ``HeteroData`` mini-graph.

    Reads the notebook-level sparse matrices ``X`` and ``Y`` and the counts
    ``num_input_features`` and ``num_reactions``.

    Args:
        batch_indices: Row indices into ``X``/``Y`` (as yielded by the dataset).

    Returns:
        HeteroData with node types ``patient``, ``drug`` and ``reaction``,
        the edge types ``patient-takes-drug`` and
        ``patient-has_reaction-reaction``, and their reversed counterparts.
        Patient ids are positions within the sorted batch; drug and reaction
        ids are column indices of ``X`` and ``Y``.
    """
    batch_indices = sorted(batch_indices)

    x_sub = X[batch_indices]
    y_sub = Y[batch_indices]

    batch_size = len(batch_indices)

    # nonzero() gives (row, col) of every stored non-zero entry; the row is the
    # batch-local patient id, the column is the drug / reaction id.
    rows_d, cols_d = x_sub.nonzero()
    edge_index_drug = torch.stack([torch.from_numpy(rows_d), torch.from_numpy(cols_d)]).long()

    rows_r, cols_r = y_sub.nonzero()
    edge_index_react = torch.stack([torch.from_numpy(rows_r), torch.from_numpy(cols_r)]).long()

    data = HeteroData()

    # Nodes: keys must match the node-type names used in the edge types below.
    # NOTE(review): only 'patient' gets a feature matrix; HANConv appears to
    # need an entry in x_dict for every node type — confirm before training.
    data['patient'].x = torch.ones((batch_size, 1), dtype=torch.float)
    data['patient'].num_nodes = batch_size
    data['drug'].num_nodes = num_input_features
    data['reaction'].num_nodes = num_reactions

    # Edges
    data['patient', 'takes', 'drug'].edge_index = edge_index_drug
    data['patient', 'has_reaction', 'reaction'].edge_index = edge_index_react

    # Reverse edges: flipping dim 0 swaps the source and target rows.
    data['drug', 'taken_by', 'patient'].edge_index = torch.flip(edge_index_drug, [0])
    data['reaction', 'reaction_in', 'patient'].edge_index = torch.flip(edge_index_react, [0])

    return data


In [46]:
# Training dataset: yields row indices drawn from the training split.
train_ds = HeteroGraphDataset(
    X=X,
    Y=Y,
    indices=train_idx,
    num_input_nodes=num_input_features,
    num_output_nodes=num_reactions,
)
train_ds

<__main__.HeteroGraphDataset at 0x7a49b87e9400>

In [47]:
# Shuffled batches of 1024 row indices; graph_collate_fn turns each batch into a graph.
train_loader = DataLoader(
    train_ds,
    collate_fn=graph_collate_fn,
    shuffle=True,
    batch_size=1024,
    num_workers=0,
)
train_loader

<torch.utils.data.dataloader.DataLoader at 0x7a49b8d61e00>

In [None]:
class HeteroHAN(torch.nn.Module):
    """Heterogeneous attention network built around a single ``HANConv`` layer.

    Args:
        hidden_channels: Input and output width of the ``HANConv`` layer.
        out_channels: Accepted but currently unused (see review note below).
        metadata: Graph metadata forwarded to ``HANConv``.
        num_head: Number of attention heads passed to ``HANConv``.
    """

    def __init__(self, hidden_channels, out_channels, metadata, num_head=2):
        super().__init__()

        # NOTE(review): out_channels is never used and forward() applies this
        # same layer twice (shared weights). Confirm whether a second layer or
        # an output projection to out_channels was intended.
        self.han_conv1 = HANConv(in_channels=hidden_channels, out_channels=hidden_channels, heads=num_head, dropout=0.2, metadata=metadata)

    def forward(self, x_dict, edge_index_dict):
        """Apply the conv, a per-node-type ReLU, then the same conv again.

        Args:
            x_dict: Features keyed by node type.
            edge_index_dict: Edge indices keyed by edge type.

        Returns:
            The dictionary produced by the second ``HANConv`` pass.
        """
        x = self.han_conv1(x_dict, edge_index_dict)
        x = {k: v.relu() for k, v in x.items()}
        x = self.han_conv1(x, edge_index_dict)
        return x