In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import pandas as pd
import time

from model import GITIII,Loss_function
from calculate_PCC import Calculate_PCC

from dataloader import AD_evaluate_dataset

to_save_dir="../edges/"
data_dir="../../data/AD/processed1/"

ligands_info = torch.load("/".join(data_dir.split("/")[:-2]) + "/ligands.pth")
genes = torch.load("/".join(data_dir.split("/")[:-2]) + "/genes.pth")

my_model = GITIII(genes, ligands_info, node_dim=256, edge_dim=48, num_heads=2, n_layers=1, node_dim_small=16,att_dim=8)
my_model = my_model.cuda()

my_model.load_state_dict(torch.load("GRIT_best.pth"))

loss_func = Loss_function(genes, ligands_info).cuda()
evaluator=Calculate_PCC(genes,ligands_info)


cell_types_dict = {}
cnt = 0
for cell_typei in ['Astrocyte', 'Chandelier', 'Endothelial', 'L2/3 IT', 'L4 IT', 'L5 ET', 'L5 IT', 'L5/6 NP', 'L6 CT', 'L6 IT', 'L6 IT Car3', 'L6b', 'Lamp5', 'Lamp5 Lhx6', 'Microglia-PVM', 'OPC', 'Oligodendrocyte', 'Pax6', 'Pvalb', 'Sncg', 'Sst', 'Sst Chodl', 'VLMC', 'Vip']:
    cell_types_dict[cnt] = cell_typei
    cnt += 1

def evaluate_AD(sample):
    my_dataset = AD_evaluate_dataset(processed_dir=data_dir,sample=sample)
    my_dataloader = DataLoader(my_dataset, batch_size=1, num_workers=0, shuffle=False)

    length = len(my_dataloader)

    my_model.eval()

    results = []
    with torch.no_grad():
        for (stepi, x) in enumerate(my_dataloader, start=1):
            x = {k: v.cuda() for k, v in x.items()}

            cell_type_name = [cell_types_dict[int(i.cpu())] for i in x["cell_types"].squeeze(dim=0)]

            y_pred = my_model(x)
            y = x["y"]
            lossi1, lossi2 = loss_func(y_pred, y)
            lossi1 = lossi1.cpu().detach()
            lossi2 = lossi2.cpu().detach()

            attention_score = y_pred[1][0].cpu().detach()
            attention_score = attention_score.squeeze(dim=0)
            attention_score = attention_score.permute(1, 2, 0)[0, :, :]
            # print(torch.topk(attention_score[0, :], k=30, dim=-1))

            edges = y_pred[1][1].cpu().detach()
            edges = edges.squeeze(dim=0).permute(1, 2, 0)[0, :, :]

            position_x = x["position_x"].cpu().detach().squeeze(dim=0)
            position_y = x["position_y"].cpu().detach().squeeze(dim=0)

            to_save_dict = {
                "edges": edges,
                "attention_score": attention_score,
                "position_x": position_x,
                "position_y": position_y,
                "cell_type_name": cell_type_name,
                "loss_all": lossi1,
                "loss_no_interact": lossi2,
                "y_pred": y_pred[0].cpu().detach().squeeze(dim=0),
                "y": y.cpu().detach().squeeze(dim=0)
            }
            results.append(to_save_dict)

            if stepi % 2000 == 0:
                print(stepi, "/", length)
    concatenated_results = {}
    for keyi in results[0].keys():
        if keyi not in ["cell_type_name"]:
            concatenated_results[keyi] = torch.stack([results[j][keyi] for j in range(len(results))], dim=0)
        else:
            concatenated_results[keyi] = [results[j][keyi] for j in range(len(results))]
    torch.save(concatenated_results, to_save_dir + "edges_" + sample + ".pth")

    print("Finish", sample)
    return concatenated_results


if __name__ == "__main__":
    sample1 = evaluate_AD('H20.33.001.CX28.MTG.02.007.1.02.02')
    for keyi in sample1.keys():
        if keyi not in ["cell_type_name"]:
            print(keyi, sample1[keyi].shape)
        else:
            print(keyi, len(sample1[keyi]), len(sample1[keyi][0]))

Have samples: ['H20.33.001.CX28.MTG.02.007.1.02.02']
There are totally 7937 cells in this dataset
2000 / 7937
4000 / 7937
6000 / 7937
Finish H20.33.001.CX28.MTG.02.007.1.02.02
edges torch.Size([7937, 49, 140])
attention_score torch.Size([7937, 49, 140])
position_x torch.Size([7937, 50])
position_y torch.Size([7937, 50])
cell_type_name 7937 50
loss_all torch.Size([7937])
loss_no_interact torch.Size([7937])
y_pred torch.Size([7937, 140])
y torch.Size([7937, 140])


In [2]:
samples=[]
data_dir="../../data/AD/processed1/"
for filei in os.listdir(data_dir):
    if filei.find("_TypeExp.npz") >= 0:
        samples.append(filei.split("_TypeExp.npz")[0])

for samplei in samples:
    evaluate_AD(samplei)

Have samples: ['H20.33.004.Cx26.MTG.02.007.1.02.04']
There are totally 6594 cells in this dataset
2000 / 6594
4000 / 6594
6000 / 6594
Finish H20.33.004.Cx26.MTG.02.007.1.02.04
Have samples: ['H20.33.004.Cx26.MTG.02.007.1.01.04']
There are totally 5323 cells in this dataset
2000 / 5323
4000 / 5323
Finish H20.33.004.Cx26.MTG.02.007.1.01.04
Have samples: ['H20.33.004.Cx26.MTG.02.007.1.01.05']
There are totally 3157 cells in this dataset
2000 / 3157
Finish H20.33.004.Cx26.MTG.02.007.1.01.05
Have samples: ['H21.33.011.Cx26.MTG.02.007.3.01.06']
There are totally 5406 cells in this dataset
2000 / 5406
4000 / 5406
Finish H21.33.011.Cx26.MTG.02.007.3.01.06
Have samples: ['H21.33.016.Cx26.MTG.02.007.3.01.01']
There are totally 4696 cells in this dataset
2000 / 4696
4000 / 4696
Finish H21.33.016.Cx26.MTG.02.007.3.01.01
Have samples: ['H21.33.028.CX28.MTG.02.007.1.01.01']
There are totally 4421 cells in this dataset
2000 / 4421
4000 / 4421
Finish H21.33.028.CX28.MTG.02.007.1.01.01
Have samples: ['

In [3]:
import os
samples=[]
for filei in os.listdir(data_dir):
    if filei.find("_TypeExp.npz") >= 0:
        samples.append(filei.split("_TypeExp.npz")[0])
print(samples, "H21.33.019.Cx30.MTG.02.007.5.01.01" in samples)

['H20.33.004.Cx26.MTG.02.007.1.02.04', 'H20.33.004.Cx26.MTG.02.007.1.01.04', 'H20.33.004.Cx26.MTG.02.007.1.01.05', 'H21.33.011.Cx26.MTG.02.007.3.01.06', 'H21.33.016.Cx26.MTG.02.007.3.01.01', 'H21.33.028.CX28.MTG.02.007.1.01.01', 'H21.33.038.Cx20.MTG.02.007.3.01.02', 'H21.33.040.Cx22.MTG.02.007.3.03.03', 'H21.33.022.Cx26.MTG.02.007.2.M.02', 'H21.33.038.Cx20.MTG.02.007.3.01.04', 'H21.33.005.Cx18.MTG.02.007.02.04', 'H20.33.012.Cx24.MTG.02.007.1.01.01', 'H20.33.012.Cx24.MTG.02.007.1.03.03', 'H21.33.023.Cx26.MTG.02.007.1.03.01', 'H20.33.025.Cx28.MTG.02.007.1.01.02', 'H21.33.012.Cx26.MTG.02.007.1.01.06', 'H20.33.025.Cx28.MTG.02.007.1.01.04', 'H20.33.044.Cx26.MTG.02.007.1.01.04', 'H21.33.023.Cx26.MTG.02.007.1.03.05', 'H20.33.004.Cx26.MTG.02.007.1.02.03', 'H21.33.016.Cx26.MTG.02.007.3.01.02', 'H20.33.040.Cx25.MTG.02.007.1.01.03', 'H21.33.001.Cx22.MTG.02.007.1.01.04', 'H20.33.012.Cx24.MTG.02.007.1.03.02', 'H21.33.015.Cx26.MTG.02.007.1.2', 'H21.33.022.Cx26.MTG.02.007.2.M.03', 'H21.33.005.Cx18.MT