In [1]:
# Pick up edits to imported project modules (e.g. main, training_deconv)
# without restarting the kernel.
%reload_ext autoreload
%autoreload 2


In [2]:
import hydra
from hydra.core.global_hydra import GlobalHydra

# Location of the Hydra config tree and the top-level config composed for this notebook.
CONFIG_DIR = "configs"
CONFIG_NAME = "merfish_deconv"

# Hydra holds a process-wide global instance; wipe it first so this cell can be
# re-run in the same kernel and initialize Hydra again from scratch.
global_hydra = GlobalHydra.instance()
global_hydra.clear()

# Register the config directory, then build the notebook-wide `cfg` object.
hydra.initialize(config_path=CONFIG_DIR, version_base=None)
cfg = hydra.compose(config_name=CONFIG_NAME)


In [3]:
# NOTE(review): is_training_complete is imported but not used in the visible cells.
from main import setup_environment, get_checkpoint_info, is_training_complete

# Setup environment; the return value is presumably the compute device
# (the logged output shows "Using device: cuda") — confirm in main.setup_environment.
device = setup_environment(cfg)

# Get checkpoint info for the configured number of training epochs.
# NOTE(review): if no checkpoint exists, `checkpoint` may be None and the
# load_state_dict calls below would fail with an unhelpful TypeError —
# confirm get_checkpoint_info's contract and consider an explicit check.
output_dir, checkpoint = get_checkpoint_info(cfg, num_epochs=cfg.training.num_epochs)

# Instantiate everything from the Hydra config.
bridge = hydra.utils.instantiate(cfg.bridge)
dataset = hydra.utils.instantiate(cfg.dataset)
model = hydra.utils.instantiate(cfg.model)
# Restore the trained weights *before* building the averaging wrapper, which is
# constructed from `model` (presumably copying it — verify); the wrapper's own
# averaged state is then restored from the checkpoint.
model.load_state_dict(checkpoint['model_state_dict'])
avg_model = hydra.utils.instantiate(cfg.averaging, model=model)
# NOTE(review): neither model is switched to eval() here; nn.Modules start in train mode.
avg_model.load_state_dict(checkpoint['avg_model_state_dict'])
    

21:53:38 - INFO - Using device: cuda
21:53:40 - INFO - Using compile: False
21:53:40 - INFO - Found checkpoint: /orcd/data/omarabu/001/njwfish/counting_flows/outputs/fd5d7615cb7a/model.pt


Found 4358 groups
Found 77435 images that are smaller than or equal to 256
Found 4790 groups
Found 88483 images that are smaller than or equal to 256
Found 9148 groups
Found 165918 images that are smaller than or equal to 256
Padding images to {img_size} x {img_size}
Processing images completed
attention mode is flash


<All keys matched successfully>

In [21]:
import torch
from torch.utils.data import DataLoader
from training_deconv import sparse_aggregation_collate_fn


# Cap handed to sparse_aggregation_collate_fn for every collated batch.
max_batch_size = 10_000
# Number of dataset items the DataLoader passes to the collate function per step.
loader_batch_size = 64


def collate_capped(items, cap=max_batch_size):
    """Collate `items` with sparse_aggregation_collate_fn, capped at `cap`.

    The cap is bound as a default argument when this cell runs. An inline
    `lambda` would instead look up the global `max_batch_size` lazily on every
    call, so reassigning it later in the kernel would silently change the
    behavior of an already-built DataLoader.
    """
    return sparse_aggregation_collate_fn(items, cap)


dl = DataLoader(
    dataset,
    batch_size=loader_batch_size,
    shuffle=False,  # keep dataset order so outputs line up with inputs
    num_workers=0,
    pin_memory=torch.cuda.is_available(),
    collate_fn=collate_capped,
)

In [22]:
from tqdm import tqdm
from training_deconv import deconv_sample_batch

# `device` is the one returned by setup_environment(cfg) in the setup cell;
# it is deliberately not hard-coded here so the config's choice is respected.

# Choose the sampling network once, outside the loop: prefer the averaged
# weights when an averaging wrapper exists. Both branches are moved to
# `device` (previously only the averaged model was).
sampling_model = avg_model.module if avg_model is not None else model
sampling_model = sampling_model.to(device)
# Inference only: disable train-mode behavior such as dropout.
sampling_model.eval()

x_0_generated = []
with torch.no_grad():
    for batch in tqdm(dl):
        x_1, context, sampler_kwargs = deconv_sample_batch(batch, device=device, condition_on_end_time=True)
        # A batch that reaches the collate cap suggests the cap was binding
        # (assumed to mean items were truncated — confirm in
        # sparse_aggregation_collate_fn). Check *before* sampling so we fail
        # fast instead of after an expensive sampler call.
        n_rows = x_1['counts'].shape[0]
        if n_rows >= max_batch_size:
            raise RuntimeError(
                f"Max batch size reached ({n_rows} >= {max_batch_size})"
            )
        # NOTE(review): sampler outputs are kept as returned; if they are CUDA
        # tensors they accumulate GPU memory across batches and the pickle
        # saved later will require CUDA to load — consider moving them to CPU.
        x_0_generated.append(
            bridge.sampler(x_1, context, sampling_model, **sampler_kwargs)
        )

  0%|          | 0/143 [00:00<?, ?it/s]

100%|██████████| 143/143 [40:08<00:00, 16.84s/it]


In [25]:
# Ensure the output directory for the pickled samples exists.
!mkdir -p results/merfish

In [None]:
# Save the generated samples (x_0_generated) for downstream analysis.
# NOTE: pickle files can execute arbitrary code on load — only load this file
# from a trusted location.
import pickle
from pathlib import Path

results_dir = Path('results/merfish')
# Create the directory here as well, so this cell does not depend on the
# separate shell `mkdir` cell having been run first.
results_dir.mkdir(parents=True, exist_ok=True)

with open(results_dir / 'x_0_generated.pkl', 'wb') as f:
    pickle.dump(x_0_generated, f)
