In [138]:
import torch
from torch_geometric.nn.models import autoencoder
from torch_geometric.nn import GAE, VGAE, GCNConv
import torch.nn as nn
from torch_geometric.data import Data
import torch_geometric.transforms as T
import pandas as pd
import numpy as np
import pickle
import os
from tqdm import tqdm
from joblib import Parallel, delayed
import itertools
import matplotlib.pyplot as plt

In [139]:
# Check whether a CUDA-capable GPU is visible to PyTorch.
cuda_available = torch.cuda.is_available()
cuda_available

True

In [140]:
# Prefer the GPU when one is visible; used below to place models and tensors.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Encoder module

In [141]:
class GCNEncoder(torch.nn.Module):
    """Two-layer GCN encoder: GCNConv -> ReLU -> Dropout -> GCNConv.

    Maps a node feature matrix ``x`` and an edge list ``edge_index`` to node
    embeddings with ``out_channels`` dimensions.

    The class (and its attribute names ``conv1``, ``conv2``, ``dropout``) must stay
    as-is: the checkpoints restored with ``torch.load`` later in this notebook
    presumably pickle instances of it, and unpickling restores those attributes
    by name without calling ``__init__``.

    Args:
        in_channels: number of input node features.
        hidden_size: width of the hidden GCN layer.
        out_channels: embedding dimension.
        dropout: dropout probability applied after the first layer.
    """

    def __init__(self, in_channels, hidden_size, out_channels, dropout):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_size)
        self.conv2 = GCNConv(hidden_size, out_channels)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        """Encode nodes from the feature matrix X and the graph's edge list."""
        x = self.conv1(x, edge_index).relu()
        # Dropout is only active in train mode; call .eval() before inference.
        x = self.dropout(x)
        return self.conv2(x, edge_index)

In [142]:
def gae_train(train_data, gae_model, optimizer):
    """Run one full-batch optimisation step of a graph autoencoder.

    Args:
        train_data: graph object exposing ``x``, ``edge_index`` and
            ``pos_edge_label_index`` (the positive supervision edges).
        gae_model: model exposing ``train()``, ``encode(x, edge_index)`` and
            ``recon_loss(z, pos_edge_index)`` (e.g. a PyG ``GAE``).
        optimizer: torch optimizer over ``gae_model``'s parameters.

    Returns:
        The reconstruction loss of this step as a Python float.
    """
    gae_model.train()
    optimizer.zero_grad()
    z = gae_model.encode(train_data.x, train_data.edge_index)
    # Align the supervision edges with the embeddings' device instead of relying
    # on a notebook-global `device`.
    pos_edge_index = train_data.pos_edge_label_index.to(z.device)
    loss = gae_model.recon_loss(z, pos_edge_index)
    # The autograd graph is rebuilt on every call, so it does not need to be retained.
    loss.backward()
    optimizer.step()
    return float(loss)

@torch.no_grad()
def gae_test(test_data, gae_model):
    """Evaluate link reconstruction on held-out edges without tracking gradients.

    Switches ``gae_model`` to eval mode, embeds ``test_data`` and forwards the
    positive/negative label edges to ``gae_model.test`` (for a PyG ``GAE`` this is
    presumably its (AUC, AP) scorer). Returns whatever ``gae_model.test`` returns.
    """
    gae_model.eval()
    embeddings = gae_model.encode(test_data.x, test_data.edge_index)
    pos_edges = test_data.pos_edge_label_index
    neg_edges = test_data.neg_edge_label_index
    return gae_model.test(embeddings, pos_edges, neg_edges)

# Load the motif data and combine all motifs into a single graph

In [260]:
# Directory holding the per-motif pickles (presumably one file per detected motif).
# Built with os.path.join rather than a hard-coded backslash path so it also resolves
# on POSIX systems (on Windows the resulting string is identical).
in_dir = os.path.join('..', '..', 'coculture_diagonal', 'primed_pbmc', '00_analysis',
                      'gae_encoder_explainer', 'cd8', 'detected_motifs')
# Sorted so motif order, and hence node order in the combined graph, is deterministic.
motif_l = sorted(os.listdir(in_dir))

In [261]:
sets = ['set3']

models_dir = os.path.join('..', '..', 'coculture_diagonal', 'primed_pbmc', '00_analysis',
                          'embedded_features', 'cd8')
# torch.load unpickles whole model objects here, so the classes they reference (e.g.
# GCNEncoder) must be defined in this session. Unpickling can execute arbitrary code:
# only load checkpoints you trust.
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects fully pickled
# modules; pass weights_only=False explicitly when running on such a version.
# map_location lets checkpoints saved on a GPU load on a CPU-only machine.
models = {
    s: torch.load(os.path.join(models_dir, f'cd8_autoencoder_{s}.pth'), map_location=device)
    for s in sets
}

In [262]:
# Move every loaded model to the compute device and switch it to evaluation mode.
# GCNEncoder applies Dropout, and torch.no_grad() alone does not disable it, so without
# eval() the embeddings computed below would be stochastic.
for model in models.values():
    model.to(device)
    model.eval()

In [263]:
# Load the first motif pickle to sanity-check the encoder on one small graph.
first_motif_path = os.path.join(in_dir, motif_l[0])
temp = pd.read_pickle(first_motif_path)

In [264]:
# Node features: drop the last two columns (presumably non-feature metadata; confirm
# against the motif pickles) and cast to float32.
x = torch.tensor(temp['motif_features'].values[:, :-2].astype('float'), dtype=torch.float)
# motif_edges is presumably a list of (src, dst) pairs. reshape(-1, 2) keeps an empty
# list 2-D, so edge_index is always shaped [2, num_edges] as PyG expects.
edge_index = torch.tensor(np.asarray(temp['motif_edges'], dtype=np.int64).reshape(-1, 2).T,
                          dtype=torch.long)

data = Data(x=x, edge_index=edge_index)
t2 = T.Compose([T.ToUndirected()])
transformed_data = t2(data)
transformed_data.to(device)

Data(x=[16, 15], edge_index=[2, 62])

In [265]:
# Embed the single sample motif graph without tracking gradients.
sample_encoder = models['set3']
with torch.no_grad():
    z_embed = sample_encoder.encode(transformed_data.x, transformed_data.edge_index)

In [267]:
# Collect per-motif node features and edge lists for each model set, plus flat per-node
# motif / patch identifiers (same node order) used to label the embeddings on export.
# NOTE(review): motif_id / patch_id are flat across sets, and the export cell slices them
# per set in data_sets order. That alignment is only guaranteed with a single set.
data_sets = {'set3': {'x': [], 'edge_index': []}}
motif_id = []
patch_id = []
for motif in tqdm(motif_l):
    s = motif.split('_')[0]  # file names start with their set name, e.g. "set3_..."
    if s not in data_sets:
        continue
    motif_record = pd.read_pickle(os.path.join(in_dir, motif))
    features = motif_record['motif_features']
    edges = motif_record['motif_edges']
    data_sets[s]['x'].append(features)
    data_sets[s]['edge_index'].append(edges)
    if len(edges) == 0:
        print(motif)  # flag motifs that have no edges

    # extend() appends in place. Rebuilding with `lst = lst + [...]` would copy the
    # whole list for every motif, which is quadratic over the loop.
    motif_id.extend([motif.split('.')[0]] * features.shape[0])
    patch_id.extend(motif_record['motif_patches']['patch_id'].tolist())

  0%|          | 0/3188 [00:00<?, ?it/s]

100%|██████████| 3188/3188 [00:05<00:00, 586.46it/s]  


In [268]:
# Stack all motifs of a set into one disconnected graph: features are stacked row-wise,
# and each motif's local node ids are shifted by the number of nodes already placed.
data_sets_combined = {}
for k, parts in data_sets.items():
    x_blocks = []
    edge_blocks = []
    node_offset = 0
    for i, (features, edges) in enumerate(zip(parts['x'], parts['edge_index'])):
        # Drop the last two columns, as in the single-motif cell above.
        x_i = features.values[:, :-2].astype('float')
        # reshape(-1, 2) turns an empty edge list into a (2, 0) block instead of a 1-D
        # array, which np.hstack would reject. int64 keeps the ids integral.
        edges_i = np.asarray(edges, dtype=np.int64).reshape(-1, 2).T
        if edges_i.shape[1] == 0:
            print(i)  # flag motifs without edges
        x_blocks.append(x_i)
        edge_blocks.append(edges_i + node_offset)
        node_offset += x_i.shape[0]
    data_sets_combined[k] = {'x': np.vstack(x_blocks), 'edge_index': np.hstack(edge_blocks)}

In [269]:
# Embed each set's combined graph with its matching model and keep the result as NumPy.
embedding = {}
with torch.no_grad():
    for k in data_sets.keys():
        encoder_model = models[k]
        # eval() disables dropout; no_grad() only stops gradient tracking.
        encoder_model.eval()
        node_x = torch.as_tensor(data_sets_combined[k]['x'], dtype=torch.float32, device=device)
        node_edges = torch.as_tensor(data_sets_combined[k]['edge_index'], dtype=torch.long,
                                     device=device)
        z_embed = encoder_model.encode(node_x, node_edges)
        embedding[k] = z_embed.cpu().numpy()

In [270]:
out_dir = os.path.join('..', '..', 'coculture_diagonal', 'primed_pbmc', '00_analysis',
                       'gae_encoder_explainer', 'cd8', 'motif_embedding')
os.makedirs(out_dir, exist_ok=True)

# motif_id / patch_id are flat per-node lists. Each set takes the next contiguous slice,
# so the totals must match exactly; otherwise slicing would silently truncate and
# mislabel rows.
total_rows = sum(embedding[k].shape[0] for k in data_sets.keys())
assert total_rows == len(motif_id) == len(patch_id), \
    'embedding rows and motif/patch ids are misaligned'

curr_shape = 0
for k in data_sets.keys():
    n_rows, n_dims = embedding[k].shape
    motif_embedding = pd.DataFrame(embedding[k], columns=[f'emb_{i}' for i in range(n_dims)])
    motif_embedding['motif_id'] = motif_id[curr_shape:curr_shape + n_rows]
    motif_embedding['patch_id'] = patch_id[curr_shape:curr_shape + n_rows]
    curr_shape += n_rows
    motif_embedding.to_csv(os.path.join(out_dir, f'{k}_motif_embedding.csv'), index=False)