In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import pandas as pd
import time

from model import GITIII,Loss_function
from calculate_PCC import Calculate_PCC

from dataloader import Mouse_brain_evaluate_dataset

# Directory that receives one edges_<sample>.pth file per evaluated sample,
# and the processed-data directory handed to the evaluation dataset.
to_save_dir="../edges/"
data_dir="../../data/Mouse_brain/processed1/"

# ligands.pth and genes.pth live in the parent directory of data_dir.
# normpath strips any trailing separator first, so dirname returns that parent
# whether or not data_dir ends with "/".
dataset_root = os.path.dirname(os.path.normpath(data_dir))
ligands_info = torch.load(os.path.join(dataset_root, "ligands.pth"))
genes = torch.load(os.path.join(dataset_root, "genes.pth"))

# These hyperparameters have to agree with the architecture stored in
# GRIT_best.pth: load_state_dict (strict by default) raises on missing keys or
# shape mismatches.
my_model = GITIII(genes, ligands_info, node_dim=256, edge_dim=48, num_heads=2, n_layers=1, node_dim_small=16, att_dim=8)
my_model = my_model.cuda()

my_model.load_state_dict(torch.load("GRIT_best.pth"))

loss_func = Loss_function(genes, ligands_info).cuda()
# NOTE(review): `evaluator` is not used by the cells visible here; kept in case
# later cells rely on it.
evaluator = Calculate_PCC(genes, ligands_info)

# Integer cell-type code -> cell-type label.
# evaluate_Brain_MERFISH looks up each integer of an item's "cell_types" tensor
# in this dict, so the list order must match the encoding used when the data
# was processed — TODO confirm against the preprocessing step.
cell_types_dict = dict(enumerate([
    'Astro', 'Endo', 'L2/3 IT', 'L4/5 IT', 'L5 ET', 'L5 IT', 'L5/6 NP', 'L6 CT', 'L6 IT',
    'L6 IT Car3', 'L6b', 'Lamp5', 'Micro', 'OPC', 'Oligo', 'PVM', 'Peri', 'Pvalb', 'SMC', 'Sncg',
    'Sst', 'VLMC', 'Vip', 'other',
]))

def evaluate_Brain_MERFISH(sample):
    """Evaluate the trained model on one sample and save the per-item outputs.

    Iterates over ``Mouse_brain_evaluate_dataset(processed_dir=data_dir, sample=sample)``
    one item at a time (batch_size=1, no shuffling) and runs ``my_model`` and
    ``loss_func`` on each item under ``torch.no_grad()``. For every item it
    collects the following fields:

    - ``edges`` and ``attention_score``, taken from ``y_pred[1]``.
    - ``position_x`` and ``position_y``.
    - ``cell_type_name``, the item's ``cell_types`` codes decoded through
      ``cell_types_dict``.
    - ``loss_all`` and ``loss_no_interact``, the two values returned by
      ``loss_func``.
    - ``y_pred`` (``y_pred[0]``) and ``y``.

    Every field except ``cell_type_name`` is stacked into one tensor with a
    leading per-item dimension. ``cell_type_name`` stays a list of per-item label
    lists. The dict is saved to
    ``os.path.join(to_save_dir, "edges_" + sample + ".pth")``, and
    ``to_save_dir`` is created first if it does not exist.

    Args:
        sample: sample identifier understood by the dataset, e.g. ``"mouse1_slice201"``.

    Returns:
        The dict that was saved.

    Raises:
        ValueError: if the dataset yields no items for ``sample``.
    """

    def _reduce(t):
        # Drop the leading batch dim, rotate dim 0 to the end, and keep index 0
        # of the new first dim. clone() makes the kept slice own its memory, so
        # the stored result does not keep the whole model output alive until
        # the final stack.
        return t.cpu().detach().squeeze(dim=0).permute(1, 2, 0)[0, :, :].clone()

    my_dataset = Mouse_brain_evaluate_dataset(processed_dir=data_dir, sample=sample)
    my_dataloader = DataLoader(my_dataset, batch_size=1, num_workers=0, shuffle=False)
    length = len(my_dataloader)

    my_model.eval()

    results = []
    with torch.no_grad():
        for stepi, x in enumerate(my_dataloader, start=1):
            x = {k: v.cuda() for k, v in x.items()}

            # A single host transfer for the whole code vector instead of one
            # device sync per element.
            cell_type_codes = x["cell_types"].squeeze(dim=0).tolist()
            cell_type_name = [cell_types_dict[int(code)] for code in cell_type_codes]

            y_pred = my_model(x)
            y = x["y"]
            loss_all, loss_no_interact = loss_func(y_pred, y)

            # y_pred[1] is indexed as (attention scores, edges). Both get the
            # same reduction.
            attention_score = _reduce(y_pred[1][0])
            edges = _reduce(y_pred[1][1])

            results.append({
                "edges": edges,
                "attention_score": attention_score,
                "position_x": x["position_x"].cpu().detach().squeeze(dim=0),
                "position_y": x["position_y"].cpu().detach().squeeze(dim=0),
                "cell_type_name": cell_type_name,
                "loss_all": loss_all.cpu().detach(),
                "loss_no_interact": loss_no_interact.cpu().detach(),
                "y_pred": y_pred[0].cpu().detach().squeeze(dim=0),
                "y": y.cpu().detach().squeeze(dim=0),
            })

            if stepi % 500 == 0:
                print(stepi, "/", length)

    if not results:
        raise ValueError(f"No items to evaluate for sample {sample!r}")

    # Stack each field across items. Label lists cannot be stacked, so they
    # stay plain lists.
    concatenated_results = {}
    for keyi in results[0]:
        column = [record[keyi] for record in results]
        if keyi == "cell_type_name":
            concatenated_results[keyi] = column
        else:
            concatenated_results[keyi] = torch.stack(column, dim=0)

    os.makedirs(to_save_dir, exist_ok=True)
    torch.save(concatenated_results, os.path.join(to_save_dir, "edges_" + sample + ".pth"))

    print("Finish", sample)
    return concatenated_results

if __name__ == "__main__":
    # Evaluate a single slice on its own so the next cell can inspect the
    # shapes of the returned fields. Inside a notebook kernel __name__ is
    # "__main__", so this branch always runs there.
    mouse1_slice201 = evaluate_Brain_MERFISH(sample="mouse1_slice201")

Have samples: ['mouse1_slice201']
There are totally 6137 cells in this dataset
500 / 6137
1000 / 6137
1500 / 6137
2000 / 6137
2500 / 6137
3000 / 6137
3500 / 6137
4000 / 6137
4500 / 6137
5000 / 6137
5500 / 6137
6000 / 6137
Finish mouse1_slice201


In [2]:
# Report the size of every saved field for the slice evaluated above.
# "cell_type_name" is a list of per-item label lists, so it is reported as
# (number of items, number of labels in the first item) rather than .shape.
for key, value in mouse1_slice201.items():
    if key == "cell_type_name":
        print(key, len(value), len(value[0]))
    else:
        print(key, value.shape)

edges torch.Size([6137, 49, 254])
attention_score torch.Size([6137, 49, 254])
position_x torch.Size([6137, 50])
position_y torch.Size([6137, 50])
cell_type_name 6137 50
loss_all torch.Size([6137])
loss_no_interact torch.Size([6137])
y_pred torch.Size([6137, 254])
y torch.Size([6137, 254])


In [3]:
# Every slice of both animals. Each evaluate_Brain_MERFISH call saves
# edges_<sample>.pth under to_save_dir.
# NOTE(review): 'mouse1_slice201' was already evaluated and saved by the first
# cell and is computed again here.
samples = [
    # mouse1
    'mouse1_slice1', 'mouse1_slice10', 'mouse1_slice102', 'mouse1_slice112',
    'mouse1_slice122', 'mouse1_slice131', 'mouse1_slice153', 'mouse1_slice162',
    'mouse1_slice170', 'mouse1_slice180', 'mouse1_slice190', 'mouse1_slice200',
    'mouse1_slice201', 'mouse1_slice21', 'mouse1_slice212', 'mouse1_slice221',
    'mouse1_slice232', 'mouse1_slice241', 'mouse1_slice251', 'mouse1_slice260',
    'mouse1_slice271', 'mouse1_slice283', 'mouse1_slice291', 'mouse1_slice301',
    'mouse1_slice31', 'mouse1_slice313', 'mouse1_slice326', 'mouse1_slice40',
    'mouse1_slice50', 'mouse1_slice62', 'mouse1_slice71', 'mouse1_slice81',
    'mouse1_slice91',
    # mouse2
    'mouse2_slice1', 'mouse2_slice10', 'mouse2_slice109', 'mouse2_slice119',
    'mouse2_slice129', 'mouse2_slice139', 'mouse2_slice151', 'mouse2_slice160',
    'mouse2_slice169', 'mouse2_slice189', 'mouse2_slice20', 'mouse2_slice201',
    'mouse2_slice209', 'mouse2_slice219', 'mouse2_slice229', 'mouse2_slice249',
    'mouse2_slice261', 'mouse2_slice270', 'mouse2_slice280', 'mouse2_slice289',
    'mouse2_slice300', 'mouse2_slice309', 'mouse2_slice31', 'mouse2_slice319',
    'mouse2_slice40', 'mouse2_slice50', 'mouse2_slice61', 'mouse2_slice70',
    'mouse2_slice79', 'mouse2_slice90', 'mouse2_slice99',
]
for sample_id in samples:
    evaluate_Brain_MERFISH(sample=sample_id)

Have samples: ['mouse1_slice1']
There are totally 2351 cells in this dataset
500 / 2351
1000 / 2351
1500 / 2351
2000 / 2351
Finish mouse1_slice1
Have samples: ['mouse1_slice10']
There are totally 2393 cells in this dataset
500 / 2393
1000 / 2393
1500 / 2393
2000 / 2393
Finish mouse1_slice10
Have samples: ['mouse1_slice102']
There are totally 6636 cells in this dataset
500 / 6636
1000 / 6636
1500 / 6636
2000 / 6636
2500 / 6636
3000 / 6636
3500 / 6636
4000 / 6636
4500 / 6636
5000 / 6636
5500 / 6636
6000 / 6636
6500 / 6636
Finish mouse1_slice102
Have samples: ['mouse1_slice112']
There are totally 4920 cells in this dataset
500 / 4920
1000 / 4920
1500 / 4920
2000 / 4920
2500 / 4920
3000 / 4920
3500 / 4920
4000 / 4920
4500 / 4920
Finish mouse1_slice112
Have samples: ['mouse1_slice122']
There are totally 6082 cells in this dataset
500 / 6082
1000 / 6082
1500 / 6082
2000 / 6082
2500 / 6082
3000 / 6082
3500 / 6082
4000 / 6082
4500 / 6082
5000 / 6082
5500 / 6082
6000 / 6082
Finish mouse1_slice1