In [5]:
import json

import torch
import matplotlib.pyplot as plt
# import seaborn as sns
# sns.set()

from datasets import get_CIFAR10, get_SVHN
from model import Glow

device = torch.device("cpu")

output_folder = '/Users/charumathibadrinath/Downloads/glow/'
model_name = 'glow_affine_coupling.pt'

with open(output_folder + 'hparams.json') as json_file:  
    hparams = json.load(json_file)
    
print(hparams)

image_shape, num_classes, _, test_cifar = get_CIFAR10(hparams['augment'], hparams['dataroot'], True)
image_shape, num_classes, _, test_svhn = get_SVHN(hparams['augment'], hparams['dataroot'], True)

model = Glow(image_shape, hparams['hidden_channels'], hparams['K'], hparams['L'], hparams['actnorm_scale'],
             hparams['flow_permutation'], hparams['flow_coupling'], hparams['LU_decomposed'], num_classes,
             hparams['learn_top'], hparams['y_condition'])

model.load_state_dict(torch.load(output_folder + model_name))
model.set_actnorm_init()

model = model.to(device)

model = model.eval()

{'K': 32, 'L': 3, 'LU_decomposed': True, 'actnorm_scale': 1.0, 'augment': True, 'batch_size': 64, 'cuda': True, 'dataroot': './', 'dataset': 'cifar10', 'download': False, 'epochs': 1500, 'eval_batch_size': 512, 'flow_coupling': 'affine', 'flow_permutation': 'invconv', 'fresh': True, 'hidden_channels': 512, 'learn_top': True, 'lr': 0.0005, 'max_grad_clip': 0, 'max_grad_norm': 0, 'n_init_batches': 8, 'n_workers': 6, 'output_dir': 'output/', 'saved_model': '', 'saved_optimizer': '', 'seed': 0, 'warmup_steps': 4000, 'y_condition': False, 'y_weight': 0.01}
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/CIFAR10/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 32134644.22it/s]


Extracting data/CIFAR10/cifar-10-python.tar.gz to data/CIFAR10
Files already downloaded and verified
Downloading http://ufldl.stanford.edu/housenumbers/train_32x32.mat to data/SVHN/train_32x32.mat


100%|██████████| 182040794/182040794 [00:37<00:00, 4862744.02it/s]


Downloading http://ufldl.stanford.edu/housenumbers/test_32x32.mat to data/SVHN/test_32x32.mat


100%|██████████| 64275384/64275384 [00:14<00:00, 4459422.60it/s]
The boolean parameter 'some' has been replaced with a string parameter 'mode'.
Q, R = torch.qr(A, some)
should be replaced with
Q, R = torch.linalg.qr(A, 'reduced' if some else 'complete') (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/BatchLinearAlgebra.cpp:2432.)
  w_init = torch.qr(torch.randn(*w_shape))[0]
LU, pivots = torch.lu(A, compute_pivots)
should be replaced with
LU, pivots = torch.linalg.lu_factor(A, compute_pivots)
and
LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)
should be replaced with
LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots) (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/BatchLinearAlgebra.cpp:2002.)
  return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))


RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

In [None]:
def compute_nll(dataset, model):
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=512, num_workers=6)
    
    nlls = []
    for x,y in dataloader:
        x = x.to(device)
        
        if hparams['y_condition']:
            y = y.to(device)
        else:
            y = None
        
        with torch.no_grad():
            _, nll, _ = model(x, y_onehot=y)
            nlls.append(nll)
        
    return torch.cat(nlls).cpu()

In [None]:
cifar_nll = compute_nll(test_cifar, model)
svhn_nll = compute_nll(test_svhn, model)

print("CIFAR NLL", torch.mean(cifar_nll))
print("SVHN NLL", torch.mean(svhn_nll))

In [None]:
plt.figure(figsize=(20,10))
plt.title("Histogram Glow - trained on CIFAR10")
plt.xlabel("Negative bits per dimension")
plt.hist(-svhn_nll.numpy(), label="SVHN", density=True, bins=30)
plt.hist(-cifar_nll.numpy(), label="CIFAR10", density=True, bins=50)
plt.legend()
plt.show()
# plt.savefig("images/histogram_glow_cifar_svhn.png", dpi=300)