In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
import numpy as np
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
from torch_geometric.utils.convert import from_networkx
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv, global_mean_pool



In [4]:
import pickle
from tqdm import tqdm

def serialize(obj, path):
    """Pickle ``obj`` into the file at ``path``.

    The file is opened in binary write mode, so an existing file at ``path``
    is overwritten. Returns ``None``.
    """
    with open(path, mode='wb') as sink:
        pickle.dump(obj, sink)

def deserialize(path):
    """Read the file at ``path`` and return the object pickled inside it.

    NOTE(review): unpickling can execute arbitrary code, so this should only
    be pointed at files from a trusted source.
    """
    with open(path, mode='rb') as source:
        restored = pickle.load(source)
    return restored

In [16]:
# Directory holding the pickled NetworkX graphs. It is handed to CellGraphData
# below, which lists it and unpickles one file per dataset item.
networkx_path = "/scr/biggest/gmachi/datasets/celldive_lung/for_ml/for_prospect"
# Target directory for the (currently disabled) one-off conversion to .pt files:
# save_path = "/scr/biggest/gmachi/datasets/celldive_lung/for_ml/for_gat"

In [6]:
# for graph in os.listdir(networkx_path):
#     G = deserialize(os.path.join(networkx_path, graph))
#     pyG = from_networkx(G, group_node_attrs=['emb'])
#     torch.save(pyG, os.path.join(save_path, graph.split('.')[0] + ".pt"))    

In [7]:
# Pickled label lookup. CellGraphData indexes it with the integer parsed from
# each graph's file name and converts the value it finds to a float.
label_path = "/scr/biggest/gmachi/datasets/celldive_lung/processed/label_dict.obj"
label_dict = deserialize(label_path)


In [None]:
# os.listdir(save_path)

In [56]:
# Which per-node attribute of the cell graphs is fed to the model as `x`.
node_feats = "emb"

# Model input width `d` for each supported node-feature choice.
NODE_FEAT_DIMS = {"emb": 512, "concept": 1, "raw": 34}

if node_feats not in NODE_FEAT_DIMS:
    # Fail here rather than later: without this check `d` would be left
    # undefined, or would silently keep a stale value from an earlier run of
    # this cell in the same kernel.
    raise ValueError(
        f"Unknown node_feats {node_feats!r}; expected one of {sorted(NODE_FEAT_DIMS)}"
    )
d = NODE_FEAT_DIMS[node_feats]

In [57]:
class CellGraphData(Dataset):
    """Dataset of pickled NetworkX cell graphs, each paired with a float label.

    Every entry of ``data_path`` is treated as one pickled graph. On access the
    file is unpickled, converted with ``from_networkx``, and the node attribute
    selected by ``node_feats`` is exposed as ``graph.x``.

    NOTE(review): nothing is cached, so each access repeats the unpickling and
    the NetworkX -> PyG conversion.

    Args:
        data_path: Directory whose entries are pickled NetworkX graphs.
        label_dict: Mapping from the integer parsed out of a file name (see
            ``_label_key``) to the label of that graph.
    """

    # Node-feature choices that ``get`` can expose as ``graph.x``.
    _NODE_FEATS = ("emb", "concept", "raw")

    def __init__(self, data_path, label_dict):
        # Run the base-class initialiser, which was previously skipped.
        super().__init__()
        self.path = data_path
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.names = sorted(os.listdir(self.path))
        self.label_dict = label_dict

    def len(self):
        """Return the number of entries found in ``data_path``."""
        return len(self.names)

    def __len__(self):
        """Support the built-in ``len()``; same value as ``len``."""
        return self.len()

    def __getitem__(self, idx, node_feats=node_feats):
        """Return ``self.get(idx, node_feats)`` so indexing and ``get`` agree."""
        return self.get(idx, node_feats=node_feats)

    def get(self, idx, node_feats=node_feats):
        """Load the ``idx``-th graph and its label.

        Args:
            idx: Integer position in the sorted list of file names.
            node_feats: ``"emb"``, ``"concept"`` or ``"raw"``; names the node
                attribute copied to ``graph.x``. The default is the value the
                module-level ``node_feats`` had when this class was defined;
                changing that global later has no effect until the class
                definition is executed again.

        Returns:
            A ``(graph, label)`` tuple: the converted graph with ``x`` set, and
            the label converted to ``float``.

        Raises:
            ValueError: If ``node_feats`` is not one of the supported names.
                Previously such a value silently produced a graph without
                ``x``.
        """
        if node_feats not in self._NODE_FEATS:
            raise ValueError(
                f"Unknown node_feats {node_feats!r}; expected one of {self._NODE_FEATS}"
            )

        name = self.names[idx]
        # NOTE(review): deserialize() unpickles the file, so the directory
        # must only contain trusted files.
        g_nx = deserialize(os.path.join(self.path, name))
        graph = from_networkx(g_nx)
        graph.x = getattr(graph, node_feats)
        # graph = torch.load(os.path.join(self.path, name))

        label = self.label_dict[self._label_key(name)]
        return graph, float(label)

    @staticmethod
    def _label_key(name):
        """Parse the ``label_dict`` key from a file name.

        Takes the text before the first ``.``, splits it on ``"S"`` and
        converts the second piece to ``int``.
        """
        return int(name.split('.')[0].split("S")[1])

In [58]:
class GAT(torch.nn.Module):
    """Two ``GATConv`` layers with batch norm, pooled into one logit per graph.

    Args:
        num_features: Width of the node feature vectors given to ``forward``.
        hidden_dim: Output width of both ``GATConv`` layers.
    """

    def __init__(self, num_features, hidden_dim):
        super().__init__()
        # Attribute names and creation order are kept as they were, so
        # parameter names and the order of random initialisation are unchanged.
        self.conv1 = GATConv(num_features, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.conv2 = GATConv(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        # self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x, edge_index, batch=None):
        """Return a flat tensor of scores, one per pooled graph.

        Args:
            x: Node features.
            edge_index: Graph connectivity passed to both ``GATConv`` layers.
            batch: Node-to-graph assignment forwarded to ``global_mean_pool``.
        """
        # First block: attention conv -> batch norm -> ReLU.
        hidden = F.relu(self.bn1(self.conv1(x, edge_index)))
        # Second block has no non-linearity before pooling.
        hidden = self.bn2(self.conv2(hidden, edge_index))
        pooled = global_mean_pool(hidden, batch)
        scores = self.fc(pooled)
        # Flatten the trailing size-1 dimension produced by the linear head.
        return scores.view(-1)

In [59]:
# Every entry under `networkx_path` becomes one training example.
train_dataset = CellGraphData(data_path=networkx_path, label_dict=label_dict)
# Shuffled mini-batches of three graphs.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=3)
    

In [60]:
# Device on which the model and every batch are placed.
device = "cpu"

# `d` input features per node, 100 hidden units in both attention layers.
model = GAT(num_features=d, hidden_dim=100).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
# Takes the raw model output; a sigmoid is applied separately in eval().
criterion = nn.BCEWithLogitsLoss()

In [61]:
def train_loop(model, dataloader, optimizer, criterion, device, epoch):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Called as ``model(x, edge_index, batch)``; put into train mode.
        dataloader: Iterable of ``(graph_batch, targets)`` pairs.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(output, targets)``.
        device: Device that each graph batch and its targets are moved to.
        epoch: Zero-based epoch index; shown one-based in the progress bar.

    Returns:
        None.
    """
    model.train()
    progress = tqdm(dataloader, desc=f"Epoch {epoch + 1}")
    for graphs, targets in progress:
        graphs, targets = graphs.to(device), targets.to(device)
        # Clear gradients left over from the previous batch.
        optimizer.zero_grad()
        output = model(graphs.x, graphs.edge_index, graphs.batch)
        criterion(output, targets).backward()
        optimizer.step()

In [62]:
def eval(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` without computing gradients.

    Args:
        model: Called as ``model(x, edge_index, batch)``; put into eval mode.
        dataloader: Iterable of ``(graph_batch, targets)`` pairs.
        criterion: Loss called as ``criterion(output, targets)``.
        device: Device that each graph batch and its targets are moved to.

    Returns:
        ``(mean_loss, y_true, y_pred)``: the plain (unweighted) mean of the
        per-batch losses, the targets as a flat list, and the sigmoid of the
        model outputs as a flat list.
    """
    model.eval()
    batch_losses, y_true, y_pred = [], [], []
    with torch.no_grad():
        for graphs, targets in dataloader:
            graphs = graphs.to(device)
            # Record the labels before moving them to the device.
            y_true += targets.tolist()
            targets = targets.to(device)
            output = model(graphs.x, graphs.edge_index, graphs.batch)
            batch_losses.append(criterion(output, targets).item())
            y_pred += torch.sigmoid(output).cpu().tolist()
    return np.mean(batch_losses), y_true, y_pred

In [63]:
# Number of passes over the training set.
NUM_EPOCHS = 10

# Lowest training loss seen so far. It used to be initialised and then never
# updated; it is now tracked across epochs.
best_loss = np.inf
for epoch in range(NUM_EPOCHS):
    train_loop(model, train_loader, optimizer, criterion, device, epoch)
    # The loss is re-measured in eval mode on the same training loader.
    train_loss, _, _ = eval(model, train_loader, criterion, device)
    best_loss = min(best_loss, train_loss)
    print(train_loss)

Epoch 1:   0%|          | 0/85 [00:00<?, ?it/s]

Epoch 1:   1%|          | 1/85 [00:11<16:16, 11.63s/it]


KeyboardInterrupt: 