In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import pandas as pd
import time
import gc

from dataloader import NSCLC_evaluate_dataset
from model import GITIII,Loss_function
from calculate_PCC import Calculate_PCC

# Where the per-sample edge files are written, and where the processed NSCLC data lives.
to_save_dir = "../edges/"
data_dir = "../../data/NSCLC/processed/"

# ligands.pth / genes.pth sit in the parent of the "processed/" directory:
# data_dir ends with "/", so dropping the last two split pieces removes "processed/".
_dataset_root = "/".join(data_dir.split("/")[:-2])
ligands_info = torch.load(_dataset_root + "/ligands.pth")
genes = torch.load(_dataset_root + "/genes.pth")

# Architecture hyper-parameters must agree with the weights stored in GRIT_best.pth,
# otherwise load_state_dict below will reject the checkpoint.
my_model = GITIII(
    genes,
    ligands_info,
    node_dim=256,
    edge_dim=128,
    num_heads=2,
    n_layers=1,
    node_dim_small=16,
    att_dim=8,
).cuda()
loss_func = Loss_function(genes, ligands_info).cuda()
evaluator = Calculate_PCC(genes, ligands_info)
my_model.load_state_dict(torch.load("GRIT_best.pth"))
# Cell-type names; the position in this list is the integer cell-type label.
# NOTE(review): the order must match the label encoding used when the NSCLC data
# was processed — confirm against the preprocessing code.
NSCLC_CELL_TYPES = [
    'B-cell', 'NK', 'T CD4 memory', 'T CD4 naive', 'T CD8 memory', 'T CD8 naive',
    'Treg', 'endothelial', 'epithelial', 'fibroblast', 'mDC', 'macrophage', 'mast',
    'monocyte', 'neutrophil', 'pDC', 'plasmablast',
    'tumor 12', 'tumor 13', 'tumor 5', 'tumor 6', 'tumor 9',
]

# Integer label -> cell-type name (0 -> 'B-cell', ..., 21 -> 'tumor 9').
cell_types_dict = dict(enumerate(NSCLC_CELL_TYPES))

def _first_slice_on_cpu(t):
    """Detach ``t``, move it to the CPU, and keep index 0 of axis 1 (after the squeeze).

    A leading size-1 dimension is squeezed away, the remaining three axes are permuted
    with ``permute(1, 2, 0)`` (which requires ``t`` to be 3-D at that point), and only
    index 0 of the new first axis is kept, giving a 2-D tensor.
    NOTE(review): presumably index 0 is the centre/target cell of the batch — confirm
    against GITIII.forward.
    """
    return t.detach().cpu().squeeze(dim=0).permute(1, 2, 0)[0, :, :]


def evaluate_NSCLC(sample):
    """Run the trained model over one NSCLC sample and save its per-batch outputs.

    Iterates ``NSCLC_evaluate_dataset`` for ``sample`` with batch size 1 (no shuffling).
    For every batch it records edges, attention scores, positions, cell-type names,
    both loss terms, the prediction and the target. Tensor entries are stacked along a
    new leading dimension, and the result is saved to ``to_save_dir/edges_<sample>.pth``.

    Args:
        sample: sample name passed to ``NSCLC_evaluate_dataset`` (e.g. "Lung6").

    Returns:
        dict: every key maps to a tensor stacked over batches, except
        ``"cell_type_name"``, which maps to a list (one list of names per batch).

    Raises:
        ValueError: if the dataset for ``sample`` yields no batches (nothing to save).
    """
    my_dataset = NSCLC_evaluate_dataset(processed_dir=data_dir, sample=sample)
    my_dataloader = DataLoader(my_dataset, batch_size=1, num_workers=0, shuffle=False)

    length = len(my_dataloader)

    my_model.eval()

    results = []
    with torch.no_grad():
        for stepi, x in enumerate(my_dataloader, start=1):
            x = {k: v.cuda() for k, v in x.items()}

            # One device->host transfer for the whole label vector instead of one per cell.
            cell_type_name = [
                cell_types_dict[int(label)]
                for label in x["cell_types"].squeeze(dim=0).tolist()
            ]

            y_pred = my_model(x)
            y = x["y"]
            loss_all, loss_no_interact = loss_func(y_pred, y)

            # y_pred[1] holds (attention scores, edges); both get the same reduction.
            to_save_dict = {
                "edges": _first_slice_on_cpu(y_pred[1][1]),
                "attention_score": _first_slice_on_cpu(y_pred[1][0]),
                "position_x": x["position_x"].detach().cpu().squeeze(dim=0),
                "position_y": x["position_y"].detach().cpu().squeeze(dim=0),
                "cell_type_name": cell_type_name,
                "loss_all": loss_all.detach().cpu(),
                "loss_no_interact": loss_no_interact.detach().cpu(),
                "y_pred": y_pred[0].detach().cpu().squeeze(dim=0),
                "y": y.detach().cpu().squeeze(dim=0),
            }
            results.append(to_save_dict)

            if stepi % 2000 == 0:
                print(stepi, "/", length)

    # Fail with a clear message rather than an IndexError on results[0].
    if not results:
        raise ValueError(f"No batches to evaluate for sample {sample!r}; nothing saved.")

    concatenated_results = {}
    for keyi in results[0]:
        values = [result[keyi] for result in results]
        # Cell-type names are Python lists of strings and cannot be stacked.
        if keyi == "cell_type_name":
            concatenated_results[keyi] = values
        else:
            concatenated_results[keyi] = torch.stack(values, dim=0)

    # Create the output directory on first use so torch.save does not fail.
    os.makedirs(to_save_dir, exist_ok=True)
    torch.save(concatenated_results, os.path.join(to_save_dir, "edges_" + sample + ".pth"))

    print("Finish", sample)
    return concatenated_results


In [2]:
# Evaluate every NSCLC sample in turn; each call writes its own edges_<sample>.pth file.
samples = [
    'Lung6',
    'Lung13',
    'Lung5_Rep1',
    'Lung5_Rep3',
    'Lung5_Rep2',
    'Lung9_Rep1',
    'Lung9_Rep2',
    'Lung12',
]
for samplei in samples:
    # The returned dict is discarded (the results are already on disk);
    # force a collection so memory can be reclaimed before the next sample.
    evaluate_NSCLC(samplei)
    gc.collect()

Have samples: ['Lung6']
There are totally 77025 cells in this dataset
2000 / 77025
4000 / 77025
6000 / 77025
8000 / 77025
10000 / 77025
12000 / 77025
14000 / 77025
16000 / 77025
18000 / 77025
20000 / 77025
22000 / 77025
24000 / 77025
26000 / 77025
28000 / 77025
30000 / 77025
32000 / 77025
34000 / 77025
36000 / 77025
38000 / 77025
40000 / 77025
42000 / 77025
44000 / 77025
46000 / 77025
48000 / 77025
50000 / 77025
52000 / 77025
54000 / 77025
56000 / 77025
58000 / 77025
60000 / 77025
62000 / 77025
64000 / 77025
66000 / 77025
68000 / 77025
70000 / 77025
72000 / 77025
74000 / 77025
76000 / 77025
Finish Lung6
Have samples: ['Lung13']
There are totally 77465 cells in this dataset
2000 / 77465
4000 / 77465
6000 / 77465
8000 / 77465
10000 / 77465
12000 / 77465
14000 / 77465
16000 / 77465
18000 / 77465
20000 / 77465
22000 / 77465
24000 / 77465
26000 / 77465
28000 / 77465
30000 / 77465
32000 / 77465
34000 / 77465
36000 / 77465
38000 / 77465
40000 / 77465
42000 / 77465
44000 / 77465
46000 / 77465
