In [1]:
%reload_ext autoreload
%autoreload 2


In [2]:
import os 
os.environ["CUDA_VISIBLE_DEVICES"] = "0" 

In [3]:
import hydra
from hydra.core.global_hydra import GlobalHydra

GlobalHydra.instance().clear()
# setup hydra config global for loading this notebook
hydra.initialize(config_path="configs", version_base=None)
cfg = hydra.compose(config_name="expr_baseline")


In [5]:
from main import setup_environment, get_checkpoint_info, is_training_complete

# Setup environment
device = setup_environment(cfg)

# Get checkpoint info
output_dir, checkpoint = get_checkpoint_info(cfg, num_epochs=cfg.training.num_epochs)

# Instantiate everything
bridge = hydra.utils.instantiate(cfg.bridge)
dataset = hydra.utils.instantiate(cfg.dataset)
model = hydra.utils.instantiate(cfg.model)
model.load_state_dict(checkpoint['model_state_dict'])
avg_model = hydra.utils.instantiate(cfg.averaging, model=model)
avg_model.load_state_dict(checkpoint['avg_model_state_dict'])
    

22:30:07 - INFO - Using device: cuda


22:30:09 - INFO - Found checkpoint: /orcd/data/omarabu/001/njwfish/counting_flows/outputs/61be507a6200/model.pt


Train split: 931, Test split: 49
Train cells: 1202750, Test cells: 61517


  from .autonotebook import tqdm as notebook_tqdm


<All keys matched successfully>

In [14]:
model = hydra.utils.instantiate(cfg.model)
model.load_state_dict(checkpoint['model_state_dict'])
avg_model = hydra.utils.instantiate(cfg.averaging, model=model)
avg_model.load_state_dict(checkpoint['avg_model_state_dict'])

<All keys matched successfully>

In [15]:
import numpy as np
import torch

ctx_window_size    = 196_608
target_window_size = 896
stride             = 896   # tiles exactly; use <= 896 to allow overlap

whalf = ctx_window_size // 2
thalf = target_window_size // 2

base_idx = dataset.base_individual_idx  # (n_cells,)
device = "cuda"  # or "cpu"


gene_counts = {}
for ci, gs, ge, gn in zip(dataset._eligible_cidx, dataset._eligible_start, dataset._eligible_end, dataset._eligible_gene_name):

    chrom_beg  = dataset.chrom_starts[ci]
    gene_start = int(chrom_beg + int(gs))
    gene_end   = int(chrom_beg + int(ge))  # end-exclusive
    gene_len   = int(gene_end - gene_start)

    print(gene_len)

    big_seq = dataset.get_seq(gene_start - whalf, gene_end + whalf)
    assert len(big_seq) == gene_len + 2*whalf

    # Centers whose 896bp target is fully inside the gene
    first_center = thalf
    last_center  = gene_len - thalf
    # if last_center < first_center:
    #     continue  # gene shorter than target; skip or handle separately

    centers_local = np.arange(first_center, last_center + 1, stride, dtype=np.int64)
    if len(centers_local) == 0:
        centers_local = np.array([first_center])
    # ensure we always include a tail center that lands exactly on the end
    if centers_local[-1] != last_center and centers_local[-1] != first_center:
        centers_local = np.append(centers_local, last_center)

    # Precompute per-cell coverage over the whole gene once; we’ll slice per window
    # (shape: [n_cells, gene_len])
    # gene_end_extended shuold be a multiple of 896
    gene_end_extended = gene_start + ((gene_len - 1) // 896 + 1) * 896

    # Output buffer: counts per cell along the gene
    count_arr = np.zeros((1, gene_len), dtype=np.int32)

    for c in centers_local:
        # Context sequence for this center
        seq = big_seq[c : c + ctx_window_size]
        # Target genomic slice indices in gene-local coords
        g0 = int(c - thalf)
        g1 = g0 + target_window_size


        # Build context dict for your model
        context = {
            'seq': torch.from_numpy(seq).to(device),
            'class_emb': torch.zeros(
                1, dataset.target_cond.shape[1], device=device
            )
        }

        # ---- your model call (unchanged) ----
        with torch.no_grad():
            target_counts = bridge.sampler(
                torch.zeros(1, 896, device=device),
                context,
                (avg_model.module if avg_model is not None else model).to(device),
                n_steps=3,
            )  # expect [n_cells, 896]

        # Write into the gene buffer
        count_arr[:, g0:g1] = target_counts.to("cpu").numpy()[:, :gene_len]

    gene_counts[gn] = count_arr

# save gene_counts with pickle
import pickle as pkl
pkl.dump(gene_counts, open("results/expr/gene_counts_baseline.pkl", "wb"))


523
7098
26277
5077
7105
36280
37494
2396
21619
4522
7814
28181
4671
6729
4514
2784
9089
7508
21570
5127


In [None]:
import numpy as np
import torch

ctx_window_size    = 196_608
target_window_size = 896
stride             = 896   # tiles exactly; use <= 896 to allow overlap

whalf = ctx_window_size // 2
thalf = target_window_size // 2

base_idx = dataset.base_individual_idx  # (n_cells,)
device = "cuda"  # or "cpu"


gene_counts = {}
for cell_type in range(dataset.target_cond.shape[1]):
    gene_counts[cell_type] = {}
    for ci, gs, ge, gn in zip(dataset._eligible_cidx, dataset._eligible_start, dataset._eligible_end, dataset._eligible_gene_name):

        chrom_beg  = dataset.chrom_starts[ci]
        gene_start = int(chrom_beg + int(gs))
        gene_end   = int(chrom_beg + int(ge))  # end-exclusive
        gene_len   = int(gene_end - gene_start)

        print(gene_len)

        big_seq = dataset.get_seq(gene_start - whalf, gene_end + whalf)
        assert len(big_seq) == gene_len + 2*whalf

        # Centers whose 896bp target is fully inside the gene
        first_center = thalf
        last_center  = gene_len - thalf
        # if last_center < first_center:
        #     continue  # gene shorter than target; skip or handle separately

        centers_local = np.arange(first_center, last_center + 1, stride, dtype=np.int64)
        if len(centers_local) == 0:
            centers_local = np.array([first_center])
        # ensure we always include a tail center that lands exactly on the end
        if centers_local[-1] != last_center and centers_local[-1] != first_center:
            centers_local = np.append(centers_local, last_center)

        # Precompute per-cell coverage over the whole gene once; we’ll slice per window
        # (shape: [n_cells, gene_len])
        # gene_end_extended shuold be a multiple of 896
        gene_end_extended = gene_start + ((gene_len - 1) // 896 + 1) * 896

        # Output buffer: counts per cell along the gene
        count_arr = np.zeros((base_idx.shape[0], gene_len), dtype=np.int32)

        for c in centers_local:
            # Context sequence for this center
            seq = big_seq[c : c + ctx_window_size]
            # Target genomic slice indices in gene-local coords
            g0 = int(c - thalf)
            g1 = g0 + target_window_size


            context = {
                'seq': torch.from_numpy(seq).to(device),
                'class_emb': torch.zeros(
                    1, dataset.target_cond.shape[1], device=device
                )
            }

            context['class_emb'][:, cell_type] = 1

            # ---- your model call (unchanged) ----
            with torch.no_grad():
                target_counts = bridge.sampler(
                    torch.zeros(1, 896, device=device),
                    context,
                    (avg_model.module if avg_model is not None else model).to(device),
                    n_steps=3,
                )  # expect [n_cells, 896]

            # Write into the gene buffer
            count_arr[:, g0:g1] = target_counts.to("cpu").numpy()[:, :gene_len]

        gene_counts[cell_type][gn] = count_arr

# save gene_counts with pickle
import pickle as pkl
pkl.dump(gene_counts, open("results/expr/cell_type_gene_counts_baseline.pkl", "wb"))

523
7098
26277
5077
7105
36280
37494
2396
21619
4522
7814
28181
4671
6729
4514
2784
9089
7508
21570
5127
523
7098
26277
5077
7105
36280
37494
2396
21619
4522
7814
28181
4671
6729
4514
2784
9089
7508
21570
5127
523
7098
26277
5077
7105
36280
37494
2396
21619
4522
7814
28181
4671
6729
4514
2784
9089
7508
21570
5127
523
7098
26277
5077
7105
36280
37494
2396
21619
4522
7814
28181
4671
6729
4514
2784
9089
7508
21570
5127
523
7098
26277
5077
7105
36280
37494
2396
21619
4522
7814
28181
4671
6729
4514
2784
9089
7508
21570
5127
523
7098
26277
5077
7105
36280
37494
2396
21619
4522
7814
28181
4671
6729
4514
2784
9089
7508
21570
5127
523
7098
26277
5077
7105
36280
37494
2396
21619
4522
7814
28181
4671
6729
4514
2784
9089
7508
21570
5127
523
7098
26277
5077
7105
36280
37494
2396
21619
4522
7814
28181
4671
6729
4514
2784
9089
7508
21570
5127
523
7098
26277
5077
7105
36280
37494
2396
21619
4522
7814
28181
4671
6729
4514
2784
9089
7508
21570
5127
523
7098
26277
5077
7105
36280
37494
2396
21619
4522
78