In [1]:
# Enable IPython's autoreload extension in mode 2, so edits to imported project
# modules (e.g. main.py) are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2


In [2]:
import os 
# Restrict this kernel to GPU 0. Set here, ahead of the cells that import and
# initialise the model code, so CUDA only ever sees that one device.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [3]:
import hydra
from hydra.core.global_hydra import GlobalHydra

# Hydra keeps process-wide state and a notebook kernel outlives individual cell
# runs: drop any earlier initialisation so this cell can be executed repeatedly.
hydra_globals = GlobalHydra.instance()
hydra_globals.clear()

# Register the `configs` directory with Hydra, then compose the "expr" config
# that the rest of the notebook reads from `cfg`.
hydra.initialize(version_base=None, config_path="configs")
cfg = hydra.compose(config_name="expr")


In [5]:
from main import setup_environment, get_checkpoint_info, is_training_complete

# Setup environment: setup_environment(cfg) returns the device used for this run.
device = setup_environment(cfg)

# Get checkpoint info: the run's output directory and the checkpoint object, which is
# indexed below by 'model_state_dict' and 'avg_model_state_dict'.
# NOTE(review): presumably num_epochs selects the end-of-training checkpoint — confirm in main.py.
output_dir, checkpoint = get_checkpoint_info(cfg, num_epochs=cfg.training.num_epochs)

# Instantiate everything from the Hydra config. Keep this order: instantiation may
# consume random state (e.g. for the dataset split) — TODO confirm.
bridge = hydra.utils.instantiate(cfg.bridge)
dataset = hydra.utils.instantiate(cfg.dataset)
model = hydra.utils.instantiate(cfg.model)
# Restore the trained weights into the freshly constructed model.
model.load_state_dict(checkpoint['model_state_dict'])
# Build the averaging wrapper configured in cfg.averaging around `model` and restore its state.
# NOTE(review): the inference cells test `avg_model is not None`, but the load below would
# already fail if cfg.averaging instantiated to None — confirm which case is intended.
# NOTE(review): nothing here calls .eval() on model / avg_model — confirm the sampler does.
avg_model = hydra.utils.instantiate(cfg.averaging, model=model)
avg_model.load_state_dict(checkpoint['avg_model_state_dict'])
    

11:50:38 - INFO - Using device: cuda
11:50:57 - INFO - Found checkpoint: /orcd/data/omarabu/001/njwfish/counting_flows/outputs/84cddc73ed43/model.pt


Train split: 931, Test split: 49
Train cells: 1202750, Test cells: 61517


  from .autonotebook import tqdm as notebook_tqdm


<All keys matched successfully>

In [None]:
# NOTE(review): this cell repeats, statement for statement, the model / avg_model
# construction and checkpoint restore done at the end of the setup cell. It is only
# useful for resetting both objects to the checkpointed weights; it relies on `cfg`
# and `checkpoint` already being defined by that setup cell.
model = hydra.utils.instantiate(cfg.model)
model.load_state_dict(checkpoint['model_state_dict'])
# The averaging wrapper is rebuilt around the new `model` object, then restored.
avg_model = hydra.utils.instantiate(cfg.averaging, model=model)
avg_model.load_state_dict(checkpoint['avg_model_state_dict'])

<All keys matched successfully>

In [None]:
import os
import pickle as pkl

import numpy as np
import torch

# Sliding-window sampling over every eligible gene with an all-zero class embedding.
#
# For each gene, target windows of `target_window_size` positions are laid along the
# gene; each window is sampled with `bridge.sampler`, conditioned on the sequence
# context centred on the window and on the observed counts for the window. The result
# per gene is an int32 array of shape (len(base_idx), gene_len) stored in
# gene_counts[gene_name]; the whole dict is pickled once at the end.
#
# Requires `dataset`, `bridge`, `model` and `avg_model` from the setup cell.

ctx_window_size    = 196_608
target_window_size = 896
stride             = 896   # tiles exactly; use <= 896 to allow overlap

whalf = ctx_window_size // 2
thalf = target_window_size // 2

base_idx = dataset.base_individual_idx  # (n_cells,)
device = "cuda"  # or "cpu"

out_path = "results/expr/gene_counts.pkl"
# Create the output directory before the expensive loop, so a missing folder cannot
# make the final save fail after all the sampling has been done.
os.makedirs(os.path.dirname(out_path), exist_ok=True)

# The network choice and device move are identical for every window: do them once.
# NOTE(review): the network is never put in eval() mode in this notebook — confirm that
# bridge.sampler takes care of it if the model contains dropout / batch-norm layers.
net = (avg_model.module if avg_model is not None else model).to(device)
n_classes = dataset.target_cond.shape[1]

gene_counts = {}
for ci, gs, ge, gn in zip(dataset._eligible_cidx, dataset._eligible_start, dataset._eligible_end, dataset._eligible_gene_name):

    # Gene coordinates are chromosome-local; shift them by the chromosome offset.
    chrom_beg  = dataset.chrom_starts[ci]
    gene_start = int(chrom_beg + int(gs))
    gene_end   = int(chrom_beg + int(ge))  # end-exclusive
    gene_len   = int(gene_end - gene_start)

    print(gene_len)

    # Fetch the sequence once per gene, padded by half a context window on each side.
    # A gene shorter than one target window is still run as one full window starting
    # at the gene start, so the fetched span must cover at least `target_window_size`
    # positions; otherwise the context slice below would come out shorter than
    # `ctx_window_size` for genes shorter than `thalf`.
    seq_span = max(gene_len, target_window_size)
    big_seq = dataset.get_seq(gene_start - whalf, gene_start + seq_span + whalf)
    assert len(big_seq) == seq_span + 2 * whalf

    # Window centres (gene-local) whose target window lies fully inside the gene.
    first_center = thalf
    last_center  = gene_len - thalf

    # Tile from the first centre; the max() keeps the single window of a short gene.
    centers_local = np.arange(
        first_center, max(first_center, last_center) + 1, stride, dtype=np.int64
    )
    # Always add a tail window that ends exactly at the gene end when tiling stops
    # short of it. This must also happen when only ONE tiled window fits
    # (target_window_size < gene_len < 2 * target_window_size); otherwise the positions
    # past the first window are never sampled and silently stay zero.
    if last_center > centers_local[-1]:
        centers_local = np.append(centers_local, last_center)

    # Precompute the conditioning counts over the whole gene once and slice per window.
    # The requested range is rounded up to a whole number of target windows so that a
    # full-width slice is available even for genes shorter than one window.
    n_tiles = (gene_len - 1) // target_window_size + 1
    gene_end_extended = gene_start + n_tiles * target_window_size
    print(gene_end - gene_start, gene_end_extended - gene_start, gene_end < gene_end_extended)
    # Kept on the CPU; only the per-window slice is moved to `device`.
    x_1_global = torch.from_numpy(
        dataset.fast_get_overlap_raw(base_idx, (gene_start, gene_end_extended))
    )

    # Output buffer: counts per cell along the gene
    count_arr = np.zeros((base_idx.shape[0], gene_len), dtype=np.int32)

    for c in centers_local:
        # Target slice in gene-local coordinates.
        g0 = int(c - thalf)
        g1 = g0 + target_window_size

        # big_seq starts `whalf` before the gene, so the context window centred on
        # gene-local position c is exactly big_seq[c : c + ctx_window_size].
        seq = big_seq[c : c + ctx_window_size]

        # Conditioning counts for the target positions; expected [n_cells, 896].
        x_1 = x_1_global[:, g0:g1].to(device)

        # Rebuilt for every window so each sampler call gets fresh tensors.
        context = {
            'seq': torch.from_numpy(seq).to(device),
            'class_emb': torch.zeros(base_idx.shape[0], n_classes, device=device),
        }

        with torch.no_grad():
            target_counts = bridge.sampler(
                x_1,
                context,
                net,
                n_steps=3,
            )  # expect [n_cells, 896]

        # Copy only the columns that fall inside the gene (a short gene uses fewer
        # than target_window_size columns). An overlapping tail window overwrites the
        # positions it shares with the previous window.
        n_valid = min(g1, gene_len) - g0
        count_arr[:, g0:g0 + n_valid] = target_counts.to("cpu").numpy()[:, :n_valid]

    gene_counts[gn] = count_arr

# Save all per-gene arrays; the context manager guarantees the file is flushed and closed.
with open(out_path, "wb") as fh:
    pkl.dump(gene_counts, fh)


523
523 896 True
7098
7098 7168 True
26277
26277 26880 True


5077
5077 5376 True
7105
7105 7168 True
36280
36280 36736 True
37494
37494 37632 True
2396
2396 2688 True
21619
21619 22400 True
4522
4522 5376 True
7814
7814 8064 True
28181
28181 28672 True
4671
4671 5376 True
6729
6729 7168 True
4514
4514 5376 True
2784
2784 3584 True
9089
9089 9856 True
7508
7508 8064 True
21570
21570 22400 True
5127
5127 5376 True


In [6]:
import os
import pickle as pkl

import numpy as np
import torch

# Sliding-window sampling over every eligible gene, repeated for every cell type.
#
# Same procedure as the unconditional run, except that the class embedding passed to
# the sampler is one-hot for the current cell type. Results are stored as
# gene_counts[cell_type][gene_name] -> int32 array of shape (len(base_idx), gene_len).
# The dict is re-pickled after each cell type finishes, so a long run that is
# interrupted still leaves the completed cell types on disk.
#
# Requires `dataset`, `bridge`, `model` and `avg_model` from the setup cell.

ctx_window_size    = 196_608
target_window_size = 896
stride             = 896   # tiles exactly; use <= 896 to allow overlap

whalf = ctx_window_size // 2
thalf = target_window_size // 2

base_idx = dataset.base_individual_idx  # (n_cells,)
device = "cuda"  # or "cpu"

out_path = "results/expr/cell_type_gene_counts.pkl"
# Create the output directory before the expensive loop, so a missing folder cannot
# make the first checkpoint fail after a whole cell type has been sampled.
os.makedirs(os.path.dirname(out_path), exist_ok=True)

# The network choice and device move are identical for every window: do them once.
# NOTE(review): the network is never put in eval() mode in this notebook — confirm that
# bridge.sampler takes care of it if the model contains dropout / batch-norm layers.
net = (avg_model.module if avg_model is not None else model).to(device)
n_classes = dataset.target_cond.shape[1]

gene_counts = {}
for cell_type in range(n_classes):
    gene_counts[cell_type] = {}
    for ci, gs, ge, gn in zip(dataset._eligible_cidx, dataset._eligible_start, dataset._eligible_end, dataset._eligible_gene_name):

        # Gene coordinates are chromosome-local; shift them by the chromosome offset.
        chrom_beg  = dataset.chrom_starts[ci]
        gene_start = int(chrom_beg + int(gs))
        gene_end   = int(chrom_beg + int(ge))  # end-exclusive
        gene_len   = int(gene_end - gene_start)

        print(gene_len)

        # Fetch the sequence once per gene, padded by half a context window on each
        # side. A gene shorter than one target window is still run as one full window
        # starting at the gene start, so the fetched span must cover at least
        # `target_window_size` positions; otherwise the context slice below would come
        # out shorter than `ctx_window_size` for genes shorter than `thalf`.
        seq_span = max(gene_len, target_window_size)
        big_seq = dataset.get_seq(gene_start - whalf, gene_start + seq_span + whalf)
        assert len(big_seq) == seq_span + 2 * whalf

        # Window centres (gene-local) whose target window lies fully inside the gene.
        first_center = thalf
        last_center  = gene_len - thalf

        # Tile from the first centre; the max() keeps the single window of a short gene.
        centers_local = np.arange(
            first_center, max(first_center, last_center) + 1, stride, dtype=np.int64
        )
        # Always add a tail window that ends exactly at the gene end when tiling stops
        # short of it. This must also happen when only ONE tiled window fits
        # (target_window_size < gene_len < 2 * target_window_size); otherwise the
        # positions past the first window are never sampled and silently stay zero.
        if last_center > centers_local[-1]:
            centers_local = np.append(centers_local, last_center)

        # Precompute the conditioning counts over the whole gene once and slice per
        # window. The requested range is rounded up to a whole number of target
        # windows so that a full-width slice is available even for short genes.
        n_tiles = (gene_len - 1) // target_window_size + 1
        gene_end_extended = gene_start + n_tiles * target_window_size
        print(cell_type, gene_end - gene_start, gene_end_extended - gene_start, gene_end < gene_end_extended)
        # Kept on the CPU; only the per-window slice is moved to `device`.
        x_1_global = torch.from_numpy(
            dataset.fast_get_overlap_raw(base_idx, (gene_start, gene_end_extended))
        )

        # Output buffer: counts per cell along the gene
        count_arr = np.zeros((base_idx.shape[0], gene_len), dtype=np.int32)

        for c in centers_local:
            # Target slice in gene-local coordinates.
            g0 = int(c - thalf)
            g1 = g0 + target_window_size

            # big_seq starts `whalf` before the gene, so the context window centred on
            # gene-local position c is exactly big_seq[c : c + ctx_window_size].
            seq = big_seq[c : c + ctx_window_size]

            # Conditioning counts for the target positions; expected [n_cells, 896].
            x_1 = x_1_global[:, g0:g1].to(device)

            # Rebuilt for every window so each sampler call gets fresh tensors.
            context = {
                'seq': torch.from_numpy(seq).to(device),
                'class_emb': torch.zeros(base_idx.shape[0], n_classes, device=device),
            }
            # One-hot class embedding: every row is conditioned on the current cell type.
            context['class_emb'][:, cell_type] = 1

            with torch.no_grad():
                target_counts = bridge.sampler(
                    x_1,
                    context,
                    net,
                    n_steps=3,
                )  # expect [n_cells, 896]

            # Copy only the columns that fall inside the gene (a short gene uses fewer
            # than target_window_size columns). An overlapping tail window overwrites
            # the positions it shares with the previous window.
            n_valid = min(g1, gene_len) - g0
            count_arr[:, g0:g0 + n_valid] = target_counts.to("cpu").numpy()[:, :n_valid]

        gene_counts[cell_type][gn] = count_arr

    # Checkpoint after every cell type; the context manager guarantees the file is
    # flushed and closed before the next (long) iteration starts.
    with open(out_path, "wb") as fh:
        pkl.dump(gene_counts, fh)

523
0 523 896 True


7098
0 7098 7168 True
26277
0 26277 26880 True
5077
0 5077 5376 True
7105
0 7105 7168 True
36280
0 36280 36736 True
37494
0 37494 37632 True
2396
0 2396 2688 True
21619
0 21619 22400 True
4522
0 4522 5376 True
7814
0 7814 8064 True
28181
0 28181 28672 True
4671
0 4671 5376 True
6729
0 6729 7168 True
4514
0 4514 5376 True
2784
0 2784 3584 True
9089
0 9089 9856 True
7508
0 7508 8064 True
21570
0 21570 22400 True
5127
0 5127 5376 True
523
1 523 896 True
7098
1 7098 7168 True
26277
1 26277 26880 True
5077
1 5077 5376 True
7105
1 7105 7168 True
36280
1 36280 36736 True
37494
1 37494 37632 True
2396
1 2396 2688 True
21619
1 21619 22400 True
4522
1 4522 5376 True
7814
1 7814 8064 True
28181
1 28181 28672 True
4671
1 4671 5376 True
6729
1 6729 7168 True
4514
1 4514 5376 True
2784
1 2784 3584 True
9089
1 9089 9856 True
7508
1 7508 8064 True
21570
1 21570 22400 True
5127
1 5127 5376 True
523
2 523 896 True
7098
2 7098 7168 True
26277
2 26277 26880 True
5077
2 5077 5376 True
7105
2 7105 7168 True