In [1]:
import train
import numpy as np
import imageio
import os
import time
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from data import set_up_data
from utils import get_cpu_stats_over_ranks
from train_helpers import set_up_hyperparams, load_vaes, load_opt, accumulate_stats, save_model, update_ema

In [None]:
# NOTE(review): train.main() is presumably the full training entry point. Calling it
# unconditionally here runs it inside the kernel on every Restart & Run All.
# The setup cell further down already calls train.train_loop when H.test_eval is False.
# Set this flag to True only when you intend to launch train.main() from the notebook.
RUN_TRAIN_MAIN = False
if RUN_TRAIN_MAIN:
    train.main()

In [2]:
def evaluate(H, ema_vae, data_valid, preprocess_fn, logprint, eval_step=None):
    """Run ``ema_vae`` over this rank's shard of ``data_valid`` and average the per-batch stats.

    Args:
        H: hyperparameter namespace; reads ``mpi_size``, ``rank`` and ``n_batch``.
        ema_vae: model passed through to the eval step.
        data_valid: map-style dataset to evaluate on.
        preprocess_fn: callable mapping a loader batch to ``(data_input, target)``.
        logprint: unused here; kept so existing callers keep working.
        eval_step: optional ``(data_input, target, model) -> dict`` callable.
            Defaults to ``train.eval_step``. Each returned dict must contain ``'elbo'``.

    Returns:
        dict with these entries:

        - ``n_batches``.
        - ``filtered_elbo``: mean of the finite ``elbo`` values, or NaN if none are finite.
        - The unfiltered mean of every key present in the last batch's stats dict.

    Raises:
        ValueError: if the loader yields no complete batch, because the per-rank
            shard is smaller than ``n_batch`` and ``drop_last=True`` discards it.
            Without this check the aggregation fails with an opaque IndexError.
    """
    step_fn = train.eval_step if eval_step is None else eval_step
    valid_sampler = DistributedSampler(data_valid, num_replicas=H.mpi_size, rank=H.rank)
    loader = DataLoader(data_valid, batch_size=H.n_batch, drop_last=True, pin_memory=True, sampler=valid_sampler)

    stats_valid = []
    for x in loader:
        data_input, target = preprocess_fn(x)
        stats_valid.append(step_fn(data_input, target, ema_vae))

    if not stats_valid:
        raise ValueError(
            f'evaluate: no complete batches (per-rank shard of {len(valid_sampler)} samples '
            f'< n_batch={H.n_batch}); nothing to average')

    vals = [a['elbo'] for a in stats_valid]
    finites = np.array(vals)[np.isfinite(vals)]
    # np.mean of an empty array warns and returns NaN; return NaN without the warning.
    filtered_elbo = np.mean(finites) if finites.size else np.nan
    stats = dict(n_batches=len(vals), filtered_elbo=filtered_elbo,
                 **{k: np.mean([a[k] for a in stats_valid]) for k in stats_valid[-1]})
    return stats

def run_test_eval(H, ema_vae, data_test, preprocess_fn, logprint):
    """Evaluate ``ema_vae`` on ``data_test`` and report the results.

    Each aggregated metric is printed to stdout. The whole set is then passed to
    ``logprint`` as keyword arguments with ``type='test_loss'``.
    """
    print('evaluating')
    test_stats = evaluate(H, ema_vae, data_test, preprocess_fn, logprint)

    print('test results')
    for metric_name, metric_value in test_stats.items():
        print(metric_name, metric_value)

    logprint(type='test_loss', **test_stats)

In [None]:
# Checkpoint whose EMA weights are restored by load_vaes (path relative to the working dir).
EMA_CHECKPOINT_PATH = 'cifar10-seed0-iter-900000-model-ema.th'
# True -> evaluate and log 'test_loss'; False -> fall through to train.train_loop.
RUN_TEST_EVAL = True

H, logprint = set_up_hyperparams()

# Apply the overrides *before* set_up_data / load_vaes so every consumer of H sees them.
# NOTE(review): set_up_data presumably picks the test vs. validation split from
# H.test_eval (hence the name data_valid_or_test). Setting the flag after it was called
# would evaluate on the validation split while still logging the result as 'test_loss'.
H.test_eval = RUN_TEST_EVAL
H.restore_ema_path = EMA_CHECKPOINT_PATH

H, data_train, data_valid_or_test, preprocess_fn = set_up_data(H)
vae, ema_vae = load_vaes(H, logprint)   # ema_vae maintains the exponential moving average of the params

if H.test_eval:
    run_test_eval(H, ema_vae, data_valid_or_test, preprocess_fn, logprint)
else:
    train.train_loop(H, data_train, data_valid_or_test, preprocess_fn, vae, ema_vae, logprint)

In [9]:
# Pull a visualization batch from the eval split, then have write_images render it with
# the EMA model into samples-1.png under H.save_dir.
viz_batch_original, viz_batch_processed = train.get_sample_for_visualization(
    data_valid_or_test,
    preprocess_fn,
    H.num_images_visualize,
    H.dataset,
)
samples_path = f'{H.save_dir}/samples-1.png'
train.write_images(
    H,
    ema_vae,
    viz_batch_original,
    viz_batch_processed,
    samples_path,
    logprint,
)

time: Thu Dec 19 05:47:50 2024, message: printing samples to ./saved_models/test/samples-1.png


In [4]:
# Debug: inspect the first named parameter of `vae` (first slice, gradient, device).
first_named_param = next(iter(vae.named_parameters()), None)
if first_named_param is not None:
    name, param = first_named_param
    print(name)
    print(param[0])
    # .grad stays None until a backward pass populates it (e.g. when only the test-eval
    # branch above has run). Indexing it unguarded would raise TypeError.
    print(param.grad[0] if param.grad is not None else 'grad is None (no backward pass yet)')
    print(param.device)

# `optimizer` is never created in this notebook (load_opt is imported but not called).
# The bare reference therefore raised NameError on a fresh kernel; show it only if it exists.
if 'optimizer' in globals():
    print(optimizer.param_groups[0]['params'])


ImportError: /home/mike_1102/miniconda3/envs/env02/lib/python3.11/site-packages/zmq/backend/cython/../../../../.././libstdc++.so.6: version `GLIBCXX_3.4.32' not found (required by /home/mike_1102/miniconda3/envs/env02/lib/python3.11/site-packages/amp_C.cpython-311-x86_64-linux-gnu.so)