# Decode seq-to-morph latent predictions into reconstructed images (hotfish crossmodal analysis)

In [1]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import os
from glob2 import glob

In [2]:
# Root of the morphseq project tree. Set the MORPHSEQ_ROOT environment variable
# to run on another machine; the default is the original workstation path.
root = os.environ.get(
    "MORPHSEQ_ROOT",
    "/media/nick/hdd021/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/",
)

# path to save data (the trailing "" keeps a trailing path separator)
out_path = os.path.join(root, "analyses", "crossmodal", "hotfish", "")
os.makedirs(out_path, exist_ok=True)

# path to figures
fig_path = os.path.join(root, "figures", "crossmodal", "hotfish", "")
os.makedirs(fig_path, exist_ok=True)

# training dataset / model run to inspect
train_name = "20241107_ds"
model_name = "ntxent_256_20250504_091524"
train_dir = os.path.join(root, "training_data", train_name, "hydra_outputs", "")
model_dir = os.path.join(train_dir, model_name, "lightning_logs", "")


def _run_sort_key(path):
    """Order run folders by their trailing integer, then by name.

    A plain string sort puts "version_10" before "version_9"; comparing the
    numeric suffix first keeps numbered folders in numeric order. Entries
    without a numeric suffix sort before all numbered ones.
    """
    name = os.path.basename(path)
    suffix = name.rsplit("_", 1)[-1]
    return (int(suffix) if suffix.isdecimal() else -1, name)


# get path to the most recent run inside lightning_logs
run_dirs = sorted(glob(os.path.join(model_dir, "*")), key=_run_sort_key)
if not run_dirs:
    # fail with a message that names the directory instead of a bare IndexError
    raise FileNotFoundError(f"No training runs found under {model_dir}")
training_path = run_dirs[-1]
# name of the run folder itself (dirname would return the parent lightning_logs directory)
training_name = os.path.basename(training_path)

In [17]:
# Read the morph prediction table that lives in out_path
pred_csv = os.path.join(out_path, "seq_to_morph_pd.csv")
morph_pd_df = pd.read_csv(pred_csv)

# Every column whose name contains "z_mu" is treated as a latent column;
# stack those columns into a (rows x latent columns) NumPy array.
latent_cols = list(filter(lambda name: "z_mu" in name, morph_pd_df.columns))
Z = morph_pd_df[latent_cols].to_numpy()

In [7]:
from src.assess.assess_hydra_results import get_hydra_runs, initialize_model_to_asses, parse_hydra_paths
from omegaconf import OmegaConf
from src.lightning.pl_wrappers import LitModel

# Load the trained model and expose its decoder.
#
# The training-data directory gets its own name (`train_root`) instead of
# reassigning `root`: the config cell uses `root` for the morphseq project
# root, and rebinding it here made later path construction depend on which
# cell happened to run last.
morphseq_root = os.environ.get(
    "MORPHSEQ_ROOT",
    "/media/nick/hdd021/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/",
)
train_root = os.path.join(morphseq_root, "training_data", "20241107_ds", "")
run_name = "ntxent_256_20250504_091524"
hydra_run_path = os.path.join(train_root, "hydra_outputs", run_name, "")
_, cfg_path_list = get_hydra_runs(hydra_run_path, run_type="name")
if len(cfg_path_list) == 0:
    # fail with a message that names the directory instead of a bare IndexError
    raise FileNotFoundError(f"No hydra run config found under {hydra_run_path}")

# use the first config returned for this run
cfg = cfg_path_list[0]
config = OmegaConf.load(cfg)

# initialize the model and its loss module from the config
model, model_config = initialize_model_to_asses(config)
loss_fn = model_config.lossconfig.create_module()
# lightning_logs sits two directory levels above the config file
run_path = os.path.join(os.path.dirname(os.path.dirname(cfg)), "lightning_logs")
# NOTE(review): this rebinds `model_dir`, which the config cell also defines — confirm that is intended
model_dir, latest_ckpt = parse_hydra_paths(run_path=run_path)

# restore the weights from the checkpoint returned by parse_hydra_paths
lit_model = LitModel.load_from_checkpoint(latest_ckpt,
                                          model=model,
                                          loss_fn=loss_fn,
                                          data_cfg=OmegaConf.create({}),)

lit_model.eval()  # 1) turn off dropout / switch BN to eval
lit_model.freeze()  # 2) freeze the module's parameters for inference

# only the decoder is needed to turn latent vectors into images
decoder = lit_model.model.decoder

Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: /home/nick/miniconda3/envs/torch-env-plot/lib/python3.10/site-packages/lpips/weights/v0.1/vgg.pth


  self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)


In [23]:
import torch

# Put the latent matrix on the model's device as float32
Z_tensor = torch.from_numpy(Z).to(device=lit_model.device, dtype=torch.float32)

# Decode all latent vectors in one pass; no gradients are needed for inference
with torch.no_grad():
    im_recon = decoder(Z_tensor)

In [40]:
from tqdm import tqdm
from PIL import Image

# Frames dropped from the start / end of the stack, and the number of rows
# cropped from the top and bottom of each frame before it is saved.
N_SKIP_START = 10
N_SKIP_END = 5
CROP_ROWS = 25

# First element of the decoder output. The loop indexes it as im_stack[m] and
# squeezes axis 0, so it is handled as a stack of single-channel frames,
# i.e. (N, 1, H, W).
im_stack = im_recon[0].detach().cpu().numpy()
im_path = os.path.join(fig_path, "recon_images_i", "")
os.makedirs(im_path, exist_ok=True)

for m in tqdm(range(N_SKIP_START, im_stack.shape[0] - N_SKIP_END), "Writing reconstructions to file..."):
    # drop the singleton channel axis -> grayscale (H, W)
    arr = im_stack[m].squeeze(0)

    # invert, then min-max normalise to [0, 1]
    arr = 1 - arr
    arr = arr - np.min(arr)
    peak = np.max(arr)
    if peak > 0:
        arr = arr / peak
    # else: the frame is constant and already all zeros; dividing would give
    # 0/0 = NaN and an undefined uint8 cast below

    # scale to 0–255 uint8
    arr = np.clip(arr, 0, 1)
    arr = (255 * arr).astype(np.uint8)

    # crop CROP_ROWS rows off the top and bottom, then build an 8-bit grayscale PIL image
    im = Image.fromarray(arr[CROP_ROWS:arr.shape[0] - CROP_ROWS], mode="L")

    fname = os.path.join(im_path, f"recon_{m:03d}.png")
    im.save(fname)


Writing reconstructions to file...: 100%|██████████| 85/85 [00:00<00:00, 206.85it/s]


In [39]:
# Sanity check: brightest pixel value of `im`, the last image built by the save loop
np.asarray(im).max()

np.uint8(247)