In [None]:
# standard library
import argparse
import copy
import csv
import math
import os
import random
import time
from collections import Counter
from decimal import Decimal

# numerics / data / statistics / plotting
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import statsmodels.api as sm
from scipy import stats
from scipy.spatial.distance import cosine
from scipy.stats import norm
from scipy.stats import spearmanr
from statsmodels.tools.eval_measures import mse

# import pytorch libraries
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.deprecation import deprecated
from torch_geometric.nn import GATConv
from torch_geometric.nn import GCNConv

In [None]:
# initial settings
# Pick the compute device once: CUDA when a GPU is visible, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# The 10 GTEx brain regions (GTEx spelling). This list also fixes the column
# order of the per-subject GTEx expression matrices built later.
region_pick = [
    'Amygdala',
    'Anterior_cingulate_cortex_BA24',
    'Caudate_basal_ganglia',
    'Cerebellar_Hemisphere',
    'Frontal_Cortex_BA9',
    'Hippocampus',
    'Hypothalamus',
    'Nucleus_accumbens_basal_ganglia',
    'Putamen_basal_ganglia',
    'Substantia_nigra',
]

# name difference
def find_allen_name(gtex_region):
    if gtex_region=='Cerebellar_Hemisphere':
        allen_name = 'Cerebellum'
    elif gtex_region=='Frontal_Cortex_BA9':
        allen_name = 'Cortex'
    else:
        allen_name = gtex_region
        
    return allen_name

# settings
# Allen donor ids; every per-donor table below is keyed by these strings.
all_ids = ['10021', '12876', '14380', '15496', '15697', '9861']

# path
allen_data_path = '/project/pi_rachel_melamed_uml_edu/Jianfeng/Allen/data/allen_data/allen/'
save_path = '/project/pi_rachel_melamed_uml_edu/Jianfeng/Allen/data/allen_data/quantile_normalized_allen/'

# donor id -> quantile-normalized expression table, indexed by gene symbol
GeneExpression_allen_dict = {}
for donor in all_ids:
    file_name = save_path + f"normalized_expr_{donor}.csv"
    GeneExpression_allen_dict[donor] = pd.read_csv(file_name, header=0).set_index('gene_symbol')
    
    
# Map each GTEx region to the Allen ontology structure ids it covers: the named
# Allen structure plus every structure whose structure_id_path starts with its
# path (presumably its descendants in the ontology tree).
ontology_path = allen_data_path + 'normalized_microarray_donor' + '9861' + '/Ontology.csv'
ontology = pd.read_csv(ontology_path, header = 0)
gtex_map_path = allen_data_path + "map_gtex_structure.txt"
gTex_map_dict = {}
print("Total number of regions in allen ontology:", ontology.shape[0])
# `with` guarantees the map file is closed; the bare `for ... in open(...)` leaked it.
with open(gtex_map_path) as gtex_map_file:
    for line in gtex_map_file:
        fields = line.strip().split("\t")
        # blank or truncated lines (e.g. a trailing newline) have no Allen column
        if len(fields) < 2:
            continue
        gtex_region = fields[0].strip()
        allen_region = fields[1].strip()
        # entries without a usable Allen counterpart
        if allen_region in ("none?", 'pituitary body'):
            continue
        # structure_id_path of the named structure (IndexError if the name is not in the ontology)
        region_id_path = ontology.loc[ontology['name'] == allen_region, 'structure_id_path'].values[0]
        covered_allen_region = ontology.loc[
            (ontology['name'] == allen_region) | ontology['structure_id_path'].str.startswith(region_id_path),
            'id',
        ]
        gTex_map_dict[gtex_region] = covered_allen_region.tolist()
        print(gtex_region, "-->", allen_region, ";  number of regions in allen:", len(covered_allen_region))
print("\n")
    
    
# Assign each Allen sample region (a column of donor 9861's table; presumably an
# ontology id stored as a string) to the GTEx region whose covered ids contain it.
intersected_region = GeneExpression_allen_dict['9861'].columns.tolist()
used_intersected_region_dict = {}
for gtex_region, covered_allen_region in gTex_map_dict.items():
    # set lookup instead of scanning the id list once per column
    covered_ids = set(covered_allen_region)
    used_region_list = [x for x in intersected_region if int(x) in covered_ids]
    used_intersected_region_dict[gtex_region] = used_region_list
    # number of Allen sample regions assigned to this GTEx region
    print(gtex_region, " # regions used:", len(used_region_list))
num_used_region = sum(len(value) for value in used_intersected_region_dict.values())
print("Total number of intersected region between allen and gtex:", len(intersected_region))
print("Total number of used allen region for generating regions for gtex:", num_used_region)
print("Total number of unseen allen regions when generating regions for gtex:", len(intersected_region)-num_used_region)
print("\n")


# Read every '<key>-gtex.txt' table (Allen data summarized into GTEx format) from
# the quantile-normalized directory, keyed by the part of the file name before
# '-gtex.txt'.
save_path = '/project/pi_rachel_melamed_uml_edu/Jianfeng/Allen/data/allen_data/quantile_normalized_allen/'
summarized_gtex_dict = {
    fname.split('-gtex.txt')[0]: pd.read_csv(os.path.join(save_path, fname), sep='\t', index_col=0)
    for fname in os.listdir(save_path)
    if fname.endswith('-gtex.txt')
}
        

# Load gtex data
# Rows 'region' and 'subject' annotate each sample column. The code below assumes
# they are the first two rows and that the remaining rows are gene ids.
data_dir = '/project/pi_rachel_melamed_uml_edu/Jianfeng/Allen/src/Pytorch/12052023/'
gt = pd.read_csv(data_dir+"new_normed_gtex_gtex_allen_gene.txt", low_memory=False, index_col=0, sep="\t")

# Count the samples per subject and per region. These are plain dicts in
# first-seen order.
sample_subject_list = gt.loc['subject'].tolist()
subject_count_dict = dict(Counter(sample_subject_list))
sample_region_list = gt.loc['region'].tolist()
region_count_dict = dict(Counter(sample_region_list))

# Find the subjects that have all 10 regions: those with exactly 10 samples.
# `submat[region_pick]` below raises KeyError if any of the 10 regions is absent.
pick_subject = [s for s, c in subject_count_dict.items() if c == 10]
# genes x regions expression matrix for each of those subjects
exp_gtex_dict = {}
for subject in pick_subject:
    # Select the samples by the 'subject' label rather than the positional row
    # gt.iloc[1]. This matches the rest of the notebook.
    submat = gt.loc[:, gt.loc['subject'] == subject]
    submat.columns = submat.loc['region', :]
    submat = submat.iloc[2:]
    submat.index.names = ['gene_id']
    submat = submat.sort_values(by=['gene_id'])
    submat = submat[region_pick]
    # The values are strings because the file mixes annotation rows with numbers.
    # errors='ignore' is deprecated (pandas >= 2.2), and it only deferred the
    # failure to the mean below. Let to_numeric raise on bad values instead.
    submat = submat.apply(pd.to_numeric)
    # Take the average if more than 1 sample have the same gene names
    submat = submat.groupby(submat.index).mean()
    exp_gtex_dict[subject] = submat
# find 30 gtex subjects
sub_all_ids = list(exp_gtex_dict.keys())
    
    
# Genes measured in both Allen (donor 9861) and GTEx (subject GTEX-N7MT), kept in GTEx order.
allen_gene_list = GeneExpression_allen_dict['9861'].index
gtex_gene_list = exp_gtex_dict['GTEX-N7MT'].index
overlapped_gene_list = list(filter(lambda gene: gene in allen_gene_list, gtex_gene_list))  # 15044 genes here
# Restrict every Allen donor table and every summarized table to those genes.
exp_allen_dict = {donor: mat.loc[overlapped_gene_list] for donor, mat in GeneExpression_allen_dict.items()}
summ_gtex_info = {key: mat.loc[overlapped_gene_list] for key, mat in summarized_gtex_dict.items()}
# Rename Cerebellum to Cerebellar_Hemisphere and Cortex to Frontal_Cortex_BA9 for the Allen donors.
# NOTE(review): this is a positional relabel. It assumes each summarized table lists its
# regions in the same order as the GTEx columns; confirm.
for summarized in summ_gtex_info.values():
    summarized.columns = exp_gtex_dict['GTEX-N7MT'].columns


# gene embeddings
# Load the pretrained gene-embedding matrix (float32, read from CSV) and the list
# of gene names stored next to it (presumably the row labels). The file names
# encode the embedding size and the error setting.
g_emb_error = 0.035
g_emb_size = 2 ** 4
g_emb_path = '/project/pi_rachel_melamed_uml_edu/Jianfeng/Allen/src/Pytorch/02162024/ATG_91_103/'
g_emb_name = f'allen_gtex_gene_emb_all6subjects_size_{g_emb_size}_pearson_err_{g_emb_error}_intersected103.csv'
g_emb_file = os.path.join(g_emb_path, g_emb_name)
pretrain_g_emb = np.genfromtxt(g_emb_file, delimiter=',', dtype=np.float32)
import pickle
# NOTE(review): pickle.load executes arbitrary code from the file; only load
# pickles produced by this project.
with open(g_emb_file + '_genenames.pkl', 'rb') as names_file:
    gene_emb_names_list = pickle.load(names_file)
    

# Build the edge list
# From the Ontology find the node relationship list
# (child -> parent pairs from donor 10021's Ontology.csv; the edge list is built
# from these pairs further down)
onto_file_path = 'normalized_microarray_donor10021/Ontology.csv'
onto_file_path = os.path.join(allen_data_path, onto_file_path)
ontology = pd.read_csv(onto_file_path)
ontology_id = ontology.loc[:, ['id', 'parent_structure_id']]
# set the parent node of 4005 to -1
# NOTE(review): assumes row 0 is the root structure 4005. Its parent is presumably
# empty (NaN), which would block the int cast below.
ontology_id.iloc[0,1] = -1
ontology_id['parent_structure_id'] = ontology_id['parent_structure_id'].astype(int)
# View the nodes in a hierarchical way
# Level 1 holds the Allen sample regions. Each later level is the set of parents of
# the previous level. Leaves at different depths climb in lockstep, so one level
# can mix ontology depths. The walk stops when a level shrinks to one node, or
# after 18 parent steps (i = 2..19).
node_child = [int(x) for x in intersected_region]
all_node = []
for i in range(1,20):
    if i==1:
        print(f"level {i}: {len(node_child)}")
        print(node_child)
        all_node.append(node_child)
    if len(node_child)==1:
        break
    if i!=1:
        node_parent = []
        for node in node_child:
            pos = ontology_id['id'].index[ontology_id['id']==node]
            # skip if it's already the ancestor
            # (ids not present in the ontology, e.g. the -1 placed above the root,
            # have no parent and drop out here)
            if len(pos)==0: continue
            parent = ontology_id['parent_structure_id'][pos].values[0]
            node_parent.append(parent)
        node_parent = set(node_parent)
        node_child = [x for x in node_parent]
        print(f"level {i}: {len(node_child)}")
        print(node_child)
        all_node.append(node_child)
# Every distinct node visited on the way up: leaves, internal nodes, and
# possibly 4005 and -1.
repeated_nodes = [x for y in all_node for x in y]
pick_nodes = set(repeated_nodes)
print(f"There are {len(pick_nodes)} nodes in total")

# sorted list of every node visited by the upward walk
pick_nodes = [x for x in pick_nodes]
pick_nodes.sort()
# exclude the ancestor node (4005) and the '-1' node
# NOTE(review): this drops the two smallest ids. Those are -1 and 4005 only if the
# upward walk actually produced both and no visited id is below 4005; confirm.
intersected_nodes_child = pick_nodes[2:]
# Node list: the Allen sample regions (donor 9861's columns, as ints) first, then
# the remaining internal nodes.
child_nodes_chr = list(exp_allen_dict['9861'].columns)
child_nodes = [int(x) for x in child_nodes_chr]
# append other hyper-level nodes to the pick_nodes
for x in intersected_nodes_child:
    if x not in child_nodes:
        child_nodes.append(x)
# find the parent nodes for the pick_nodes
# parent_nodes[k] is the ontology parent of child_nodes[k] (IndexError if a node
# is not in the ontology).
parent_nodes = []
for x in child_nodes:
    pos = ontology_id['id'].index[ontology_id['id']==x][0]
    parent = ontology_id['parent_structure_id'][pos]
    parent_nodes.append(parent)

# Collapse single-child chains: when a node's parent (other than the root 4005)
# has no other child, drop that parent and attach the node to its grandparent.
# Each pass removes at most one node. The loop stops when a pass removes nothing,
# or after len(parent_nodes) passes (the range is fixed when the loop starts).
for _ in range(len(parent_nodes)):
    length = len(parent_nodes)
    for i in range(length):
        cid = child_nodes[i]
        pid = parent_nodes[i]
        if pid!=4005:
            # find how many children this parent node has
            count1 = parent_nodes.count(pid)
            # if this count is more than one, we don't remove this node
            if count1 > 1:
                continue
            # if this parent node only has one child, we remove it
            else:
                # find the position of this parent node in the children node list
                # (ValueError if pid is not itself in child_nodes)
                pidx = child_nodes.index(pid)
                # find the grandparent
                # (ppid is not used afterwards: parent_nodes[pidx] already holds it, so
                # writing cid into child_nodes[pidx] re-links cid to the grandparent)
                ppid = parent_nodes[pidx]
                # remove this parent and directly connect the child to its grandparent
                child_nodes[pidx] = cid
                child_nodes.pop(i)
                parent_nodes.pop(i)
                break
    if len(parent_nodes)==length:
        break
        
# Re-order the nodes so that the Allen sample regions (leaves) come first, in
# donor 9861's column order. The surviving higher-level nodes follow.
initial_nodes_chr = list(exp_allen_dict['9861'].columns)
new_child_nodes = [int(x) for x in initial_nodes_chr]
for node in child_nodes:
    if node not in new_child_nodes:
        new_child_nodes.append(node)
# each child keeps the parent assigned to it by the pruning step
new_parent_nodes = [parent_nodes[child_nodes.index(node)] for node in new_child_nodes]

# Every node exactly once: children first, then parents that are never children.
all_nodes = list(new_child_nodes)
for node in new_parent_nodes:
    if node not in all_nodes:
        all_nodes.append(node)

# A node's position in all_nodes becomes its graph node id (first occurrence wins).
node_position = {}
for pos, node in enumerate(all_nodes):
    node_position.setdefault(node, pos)
child_nodes_idx = [node_position[node] for node in new_child_nodes]
parent_nodes_idx = [node_position[node] for node in new_parent_nodes]


# Model pre-setting
# (`device` is already defined in the settings at the top of this cell)
# other settings
N_gene = len(exp_allen_dict['9861'])
# Count every graph node: all children plus the parent-only nodes (such as the
# root). The former `len(child_nodes_idx) + 1` silently assumed exactly one
# parent-only node.
N_node = len(all_nodes)
n_node = len(intersected_region)
# Define the edge list: each tree edge in both directions (child->parent and
# parent->child), plus one self-loop per node.
edge_index_1 = [[c, p] for c, p in zip(child_nodes_idx, parent_nodes_idx)]
edge_index_2 = [[p, c] for c, p in zip(child_nodes_idx, parent_nodes_idx)]
edge_index = edge_index_1 + edge_index_2 + [[k, k] for k in range(N_node)]

Total number of regions in allen ontology: 1839
Amygdala --> amygdala ;  number of regions in allen: 135
Anterior_cingulate_cortex_BA24 --> cingulate gyrus, frontal part ;  number of regions in allen: 7
Caudate_basal_ganglia --> caudate nucleus ;  number of regions in allen: 10
Cerebellum --> cerebellum ;  number of regions in allen: 95
Cortex --> frontal lobe ;  number of regions in allen: 87
Hippocampus --> hippocampal formation ;  number of regions in allen: 30
Hypothalamus --> hypothalamus ;  number of regions in allen: 176
Nucleus_accumbens_basal_ganglia --> nucleus accumbens ;  number of regions in allen: 3
Putamen_basal_ganglia --> putamen ;  number of regions in allen: 3
Substantia_nigra --> substantia nigra ;  number of regions in allen: 7


Amygdala  # regions expired: 6
Anterior_cingulate_cortex_BA24  # regions expired: 2
Caudate_basal_ganglia  # regions expired: 3
Cerebellum  # regions expired: 9
Cortex  # regions expired: 14
Hippocampus  # regions expired: 6
Hypothalamus  

In [None]:
# Load gtex data
data_dir = '/project/pi_rachel_melamed_uml_edu/Jianfeng/Allen/src/Pytorch/12052023/'
gt = pd.read_csv(data_dir+"new_normed_gtex_gtex_allen_gene.txt", low_memory=False, index_col=0, sep="\t")
region_pick = ['Amygdala', 'Anterior_cingulate_cortex_BA24', 'Caudate_basal_ganglia', 
               'Cerebellar_Hemisphere', 'Frontal_Cortex_BA9', 'Hippocampus', 'Hypothalamus', 
               'Nucleus_accumbens_basal_ganglia', 'Putamen_basal_ganglia', 'Substantia_nigra']

# Subject x region indicator matrix: 1 if GTEx has a sample of that region for the
# subject, else 0. Rows are sorted by subject id; columns follow region_pick.
sample_subject_list = gt.loc['subject'].tolist()
sample_region_list = gt.loc['region'].tolist()
subject_region_mat = pd.DataFrame(0, index=sorted(set(sample_subject_list)), columns=region_pick, dtype=int)
# Iterate the annotation rows as plain lists. `gt.loc['region'][i]` indexed a
# label-indexed Series by position (deprecated in pandas 2.x) and re-extracted the
# whole row on every iteration.
for region, subject in zip(sample_region_list, sample_subject_list):
    if region not in region_pick:
        raise ValueError(f"region {region!r} is not in region_pick")
    subject_region_mat.at[subject, region] = 1

### Generate prediction for each tissue

In [None]:
def gen_GO_generator_by_region(t_epoch, f_epoch, missing_N, pred_region):
    """Assemble a gene x subject matrix for one GTEx region, predicting it where missing.

    For every subject in ``subject_region_mat``:
      * if GTEx has a sample for ``pred_region``, the measured expression is used;
      * otherwise, if the subject has at least ``10 - missing_N`` of the 10 regions,
        the region-specific GO model predicts ``pred_region``. The model input is
        the subject's other 9 regions. Any of those the subject lacks are first
        imputed with one OLS model per region, fitted on the subjects that have
        all 10 regions;
      * subjects with fewer regions are left out.

    Args:
        t_epoch: training epoch encoded in the checkpoint file name.
        f_epoch: fine-tuning epoch encoded in the checkpoint file name; 0 selects
            the "_bf_" checkpoint.
        missing_N: how many of the 10 regions (including ``pred_region``) a
            subject may lack and still be predicted.
        pred_region: GTEx region name (an entry of ``region_pick``).

    Returns:
        DataFrame indexed by ``gt.index[2:]`` with one column per included subject.
    """
    data_dir = '/project/pi_rachel_melamed_uml_edu/Jianfeng/Allen/src/Pytorch/02162024/ATG_91_103/Result/'
    if f_epoch == 0:
        model_name = f'ATG_91_103_{pred_region}_epoch_{t_epoch}_bf_architecture.pth'
        weights_name = f'ATG_91_103_{pred_region}_epoch_{t_epoch}_bf_weights.pth'
    else:
        model_name = f'ATG_91_103_{pred_region}_trainable_GNN_af_epoch_{t_epoch}_finetuning_epoch_{f_epoch}_architecture.pth'
        weights_name = f'ATG_91_103_{pred_region}_trainable_GNN_af_epoch_{t_epoch}_finetuning_epoch_{f_epoch}_weights.pth'
    # map_location lets a checkpoint saved on one device load on another. The model
    # and its inputs are placed on `device`, as in gen_allen_regions_LM_fromGTEx.
    model = torch.load(data_dir + model_name, map_location=device).to(device)
    model.load_state_dict(torch.load(data_dir + weights_name, map_location=device))
    # Inference only: this disables dropout and batch-norm updates, if the model has any.
    model.eval()
    # Find the graph node indices of the Allen regions under the predicted region.
    nodes_for_pred_gtex_region = used_intersected_region_dict[find_allen_name(pred_region)]
    nodes_for_pred_gtex_region_idx = [new_child_nodes.index(int(x)) for x in nodes_for_pred_gtex_region]

    # Take the sample annotations as plain lists. Positional indexing of the
    # label-indexed row (gt.loc['region'][i]) is deprecated in pandas 2.x, and
    # re-extracting the row inside nested loops was quadratic.
    region_row = gt.loc['region'].tolist()
    subject_row = gt.loc['subject'].tolist()
    # (subject, region) -> column of gt. A later column overwrites an earlier one,
    # matching the previous "last match wins" column scan.
    sample_col = {(s, r): i for i, (r, s) in enumerate(zip(region_row, subject_row))}

    # reference panel for the OLS imputation: subjects with all 10 regions
    subject_w_10_r = [s for s in subject_region_mat.index if subject_region_mat.loc[s].sum() == 10]
    gen_tuple = torch.tensor(np.arange(N_gene), dtype=torch.long).to(device)
    go_pred_dict = {}
    for subject in subject_region_mat.index:
        if subject_region_mat.loc[subject, pred_region] == 1:
            # measured: use the GTEx sample directly
            go_pred_dict[subject] = pd.to_numeric(gt.iloc[2:, sample_col[(subject, pred_region)]])
            continue
        # too many missing regions: leave this subject out
        if subject_region_mat.loc[subject].sum() < (10 - missing_N):
            continue
        other_9_region_dict = {}
        # these are the existing regions besides the predicted region
        x_region = [r for r in region_pick if subject_region_mat.loc[subject, r] == 1 and r != pred_region]
        for region in x_region:
            other_9_region_dict[region] = pd.to_numeric(gt.iloc[2:, sample_col[(subject, region)]])
        # these are the regions to fill with the lm model
        fill_region = [r for r in region_pick if r not in x_region and r != pred_region]
        if fill_region:
            # design matrices: intercept + measured regions; training rows stack all
            # reference subjects' genes
            train_dfs = [exp_gtex_dict[key].loc[:, x_region] for key in subject_w_10_r]
            Xtrain = sm.add_constant(pd.concat(train_dfs, axis=0, ignore_index=True))
            pick_col = [i for i, s in enumerate(subject_row) if s == subject]
            Xtest = gt.iloc[:, pick_col]
            Xtest.columns = Xtest.loc['region']
            Xtest = Xtest.drop(['region', 'subject'])
            Xtest = Xtest.loc[:, x_region]
            Xtest = sm.add_constant(Xtest.apply(pd.to_numeric))
            for lm_region in fill_region:
                train_preds = [exp_gtex_dict[key].loc[:, lm_region] for key in subject_w_10_r]
                ytrain = pd.concat(train_preds, axis=0, ignore_index=True)
                fmod = sm.OLS(ytrain, Xtrain).fit()
                other_9_region_dict[lm_region] = fmod.predict(Xtest)
        # GO model input: the 9 other regions in region_pick order
        sorted_region = [r for r in region_pick if r != pred_region]
        input_mat = pd.DataFrame.from_dict(other_9_region_dict)
        input_mat.index = gt.index[2:]
        input_mat = input_mat.loc[:, sorted_region]
        x_reg_exp = torch.tensor(input_mat.values).float().to(device)
        with torch.no_grad():
            concat_pred = model(nodes_for_pred_gtex_region_idx, gen_tuple, x_reg_exp, edge_index, N_gene).reshape(-1)
        # Store a numpy array, not a tensor, so pandas builds the final frame from
        # plain arrays and Series.
        go_pred_dict[subject] = concat_pred.detach().cpu().numpy()
    # generate the final output
    pred_mat = pd.DataFrame.from_dict(go_pred_dict)
    pred_mat.index = gt.index[2:]

    return pred_mat

In [None]:
# region_pick = ['Amygdala', 'Anterior_cingulate_cortex_BA24', 'Caudate_basal_ganglia', 
#                'Cerebellar_Hemisphere', 'Frontal_Cortex_BA9', 'Hippocampus', 'Hypothalamus', 
#                'Nucleus_accumbens_basal_ganglia', 'Putamen_basal_ganglia', 'Substantia_nigra']

# # build a dictionary to count the freq of each subject 
# sample_subject_list = gt.loc['subject'].tolist()
# subject_region_mat = pd.DataFrame(np.zeros((len(set(sample_subject_list)), len(region_pick)), dtype=int))
# subject_region_mat.index = sorted(set(sample_subject_list))
# subject_region_mat.columns = region_pick
# for i in range(gt.shape[1]):
#     region = gt.loc['region'][i]
#     subject = gt.loc['subject'][i]
#     region_idx = region_pick.index(region)
#     subject_idx = subject_region_mat.index.tolist().index(subject)
#     subject_region_mat.iloc[subject_idx, region_idx] = 1

# # generate prediction for every tissue
# t_epoch = 300
# f_epoch = 500
# missing_N = 5
# # generate prediction
# for pred_region in region_pick:
#     go_mat = gen_GO_generator_by_region(t_epoch, f_epoch, missing_N, pred_region)
#     save_dir = '/project/pi_rachel_melamed_uml_edu/Jianfeng/Allen/src/Pytorch/02162024/ATG_8_or_less/Prediction/'
#     csv_file = f'{pred_region}_{t_epoch}_{f_epoch}_{missing_N}_LMfromGTEx.csv'
#     go_mat.to_csv(save_dir+csv_file, index=True)

### Generate prediction for each subject

In [25]:
def gen_gtex_regions_LM_fromGTEx(subject):
    """Return one subject's expression for all 10 GTEx regions (genes x regions).

    Regions with a GTEx sample are taken as measured. Each missing region is imputed
    with an OLS model (with intercept). The model regresses that region on the
    subject's measured regions and is fitted on the subjects that have all 10 regions.

    Args:
        subject: a GTEx subject id present in ``subject_region_mat``.

    Returns:
        DataFrame indexed by ``gt.index[2:]`` with columns in ``region_pick`` order.
    """
    # Take the sample annotations as plain lists. Positional indexing of the
    # label-indexed row (gt.loc['region'][i]) is deprecated in pandas 2.x, and
    # re-extracting the row inside the loops was quadratic.
    region_row = gt.loc['region'].tolist()
    subject_row = gt.loc['subject'].tolist()
    # (subject, region) -> column of gt. A later column overwrites an earlier one,
    # matching the previous "last match wins" column scan.
    sample_col = {(s, r): i for i, (r, s) in enumerate(zip(region_row, subject_row))}
    # reference panel for the regression: subjects with all 10 regions
    subject_w_10_r = [s for s in subject_region_mat.index if subject_region_mat.loc[s].sum() == 10]
    gtex_region_to_predict_dict = {}
    # find existing regions
    x_region = [r for r in region_pick if subject_region_mat.loc[subject, r] == 1]
    for region in x_region:
        gtex_region_to_predict_dict[region] = pd.to_numeric(gt.iloc[2:, sample_col[(subject, region)]])
    # these are the regions to be filled with the lm model
    fill_region = [r for r in region_pick if r not in x_region]
    if fill_region:
        # design matrices: intercept + measured regions; training rows stack all
        # reference subjects' genes
        train_dfs = [exp_gtex_dict[key].loc[:, x_region] for key in subject_w_10_r]
        Xtrain = sm.add_constant(pd.concat(train_dfs, axis=0, ignore_index=True))
        pick_col = [i for i, s in enumerate(subject_row) if s == subject]
        Xtest = gt.iloc[:, pick_col]
        Xtest.columns = Xtest.loc['region']
        Xtest = Xtest.drop(['region', 'subject'])
        Xtest = Xtest.loc[:, x_region]
        Xtest = sm.add_constant(Xtest.apply(pd.to_numeric))
        for lm_region in fill_region:
            # one OLS model per missing region
            train_preds = [exp_gtex_dict[key].loc[:, lm_region] for key in subject_w_10_r]
            ytrain = pd.concat(train_preds, axis=0, ignore_index=True)
            fmod = sm.OLS(ytrain, Xtrain).fit()
            gtex_region_to_predict_dict[lm_region] = fmod.predict(Xtest)
    # build the inputs for the GO model, with columns in region_pick order
    input_mat = pd.DataFrame.from_dict(gtex_region_to_predict_dict)
    input_mat.index = gt.index[2:]
    input_mat = input_mat.loc[:, region_pick]

    return input_mat

# Predict every Allen region for one GTEx subject, using LM-imputed GTEx inputs.
def gen_allen_regions_LM_fromGTEx(t_epoch, f_epoch, status, subject):
    """Predict expression in every Allen sample region for one GTEx subject.

    The subject's 10 GTEx regions come from ``gen_gtex_regions_LM_fromGTEx``, which
    returns measured or OLS-imputed values. They are fed to the ATG_10_103 model
    once per Allen region.

    Args:
        t_epoch: training epoch encoded in the checkpoint file name.
        f_epoch: fine-tuning epoch; 0 selects the "_bf_" checkpoint.
        status: "freeze" selects the freeze_GNN fine-tuned checkpoint; any other
            value selects the other fine-tuned checkpoint. Ignored when f_epoch == 0.
        subject: GTEx subject id.

    Returns:
        DataFrame indexed by ``gt.index[2:]`` with one column per Allen region
        (the column labels of ``exp_allen_dict[all_ids[0]]``).
    """
    # generate the input 10 gtex regions
    input_mat = gen_gtex_regions_LM_fromGTEx(subject)
    # model directory and checkpoint names
    data_dir = '/project/pi_rachel_melamed_uml_edu/Jianfeng/Allen/src/Pytorch/02162024/ATG_10_103/Result/'
    if f_epoch == 0:
        model_name = f'ATG_10_103_epoch_{t_epoch}_bf_architecture.pth'
        weights_name = f'ATG_10_103_epoch_{t_epoch}_bf_weights.pth'
    elif status == "freeze":
        model_name = f'ATG_10_103_freeze_GNN_af_epoch_{t_epoch}_finetuning_epoch_{f_epoch}_architecture.pth'
        weights_name = f'ATG_10_103_freeze_GNN_af_epoch_{t_epoch}_finetuning_epoch_{f_epoch}_weights.pth'
    else:
        model_name = f'ATG_10_103_af_epoch_{t_epoch}_finetuning_epoch_{f_epoch}_architecture.pth'
        weights_name = f'ATG_10_103_af_epoch_{t_epoch}_finetuning_epoch_{f_epoch}_weights.pth'
    # map_location lets a checkpoint saved on one device load on another
    model = torch.load(data_dir + model_name, map_location=device).to(device)
    model.load_state_dict(torch.load(data_dir + weights_name, map_location=device))
    model.eval()
    gtex_exp_10 = torch.tensor(input_mat.values).float().to(device)
    gen_tuple = torch.tensor(np.arange(N_gene), dtype=torch.long).to(device)
    # NOTE(review): a region's position in this list is used as its graph node id.
    # The node list was built from donor 9861's columns, so this matches only if
    # donor all_ids[0] has its columns in the same order; confirm.
    keys = list(exp_allen_dict[all_ids[0]].columns)
    exp_dict = {}
    # Inference only: no autograd graph is needed for these predictions.
    with torch.no_grad():
        for reg_id, region in enumerate(keys):
            pred = model([reg_id], gen_tuple, gtex_exp_10, edge_index, N_gene).reshape(-1)
            exp_dict[region] = pred.cpu().numpy()
    exp_mat = pd.DataFrame(exp_dict)
    exp_mat.index = gt.index[2:]

    return exp_mat

In [36]:
# Generate predictions for every subject that has at least (10 - missing_N) GTEx regions.
t_epoch = 300
f_epoch = 0
status = 'trainable'
missing_N = 5
save_dir = '/project/pi_rachel_melamed_uml_edu/Jianfeng/Allen/src/Pytorch/02162024/ATG_8_or_less/Prediction/'
go_prediction_dict = {}
include_subject = [s for s in subject_region_mat.index if subject_region_mat.loc[s].sum() >= (10 - missing_N)]
for subject in include_subject:
    go_mat = gen_allen_regions_LM_fromGTEx(t_epoch, f_epoch, status, subject)
    # optional per-subject CSV export:
    # folder_name = f'{t_epoch}_{f_epoch}_by_subject/'
    # csv_file = f'{subject}_{t_epoch}_{f_epoch}_{status}_LMfromGTEx.csv'
    # go_mat.to_csv(save_dir+folder_name+csv_file, index=True)
    go_prediction_dict[subject] = go_mat

# Save the dictionary. For fine-tuned checkpoints, the file name records whether
# the GNN was frozen or trainable. This matches the checkpoint family that
# gen_allen_regions_LM_fromGTEx picks from `status`.
if f_epoch == 0:
    dict_name = f'gtex_allen_region_epoch_{t_epoch}_missing_N_{missing_N}.pickle'
else:
    gnn_mode = 'freeze' if status == 'freeze' else 'trainable'
    dict_name = f'gtex_allen_region_{gnn_mode}_GNN_af_epoch_{t_epoch}_fepoch_{f_epoch}_missing_N_{missing_N}.pickle'
# Writing the dictionary to a file using pickle
with open(save_dir + dict_name, 'wb') as f:
    pickle.dump(go_prediction_dict, f)
# # Reading the pickled data from file
# with open(save_dir+dict_name, 'rb') as f:
#     go_prediction_dict = pickle.load(f)