In [1]:
# Move two directories up so the relative paths used in later cells
# (configs/, results/, *_logs/) resolve; presumably this is the repository
# root, as the recorded output suggests.
cd ../..

/home/paperspace/time-varying-graphs


In [2]:
import numpy as np
import torch
from omegaconf import OmegaConf
from sklearn.metrics import *
from tqdm.auto import tqdm

import stereo
from stereo.utils.results import *
from stereo.utils.stats import htest

# Short aliases for the NumPy set/array helpers used in the cells below.
npi, npc = np.intersect1d, np.concatenate
npu, npd = np.unique, np.setdiff1d

In [20]:
# Notebook-wide configuration: the model and dataset whose runs are analysed.
# Both values are interpolated into the config, run-ID and output paths below.
model_name = 'STEREO-GCN'
dataset_name = 'AT2'

In [21]:
# Load the per-dataset GRN config and build the dataset object from it.
cfg = OmegaConf.load(f'configs/dataset/GRN-{dataset_name}.yaml')
dataset = load_dataset(cfg)
# Gene names restricted by `gene_is_tf`, which is used here as a mask over
# `dataset.genes` (presumably a boolean array marking transcription factors —
# TODO confirm).
training_tfs = dataset.genes[dataset.gene_is_tf]

Keeping 2263 genes.


### Run seeds and IDs

#### Hier-Prox Step 1

In [11]:
# Look up the runs tagged 'DEG4-hier' for this dataset/model. The result is
# iterated with `.items()` as (seed, run) pairs in the next cell.
hier_runs = load_run_ids(f'results/runs/{dataset_name}/{model_name}.json', 'DEG4-hier')

In [12]:
# Per-time-step accumulators over all Hier-Prox runs. There is one vector per
# sequence step, with one entry per gene flagged by `dataset.gene_is_tf`:
#   counters[t][j] -> number of runs in which theta[j] at step t is > 0
#   theta_d[t][j]  -> sum over runs of theta[j] at step t
n_tfs = int(dataset.gene_is_tf.sum())
counters = [np.zeros(n_tfs) for _ in range(dataset.n_seq)]
theta_d = [np.zeros(n_tfs) for _ in range(dataset.n_seq)]

# The log directory depends only on the dataset, so choose it once instead of
# re-evaluating the branch for every run.
if dataset_name == 'PBMC':
    log_dir = 'PBMCGeneRegPseudotimeDataset_logs'
else:
    log_dir = 'GeneRegPseudotimeDataset_logs'

for seed, run in tqdm(hier_runs.items()):
    ckpt_path = f'{log_dir}/{run}/checkpoints/last.ckpt'
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    hp = torch.load(ckpt_path, map_location='cpu')['callbacks']['HierProx']
    for idx, theta in enumerate(hp['theta_seq_']):
        # Convert to NumPy once so that the mask and the running sum are both
        # computed on a NumPy array (no torch mask indexing a NumPy array).
        theta_np = theta.detach().cpu().numpy()
        counters[idx][theta_np > 0] += 1
        theta_d[idx] += theta_np

# Stack the per-step sums into a single 2-D array (steps along axis 0).
theta_d = np.stack(theta_d)

  0%|          | 0/10 [00:00<?, ?it/s]

In [13]:
# For every time step, keep the entries of `training_tfs` whose counter reaches
# the consensus threshold.
# slist maps time-step index -> (selected TF names, their indices).
slist = {}

# Minimum counter value (number of runs with theta > 0) required for selection.
consensus = 5

for idx, counter in enumerate(counters):
    # Build the consensus mask once and derive both indices and names from it.
    keep = counter >= consensus
    selected_indices = keep.nonzero()[0]
    print(selected_indices.tolist())
    selected_tfs = training_tfs[keep]
    print(selected_tfs.tolist())
    slist[idx] = (selected_tfs, selected_indices)

[4, 33, 36, 38, 46, 51, 53, 55, 56, 86, 93, 103, 106, 111, 115, 121]
['YBX1', 'HES1', 'EPAS1', 'ID3', 'NR4A1', 'IRF1', 'ID1', 'KLF2', 'ZFP36', 'ETS2', 'IFI16', 'FOS', 'CEBPB', 'JUN', 'NPM1', 'TCF4']
[36, 38, 53, 55, 56, 93, 111, 121]
['EPAS1', 'ID3', 'ID1', 'KLF2', 'ZFP36', 'IFI16', 'JUN', 'TCF4']
[4, 36, 38, 46, 53, 55, 56, 84, 86, 93, 103, 105, 109, 111, 115, 121]
['YBX1', 'EPAS1', 'ID3', 'NR4A1', 'ID1', 'KLF2', 'ZFP36', 'TSC22D3', 'ETS2', 'IFI16', 'FOS', 'MAL', 'NUPR1', 'JUN', 'NPM1', 'TCF4']
[4, 36, 38, 53, 55, 56, 84, 86, 93, 111, 115]
['YBX1', 'EPAS1', 'ID3', 'ID1', 'KLF2', 'ZFP36', 'TSC22D3', 'ETS2', 'IFI16', 'JUN', 'NPM1']
[4, 36, 38, 42, 51, 53, 55, 56, 86, 93, 103, 111, 115]
['YBX1', 'EPAS1', 'ID3', 'EGR1', 'IRF1', 'ID1', 'KLF2', 'ZFP36', 'ETS2', 'IFI16', 'FOS', 'JUN', 'NPM1']


In [14]:
# Pool the per-time-step selections and report the union of TF names (with its
# size) followed by the union of TF indices.
allgene, allidx = [], []
for genes_t, indices_t in slist.values():
    allgene.append(genes_t)
    allidx.append(indices_t)

catted = npu(npc(allgene))
print(catted, len(catted))
print(npu(npc(allidx)).tolist())

['CEBPB' 'EGR1' 'EPAS1' 'ETS2' 'FOS' 'HES1' 'ID1' 'ID3' 'IFI16' 'IRF1'
 'JUN' 'KLF2' 'MAL' 'NPM1' 'NR4A1' 'NUPR1' 'TCF4' 'TSC22D3' 'YBX1' 'ZFP36'] 20
[4, 33, 36, 38, 42, 46, 51, 53, 55, 56, 84, 86, 93, 103, 105, 106, 109, 111, 115, 121]


In [10]:
# Overlap of the selected TF names among the first three time steps: the
# three-way intersection first, then each pair.
# NOTE(review): the output recorded below lists genes (e.g. 'ELF3', 'HOPX')
# that are absent from the selections printed by the consensus cell, and the
# execution count is out of order — the output appears stale; re-run this cell.
sel = [genes for genes, _ in slist.values()]
first, second, third = sel[0], sel[1], sel[2]
print(npi(npi(first, second), third))
print(npi(first, second))
print(npi(second, third))
print(npi(first, third))

['EGR1' 'ELF3' 'FOS' 'HOPX' 'JUN' 'NPM1' 'NUPR1' 'TSC22D3' 'YBX1' 'ZFP36']
['EGR1' 'ELF3' 'FOS' 'FOSB' 'HOPX' 'JUN' 'NPM1' 'NUPR1' 'TSC22D3' 'YBX1'
 'ZFP36']
['EGR1' 'ELF3' 'FOS' 'HOPX' 'ID1' 'JUN' 'NPM1' 'NR4A1' 'NUPR1' 'TSC22D3'
 'YBX1' 'ZFP36']
['BTG2' 'EGR1' 'ELF3' 'FOS' 'HOPX' 'JUN' 'NPM1' 'NUPR1' 'TSC22D3' 'YBX1'
 'ZFP36']


In [13]:
# Names in the pooled selection `catted` that also appear in `dataset.snc_tfs`.
# NOTE(review): the recorded output contains genes (e.g. 'CREG1') that are not
# in the `catted` array printed above, so it looks stale — re-run to refresh.
npi(catted, dataset.snc_tfs)

array(['CEBPB', 'CREG1', 'CTNNB1', 'DEK', 'ETS2', 'FOS', 'ID1', 'JUN',
       'NPM1', 'RBL2', 'YBX1', 'ZFP36'], dtype='<U29')

In [4]:
# Format a hand-written gene list as a comma-separated run of LaTeX \textit{}
# commands.
a = ['BTF3', 'FOS', 'JUNB', 'JUND', 'KLF2', 'NFKBIA', 'NPM1', 'PTMA', 'SF1', 'TSC22D3']
# Join with ', ' so the printed line carries no trailing separator and ends
# with a newline.
latex_genes = ', '.join(f"\\textit{{{g}}}" for g in a)
print(latex_genes)

\textit{BTF3}, \textit{FOS}, \textit{JUNB}, \textit{JUND}, \textit{KLF2}, \textit{NFKBIA}, \textit{NPM1}, \textit{PTMA}, \textit{SF1}, \textit{TSC22D3}, 

#### Full Step 2

In [22]:
# Look up the runs tagged 'DEG4-full' for this dataset/model. The result is
# iterated with `.items()` as (seed, run) pairs in the next cell.
full_runs = load_run_ids(f'results/runs/{dataset_name}/{model_name}.json', 'DEG4-full')

In [23]:
# Collect the `A_seq_` matrices of every 'DEG4-full' run: one list entry per
# run, each holding one NumPy array per element of `module.A_seq_`.
As = []

# The log directory depends only on the dataset, so choose it once instead of
# re-evaluating the branch for every run.
if dataset_name == 'PBMC':
    log_dir = 'PBMCGeneRegPseudotimeDataset_logs'
else:
    log_dir = 'GeneRegPseudotimeDataset_logs'

for seed, run in tqdm(full_runs.items()):
    ckpt_path = f'{log_dir}/{run}/checkpoints/last.ckpt'
    module = stereo.STEREO_GCN_Module.load_from_checkpoint(ckpt_path, map_location='cpu')
    As.append([A.detach().numpy() for A in module.A_seq_])

# `module` (the last run loaded) stays bound on purpose: the aggregation cell
# that follows reads `module.sources_mask_seq` from it.
# Stacked result has runs on axis 0; the remaining axes are presumably
# (time steps, genes, TFs) — TODO confirm.
As = np.stack(As)
print(As.shape)

  0%|          | 0/10 [00:00<?, ?it/s]

(10, 5, 2263, 22)


In [24]:
# Aggregate the stacked per-run A matrices with `aggregate_As`; the recorded
# output reports selected edges, TFs and genes per time step.
# NOTE(review): `module` is whatever the checkpoint-loading loop bound last,
# i.e. the final run. Using its first sources mask for all runs assumes the TF
# set is identical across seeds and time steps — TODO confirm.
t_to_selected = aggregate_As(
    As,
    tfs=dataset.genes[module.sources_mask_seq[0]],
    genes=dataset.genes,
    # presumably the minimum number of runs that must agree on an edge —
    # verify against aggregate_As
    consensus=3,
)

t=0	 Selected 724 edges	N. TFs = 22	N. genes = 547	
t=1	 Selected 798 edges	N. TFs = 22	N. genes = 637	
t=2	 Selected 818 edges	N. TFs = 22	N. genes = 662	
t=3	 Selected 786 edges	N. TFs = 22	N. genes = 628	
t=4	 Selected 751 edges	N. TFs = 22	N. genes = 614	


In [26]:
# Hand the per-time-step selections to `write_graphs`, with
# 'results/graphs/<dataset>/<model>-<dataset>' as the path argument
# (presumably an output prefix — verify against write_graphs).
write_graphs(t_to_selected, f'results/graphs/{dataset_name}/{model_name}-{dataset_name}')