In [1]:
import pandas as pd
import numpy as np
import os
import pickle
import networkx as nx
import mygene
import torch
import math
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, Sequential
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.nn import GAE
from torch_geometric.data import Data, DataLoader

In [3]:
from typing import Type

import torch
import torch_geometric

from torch.nn import Linear, ReLU,Dropout
from torch_geometric.nn import Sequential, GCNConv, TopKPooling
import torch.nn.functional as F
import torch.nn as nn


class GNNExample(nn.Module):
    def __init__(self, num_features, input_dim, num_samples, L, batch_size):
        super(GNNExample, self).__init__()
        self.conv = GCNConv(num_features, num_features)
        self.batch_size = batch_size
        self.input_dim = input_dim
        self.num_features = num_features
        self.num_samples = num_samples

    def convolute(self, data):
        xs = torch.tensor([])
        for i in range(len(data)):
            x, edge_index = data[i].x, data[i].edge_index
            h = self.conv(x, edge_index)
            h = h.sigmoid()
            xs = torch.cat([xs, h])
        return xs

    def forward(self, data):
        xs = self.convolute(data)
        xs = torch.reshape(xs, (self.num_samples, self.input_dim))
        return xs

    def get_latent_space(self, data):
        xs = self.convolute(data)
        xs = torch.reshape(xs, (self.batch_size, self.input_dim))
        return xs


In [107]:
from typing import Type

import torch


class MWE_AE(torch.nn.Module):

    def __init__(self, input_dim, L):
        super().__init__()

        print("Initializing Minimal Working Example AE with input dim: ", input_dim)

        self.encoder = torch.nn.Sequential(
            custom_block(input_dim, 2500),
            custom_block(2500, 2000),
            custom_block(2000, 1500),
            custom_block(1500, 1200),
            custom_block(1200, 1000),
            custom_block(1000, 800),
            custom_block(800, 600),
            custom_block(600, 400),
            custom_block(400, 300),
            custom_block(300, 150),
            custom_block(150, 100),
            custom_block(100, 50),
            custom_block(50, L),
            torch.nn.Sigmoid()
        )

        self.decoder = torch.nn.Sequential(
            custom_block(L, 50),
            custom_block(50, 100),
            custom_block(100, 150),
            custom_block(150, 300),
            custom_block(300, 400),
            custom_block(400, 600),
            custom_block(600, 800),
            custom_block(800, 1000),
            custom_block(1000, 1200),
            custom_block(1200, 1500),
            custom_block(1500, 2000),
            custom_block(2000, 2500),
            custom_block(2500, input_dim),
            torch.nn.BatchNorm1d(input_dim),
            torch.nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)

        decoded = self.decoder(encoded)
        return decoded

    def get_latent_space(self, x):
        return self.encoder(x)


def custom_block(input_dim, output_dim, dropout_rate=0.4):
    return torch.nn.Sequential(
        torch.nn.Linear(input_dim, output_dim),
        torch.nn.BatchNorm1d(output_dim),
        torch.nn.PReLU(),
        torch.nn.Dropout(dropout_rate)
    )


In [82]:
class CustomDataset():
    """
    This class is used to have all types of data in one place. For example, the entire train set can be housed
    within this class. This way when we need to merge genData and cliData together, it can be done easily, as well
    as checking the labels for later use.
    """
    def __init__(self, genData, cliData, labels):
        self.genData = genData
        self.cliData = cliData
        self.labels = labels

    def __len__(self):
        return len(self.genData)

    def __getitem__(self, idx):
        return self.genData[idx], self.cliData[idx], self.labels[idx]

### TABULAR SIDE

In [222]:
current_directory = os.getcwd()

somepath = os.path.abspath(
    os.path.join(current_directory, '..', 'Data', 'RNA_dataset_tabular_R3.csv'))

# expression data
tabular_data = pd.read_csv(somepath, sep = ',', index_col = 0)
tabular_data = tabular_data.astype('float32')
gene_data = tabular_data
gene_data

Unnamed: 0_level_0,NFKB1,TNIP2,AMOT,VASP,SS18L1,SMARCA4,SMURF1,HSPA5,SKIL,UBE2I,...,SULF2,CD276,TIPARP,PFS_P,PFS_P_CNSR,MATH,HE_TUMOR_CELL_CONTENT_IN_TUMOR_AREA,PD-L1_TOTAL_IMMUNE_CELLS_PER_TUMOR_AREA,CD8_POSITIVE_CELLS_TUMOR_CENTER,CD8_POSITIVE_CELLS_TOTAL_AREA
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
X00936b9285d6b8665ae9122993fb8e91,6.10,4.38,3.07,6.14,4.70,7.52,4.93,8.37,6.21,5.69,...,8.68,7.09,4.94,4.172484,0.0,17.928391,70.0,0.0,0.08,0.1931
X105622fadc33f23755ac2df823110aca,5.07,3.33,1.73,5.11,5.62,6.58,4.34,7.42,6.15,4.86,...,6.08,6.56,4.61,16.591375,1.0,16.122089,85.0,1.0,0.12,0.1214
Xe44f39747a8e84b02b4cb24659312144,6.13,4.41,3.23,6.32,5.57,8.02,5.14,7.55,6.87,6.27,...,6.33,7.14,8.42,11.104723,0.0,23.616636,80.0,5.0,0.92,0.9203
X293dd1284496215e9a0eca9f17a98e7e,5.82,4.30,3.44,6.45,4.86,7.45,4.90,8.39,6.83,5.70,...,6.97,6.73,6.28,14.028748,1.0,24.817434,60.0,5.0,3.16,3.1635
X01ed7190ce00862696edbf047b542045,6.15,4.21,3.90,5.93,4.43,7.60,4.74,8.31,6.38,5.94,...,5.67,6.66,4.93,12.418891,0.0,19.303864,80.0,2.0,1.98,2.0708
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Xc3d410d70dd7359baa40126494fb6765,6.25,3.62,4.80,4.72,5.39,7.01,4.57,6.96,6.26,5.70,...,2.40,6.58,6.59,9.790554,1.0,24.552612,75.0,0.0,1.01,1.0089
X50772aa64efb859960b20f8801cd6f58,6.27,3.78,3.98,5.95,4.66,7.18,4.61,8.04,6.62,5.80,...,5.92,6.74,5.48,4.271047,0.0,15.672304,75.0,1.0,1.10,1.1775
X91bcd3067a1a7954692d836515e04869,6.12,3.94,3.25,6.09,4.98,7.32,5.02,8.21,6.71,5.89,...,7.84,7.59,7.59,2.496920,0.0,27.837849,50.0,1.0,4.03,3.9642
Xc7439a06ffa32b313b0ec1b987b992a2,5.91,3.62,3.34,5.80,8.86,8.69,4.52,8.15,5.77,6.85,...,6.81,6.82,5.27,6.505134,1.0,26.606825,80.0,1.0,0.14,0.1417


In [223]:
cli_vars = ['PFS_P', 'PFS_P_CNSR', 'MATH', 'HE_TUMOR_CELL_CONTENT_IN_TUMOR_AREA', 'PD-L1_TOTAL_IMMUNE_CELLS_PER_TUMOR_AREA', 'CD8_POSITIVE_CELLS_TUMOR_CENTER', 'CD8_POSITIVE_CELLS_TOTAL_AREA']
gene_data = tabular_data.drop(columns = cli_vars)

In [224]:
maxVal = max([x for L in gene_data.values for x in L])
X_normalized = gene_data / maxVal

In [225]:
X_normalized

Unnamed: 0_level_0,NFKB1,TNIP2,AMOT,VASP,SS18L1,SMARCA4,SMURF1,HSPA5,SKIL,UBE2I,...,SLC22A3,SPAG16,HTATIP2,SLC17A1,MGST2,CHPT1,STK17A,SULF2,CD276,TIPARP
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
X00936b9285d6b8665ae9122993fb8e91,0.395078,0.283679,0.198834,0.397668,0.304404,0.487047,0.319301,0.542098,0.402202,0.368523,...,0.163860,0.281736,0.288212,0.289508,0.457902,0.338731,0.264249,0.562176,0.459197,0.319948
X105622fadc33f23755ac2df823110aca,0.328368,0.215674,0.112047,0.330959,0.363990,0.426166,0.281088,0.480570,0.398316,0.314767,...,0.371762,0.287565,0.308290,0.288212,0.479275,0.443005,0.183938,0.393782,0.424870,0.298575
Xe44f39747a8e84b02b4cb24659312144,0.397021,0.285622,0.209197,0.409326,0.360751,0.519430,0.332902,0.488990,0.444948,0.406088,...,0.300518,0.264249,0.299870,0.190415,0.431347,0.376295,0.269430,0.409974,0.462435,0.545337
X293dd1284496215e9a0eca9f17a98e7e,0.376943,0.278497,0.222798,0.417746,0.314767,0.482513,0.317358,0.543394,0.442358,0.369171,...,0.292746,0.306995,0.318653,0.053756,0.431995,0.341321,0.266839,0.451425,0.435881,0.406736
X01ed7190ce00862696edbf047b542045,0.398316,0.272668,0.252591,0.384067,0.286917,0.492228,0.306995,0.538212,0.413212,0.384715,...,0.127591,0.347798,0.361399,0.433938,0.422927,0.378238,0.281736,0.367228,0.431347,0.319301
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Xc3d410d70dd7359baa40126494fb6765,0.404793,0.234456,0.310881,0.305699,0.349093,0.454016,0.295984,0.450777,0.405440,0.369171,...,0.370466,0.457902,0.306995,0.448187,0.367228,0.488990,0.150259,0.155440,0.426166,0.426813
X50772aa64efb859960b20f8801cd6f58,0.406088,0.244819,0.257772,0.385363,0.301813,0.465026,0.298575,0.520725,0.428756,0.375648,...,0.312824,0.372409,0.360751,0.505829,0.470207,0.398316,0.227979,0.383420,0.436529,0.354922
X91bcd3067a1a7954692d836515e04869,0.396373,0.255181,0.210492,0.394430,0.322539,0.474093,0.325130,0.531736,0.434586,0.381477,...,0.199482,0.313472,0.338731,0.000648,0.413212,0.336788,0.262306,0.507772,0.491580,0.491580
Xc7439a06ffa32b313b0ec1b987b992a2,0.382772,0.234456,0.216321,0.375648,0.573834,0.562824,0.292746,0.527850,0.373705,0.443653,...,0.060881,0.389896,0.261658,0.473446,0.437824,0.461140,0.210492,0.441062,0.441710,0.341321


In [229]:
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

torch_tensor = torch.tensor(X_normalized.values)

data = CustomDataset(torch_tensor, [1] * len(torch_tensor), [2] * len(torch_tensor))

batch_size = 64
train_loader = list(DataLoader(data, batch_size=batch_size, shuffle=False))

In [230]:
X = list(train_loader)
res = torch.tensor([])
for x in X:
    res = torch.cat((res, x[2]), dim=0)
res

tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 

In [139]:
model_tab = MWE_AE(2866, 50)
lat = model_tab(res)
lat

Initializing Minimal Working Example AE with input dim:  2866


tensor([[0.4127, 0.3174, 0.9258,  ..., 0.3901, 0.5520, 0.3169],
        [0.4575, 0.3279, 0.6706,  ..., 0.9447, 0.7058, 0.4176],
        [0.4127, 0.3884, 0.6172,  ..., 0.2966, 0.3303, 0.4176],
        ...,
        [0.3089, 0.4147, 0.4174,  ..., 0.7215, 0.4200, 0.9184],
        [0.4127, 0.4273, 0.4162,  ..., 0.7860, 0.4790, 0.3998],
        [0.7091, 0.5397, 0.3998,  ..., 0.4216, 0.2475, 0.4176]],
       grad_fn=<SigmoidBackward0>)

-----------------

-----------------

-----------------

### GRAPH SIDE

In [140]:
import os

current_directory = os.getcwd()

somepath = os.path.abspath(
    os.path.join(current_directory, '..', 'Data', 'RNA_dataset_graph_R3.pkl'))

with open(somepath, 'rb') as f:
    loaded_object = pickle.load(f)

In [141]:
G = loaded_object[0]
features = []
for node, attr in G.nodes(data = True):
    features += [attr['node_attr']]
features = torch.tensor(features)
features

tensor([6.1000, 4.3800, 3.0700,  ..., 8.6800, 7.0900, 4.9400],
       dtype=torch.float64)

In [142]:
from torch_geometric.data import Data

data = None

G = loaded_object[0]
# we enumerate each node in a dict
node_to_index = {node: idx for idx, node in enumerate(G.nodes())}

edge_index = torch.tensor([(node_to_index[edge[0]], node_to_index[edge[1]]) for edge in G.edges()] +
                 [(node_to_index[edge[1]], node_to_index[edge[0]]) for edge in G.edges()]).t().contiguous()
data = Data(x= features, edge_index = edge_index)
data.validate(raise_on_error=True)
data

Data(x=[2866], edge_index=[2, 90932])

In [143]:
def collect_all_graph_data(graphs):
    D = []
    # edges are the same for all graphs so we only need to compute this once.
    G = graphs[0]
    node_to_index = {node: idx for idx, node in enumerate(G.nodes())}

    edge_index = torch.tensor([(node_to_index[edge[0]], node_to_index[edge[1]]) for edge in G.edges()] +
                 [(node_to_index[edge[1]], node_to_index[edge[0]]) for edge in G.edges()]).t().contiguous()
    
    for g in graphs:
        features = []
        for node, attr in g.nodes(data = True):
            features += [[float(attr['node_attr'])]]
        features = torch.tensor(features)
        d = Data(x = features, edge_index = edge_index)
        d.validate(raise_on_error=True)
        D += [d]
    return D
        

In [144]:
X = collect_all_graph_data(loaded_object)

In [145]:
from torch_geometric.loader import DataLoader

custom_dataset = CustomDataset(X, [1] * len(X), [2] * len(X))
    
loader = list(DataLoader(custom_dataset, batch_size=64, shuffle=False))
loader

[[DataBatch(x=[183424, 1], edge_index=[2, 5819648], batch=[183424], ptr=[65]),
  tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
  tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
          2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
          2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])],
 [DataBatch(x=[183424, 1], edge_index=[2, 5819648], batch=[183424], ptr=[65]),
  tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
  tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
          2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
     

In [218]:
X = list(loader)
dim = 0
res = []
for x in X:
    res += [x[dim]]
z = list(DataLoader(res, batch_size = len(res)))

In [232]:
z[0]

DataBatch(x=[1862900, 1], edge_index=[2, 59105800], batch=[1862900])

In [221]:
model = GNNExample(1, 2866, 650, 50, 650)
lat = model(z)
lat

tensor([[0.9231, 0.7012, 0.8368,  ..., 0.7934, 0.7635, 0.6901],
        [0.9042, 0.6851, 0.8098,  ..., 0.7188, 0.7471, 0.6788],
        [0.9257, 0.7095, 0.8421,  ..., 0.7359, 0.7640, 0.7821],
        ...,
        [0.9251, 0.7017, 0.8335,  ..., 0.7628, 0.7763, 0.7630],
        [0.9216, 0.6988, 0.8443,  ..., 0.7445, 0.7536, 0.6975],
        [0.9272, 0.7061, 0.8407,  ..., 0.7397, 0.7848, 0.7020]],
       grad_fn=<ViewBackward0>)

In [94]:
len(lat)

650

-------------------

## PROOF THAT DATA LOADER CAN ACTUALLY HANDLE DATA BATCHES!!

The problem here was that if we already have a dataloader with batches of size 32, if we were to unroll these batches to make a bigger dataloader without batches (so all data in one batch), how would this be done? Luckily, the library accounts for that.

In [213]:
import torch
import torch_geometric.data as data

# Define the edges for the first graph
edges1 = torch.tensor([[0,1], [1,0]])


# Create data objects for each graph
graph1 = data.Data(edge_index=edges1)
graph2 = data.Data(edge_index=edges1)
graph3 = data.Data(edge_index=edges1)

# Print the graphs
print("Graph 1:")
print(graph1)
print("Graph 2:")
print(graph2)
print("Graph 3:")
print(graph3)

x = [graph1, graph2, graph3]

print("First data loader")
d = DataLoader(x, batch_size = 2)

xs = []
for i in d:
    print(i)
    xs +=[i]

print("Second data loader")
d2 = DataLoader(xs, batch_size = 3)
for i in d2:
    print(i)
    print(i.edge_index)

Graph 1:
Data(edge_index=[2, 2])
Graph 2:
Data(edge_index=[2, 2])
Graph 3:
Data(edge_index=[2, 2])
First data loader
DataBatch(edge_index=[2, 4], batch=[4], ptr=[3])
DataBatch(edge_index=[2, 2], batch=[2], ptr=[2])
Second data loader
DataBatch(edge_index=[2, 6], batch=[6])
tensor([[0, 1, 2, 3, 4, 5],
        [1, 0, 3, 2, 5, 4]])


------------------
---------------------
---------------------

In [5]:
import torch
from torch_geometric.nn import GCNConv, SimpleConv
from torch_geometric.data import Data

# adjacency matrix of graph
edge_index = torch.tensor([[0,1,1,2], [1,0,2,1]], dtype=torch.long)

# features for each node
x = torch.tensor([[1], [2], [3]], dtype=torch.float)

# simpleconv declaration
in_channels = 1  # Number of input features per node
out_channels = 1  # Number of output features per node
conv = GCNConv(in_channels, out_channels)

# process it
x = conv(x, edge_index)

print(x)


tensor([[2.0129],
        [3.5161],
        [3.5418]], grad_fn=<AddBackward0>)


In [3]:
import os
import pickle 
import torch

current_directory = os.getcwd()

somepath = os.path.abspath(
    os.path.join(current_directory, '..', 'Data', 'RNA_dataset_graph_R3.pkl'))

with open(somepath, 'rb') as f:
    loaded_object = pickle.load(f)

In [4]:
G = loaded_object[0]
features = []
for node, attr in G.nodes(data = True):
    features += [[attr['node_attr']]]
features = torch.tensor(features)
features

tensor([[6.1000],
        [4.3800],
        [3.0700],
        ...,
        [8.6800],
        [7.0900],
        [4.9400]], dtype=torch.float64)

In [28]:
from torch_geometric.data import Data

data = None

G = loaded_object[0]
# we enumerate each node in a dict
node_to_index = {node: idx for idx, node in enumerate(G.nodes())}

edge_index = torch.tensor([(node_to_index[edge[0]], node_to_index[edge[1]]) for edge in G.edges()] +
                 [(node_to_index[edge[1]], node_to_index[edge[0]]) for edge in G.edges()]).t().contiguous()
data = Data(x= features, edge_index = edge_index)
data.validate(raise_on_error=True)
data

Data(x=[2866, 1], edge_index=[2, 90932])

In [29]:
data

Data(x=[2866, 1], edge_index=[2, 90932])

In [30]:
xx = data.x / max(data.x)

In [54]:
xx

tensor([[0.4743],
        [0.3406],
        [0.2387],
        ...,
        [0.6750],
        [0.5513],
        [0.3841]], dtype=torch.float64)

In [46]:
recons_matrix = torch.matmul(xx, xx.t())
recons_matrix

tensor([[0.2250, 0.1616, 0.1132,  ..., 0.3202, 0.2615, 0.1822],
        [0.1616, 0.1160, 0.0813,  ..., 0.2299, 0.1878, 0.1308],
        [0.1132, 0.0813, 0.0570,  ..., 0.1611, 0.1316, 0.0917],
        ...,
        [0.3202, 0.2299, 0.1611,  ..., 0.4556, 0.3721, 0.2593],
        [0.2615, 0.1878, 0.1316,  ..., 0.3721, 0.3040, 0.2118],
        [0.1822, 0.1308, 0.0917,  ..., 0.2593, 0.2118, 0.1476]],
       dtype=torch.float64)

In [36]:
data.edge_index[1]

tensor([   1,  282,  554,  ..., 2702, 2706, 2708])

In [40]:
edges = [(data.edge_index[0][i].item(), data.edge_index[1][i].item()) for i in range(len(data.edge_index[0]))]
edges

[(0, 1),
 (0, 282),
 (0, 554),
 (0, 665),
 (0, 29),
 (0, 1213),
 (0, 1102),
 (0, 366),
 (0, 145),
 (0, 67),
 (0, 1014),
 (0, 1134),
 (0, 523),
 (0, 1163),
 (0, 213),
 (0, 640),
 (0, 452),
 (0, 277),
 (0, 963),
 (0, 152),
 (0, 512),
 (0, 479),
 (0, 409),
 (0, 1669),
 (0, 545),
 (0, 693),
 (0, 830),
 (0, 669),
 (0, 308),
 (0, 1432),
 (0, 270),
 (0, 86),
 (0, 820),
 (0, 684),
 (0, 293),
 (0, 1126),
 (0, 392),
 (0, 135),
 (0, 1433),
 (0, 1827),
 (0, 1016),
 (0, 1064),
 (0, 924),
 (0, 240),
 (0, 1291),
 (0, 54),
 (0, 142),
 (0, 42),
 (0, 95),
 (0, 623),
 (0, 1274),
 (0, 651),
 (0, 1476),
 (0, 388),
 (0, 620),
 (0, 1891),
 (0, 1118),
 (0, 890),
 (0, 1235),
 (0, 1399),
 (0, 526),
 (0, 1680),
 (0, 126),
 (0, 7),
 (0, 214),
 (0, 197),
 (0, 389),
 (0, 348),
 (0, 472),
 (0, 1574),
 (0, 1690),
 (0, 511),
 (0, 1105),
 (0, 1089),
 (0, 296),
 (0, 547),
 (0, 99),
 (0, 246),
 (0, 1221),
 (0, 438),
 (0, 426),
 (0, 1369),
 (0, 995),
 (0, 1212),
 (0, 807),
 (0, 2823),
 (0, 586),
 (0, 801),
 (0, 1234),
 (0

In [44]:
adjMatrix = torch.zeros(len(xx), len(xx))
adjMatrix

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [47]:
for i in edges:
    adjMatrix[i[0]][i[1]] = 1
adjMatrix

tensor([[0., 1., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [56]:
crit = nn.CrossEntropyLoss(reduction='mean')
crit(recons_matrix, adjMatrix)

tensor(252.0417, dtype=torch.float64)

In [87]:
latent_space.shape[0]

100

In [62]:
latent_space = torch.rand(100,500)
latent_space

tensor([[1.2105e-01, 5.0318e-01, 6.5654e-01,  ..., 9.9106e-01, 1.2716e-01,
         8.4550e-01],
        [2.8750e-01, 5.1376e-02, 1.7650e-01,  ..., 6.2799e-02, 8.7049e-01,
         2.3504e-02],
        [9.5096e-01, 7.1031e-01, 6.5141e-01,  ..., 4.1342e-01, 1.5894e-01,
         4.6340e-01],
        ...,
        [6.7209e-01, 7.9483e-01, 4.2482e-02,  ..., 7.3005e-01, 7.0396e-01,
         9.0176e-01],
        [4.1348e-01, 1.8075e-01, 4.0985e-01,  ..., 4.1981e-01, 1.3872e-01,
         3.6643e-01],
        [6.8927e-04, 3.7078e-01, 8.3436e-01,  ..., 7.5807e-01, 8.7345e-01,
         1.9088e-01]])

In [95]:
import xlsxwriter

row = 1
col = 1

workbook = xlsxwriter.Workbook('Latent_space.xlsx')
worksheet = workbook.add_worksheet()

bold_format = workbook.add_format({'bold': True, 'bg_color': '#DDDDDD'})

for i in range(latent_space.shape[1]):
    worksheet.write(0, i + 1, 'Lat. '+str(i), bold_format)

for i in range(latent_space.shape[0]):
    worksheet.write(i + 1, 0, 'Pat. '+str(i), bold_format)
    

for xs in latent_space:
    for x in xs:
        red_component = max(int(x * 255), 30)
        blue_component = max(int((1 - x) * 255), 30)
        purple_hex = '#{:02X}30{:02X}'.format(red_component, blue_component)
        purple_format = workbook.add_format({'bg_color': purple_hex, 'font_color': 'white'})
        worksheet.write(row, col, round(x.item(), 2), purple_format)
        col += 1
    row += 1
    col = 1
workbook.close()