In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import pandas as pd
import time

from model import GITIII,Loss_function
from calculate_PCC import Calculate_PCC

from dataloader import BC_evaluate_dataset

# Output folder for the per-sample result files and input folder holding the
# processed breast-cancer (BC) samples.
to_save_dir = "../edges/"
data_dir = "../../data/BC/processed1/"

# ligands.pth and genes.pth are read from the directory obtained by dropping
# the last two "/"-separated components of data_dir (the trailing empty string
# and "processed1"), i.e. "../../data/BC".
dataset_root = "/".join(data_dir.split("/")[:-2])
ligands_info = torch.load(dataset_root + "/ligands.pth")
genes = torch.load(dataset_root + "/genes.pth")

# Build the network on the GPU and restore the saved weights.
my_model = GITIII(
    genes,
    ligands_info,
    node_dim=256,
    edge_dim=48,
    num_heads=2,
    n_layers=1,
    node_dim_small=16,
    att_dim=8,
).cuda()
my_model.load_state_dict(torch.load("GRIT_best.pth"))

# Loss module (on the GPU) and PCC evaluator, both built from the same
# gene / ligand metadata as the model.
loss_func = Loss_function(genes, ligands_info).cuda()
evaluator = Calculate_PCC(genes, ligands_info)


# Cell-type names in label order: the position of a name in this list is the
# integer key it is stored under in cell_types_dict.
CELL_TYPE_NAMES = [
    'B_Cells', 'CD4+_T_Cells', 'CD8+_T_Cells', 'DCIS_1', 'DCIS_2',
    'Endothelial', 'IRF7+_DCs', 'Invasive_Tumor', 'LAMP3+_DCs',
    'Macrophages_1', 'Macrophages_2', 'Mast_Cells', 'Myoepi_ACTA2+',
    'Myoepi_KRT15+', 'Perivascular-Like', 'Prolif_Invasive_Tumor', 'Stromal',
    'Stromal_&_T_Cell_Hybrid', 'T_Cell_&_Tumor_Hybrid', 'Unlabeled',
]

# Maps integer label (0-based) -> cell-type name.
cell_types_dict = dict(enumerate(CELL_TYPE_NAMES))

# The previous hand-written counter loop left `cnt` at module level holding
# the number of cell types; keep it defined in case other code reads it.
cnt = len(cell_types_dict)

def evaluate_BC(sample):
    """Run the loaded model over one BC sample and save the per-step outputs.

    Iterates ``BC_evaluate_dataset(processed_dir=data_dir, sample=sample)``
    with batch size 1 and no shuffling, runs ``my_model`` and ``loss_func``
    under ``torch.no_grad()``, and collects for every step:

    * ``edges`` / ``attention_score`` -- taken from ``y_pred[1][1]`` and
      ``y_pred[1][0]``; after dropping the batch axis, index 0 of what was
      axis 1 is kept and the remaining two axes are swapped
      (NOTE(review): presumably this selects the centre cell / first head --
      confirm against the model definition).
    * ``position_x`` / ``position_y`` -- copied from the input batch.
    * ``cell_type_name`` -- list of names looked up in ``cell_types_dict``.
    * ``loss_all`` / ``loss_no_interact`` -- the two values returned by
      ``loss_func(y_pred, y)``.
    * ``y_pred`` / ``y`` -- ``y_pred[0]`` and the target ``x["y"]``.

    All tensor entries are stacked along a new leading step axis; the
    ``cell_type_name`` entry stays a list of lists. The resulting dict is
    written to ``<to_save_dir>/edges_<sample>.pth`` with ``torch.save``.

    Args:
        sample: Sample identifier passed to the dataset and used in the
            output file name (e.g. ``'sample1_rep1'``).

    Returns:
        The dict described above. It is empty if the dataloader yields
        nothing.
    """
    # Resolve and create the output location *before* the long inference
    # pass, so a missing directory cannot throw the results away at the end.
    save_path = os.path.join(to_save_dir, "edges_" + sample + ".pth")
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)

    my_dataset = BC_evaluate_dataset(processed_dir=data_dir, sample=sample)
    my_dataloader = DataLoader(my_dataset, batch_size=1, num_workers=0, shuffle=False)

    length = len(my_dataloader)

    my_model.eval()

    results = {}
    with torch.no_grad():
        for stepi, x in enumerate(my_dataloader, start=1):
            x = {k: v.cuda() for k, v in x.items()}

            # One device-to-host transfer for all labels instead of one per
            # element.
            cell_type_name = [
                cell_types_dict[int(i)]
                for i in x["cell_types"].squeeze(dim=0).tolist()
            ]

            y_pred = my_model(x)
            y = x["y"]
            lossi1, lossi2 = loss_func(y_pred, y)

            # Drop the batch axis, move the leading axis last, keep index 0
            # of the new leading axis.
            attention_score = (
                y_pred[1][0].cpu().detach().squeeze(dim=0).permute(1, 2, 0)[0, :, :]
            )
            edges = (
                y_pred[1][1].cpu().detach().squeeze(dim=0).permute(1, 2, 0)[0, :, :]
            )

            step_result = {
                "edges": edges,
                "attention_score": attention_score,
                "position_x": x["position_x"].cpu().detach().squeeze(dim=0),
                "position_y": x["position_y"].cpu().detach().squeeze(dim=0),
                "cell_type_name": cell_type_name,
                "loss_all": lossi1.cpu().detach(),
                "loss_no_interact": lossi2.cpu().detach(),
                "y_pred": y_pred[0].cpu().detach().squeeze(dim=0),
                "y": y.cpu().detach().squeeze(dim=0),
            }
            # Single accumulation path for the first and all later steps.
            for keyi, valuei in step_result.items():
                results.setdefault(keyi, []).append(valuei)

            if stepi % 2000 == 0:
                print(stepi, "/", length)

    # Stack every tensor-valued entry along a new leading axis; the list of
    # cell-type names is not a tensor and is kept as-is.
    for keyi in results:
        if keyi != "cell_type_name":
            results[keyi] = torch.stack(results[keyi], dim=0)
    torch.save(results, save_path)

    print("Finish", sample)
    return results


if __name__ == "__main__":
    # Evaluate the first replicate and print a summary of what was collected:
    # the shape of every stacked tensor, and for the cell-type names the
    # number of steps plus the number of names recorded in the first step.
    sample1 = evaluate_BC('sample1_rep1')
    for keyi, valuei in sample1.items():
        if keyi == "cell_type_name":
            print(keyi, len(valuei), len(valuei[0]))
        else:
            print(keyi, valuei.shape)
    # Evaluate the second replicate; its result replaces `sample1`.
    sample1 = evaluate_BC('sample1_rep2')

Have samples: ['sample1_rep1']
There are totally 159224 cells in this dataset
2000 / 159224
4000 / 159224
6000 / 159224
8000 / 159224
10000 / 159224
12000 / 159224
14000 / 159224
16000 / 159224
18000 / 159224
20000 / 159224
22000 / 159224
24000 / 159224
26000 / 159224
28000 / 159224
30000 / 159224
32000 / 159224
34000 / 159224
36000 / 159224
38000 / 159224
40000 / 159224
42000 / 159224
44000 / 159224
46000 / 159224
48000 / 159224
50000 / 159224
52000 / 159224
54000 / 159224
56000 / 159224
58000 / 159224
60000 / 159224
62000 / 159224
64000 / 159224
66000 / 159224
68000 / 159224
70000 / 159224
72000 / 159224
74000 / 159224
76000 / 159224
78000 / 159224
80000 / 159224
82000 / 159224
84000 / 159224
86000 / 159224
88000 / 159224
90000 / 159224
92000 / 159224
94000 / 159224
96000 / 159224
98000 / 159224
100000 / 159224
102000 / 159224
104000 / 159224
106000 / 159224
108000 / 159224
110000 / 159224
112000 / 159224
114000 / 159224
116000 / 159224
118000 / 159224
120000 / 159224
122000 / 159224