In [None]:
cd ../..

In [None]:
import numpy as np
import torch
from omegaconf import OmegaConf
from sklearn.metrics import *
from tqdm.auto import tqdm

import stereo
from stereo.utils.results import *
from stereo.utils.stats import htest

# Short aliases for the NumPy set/array helpers used throughout this notebook:
#   npi -> intersect1d, npc -> concatenate, npu -> unique, npd -> setdiff1d
npi, npc, npu, npd = (
    np.intersect1d,
    np.concatenate,
    np.unique,
    np.setdiff1d,
)

In [None]:
# Identifiers interpolated into the config, run-ID, checkpoint and output paths below.
dataset_name = 'AT2'
model_name = 'STEREO-GCN'

In [None]:
# Read the per-dataset config file and build the dataset object from it.
cfg = OmegaConf.load(f'configs/dataset/GRN-{dataset_name}.yaml')
dataset = load_dataset(cfg)

# Keep only the genes flagged by `gene_is_tf`; these are the candidate TFs
# that the per-TF counters further down are aligned with.
tf_mask = dataset.gene_is_tf
training_tfs = dataset.genes[tf_mask]

### Run seeds and IDs

#### Hier-Prox Step 1

In [None]:
# Run IDs recorded under the 'DEG4-hier' key for this model/dataset.
# Iterated below as (seed, run) pairs.
hier_runs = load_run_ids(
    f'results/runs/{dataset_name}/{model_name}.json',
    'DEG4-hier',
)

In [None]:
# Aggregate the HierProx `theta_seq_` state over every 'DEG4-hier' run.
# For each sequence index `idx` and TF position `j`:
#   counters[idx][j] -> number of runs in which theta[j] > 0
#   theta_d[idx][j]  -> sum of theta[j] over all runs
n_tfs = dataset.gene_is_tf.sum()
counters = [np.zeros(n_tfs) for _ in range(dataset.n_seq)]
theta_d = [np.zeros(n_tfs) for _ in range(dataset.n_seq)]

# The log directory depends only on the dataset, so resolve it once outside the loop.
if dataset_name == 'PBMC':
    log_root = 'PBMCGeneRegPseudotimeDataset_logs'
else:
    log_root = 'GeneRegPseudotimeDataset_logs'

for seed, run in tqdm(hier_runs.items()):
    ckpt_path = f'{log_root}/{run}/checkpoints/last.ckpt'
    # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
    hp = torch.load(ckpt_path, map_location='cpu')['callbacks']['HierProx']
    for idx, theta in enumerate(hp['theta_seq_']):
        # Convert to NumPy once and use that array for both the mask and the sum,
        # so the NumPy counters are never indexed with a torch tensor. `detach()`
        # mirrors how A_seq_ is converted elsewhere in this notebook and is a
        # no-op for tensors that do not track gradients.
        theta_np = theta.detach().numpy()
        counters[idx][theta_np > 0] += 1
        theta_d[idx] += theta_np

# Stack the per-index sums into a single array with the sequence index on axis 0.
theta_d = np.stack(theta_d)

In [None]:
# A TF counts as selected at a sequence index when its counter reaches this
# threshold, i.e. theta was > 0 in at least this many runs.
consensus = 5

# slist[idx] = (names of selected TFs, their positions within training_tfs)
slist = {}
for idx, counter in enumerate(counters):
    keep = counter >= consensus
    selected_indices = np.flatnonzero(keep)
    print(selected_indices.tolist())
    selected_tfs = training_tfs[keep]
    print(selected_tfs.tolist())
    slist[idx] = (selected_tfs, selected_indices)

In [None]:
# Split the (tf_names, tf_indices) pairs stored in slist into two parallel lists.
allgene = [tf_names for tf_names, _ in slist.values()]
allidx = [tf_indices for _, tf_indices in slist.values()]

# Union of the selections over all sequence indices: names with their count, then positions.
catted = npu(npc(allgene))
print(catted, len(catted))
print(npu(npc(allidx)).tolist())

In [None]:
# Overlap of the selected TF names between the first three sequence indices.
# Printed in this order: all three, then (0, 1), (1, 2), (0, 2).
sel = [tf_names for tf_names, _ in slist.values()]

shared_01 = npi(sel[0], sel[1])
print(npi(shared_01, sel[2]))
print(shared_01)
for left, right in ((1, 2), (0, 2)):
    print(npi(sel[left], sel[right]))

In [None]:
# Members of the selected-TF union that also appear in dataset.snc_tfs
# (left as the trailing expression so the notebook displays it).
np.intersect1d(catted, dataset.snc_tfs)

In [None]:
# Print a hard-coded list of gene symbols as LaTeX italics on one line.
# Every entry, including the last, is followed by ', ' and no newline is emitted.
a = [
    'BTF3', 'FOS', 'JUNB', 'JUND', 'KLF2',
    'NFKBIA', 'NPM1', 'PTMA', 'SF1', 'TSC22D3',
]
print(''.join(f"\\textit{{{g}}}, " for g in a), end='')

#### Full Step 2

In [None]:
# Run IDs recorded under the 'DEG4-full' key for this model/dataset.
# Iterated below as (seed, run) pairs.
full_runs = load_run_ids(
    f'results/runs/{dataset_name}/{model_name}.json',
    'DEG4-full',
)

In [None]:
# Load every 'DEG4-full' checkpoint and collect its A_seq_ matrices as NumPy arrays.
# The log directory depends only on the dataset, so it is chosen once up front.
log_root = (
    'PBMCGeneRegPseudotimeDataset_logs'
    if dataset_name == 'PBMC'
    else 'GeneRegPseudotimeDataset_logs'
)

per_run_As = []
for seed, run in tqdm(full_runs.items()):
    ckpt_path = f'{log_root}/{run}/checkpoints/last.ckpt'
    # `module` stays bound to the last loaded run after the loop; a later cell reads it.
    module = stereo.STEREO_GCN_Module.load_from_checkpoint(ckpt_path, map_location='cpu')
    per_run_As.append([A.detach().numpy() for A in module.A_seq_])

# Axis 0 indexes the run, axis 1 the position within A_seq_.
As = np.stack(per_run_As)
print(As.shape)

In [None]:
# TF names are picked with the first entry of `sources_mask_seq` on `module`,
# i.e. the module left over from the last checkpoint loaded in the loop above.
source_tfs = dataset.genes[module.sources_mask_seq[0]]

# Combine the stacked per-run matrices using a consensus threshold of 3.
# NOTE(review): exact thresholding semantics live in aggregate_As — confirm there.
t_to_selected = aggregate_As(As, tfs=source_tfs, genes=dataset.genes, consensus=3)

In [None]:
# Persist the aggregated selection under results/graphs/<dataset>/.
# NOTE(review): whether this string is used as a directory or a file prefix is decided by write_graphs.
graphs_target = f'results/graphs/{dataset_name}/{model_name}-{dataset_name}'
write_graphs(t_to_selected, graphs_target)