In [1]:
!nvidia-smi

Thu Jan 20 18:57:31 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.42.01    Driver Version: 470.42.01    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:1B:00.0 Off |                  N/A |
| 29%   37C    P8     1W / 250W |   5955MiB / 11016MiB |      9%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:1C:00.0 Off |                  N/A |
| 43%   70C    P2   263W / 250W |   8251MiB / 11019MiB |     93%      Default |
|       

In [2]:
device = 'cuda:4'

In [3]:
import time
import os, shutil
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, transforms, utils
from tensorboardX import SummaryWriter
from nits.pixelcnn_model import *
from nits.model import NITS, ConditionalNITS
from nits.discretized_mol import discretized_nits_loss, nits_sample
from PIL import Image

import matplotlib.pyplot as plt

def list_str_to_list(s):
    print(s)
    assert s[0] == '[' and s[-1] == ']'
    s = s[1:-1]
    s = s.replace(' ', '')
    s = s.split(',')

    s = [int(x) for x in s]

    return s

parser = argparse.ArgumentParser()
# data I/O
parser.add_argument('-g', '--gpu', type=str,
                    default='', help='Location for the dataset')
parser.add_argument('-i', '--data_dir', type=str,
                    default='/data/pixelcnn/data/', help='Location for the dataset')
parser.add_argument('-o', '--save_dir', type=str, default='/data/pixelcnn/models/',
                    help='Location for parameter checkpoints and samples')
parser.add_argument('-d', '--dataset', type=str,
                    default='cifar', help='Can be cifar / mnist')
parser.add_argument('-p', '--print_every', type=int, default=50,
                    help='how many iterations between print statements')
parser.add_argument('-t', '--save_interval', type=int, default=10,
                    help='Every how many epochs to write checkpoint/samples?')
parser.add_argument('-r', '--load_params', type=str, default=None,
                    help='Restore training from previous model checkpoint?')

# pixelcnn model
parser.add_argument('-q', '--nr_resnet', type=int, default=5,
                    help='Number of residual blocks per stage of the model')
parser.add_argument('-n', '--nr_filters', type=int, default=160,
                    help='Number of filters to use across the model. Higher = larger model.')
parser.add_argument('-m', '--nr_logistic_mix', type=int, default=10,
                    help='Number of logistic components in the mixture. Higher = more flexible model')
parser.add_argument('-l', '--lr', type=float,
                    default=0.0002, help='Base learning rate')
parser.add_argument('-e', '--lr_decay', type=float, default=(1 - 5e-6),
                    help='Learning rate decay, applied every step of the optimization')
parser.add_argument('-b', '--batch_size', type=int, default=16,
                    help='Batch size during training per GPU')
parser.add_argument('-x', '--max_epochs', type=int,
                    default=5000, help='How many epochs to run in total?')
parser.add_argument('-s', '--seed', type=int, default=1,
                    help='Random seed to use')

# nits model
parser.add_argument('-a', '--nits_arch', type=list_str_to_list, default='[8,8,1]',
                    help='Architecture of NITS model')
parser.add_argument('-nb', '--nits_bound', type=float, default=5.,
                    help='Upper and lower bound of NITS model')
parser.add_argument('-c', '--constraint', type=str, default='neg_exp',
                    help='Upper and lower bound of NITS model')
parser.add_argument('-fc', '--final_constraint', type=str, default='softmax',
                    help='Upper and lower bound of NITS model')


args = parser.parse_args(['-a', '[16,10,1]', '-g', '4'])

device = 'cuda:' + args.gpu if args.gpu else 'cpu'
print('device:', device)

# HOUSEKEEPING

# reproducibility
torch.manual_seed(args.seed)
np.random.seed(args.seed)

model_name = 'lr_{:.5f}_nr_resnet{}_nr_filters{}_nits_arch{}_constraint{}_final_constraint{}'.format(
    args.lr, args.nr_resnet, args.nr_filters, args.nits_arch, args.constraint, args.final_constraint)
if os.path.exists(os.path.join('runs_test', model_name)):
    shutil.rmtree(os.path.join('runs_test', model_name))

sample_batch_size = 25
obs = (1, 28, 28) if 'mnist' in args.dataset else (3, 32, 32)
input_channels = obs[0]
rescaling     = lambda x : (x - .5) * 2.
rescaling_inv = lambda x : .5 * x  + .5
kwargs = {'num_workers':1, 'pin_memory':True, 'drop_last':True}
ds_transforms = transforms.Compose([transforms.ToTensor(), rescaling])

if 'mnist' in args.dataset :
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(args.data_dir, download=True,
                        train=True, transform=ds_transforms), batch_size=args.batch_size,
                            shuffle=True, **kwargs)

    test_loader  = torch.utils.data.DataLoader(datasets.MNIST(args.data_dir, train=False,
                    transform=ds_transforms), batch_size=args.batch_size, shuffle=True, **kwargs)
elif 'cifar' in args.dataset :
    train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(args.data_dir, train=True,
        download=True, transform=ds_transforms), batch_size=args.batch_size, shuffle=True, **kwargs)

    test_loader  = torch.utils.data.DataLoader(datasets.CIFAR10(args.data_dir, train=False,
                    transform=ds_transforms), batch_size=args.batch_size, shuffle=True, **kwargs)
else :
    raise Exception('{} dataset not in {mnist, cifar10}'.format(args.dataset))


# INITIALIZE NITS MODEL
if 'mnist' in args.dataset:
    arch = [1] + args.nits_arch
    nits_model = NITS(d=1, start=-args.nits_bound, end=args.nits_bound, monotonic_const=1e-5,
                      A_constraint=args.constraint, arch=arch,
                      final_layer_constraint=args.final_constraint).to(device)
elif 'cifar' in args.dataset:
    arch = [1] + args.nits_arch
    nits_model = ConditionalNITS(d=3, start=-args.nits_bound, end=args.nits_bound, monotonic_const=1e-5,
                                 A_constraint=args.constraint, arch=arch, autoregressive=True,
                                 pixelrnn=False, normalize_inverse=True,
                                 final_layer_constraint=args.final_constraint,
                                 softmax_temperature=True).to(device)
tot_params = nits_model.tot_params
loss_op = lambda real, params: discretized_nits_loss(real, params, nits_model)
sample_op = lambda params: nits_sample(params, nits_model)

# INITIALIZE PIXELCNN MODEL
model = PixelCNN(nr_resnet=args.nr_resnet, nr_filters=args.nr_filters,
                 input_channels=input_channels, nr_logistic_mix=tot_params, num_mix=1)
model = model.to(device)

if args.load_params:
    load_part_of_model(model, args.load_params)
    print('model parameters loaded')

optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=args.lr_decay)
test_losses = []

def sample(model):
    model.train(False)
    with torch.no_grad():
        data = torch.zeros(sample_batch_size, obs[0], obs[1], obs[2])
        data = data.to(device)
        for i in range(obs[1]):
            for j in range(obs[2]):
                data_v = Variable(data)
                out   = model(data_v, sample=True)
                out_sample = sample_op(out)
                data[:, :, i, j] = out_sample.data[:, :, i, j]
        return data


[16,10,1]
device: cuda:4
Files already downloaded and verified


In [None]:
print('starting training')
writes = 0
for epoch in range(args.max_epochs):
    model.train(True)
    torch.cuda.synchronize()
    train_loss = 0.
    time_ = time.time()
    model.train()
    for batch_idx, (input,_) in enumerate(train_loader):
        input = input.to(device)
        input = Variable(input)
        output = model(input)
        output.retain_grad()
        loss = loss_op(input, output)
        optimizer.zero_grad()
        loss.backward()
        if output.grad.isnan().any():
            print('output grad are nan')
            break
        optimizer.step()
        train_loss += loss.detach().cpu().numpy()
        if (batch_idx +1) % args.print_every == 0 :
            deno = args.print_every * args.batch_size * np.prod(obs) * np.log(2.)
            print('loss : {:.4f}, time : {:.4f}'.format(
                (train_loss / deno),
                (time.time() - time_)))
            train_loss = 0.
            writes += 1
            time_ = time.time()

    if loss.isnan() or loss.isinf() or output.grad.isnan().any():
        break

    # decrease learning rate
    scheduler.step()

    torch.cuda.synchronize()
    model.eval()
    test_loss = 0.
    for batch_idx, (input,_) in enumerate(test_loader):
        input = input.to(device)
        input_var = Variable(input)
        output = model(input_var)
        loss = loss_op(input_var, output)
        test_loss += loss.detach().cpu().numpy()
        del loss, output

    deno = batch_idx * args.batch_size * np.prod(obs) * np.log(2.)
    print('test loss : {:4f}, lr : {:4e}'.format(test_loss / deno, optimizer.param_groups[0]['lr']))
    test_losses.append(test_loss / deno)

    if (epoch + 1) % args.save_interval == 0:
        torch.save(model.state_dict(), '{}/{}_{}.pth'.format(args.save_dir, model_name, epoch))
        print('sampling...')
        sample_t = sample(model)
        sample_t = rescaling_inv(sample_t)
        utils.save_image(sample_t,'/data/pixelcnn/images/{}_{}.png'.format(model_name, epoch),
                nrow=5, padding=0)

starting training
loss : 3.0668, time : 28.3734
loss : 3.0410, time : 28.3665
loss : 3.0675, time : 28.4162
loss : 3.0764, time : 28.4139
loss : 3.0381, time : 28.4378
loss : 3.0558, time : 28.5071
loss : 3.1054, time : 28.4681
loss : 3.1033, time : 28.4500
loss : 3.1288, time : 28.4636
loss : 3.0957, time : 28.5035
loss : 3.1207, time : 28.5214
loss : 3.0875, time : 28.4775
loss : 3.0969, time : 28.4749
loss : 3.0554, time : 28.5275
loss : 3.0987, time : 28.4911
loss : 3.1119, time : 28.4661
loss : 3.0588, time : 28.4763
loss : 3.0877, time : 28.4838
loss : 3.0757, time : 28.4802
loss : 3.1376, time : 28.4779
loss : 3.1525, time : 28.4722
loss : 3.1153, time : 28.4694
loss : 3.0585, time : 28.4835
loss : 3.0988, time : 28.4736
loss : 3.0847, time : 28.4710
loss : 3.0616, time : 28.5035
loss : 3.1251, time : 28.4637
loss : 3.1265, time : 28.4621
loss : 3.1016, time : 28.4645
loss : 3.0890, time : 28.4701
loss : 3.0740, time : 28.5466
loss : 3.1018, time : 28.4863
loss : 3.0500, time : 

loss : 3.0695, time : 28.4736
loss : 3.0755, time : 28.4810
loss : 3.0455, time : 28.4700
loss : 3.0915, time : 28.4845
loss : 3.1149, time : 28.4436
loss : 3.1332, time : 28.4608
loss : 3.0566, time : 28.4380
loss : 3.0387, time : 28.4376
loss : 3.1125, time : 28.4677
loss : 3.0939, time : 28.4634
loss : 3.0283, time : 28.4509
loss : 3.0970, time : 28.4662
loss : 3.0796, time : 28.4497
loss : 3.1181, time : 28.4738
loss : 3.0792, time : 28.4583
loss : 3.0585, time : 28.4550
loss : 3.0974, time : 28.4406
loss : 3.0786, time : 28.4784
loss : 3.0964, time : 28.4677
loss : 3.0538, time : 28.4589
loss : 3.0646, time : 28.4865
loss : 3.0945, time : 28.4842
loss : 3.0470, time : 28.4695
loss : 3.0807, time : 28.4701
loss : 3.0618, time : 28.4543
loss : 3.0230, time : 28.4189
loss : 3.1016, time : 28.4386
loss : 3.0765, time : 28.4472
loss : 3.0440, time : 28.4571
loss : 3.0589, time : 28.4876
loss : 3.0749, time : 28.4529
loss : 3.0566, time : 28.4603
loss : 3.0324, time : 28.4511
loss : 3.0

loss : 3.0655, time : 28.4490
loss : 3.0493, time : 28.4448
loss : 3.0723, time : 28.4341
loss : 3.0636, time : 28.4606
loss : 3.1048, time : 28.4580
loss : 3.0993, time : 28.4562
loss : 3.0683, time : 28.4564
loss : 3.0781, time : 28.4356
loss : 3.0757, time : 28.4408
loss : 3.0745, time : 28.4481
loss : 3.0999, time : 28.4898
loss : 3.1161, time : 28.4586
loss : 3.0619, time : 28.4618
loss : 3.1318, time : 28.4785
loss : 3.0613, time : 28.4643
loss : 3.0242, time : 28.4744
loss : 3.0766, time : 28.4736
loss : 3.1054, time : 28.4459
test loss : 3.068322, lr : 1.999410e-04
loss : 3.0840, time : 28.6445
loss : 3.0597, time : 28.4409
loss : 3.0983, time : 28.3828
loss : 3.0824, time : 28.4237
loss : 3.1313, time : 28.3933
loss : 3.1527, time : 28.3772
loss : 3.0834, time : 28.3575
loss : 3.0864, time : 28.3694
loss : 3.0599, time : 28.3792
loss : 3.1227, time : 28.4285
loss : 3.0882, time : 28.3764
loss : 3.0563, time : 28.3574
loss : 3.0782, time : 28.4354
loss : 3.0772, time : 28.4888


loss : 3.1046, time : 28.4697
loss : 3.0566, time : 28.4750
loss : 3.0790, time : 28.4754
loss : 3.0735, time : 28.4896
loss : 3.1045, time : 28.5058
loss : 3.0454, time : 28.4894
loss : 3.0571, time : 28.5107
loss : 3.0473, time : 28.4768
loss : 3.0796, time : 28.4773
loss : 3.0837, time : 28.4669
loss : 3.0542, time : 28.4760
loss : 3.0649, time : 28.5017
loss : 3.0792, time : 28.4719
loss : 3.0627, time : 28.5741
loss : 3.0606, time : 28.4995
loss : 3.0582, time : 28.5087
loss : 3.0228, time : 28.4904
loss : 3.1069, time : 28.4799
loss : 3.0262, time : 28.4740
loss : 3.0629, time : 28.4520
loss : 3.0026, time : 28.4613
loss : 3.0376, time : 28.4490
loss : 3.0938, time : 28.4589
loss : 3.0713, time : 28.4978
loss : 3.0648, time : 28.4571
loss : 3.0573, time : 28.4597
loss : 3.0457, time : 28.4533
loss : 3.0765, time : 28.5056
loss : 3.0428, time : 28.4987
loss : 3.0852, time : 28.4821
loss : 3.0837, time : 28.4820
loss : 3.0501, time : 28.4878
loss : 3.0681, time : 28.4981
loss : 3.0

loss : 3.0756, time : 28.4674
loss : 3.0588, time : 28.4622
loss : 3.0293, time : 28.4762
loss : 3.0494, time : 28.4724
loss : 3.0796, time : 28.5084
loss : 3.1019, time : 28.4645
loss : 3.0239, time : 28.4714
loss : 3.0484, time : 28.5000
loss : 3.0655, time : 28.4689
loss : 3.0748, time : 28.4732
loss : 3.0242, time : 28.4891
loss : 3.0542, time : 28.4804
loss : 3.0718, time : 28.4790
loss : 3.0599, time : 28.4737
loss : 3.0599, time : 28.4844
loss : 3.0464, time : 28.4897
loss : 3.0436, time : 28.5169
loss : 3.0709, time : 28.4844
loss : 3.0621, time : 28.4877
loss : 3.0893, time : 28.5037
loss : 3.0795, time : 28.4810
loss : 3.0516, time : 28.4666
loss : 3.0476, time : 28.4652
loss : 3.0759, time : 28.4601
loss : 3.0552, time : 28.4689
loss : 3.0559, time : 28.4711
loss : 3.0751, time : 28.4871
loss : 3.0317, time : 28.4613
loss : 3.0271, time : 28.4996
loss : 3.0788, time : 28.4947
loss : 3.0651, time : 28.4857
loss : 3.0340, time : 28.4704
loss : 3.0477, time : 28.4894
loss : 3.0

loss : 3.0861, time : 28.4701
loss : 3.0673, time : 28.4708
loss : 3.0536, time : 28.4826
loss : 3.0520, time : 28.4894
loss : 3.0357, time : 28.4850
loss : 3.0598, time : 28.4791
loss : 3.0636, time : 28.4763
loss : 3.0636, time : 28.4756
loss : 3.0695, time : 28.5062
loss : 3.0307, time : 28.4696
loss : 3.0374, time : 28.4949
loss : 3.0418, time : 28.4972
loss : 3.0194, time : 28.5590
loss : 3.0567, time : 28.5009
loss : 3.0296, time : 28.4944
loss : 3.0394, time : 28.4732
loss : 3.0938, time : 28.4507
loss : 3.0521, time : 28.4564
test loss : 3.031682, lr : 1.999280e-04
loss : 3.0122, time : 28.6146
loss : 3.0515, time : 28.4985
loss : 3.0835, time : 28.4557
loss : 3.0654, time : 28.4551
loss : 3.0495, time : 28.4537
loss : 3.0659, time : 28.4747
loss : 3.0479, time : 28.4661
loss : 3.0406, time : 28.4651
loss : 3.0794, time : 28.4581
loss : 3.0686, time : 28.4725
loss : 3.0526, time : 28.4733
loss : 3.0209, time : 28.4785
loss : 3.0288, time : 28.4692
loss : 3.0412, time : 28.4943


loss : 3.0705, time : 28.5449
loss : 3.0524, time : 28.4896
loss : 3.0357, time : 28.4644
loss : 3.0154, time : 28.4839
loss : 3.0005, time : 28.4801
loss : 3.0108, time : 28.4738
loss : 3.0682, time : 28.4769
loss : 3.0588, time : 28.5004
loss : 3.0619, time : 28.4931
loss : 3.0471, time : 28.4966
loss : 3.0322, time : 28.5057
loss : 3.0350, time : 28.5054
loss : 3.0851, time : 28.4679
loss : 3.0189, time : 28.5072
loss : 3.0658, time : 28.4687
loss : 3.0390, time : 28.4872
loss : 3.0196, time : 28.4987
loss : 3.0864, time : 28.4961
loss : 3.0088, time : 28.4875
loss : 3.0023, time : 28.5413
loss : 3.0272, time : 28.4604
loss : 3.0207, time : 28.4531
loss : 3.0664, time : 28.4556
loss : 3.0506, time : 28.4683
loss : 3.0898, time : 28.4758
loss : 3.0736, time : 28.5087
loss : 3.0420, time : 28.4543
loss : 3.0487, time : 28.4816
loss : 3.0873, time : 28.4693
loss : 3.0640, time : 28.4955
loss : 3.0745, time : 28.4878
loss : 3.0064, time : 28.4802
loss : 3.0205, time : 28.4952
loss : 3.0

loss : 3.0434, time : 28.5098
loss : 3.0197, time : 28.5066
loss : 3.0282, time : 28.5102
loss : 3.0905, time : 28.4778
loss : 3.0522, time : 28.4977
loss : 3.0160, time : 28.5478
loss : 3.0450, time : 28.5019
loss : 3.0302, time : 28.5030
loss : 3.0397, time : 28.4595
loss : 3.0475, time : 28.4605
loss : 2.9900, time : 28.4599
loss : 3.0364, time : 28.4458
loss : 3.0197, time : 28.4641
loss : 3.0412, time : 28.4514
loss : 3.0495, time : 28.4664
loss : 3.0144, time : 28.4583
loss : 3.0586, time : 28.4516
loss : 3.0671, time : 28.5004
loss : 3.0540, time : 28.4719
loss : 3.0644, time : 28.4800
loss : 3.0264, time : 28.4868
loss : 3.1150, time : 28.4686
loss : 3.0240, time : 28.4735
loss : 3.0470, time : 28.4663
loss : 3.0594, time : 28.4668
loss : 3.0257, time : 28.4758
loss : 3.0044, time : 28.4851
loss : 3.0677, time : 28.4761
loss : 3.0370, time : 28.4905
loss : 3.0490, time : 28.5305
loss : 2.9992, time : 28.5005
loss : 3.0382, time : 28.4976
loss : 3.0319, time : 28.4851
loss : 3.0

loss : 3.0349, time : 28.4709
loss : 3.0802, time : 28.4833
loss : 3.0469, time : 28.4780
loss : 3.0663, time : 28.4950
loss : 3.0343, time : 28.4959
loss : 3.0718, time : 28.4836
loss : 3.0405, time : 28.4876
loss : 3.0286, time : 28.4808
loss : 3.0101, time : 28.4823
loss : 3.0212, time : 28.5165
loss : 3.0192, time : 28.5016
loss : 3.0062, time : 28.4974
loss : 3.0485, time : 28.4742
loss : 3.0235, time : 28.4750
loss : 2.9881, time : 28.4893
loss : 3.0036, time : 28.4812
loss : 3.0862, time : 28.4820
loss : 3.0247, time : 28.4837
loss : 3.0132, time : 28.4818
loss : 3.0114, time : 28.4902
loss : 3.0275, time : 28.4842
test loss : 3.001467, lr : 1.999150e-04
loss : 3.0211, time : 28.6422
loss : 3.0384, time : 28.5110
loss : 3.0339, time : 28.4465
loss : 3.0308, time : 28.4615
loss : 3.0159, time : 28.4527
loss : 3.0649, time : 28.4467
loss : 3.0516, time : 28.4479
loss : 3.0379, time : 28.4492
loss : 3.0547, time : 28.4474
loss : 3.0950, time : 28.4466
loss : 3.0235, time : 28.4522


loss : 3.0173, time : 28.5988
test loss : 3.005449, lr : 1.999110e-04
loss : 3.0416, time : 28.8112
loss : 3.0102, time : 28.5168
loss : 3.0119, time : 28.5208
loss : 3.0238, time : 28.5102
loss : 3.0528, time : 28.5305
loss : 3.0130, time : 28.5246
loss : 2.9962, time : 28.5238
loss : 3.0109, time : 28.5301
loss : 3.0345, time : 28.5365
loss : 3.0322, time : 28.4966
loss : 3.0444, time : 28.5106
loss : 3.0260, time : 28.5102
loss : 3.0163, time : 28.5524
loss : 3.0523, time : 28.5194
loss : 3.0413, time : 28.5384
loss : 3.0066, time : 28.5086
loss : 3.0341, time : 28.5257
loss : 3.0239, time : 28.5235
loss : 3.0684, time : 28.5157
loss : 3.0440, time : 28.5213
loss : 3.0386, time : 28.5202
loss : 2.9942, time : 28.5014
loss : 3.0362, time : 28.4995
loss : 3.0250, time : 28.5165
loss : 3.0517, time : 28.5433
loss : 3.0610, time : 28.5188
loss : 3.0106, time : 28.4962
loss : 2.9979, time : 28.5079
loss : 2.9992, time : 28.5101
loss : 3.0099, time : 28.5014
loss : 3.0459, time : 28.4934


loss : 3.0140, time : 28.5176
loss : 3.0502, time : 28.5075
loss : 3.0021, time : 28.4892
loss : 3.0067, time : 28.5081
loss : 3.0461, time : 28.5058
loss : 3.0036, time : 28.5165
loss : 3.0354, time : 28.5083
loss : 3.0667, time : 28.5459
loss : 2.9799, time : 28.5090
loss : 2.9832, time : 28.5074
loss : 3.0236, time : 28.5060
loss : 3.0421, time : 28.5092
loss : 3.0306, time : 28.5014
loss : 2.9871, time : 28.5133
loss : 3.0035, time : 28.5007
loss : 3.0152, time : 28.5252
loss : 3.0411, time : 28.5470
loss : 3.0387, time : 28.5188
loss : 3.0318, time : 28.5113
loss : 3.0203, time : 28.5308
loss : 3.0592, time : 28.5040
loss : 3.0144, time : 28.5173
loss : 2.9840, time : 28.5008
loss : 3.0338, time : 28.4989
loss : 3.0193, time : 28.4953
loss : 3.0193, time : 28.4908
loss : 3.0122, time : 28.4958
loss : 3.1005, time : 28.5107
loss : 3.0449, time : 28.5197
loss : 2.9981, time : 28.5046
loss : 3.0226, time : 28.5179
loss : 3.0168, time : 28.5205
loss : 3.0465, time : 28.4944
loss : 3.0

loss : 3.0028, time : 28.5191
loss : 3.0605, time : 28.5164
loss : 3.0186, time : 28.5101
loss : 3.0142, time : 28.5204
loss : 3.0189, time : 28.5163
loss : 3.0174, time : 28.5290
loss : 2.9845, time : 28.5246
loss : 3.0249, time : 28.5265
loss : 3.0344, time : 28.5194
loss : 3.0341, time : 28.5167
loss : 3.0542, time : 28.5108
loss : 2.9997, time : 28.5529
loss : 3.0363, time : 28.5135
loss : 3.0548, time : 28.5146
loss : 3.0232, time : 28.5588
loss : 2.9740, time : 28.5190
loss : 2.9981, time : 28.5262
loss : 3.0171, time : 28.5141
loss : 3.0520, time : 28.5377
loss : 3.0324, time : 28.5295
loss : 3.0404, time : 28.5166
loss : 3.0500, time : 28.5014
loss : 3.0131, time : 28.4880
loss : 3.0362, time : 28.5153
test loss : 3.005808, lr : 1.999020e-04
loss : 3.0032, time : 28.6585
loss : 3.0518, time : 28.5327
loss : 3.0382, time : 28.5267
loss : 3.0298, time : 28.5592
loss : 3.0475, time : 28.5229
loss : 3.0370, time : 28.4832
loss : 2.9769, time : 28.4851
loss : 2.9666, time : 28.4686


loss : 3.0178, time : 28.5237
loss : 3.0332, time : 28.5195
loss : 3.0379, time : 28.5245
loss : 2.9842, time : 28.5584
test loss : 2.992405, lr : 1.998980e-04
loss : 3.0361, time : 28.7323
loss : 3.0232, time : 28.5270
loss : 3.0290, time : 28.5110
loss : 3.0166, time : 28.5002
loss : 3.0321, time : 28.5039
loss : 2.9832, time : 28.5033
loss : 3.0393, time : 28.5162
loss : 3.0229, time : 28.5048
loss : 2.9857, time : 28.5191
loss : 3.0320, time : 28.5243
loss : 3.0414, time : 28.5502
loss : 2.9985, time : 28.5029
loss : 3.0345, time : 28.5583
loss : 3.0755, time : 28.5226
loss : 3.0579, time : 28.5054
loss : 2.9983, time : 28.5208
loss : 3.0085, time : 28.5129
loss : 3.0168, time : 28.5290
loss : 3.0527, time : 28.4905
loss : 2.9694, time : 28.4933
loss : 3.0031, time : 28.4943
loss : 3.0083, time : 28.4906
loss : 3.0100, time : 28.5161
loss : 3.0063, time : 28.4922
loss : 3.0137, time : 28.4902
loss : 3.0126, time : 28.4863
loss : 3.0008, time : 28.4821
loss : 3.0050, time : 28.4923


loss : 2.9793, time : 28.5286
loss : 3.0313, time : 28.5286
loss : 3.0184, time : 28.5251
loss : 2.9953, time : 28.5510
loss : 3.0151, time : 28.5423
loss : 2.9695, time : 28.5401
loss : 2.9315, time : 28.5265
loss : 2.9816, time : 28.5181
loss : 3.0047, time : 28.5291
loss : 2.9757, time : 28.5115
loss : 3.0483, time : 28.5288
loss : 3.0088, time : 28.5340
loss : 3.0226, time : 28.5420
loss : 3.0025, time : 28.5389
loss : 2.9920, time : 28.5446
loss : 3.0024, time : 28.5711
loss : 3.0085, time : 28.5106
loss : 3.0215, time : 28.5435
loss : 2.9950, time : 28.5426
loss : 3.0208, time : 28.5281
loss : 2.9709, time : 28.5343
loss : 3.0249, time : 28.5570
loss : 3.0043, time : 28.5197
loss : 3.0102, time : 28.5170
loss : 3.0212, time : 28.5086
loss : 2.9840, time : 28.5259
loss : 2.9814, time : 28.5187
loss : 2.9921, time : 28.5653
loss : 3.0123, time : 28.5308
loss : 3.0127, time : 28.5385
loss : 2.9963, time : 28.5223
loss : 2.9510, time : 28.5212
loss : 2.9937, time : 28.5189
loss : 3.0

loss : 2.9744, time : 28.5492
loss : 2.9851, time : 28.5402
loss : 2.9775, time : 28.5206
loss : 2.9585, time : 28.5209
loss : 3.0119, time : 28.5423
loss : 2.9815, time : 28.5371
loss : 2.9802, time : 28.5338
loss : 3.0055, time : 28.5721
loss : 2.9668, time : 28.5356
loss : 3.0467, time : 28.5302
loss : 3.0297, time : 28.5455
loss : 3.0127, time : 28.5494
loss : 3.0140, time : 28.5401
loss : 3.0066, time : 28.5377
loss : 2.9998, time : 28.5328
loss : 2.9912, time : 28.5491
test loss : 2.982928, lr : 1.998660e-04
loss : 2.9670, time : 28.6996
loss : 2.9719, time : 28.5474
loss : 2.9858, time : 28.5937
loss : 2.9647, time : 28.5347
loss : 3.0039, time : 28.5307
loss : 2.9960, time : 28.5418
loss : 3.0216, time : 28.5341
loss : 3.0353, time : 28.5310
loss : 2.9903, time : 28.5346
loss : 2.9903, time : 28.5422
loss : 3.0125, time : 28.5359
loss : 3.0240, time : 28.5281
loss : 2.9989, time : 28.5321
loss : 2.9421, time : 28.5332
loss : 3.0095, time : 28.5691
loss : 3.0029, time : 28.5313


loss : 3.0304, time : 28.5340
loss : 3.0077, time : 28.5340
loss : 2.9622, time : 28.5335
loss : 3.0100, time : 28.5276
loss : 2.9859, time : 28.5391
loss : 2.9577, time : 28.5624
loss : 2.9943, time : 28.5526
loss : 3.0244, time : 28.5471
loss : 2.9996, time : 28.5339
loss : 3.0028, time : 28.5350
loss : 3.0190, time : 28.5194
loss : 3.0237, time : 28.5612
loss : 2.9904, time : 28.5394
loss : 2.9994, time : 28.5311
loss : 2.9970, time : 28.5232
loss : 2.9747, time : 28.5348
loss : 2.9993, time : 28.5393
loss : 2.9856, time : 28.5432
loss : 2.9921, time : 28.5385
loss : 2.9814, time : 28.5438
loss : 3.0111, time : 28.5405
loss : 2.9968, time : 28.5382
loss : 3.0076, time : 28.5568
loss : 3.0355, time : 28.5662
loss : 2.9902, time : 28.5739
loss : 2.9931, time : 28.5580
loss : 2.9540, time : 28.5523
loss : 2.9648, time : 28.5549
loss : 2.9684, time : 28.5389
loss : 3.0198, time : 28.5560
loss : 3.0071, time : 28.5297
loss : 2.9935, time : 28.5424
loss : 2.9506, time : 28.5415
loss : 3.0

loss : 2.9344, time : 28.5296
loss : 2.9734, time : 28.5331
loss : 2.9724, time : 28.5646
loss : 2.9771, time : 28.5535
loss : 3.0270, time : 28.5857
loss : 2.9420, time : 28.5331
loss : 2.9758, time : 28.5298
loss : 3.0152, time : 28.5312
loss : 2.9641, time : 28.5361
loss : 3.0062, time : 28.5413
loss : 2.9836, time : 28.5496
loss : 3.0061, time : 28.5404
loss : 2.9821, time : 28.5325
loss : 2.9854, time : 28.5259
loss : 3.0222, time : 28.5226
loss : 2.9902, time : 28.5234
loss : 2.9853, time : 28.5687
loss : 2.9891, time : 28.5242
loss : 2.9803, time : 28.5287
loss : 2.9651, time : 28.5172
loss : 2.9742, time : 28.5468
loss : 2.9547, time : 28.5348
loss : 3.0212, time : 28.5260
loss : 3.0004, time : 28.5455
loss : 3.0315, time : 28.5299
loss : 3.0002, time : 28.5234
loss : 3.0014, time : 28.5090
loss : 3.0096, time : 28.5197
loss : 3.0078, time : 28.5537
loss : 3.0116, time : 28.5166
loss : 3.0138, time : 28.5153
loss : 2.9773, time : 28.5061
loss : 3.0227, time : 28.5183
loss : 2.9

In [None]:
# torch.save(model.to('cpu').state_dict(), './cifar_model.pth')

In [5]:
test_losses

[3.9946067419982145,
 3.7095731649833557,
 3.5397387527198334,
 3.4285993233309866,
 3.506644070579394,
 3.355817989354682,
 3.332827659111572,
 3.318594494726047,
 3.3263378349837764,
 3.337990469515122,
 3.241161852473759,
 3.2009310717845274,
 3.215256899540375,
 3.2043920480781005,
 3.2180318627635573,
 3.1785560517355735,
 3.192287648196525,
 3.2052902351894015,
 3.181655520497033,
 3.1468223925211043,
 3.1636116053244465,
 3.127488230428158,
 3.184809518048376,
 3.1293779934350847,
 3.1181206981426617,
 3.1670931901373796,
 3.1238606558636377,
 3.116251112596015,
 3.1218776676028988,
 3.094179535355388,
 3.103162447551413,
 3.095869164944386,
 3.1980283989486136,
 3.169104161959829,
 3.0808783839082237,
 3.10101412626816,
 3.097537354029577,
 3.1172043411377937,
 3.073054336381765,
 3.0977771033571444,
 3.064713736758534,
 3.064957556709656,
 3.0588753182349704,
 3.065151132737794,
 3.0964910075394005,
 3.1430258023739324,
 3.0580621464108084,
 3.076979823850915,
 3.0595846253999

In [None]:
input, _ = next(iter(train_loader))

with torch.no_grad():
    output = model(input.to(device))

n_steps = 300
unif_x = torch.linspace(-bounds, bounds, steps=n_steps, device=output.device).reshape(-1, 1).tile(1, 3)
params = output[0,:,10,10].reshape(-1, nits_model.tot_params).tile((n_steps, 1))
pdfs = nits_model.to(unif_x.device).pdf(unif_x, params)

colors = ['red', 'green', 'blue']
for i in range(nits_model.d):
    plt.scatter(unif_x[:,i].detach().cpu(), pdfs[:,i].detach().cpu(), c=colors[i])