In [1]:
%reload_ext autoreload
%autoreload 2


In [2]:
import os 
os.environ["CUDA_VISIBLE_DEVICES"] = "0" 

In [3]:
import hydra
from hydra.core.global_hydra import GlobalHydra

GlobalHydra.instance().clear()
# setup hydra config global for loading this notebook
hydra.initialize(config_path="configs", version_base=None)
cfg = hydra.compose(config_name="expr")


In [None]:
from main import setup_environment, get_checkpoint_info, is_training_complete

# Setup environment
device = setup_environment(cfg)

# Get checkpoint info
output_dir, checkpoint = get_checkpoint_info(cfg, num_epochs=cfg.training.num_epochs)

# Instantiate everything
bridge = hydra.utils.instantiate(cfg.bridge)
dataset = hydra.utils.instantiate(cfg.dataset)
model = hydra.utils.instantiate(cfg.model)
model.load_state_dict(checkpoint['model_state_dict'])
avg_model = hydra.utils.instantiate(cfg.averaging, model=model)
avg_model.load_state_dict(checkpoint['avg_model_state_dict'])
    

11:36:40 - INFO - Using device: cuda


11:36:42 - INFO - Found checkpoint: /orcd/data/omarabu/001/njwfish/counting_flows/outputs/84cddc73ed43/model.pt
  from .autonotebook import tqdm as notebook_tqdm


Train split: 931, Test split: 49
Train cells: 1202750, Test cells: 61517


<All keys matched successfully>

In [None]:
model = hydra.utils.instantiate(cfg.model)
model.load_state_dict(checkpoint['model_state_dict'])
avg_model = hydra.utils.instantiate(cfg.averaging, model=model)
avg_model.load_state_dict(checkpoint['avg_model_state_dict'])

<All keys matched successfully>

In [22]:
model

EnergyScoreLoss(
  (architecture): SCFormer(
    (enformer): Enformer(
      (stem): Sequential(
        (0): Conv1d(4, 768, kernel_size=(15,), stride=(1,), padding=(7,))
        (1): Residual(
          (fn): Sequential(
            (0): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (1): GELU()
            (2): Conv1d(768, 768, kernel_size=(1,), stride=(1,))
          )
        )
        (2): AttentionPool(
          (pool_fn): Rearrange('b d (n p) -> b d n p', p=2)
          (to_attn_logits): Conv2d(768, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
      )
      (conv_tower): Sequential(
        (0): Sequential(
          (0): Sequential(
            (0): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (1): GELU()
            (2): Conv1d(768, 768, kernel_size=(5,), stride=(1,), padding=(2,))
          )
          (1): Residual(
            (fn): Sequential(
              (0): 

In [None]:
import numpy as np
import torch

ctx_window_size    = 196_608
target_window_size = 896
stride             = 896   # tiles exactly; use <= 896 to allow overlap

whalf = ctx_window_size // 2
thalf = target_window_size // 2

base_idx = dataset.base_individual_idx  # (n_cells,)
target_idxs = dataset.test_individual_idxs

device = "cuda"  # or "cpu"


# gene_counts = {}
# gene_counts_true = {}
for target_key in target_idxs:
    if target_key in gene_counts:
        continue
    gene_counts[target_key] = {}
    gene_counts_true[target_key] = {}
    target_idx = target_idxs[target_key]
    try:    
        for ci, gs, ge, gn in zip(dataset._eligible_cidx, dataset._eligible_start, dataset._eligible_end, dataset._eligible_gene_name):

            chrom_beg  = dataset.chrom_starts[ci]
            gene_start = int(chrom_beg + int(gs))
            gene_end   = int(chrom_beg + int(ge))  # end-exclusive
            gene_len   = int(gene_end - gene_start)

            print(gene_len)

            big_seq = dataset.get_seq(gene_start - whalf, gene_end + whalf)
            assert len(big_seq) == gene_len + 2*whalf

            # Centers whose 896bp target is fully inside the gene
            first_center = thalf
            last_center  = gene_len - thalf
            # if last_center < first_center:
            #     continue  # gene shorter than target; skip or handle separately

            centers_local = np.arange(first_center, last_center + 1, stride, dtype=np.int64)
            if len(centers_local) == 0:
                centers_local = np.array([first_center])
            # ensure we always include a tail center that lands exactly on the end
            if centers_local[-1] != last_center and centers_local[-1] != first_center:
                centers_local = np.append(centers_local, last_center)

            # Precompute per-cell coverage over the whole gene once; we’ll slice per window
            # (shape: [n_cells, gene_len])
            # gene_end_extended shuold be a multiple of 896
            gene_end_extended = gene_start + ((gene_len - 1) // 896 + 1) * 896
            print(target_key, gene_end - gene_start, gene_end_extended - gene_start, gene_end < gene_end_extended)
            x_1_global = torch.from_numpy(
                dataset.fast_get_overlap_raw(base_idx[:target_idx.shape[0]], (gene_start, gene_end_extended))
            )# .to(device)

            gene_counts_true[target_key][gn] = dataset.fast_get_overlap_raw(target_idx, (gene_start, gene_end_extended))
            target_sum_global = torch.from_numpy(gene_counts_true[target_key][gn]).sum(dim=0)
            gene_counts_true[target_key][gn] = gene_counts_true[target_key][gn][:,:gene_len]

            # Output buffer: counts per cell along the gene
            count_arr = np.zeros((target_idx.shape[0], gene_len), dtype=np.int32)

            for c in centers_local:
                # Context sequence for this center
                seq = big_seq[c : c + ctx_window_size]
                # Target genomic slice indices in gene-local coords
                g0 = int(c - thalf)
                g1 = g0 + target_window_size


                # Slice the “conditioning” counts for these positions
                x_1 = x_1_global[:, g0:g1].to(device)  # [n_cells, 896]
                # Build context dict for your model
                context = {
                    'seq': torch.from_numpy(seq).to(device),
                    'class_emb': torch.zeros(
                        target_idx.shape[0], dataset.target_cond.shape[1], device=device
                    ),
                    'target_sum': target_sum_global[g0:g1].unsqueeze(0).to(device),
                    'A': torch.ones(1, target_idx.shape[0], device=device).to(device)
                }

                with torch.no_grad():
                    target_counts = bridge.sampler(
                        x_1,
                        context,
                        (avg_model.module if avg_model is not None else model).to(device),
                        n_steps=3,
                    )  # expect [n_cells, 896]

                # Write into the gene buffer
                count_arr[:, g0:g1] = target_counts.to("cpu").numpy()[:, :gene_len]

            gene_counts[target_key][gn] = count_arr
            print("counts match",np.all(count_arr.sum(axis=0) == gene_counts_true[target_key][gn].sum(axis=0)))
    except Exception as e:
        import gc
        torch.cuda.empty_cache()
        gc.collect()
        continue
    # save gene_counts with pickle
    import pickle as pkl
    pkl.dump(gene_counts, open("results/expr/deconv_gene_counts_proj.pkl", "wb"))
    # pkl.dump(gene_counts_true, open("results/expr/deconv_gene_counts_true.pkl", "wb"))

523
794_795 523 896 True
counts match True
7098
794_795 7098 7168 True
counts match True
26277
794_795 26277 26880 True
counts match True
5077
794_795 5077 5376 True
counts match True
7105
794_795 7105 7168 True
counts match True
36280
794_795 36280 36736 True
counts match True
37494
794_795 37494 37632 True
counts match True
2396
794_795 2396 2688 True
counts match True
21619
794_795 21619 22400 True
counts match True
4522
794_795 4522 5376 True
counts match True
7814
794_795 7814 8064 True
counts match True
28181
794_795 28181 28672 True
counts match True
4671
794_795 4671 5376 True
counts match True
6729
794_795 6729 7168 True
counts match True
4514
794_795 4514 5376 True
counts match True
2784
794_795 2784 3584 True
counts match True
9089
794_795 9089 9856 True
counts match True
7508
794_795 7508 8064 True
counts match True
21570
794_795 21570 22400 True
counts match True
5127
794_795 5127 5376 True
counts match True
523
93_93 523 896 True
counts match True
7098
93_93 7098 7168 Tru

In [18]:
target_idx.shape[0]

2674