In [121]:
import torch
from torch_geometric.nn.models import autoencoder
from torch_geometric.nn import GAE, VGAE, GCNConv
import torch.nn as nn
from torch_geometric.data import Data
import torch_geometric.transforms as T
import pandas as pd
import numpy as np
import pickle
import os
from tqdm import tqdm
from joblib import Parallel, delayed
import itertools
import matplotlib.pyplot as plt

In [122]:
torch.cuda.is_available()

True

In [123]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Encoder module

In [4]:
class GCNEncoder(torch.nn.Module):
  
  def __init__(self, in_channels, hidden_size, out_channels, dropout):
    super(GCNEncoder, self).__init__()
    self.conv1 = GCNConv(in_channels, hidden_size)
    self.conv2 = GCNConv(hidden_size, out_channels)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, edge_index):
    x = self.conv1(x, edge_index).relu()
    x = self.dropout(x)
    return self.conv2(x, edge_index)

In [5]:
def gae_train(train_data, gae_model, optimizer):
    gae_model.train()
    optimizer.zero_grad()
    z = gae_model.encode(train_data.x, train_data.edge_index)
    loss = gae_model.recon_loss(z, train_data.pos_edge_label_index.to(device))
    loss.backward(retain_graph=True)
    optimizer.step()
    return float(loss)

@torch.no_grad()
def gae_test(test_data, gae_model):
    gae_model.eval()
    z = gae_model.encode(test_data.x, test_data.edge_index)
    return gae_model.test(z, test_data.pos_edge_label_index, test_data.neg_edge_label_index)

# Load data

In [25]:
proxi_df = pd.read_csv(r'..\..\coculture_diagonal\primed_pbmc\00_analysis\networks\proxi_dfs\cd4\set1.csv', index_col=0)

In [26]:
edges = pd.read_pickle(r'..\..\coculture_diagonal\primed_pbmc\00_analysis\networks\connected_patches\cd4\set1.pkl')

In [27]:
patch_centers = pd.read_csv(r'..\..\coculture_diagonal\primed_pbmc\00_analysis\networks\patch_centers\cd4\set1_centers.csv', index_col=0)

In [28]:
edge_temp = []
for item in tqdm(edges):
    edge_temp = edge_temp + item

total_connections = 0
for item in tqdm(edges):
    total_connections = total_connections + len(item)

edges = np.zeros((2,total_connections))
for i,item in tqdm(enumerate(edge_temp)):
    idx1 = proxi_df.index.tolist().index(item[0])
    idx2 = proxi_df.index.tolist().index(item[1])
    edges[0,i] = idx1
    edges[1,i] = idx2

100%|██████████| 229/229 [00:00<00:00, 4623.75it/s]
100%|██████████| 229/229 [00:00<00:00, 229180.53it/s]
61120it [00:46, 1328.61it/s]


In [29]:
X = proxi_df.values
X_tensor = torch.tensor(proxi_df.values, dtype=torch.float)
edge_tensor = torch.tensor(edges, dtype=torch.long)
data = Data(x=X_tensor,edge_index=edge_tensor)

In [11]:
# transformation
t = T.Compose([T.ToUndirected(),T.RandomLinkSplit(is_undirected=True,split_labels=True)])
train_set,val_set,test_set = t(data)

In [12]:
train_set.to(device)
val_set.to(device)
test_set.to(device)

Data(x=[12224, 15], edge_index=[2, 66832], pos_edge_label=[8354], pos_edge_label_index=[2, 8354], neg_edge_label=[8354], neg_edge_label_index=[2, 8354])

# training

In [13]:
NUM_FEATURES = X.shape[1]
HIDDEN_SIZE = 15
OUT_CHANNELS = 10

In [14]:
gae_model = GAE(GCNEncoder(NUM_FEATURES, HIDDEN_SIZE, OUT_CHANNELS, 0.5))
gae_model = gae_model.to(device)

In [15]:
EPOCHS = 1000

In [16]:
# Train
losses = []
test_auc = []
test_ap = []
train_aucs = []
train_aps = []

optimizer = torch.optim.Adam(gae_model.parameters(), lr=0.001)

for epoch in range(1, EPOCHS + 1):
    loss = gae_train(train_set, gae_model, optimizer)
    losses.append(loss)
    auc, ap = gae_test(test_set, gae_model)
    test_auc.append(auc)
    test_ap.append(ap)

    train_auc, train_ap = gae_test(train_set, gae_model)

    train_aucs.append(train_auc)
    train_aps.append(train_ap)

    print('Epoch: {:03d}, test AUC: {:.4f}, test AP: {:.4f}, train AUC: {:.4f}, train AP: {:.4f}, loss:{:.4f}'.format(epoch, auc, ap, train_auc, train_ap, loss))

Epoch: 001, test AUC: 0.6231, test AP: 0.6408, train AUC: 0.6721, train AP: 0.6914, loss:2.3292
Epoch: 002, test AUC: 0.6254, test AP: 0.6423, train AUC: 0.6744, train AP: 0.6929, loss:2.2526
Epoch: 003, test AUC: 0.6276, test AP: 0.6438, train AUC: 0.6766, train AP: 0.6945, loss:2.1923
Epoch: 004, test AUC: 0.6298, test AP: 0.6453, train AUC: 0.6789, train AP: 0.6961, loss:2.1510
Epoch: 005, test AUC: 0.6321, test AP: 0.6470, train AUC: 0.6811, train AP: 0.6976, loss:2.1259
Epoch: 006, test AUC: 0.6343, test AP: 0.6485, train AUC: 0.6834, train AP: 0.6992, loss:2.0838
Epoch: 007, test AUC: 0.6365, test AP: 0.6499, train AUC: 0.6856, train AP: 0.7007, loss:2.0197
Epoch: 008, test AUC: 0.6387, test AP: 0.6515, train AUC: 0.6878, train AP: 0.7022, loss:1.9774
Epoch: 009, test AUC: 0.6409, test AP: 0.6530, train AUC: 0.6900, train AP: 0.7037, loss:1.9026
Epoch: 010, test AUC: 0.6431, test AP: 0.6545, train AUC: 0.6922, train AP: 0.7052, loss:1.8939
Epoch: 011, test AUC: 0.6452, test AP: 0

In [17]:
out_dir = r'..\..\coculture_diagonal\primed_pbmc\00_analysis\embedded_features\cd4'
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
torch.save(gae_model, r'..\..\coculture_diagonal\primed_pbmc\00_analysis\embedded_features\cd4\cd4_autoencoder_set1.pth')

In [18]:
model = torch.load(r'..\..\coculture_diagonal\primed_pbmc\00_analysis\embedded_features\cd4\cd4_autoencoder_set1.pth')

In [30]:
t2 = T.Compose([T.ToUndirected()])
transformed_data = t2(data)
transformed_data.to(device)

Data(x=[12224, 15], edge_index=[2, 71316])

In [20]:
with torch.no_grad():
    z_embed = gae_model.encode(transformed_data.x, transformed_data.edge_index)

In [21]:
z_embed = z_embed.cpu()
numpy_z = z_embed.numpy()

In [22]:
numpy_z.shape

(12224, 10)

In [23]:
patch_centers

Unnamed: 0,row,col,z,patch_id,cellID
0,1052.462366,1384.064516,11.741935,000_1_0,000_1
1,1262.181818,1810.431818,22.363636,000_1_1,000_1
2,1364.704545,1427.886364,8.727273,000_1_2,000_1
3,1194.296296,1835.246914,11.740741,000_1_3,000_1
4,1342.506173,1526.691358,3.851852,000_1_4,000_1
...,...,...,...,...,...
12219,1442.291667,1604.791667,0.125000,029_6_48,029_6
12220,1299.375000,2017.750000,0.000000,029_6_49,029_6
12221,1403.652174,1611.478261,1.043478,029_6_50,029_6
12222,1313.894737,1677.684211,46.263158,029_6_51,029_6


In [24]:
z_df = pd.DataFrame(numpy_z, index=proxi_df.index)
z_df.insert(loc=z_df.shape[1], column='patch', value=proxi_df.index)
z_df.insert(loc=z_df.shape[1], column='cellID', value=patch_centers['cellID'])

In [25]:
z_df.to_csv(r'..\..\coculture_diagonal\primed_pbmc\00_analysis\embedded_features\cd4\set1_proxi_embedding.csv')

# Multi-set training

In [149]:
sets = ['set3']

In [150]:
proxi_dir = r'..\..\coculture_diagonal\primed_pbmc\00_analysis\networks\proxi_dfs\cd8'
edges_dir = r'..\..\coculture_diagonal\primed_pbmc\00_analysis\networks\connected_patches\cd8'
centers_dir = r'..\..\coculture_diagonal\primed_pbmc\00_analysis\networks\patch_centers\cd8'
out_dir = r'..\..\coculture_diagonal\primed_pbmc\00_analysis\embedded_features\cd8'
os.makedirs(out_dir, exist_ok=True)

HIDDEN_SIZE = 15
OUT_CHANNELS = 10
EPOCHS = 1000

for s in tqdm(sets):
    # Load data
    proxi_df = pd.read_csv(os.path.join(proxi_dir,s+'.csv'), index_col=0)
    edges = pd.read_pickle(os.path.join(edges_dir,s+'.pkl'))
    patch_centers = pd.read_csv(os.path.join(centers_dir,s+'_centers.csv'), index_col=0)
    
    # Read edges
    edge_temp = []
    for item in tqdm(edges):
        edge_temp = edge_temp + item

    total_connections = 0
    for item in tqdm(edges):
        total_connections = total_connections + len(item)

    edges = np.zeros((2,total_connections))
    for i,item in tqdm(enumerate(edge_temp)):
        idx1 = proxi_df.index.tolist().index(item[0])
        idx2 = proxi_df.index.tolist().index(item[1])
        edges[0,i] = idx1
        edges[1,i] = idx2
    
    # Create data object
    X = proxi_df.values
    X_tensor = torch.tensor(proxi_df.values, dtype=torch.float)
    edge_tensor = torch.tensor(edges, dtype=torch.long)
    data = Data(x=X_tensor,edge_index=edge_tensor)

    # transformation
    t = T.Compose([T.ToUndirected(),T.RandomLinkSplit(is_undirected=True,split_labels=True)])
    train_set,val_set,test_set = t(data)
    train_set.to(device)
    val_set.to(device)
    test_set.to(device)

    # Train
    NUM_FEATURES = X.shape[1]
    gae_model = GAE(GCNEncoder(NUM_FEATURES, HIDDEN_SIZE, OUT_CHANNELS, 0.5))
    gae_model = gae_model.to(device)

    losses = []
    test_auc = []
    test_ap = []
    train_aucs = []
    train_aps = []

    optimizer = torch.optim.Adam(gae_model.parameters(), lr=0.001)

    for epoch in range(1, EPOCHS + 1):
        loss = gae_train(train_set, gae_model, optimizer)
        losses.append(loss)
        auc, ap = gae_test(test_set, gae_model)
        test_auc.append(auc)
        test_ap.append(ap)

        train_auc, train_ap = gae_test(train_set, gae_model)

        train_aucs.append(train_auc)
        train_aps.append(train_ap)

        print('Epoch: {:03d}, test AUC: {:.4f}, test AP: {:.4f}, train AUC: {:.4f}, train AP: {:.4f}, loss:{:.4f}'.format(epoch, auc, ap, train_auc, train_ap, loss))

    torch.save(gae_model, os.path.join(out_dir,'cd8_autoencoder_'+s+'.pth'))
    
    # Embedding
    t2 = T.Compose([T.ToUndirected()])
    transformed_data = t2(data)
    transformed_data.to(device)
    with torch.no_grad():
        z_embed = gae_model.encode(transformed_data.x, transformed_data.edge_index)

    z_embed = z_embed.cpu()
    numpy_z = z_embed.numpy()
    z_df = pd.DataFrame(numpy_z, index=proxi_df.index)
    z_df.insert(loc=z_df.shape[1], column='patch', value=proxi_df.index)
    z_df.insert(loc=z_df.shape[1], column='cellID', value=patch_centers['cellID'])
    z_df.to_csv(os.path.join(out_dir,'cd8_'+s+'_proxi_embedding.csv'))

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 211/211 [00:00<00:00, 5023.20it/s]
100%|██████████| 211/211 [00:00<?, ?it/s]
47110it [00:28, 1626.07it/s]


Epoch: 001, test AUC: 0.6785, test AP: 0.6894, train AUC: 0.7560, train AP: 0.7572, loss:1.4436
Epoch: 002, test AUC: 0.6839, test AP: 0.6938, train AUC: 0.7616, train AP: 0.7617, loss:1.4398
Epoch: 003, test AUC: 0.6893, test AP: 0.6981, train AUC: 0.7673, train AP: 0.7663, loss:1.3963
Epoch: 004, test AUC: 0.6948, test AP: 0.7026, train AUC: 0.7729, train AP: 0.7709, loss:1.4031
Epoch: 005, test AUC: 0.7002, test AP: 0.7071, train AUC: 0.7786, train AP: 0.7755, loss:1.3653
Epoch: 006, test AUC: 0.7056, test AP: 0.7115, train AUC: 0.7841, train AP: 0.7802, loss:1.3627
Epoch: 007, test AUC: 0.7109, test AP: 0.7160, train AUC: 0.7896, train AP: 0.7848, loss:1.3273
Epoch: 008, test AUC: 0.7161, test AP: 0.7204, train AUC: 0.7950, train AP: 0.7893, loss:1.3218
Epoch: 009, test AUC: 0.7212, test AP: 0.7247, train AUC: 0.8002, train AP: 0.7938, loss:1.3057
Epoch: 010, test AUC: 0.7262, test AP: 0.7290, train AUC: 0.8052, train AP: 0.7981, loss:1.2794
Epoch: 011, test AUC: 0.7310, test AP: 0

100%|██████████| 1/1 [02:17<00:00, 137.07s/it]


# explanability

In [133]:
from torch_geometric.utils import remove_self_loops
from torch_geometric.explain import Explainer, GNNExplainer

In [151]:
sets = ['set3']

In [152]:
models_dir = r'..\..\coculture_diagonal\primed_pbmc\00_analysis\embedded_features\cd8'
models = {}
for s in sets:
    models[s] = torch.load(os.path.join(models_dir, 'cd8_autoencoder_' + s + '.pth'))

In [153]:
models['set3'].encoder

GCNEncoder(
  (conv1): GCNConv(15, 15)
  (conv2): GCNConv(15, 10)
  (dropout): Dropout(p=0.5, inplace=False)
)

In [154]:
cd4_central_proxi = pd.read_pickle(r'..\..\coculture_diagonal\primed_pbmc\00_analysis\cd4_central_proxi.pkl')
cd8_central_proxi = pd.read_pickle(r'..\..\coculture_diagonal\primed_pbmc\00_analysis\cd8_central_proxi.pkl')

In [155]:
patch_centers = pd.read_csv(r'..\..\coculture_diagonal\primed_pbmc\00_analysis\networks\patch_centers\cd8\set3_centers.csv', index_col=0)

In [156]:
proxi_dir = r'..\..\coculture_diagonal\primed_pbmc\00_analysis\networks\proxi_dfs\cd8'
edge_dir = r'..\..\coculture_diagonal\primed_pbmc\00_analysis\networks\connected_patches\cd8'
centers_dir = r'..\..\coculture_diagonal\primed_pbmc\00_analysis\networks\patch_centers\cd8'
out_dir = r'..\..\coculture_diagonal\primed_pbmc\00_analysis\gae_encoder_explainer\cd8'
t = T.Compose([T.ToUndirected()])
for s in sets:
    model = models[s]
    model.to(device)
    central_proxi = cd8_central_proxi[s]
    proxi_df = pd.read_csv(os.path.join(proxi_dir, s + '.csv'), index_col=0)
    patch_centers = pd.read_csv(os.path.join(centers_dir, s + '_centers.csv'), index_col=0)
    edges = pd.read_pickle(os.path.join(edge_dir, s + '.pkl'))

    proxi_df['patch_id'] = patch_centers['patch_id']
    proxi_df['cellID'] = patch_centers['cellID']
    cells = proxi_df['cellID'].unique().tolist()

    explainer = Explainer(model=model.encoder,
                        algorithm=GNNExplainer(epochs=200),
                        explanation_type='model',
                        node_mask_type='attributes',
                        edge_mask_type='object',
                        model_config=dict(mode='regression',
                                            task_level='graph',
                                            return_type='raw')
                        )

    for i in tqdm(central_proxi.index.tolist()):
        patch_id = central_proxi['patch_id'][i]
        cell = central_proxi['cellID'][i]

        cell_index = cells.index(cell)
        cell_x = proxi_df[proxi_df['cellID'] == cell]
        patch_index = cell_x['patch_id'].tolist().index(patch_id)
        cell_x = cell_x.drop(columns=['patch_id','cellID'])
        single_cell_x = torch.Tensor(cell_x.values.astype('float'))

        edge = np.array(edges[cell_index]).T
        min_idx = np.min(edge)
        edge = edge - min_idx
        cell_edge = torch.LongTensor(edge)

        cell_data = Data(x=single_cell_x, edge_index=cell_edge)
        single_cell_transformed_data = t(cell_data)
        single_cell_transformed_data.to(device)

        edge_index, edge_attr = remove_self_loops(single_cell_transformed_data.edge_index, single_cell_transformed_data.edge_attr)

        explanation = explainer(single_cell_transformed_data.x, edge_index, index = patch_index)

        edge_mask = explanation.edge_mask.detach().cpu().numpy()
        node_mask = explanation.node_mask.detach().cpu().numpy()

        explained_edges = edge_index.cpu().numpy()

        with open(os.path.join(out_dir,s + '_' + patch_id + '.pkl'),'wb') as f:
            pickle.dump({'edge_mask':edge_mask,'node_mask':node_mask,'explained_edges':explained_edges},f)

100%|██████████| 1055/1055 [36:08<00:00,  2.06s/it]
