In [None]:
import sys
import os

# Uncomment to pin the notebook to a specific GPU. CUDA_VISIBLE_DEVICES only
# takes effect if it is set before CUDA is initialised in this kernel.
#os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# NOTE: assuming we are in `ca_body/notebooks`, so the repo root is one level up
sys.path.insert(0, '../')
from attrdict import AttrDict

import matplotlib.pyplot as plt
import torch as th
from omegaconf import OmegaConf
from torchvision.utils import make_grid

from ca_body.utils.module_loader import load_from_config
from ca_body.utils.lbs import LBSModule
from ca_body.utils.train import load_checkpoint


def to_device(values, device):
    """Return `values` with every tensor moved to `device`.

    Recurses into dicts, lists and tuples (dict subclasses come back as plain
    dicts). Non-tensor leaves such as ints, strings or None are returned as-is.

    Args:
        values: a tensor, or an arbitrarily nested dict/list/tuple of them.
        device: target device, e.g. `th.device('cuda:0')`.
    """
    if isinstance(values, th.Tensor):
        return values.to(device)
    if isinstance(values, dict):
        return {key: to_device(value, device) for key, value in values.items()}
    if isinstance(values, (list, tuple)):
        moved = [to_device(value, device) for value in values]
        return moved if isinstance(values, list) else tuple(moved)
    return values


# all model inputs and weights are placed on this device
device = th.device('cuda:0')

In [None]:
# NOTE: make sure to download the data
model_dir = '../data/cca/PXB184'

# every artifact this notebook reads lives under `model_dir`
ckpt_path = f'{model_dir}/body_dec.ckpt'
config_path = f'{model_dir}/config.yml'
assets_path = f'{model_dir}/static_assets.pt'
batch_path = f'{model_dir}/sample_batch.pt'

# model configuration (its `model` section is used to build the network below)
config = OmegaConf.load(config_path)
# static assets handed to the model constructor; AttrDict gives attribute access
# NOTE(review): th.load unpickles the file, so only load assets from a trusted source
static_assets = AttrDict(th.load(assets_path))
# sample input batch, with all of its tensors moved onto `device`
batch = to_device(th.load(batch_path), device)

In [None]:
# Instantiate the network described by `config.model` (the static assets are
# forwarded to its constructor), then place it on the target device.
model = load_from_config(config.model, assets=static_assets)
model = model.to(device)

# Restore the trained weights into `model`. Parameters whose names match
# 'lbs_fn.*' are ignored: this accounts for a difference in the LBS
# implementation between the checkpoint and the current code.
ckpt_modules = {'model': model}
ckpt_ignore_names = {'model': ['lbs_fn.*']}
load_checkpoint(ckpt_path, modules=ckpt_modules, ignore_names=ckpt_ignore_names)

In [None]:
# Prepare for inference.
# model.eval() switches nn.Module layers (e.g. dropout / normalization) to their
# evaluation behaviour; without it the forward pass below runs in training mode.
# NOTE(review): assumes `model` is a torch nn.Module (it supports `.to(device)`).
model.eval()
# disable training-only components (learned blur and calibration)
model.learn_blur_enabled = False
model.pixel_cal_enabled = False
model.cal_enabled = False

# forward pass on the sample batch, without building an autograd graph
with th.no_grad():
    preds = model(**batch)

In [None]:
# visualizing: tile the predicted images into one grid and show a downsampled view
GRID_NROW = 4       # images per row in the grid
DISPLAY_STRIDE = 4  # keep every N-th pixel along each axis to keep the figure light

# make_grid returns a (C, H, W) tensor; permute to (H, W, C) for matplotlib.
# Assumes preds['rgb'] holds a (B, 3, H, W) batch in roughly [0, 255] — TODO confirm.
rgb_preds_grid = make_grid(preds['rgb'], nrow=GRID_NROW).permute(1, 2, 0).cpu().numpy() / 255.
# imshow would clip float RGB to [0, 1] anyway (with a warning); do it explicitly
rgb_preds_grid = rgb_preds_grid.clip(0.0, 1.0)

plt.figure(figsize=(15, 15))
plt.imshow(rgb_preds_grid[::DISPLAY_STRIDE, ::DISPLAY_STRIDE])
plt.axis('off');