In [1]:
!nvidia-smi

Tue Jan 18 02:31:22 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.42.01    Driver Version: 470.42.01    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:1B:00.0 Off |                  N/A |
| 29%   37C    P8     1W / 250W |   5955MiB / 11016MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:1C:00.0 Off |                  N/A |
| 44%   71C    P2   259W / 250W |   8249MiB / 11019MiB |     91%      Default |
|       

In [2]:
device = 'cuda:4'

In [3]:
import time
import os, shutil
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, transforms, utils
from tensorboardX import SummaryWriter
from nits.pixelcnn_model import *
from nits.model import NITS, ConditionalNITS
from nits.discretized_mol import discretized_nits_loss, nits_sample
from PIL import Image

import matplotlib.pyplot as plt

def list_str_to_list(s):
    print(s)
    assert s[0] == '[' and s[-1] == ']'
    s = s[1:-1]
    s = s.replace(' ', '')
    s = s.split(',')

    s = [int(x) for x in s]

    return s

parser = argparse.ArgumentParser()
# data I/O
parser.add_argument('-g', '--gpu', type=str,
                    default='', help='Location for the dataset')
parser.add_argument('-i', '--data_dir', type=str,
                    default='/data/pixelcnn/data/', help='Location for the dataset')
parser.add_argument('-o', '--save_dir', type=str, default='/data/pixelcnn/models/',
                    help='Location for parameter checkpoints and samples')
parser.add_argument('-d', '--dataset', type=str,
                    default='cifar', help='Can be cifar / mnist')
parser.add_argument('-p', '--print_every', type=int, default=50,
                    help='how many iterations between print statements')
parser.add_argument('-t', '--save_interval', type=int, default=10,
                    help='Every how many epochs to write checkpoint/samples?')
parser.add_argument('-r', '--load_params', type=str, default=None,
                    help='Restore training from previous model checkpoint?')

# pixelcnn model
parser.add_argument('-q', '--nr_resnet', type=int, default=5,
                    help='Number of residual blocks per stage of the model')
parser.add_argument('-n', '--nr_filters', type=int, default=160,
                    help='Number of filters to use across the model. Higher = larger model.')
parser.add_argument('-m', '--nr_logistic_mix', type=int, default=10,
                    help='Number of logistic components in the mixture. Higher = more flexible model')
parser.add_argument('-l', '--lr', type=float,
                    default=0.0002, help='Base learning rate')
parser.add_argument('-e', '--lr_decay', type=float, default=(1 - 5e-6),
                    help='Learning rate decay, applied every step of the optimization')
parser.add_argument('-b', '--batch_size', type=int, default=16,
                    help='Batch size during training per GPU')
parser.add_argument('-x', '--max_epochs', type=int,
                    default=5000, help='How many epochs to run in total?')
parser.add_argument('-s', '--seed', type=int, default=1,
                    help='Random seed to use')

# nits model
parser.add_argument('-a', '--nits_arch', type=list_str_to_list, default='[8, 8, 1]',
                    help='Architecture of NITS model')
parser.add_argument('-nb', '--nits_bound', type=float, default=5.,
                    help='Upper and lower bound of NITS model')
parser.add_argument('-c', '--constraint', type=str, default='neg_exp',
                    help='Upper and lower bound of NITS model')
parser.add_argument('-fc', '--final_constraint', type=str, default='softmax',
                    help='Upper and lower bound of NITS model')


args = parser.parse_args(['-a', '[8,8,1]', '-g', '4', '-fc', 'softmax'])

device = 'cuda:' + args.gpu if args.gpu else 'cpu'
print('device:', device)

# HOUSEKEEPING

# reproducibility
torch.manual_seed(args.seed)
np.random.seed(args.seed)

model_name = 'lr_{:.5f}_nr_resnet{}_nr_filters{}_nits_arch{}_constraint{}_final_constraint{}'.format(
    args.lr, args.nr_resnet, args.nr_filters, args.nits_arch, args.constraint, args.final_constraint)
if os.path.exists(os.path.join('runs_test', model_name)):
    shutil.rmtree(os.path.join('runs_test', model_name))

sample_batch_size = 25
obs = (1, 28, 28) if 'mnist' in args.dataset else (3, 32, 32)
input_channels = obs[0]
rescaling     = lambda x : (x - .5) * 2.
rescaling_inv = lambda x : .5 * x  + .5
kwargs = {'num_workers':1, 'pin_memory':True, 'drop_last':True}
ds_transforms = transforms.Compose([transforms.ToTensor(), rescaling])

if 'mnist' in args.dataset :
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(args.data_dir, download=True,
                        train=True, transform=ds_transforms), batch_size=args.batch_size,
                            shuffle=True, **kwargs)

    test_loader  = torch.utils.data.DataLoader(datasets.MNIST(args.data_dir, train=False,
                    transform=ds_transforms), batch_size=args.batch_size, shuffle=True, **kwargs)
elif 'cifar' in args.dataset :
    train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(args.data_dir, train=True,
        download=True, transform=ds_transforms), batch_size=args.batch_size, shuffle=True, **kwargs)

    test_loader  = torch.utils.data.DataLoader(datasets.CIFAR10(args.data_dir, train=False,
                    transform=ds_transforms), batch_size=args.batch_size, shuffle=True, **kwargs)
else :
    raise Exception('{} dataset not in {mnist, cifar10}'.format(args.dataset))


# INITIALIZE NITS MODEL
if 'mnist' in args.dataset:
    arch = [1] + args.nits_arch
    nits_model = NITS(d=1, start=-args.nits_bound, end=args.nits_bound, monotonic_const=1e-5,
                      A_constraint=args.constraint, arch=arch, final_layer_constraint=args.final_constraint).to(device)
elif 'cifar' in args.dataset:
    arch = [1] + args.nits_arch
    nits_model = ConditionalNITS(d=3, start=-args.nits_bound, end=args.nits_bound, monotonic_const=1e-5,
                                 A_constraint=args.constraint, arch=arch, autoregressive=True,
                                 pixelrnn=True, normalize_inverse=True, final_layer_constraint=args.final_constraint).to(device)
tot_params = nits_model.tot_params
loss_op = lambda real, params: discretized_nits_loss(real, params, nits_model)
sample_op = lambda params: nits_sample(params, nits_model)

# INITIALIZE PIXELCNN MODEL
model = PixelCNN(nr_resnet=args.nr_resnet, nr_filters=args.nr_filters,
                 input_channels=input_channels, nr_logistic_mix=tot_params, num_mix=1)
model = model.to(device)

if args.load_params:
    load_part_of_model(model, args.load_params)
    print('model parameters loaded')

optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=args.lr_decay)
test_losses = []

def sample(model):
    model.train(False)
    with torch.no_grad():
        data = torch.zeros(sample_batch_size, obs[0], obs[1], obs[2])
        data = data.to(device)
        for i in range(obs[1]):
            for j in range(obs[2]):
                data_v = Variable(data)
                out   = model(data_v, sample=True)
                out_sample = sample_op(out)
                data[:, :, i, j] = out_sample.data[:, :, i, j]
        return data


[8,8,1]
device: cuda:4
Files already downloaded and verified
starting training




loss : 7.3243, time : 29.9238
loss : 6.3570, time : 26.0241
loss : 6.0897, time : 26.1377
loss : 5.8779, time : 26.0773
loss : 5.7383, time : 26.0161
loss : 5.6196, time : 26.0925
loss : 5.5704, time : 26.0538
loss : 5.4705, time : 26.0643
loss : 5.4227, time : 26.2111
loss : 5.3895, time : 26.1986
loss : 5.3085, time : 26.0519
loss : 5.2375, time : 26.0255
loss : 5.2448, time : 26.0247
loss : 5.1785, time : 26.0577
loss : 5.1621, time : 26.0832
loss : 5.0807, time : 26.0666
loss : 5.0644, time : 26.0441
loss : 5.0400, time : 26.1670
loss : 5.0332, time : 26.1444
loss : 5.0155, time : 26.0591
loss : 4.9112, time : 26.4112
loss : 4.9483, time : 26.0707
loss : 4.8995, time : 26.0654
loss : 4.8636, time : 26.0548
loss : 4.8114, time : 26.0596
loss : 4.8009, time : 26.2782
loss : 4.8075, time : 26.0521
loss : 4.7955, time : 26.0513
loss : 4.6752, time : 26.0396
loss : 4.7626, time : 26.0414
loss : 4.7484, time : 26.0806
loss : 4.6911, time : 26.0407
loss : 4.6883, time : 26.0514
loss : 4.6

loss : 3.8990, time : 26.0806
loss : 3.8713, time : 26.0737
loss : 3.8917, time : 26.1227
loss : 3.9572, time : 26.0751
loss : 3.9020, time : 26.0692
loss : 3.8986, time : 26.0797
loss : 3.8554, time : 26.0739
loss : 3.9662, time : 26.0978
loss : 3.8788, time : 26.0781
loss : 3.8985, time : 26.0837
loss : 3.8995, time : 26.1655
loss : 3.9043, time : 26.1860
loss : 3.8962, time : 26.1525
loss : 3.8774, time : 26.1477
loss : 3.8965, time : 26.2501
loss : 3.9045, time : 26.1633
loss : 3.9472, time : 26.1978
loss : 3.9037, time : 26.1787
loss : 3.8598, time : 26.2043
loss : 3.8607, time : 26.1939
loss : 3.8260, time : 26.1100
loss : 3.8787, time : 26.0646
loss : 3.8677, time : 26.0321
loss : 3.9411, time : 26.0538
loss : 3.9611, time : 26.0598
loss : 3.8546, time : 26.1076
loss : 3.8798, time : 26.0729
loss : 3.9054, time : 26.0662
loss : 3.8717, time : 26.0538
loss : 3.8533, time : 26.0488
loss : 3.8399, time : 26.0731
loss : 3.8807, time : 26.1517
loss : 3.8468, time : 26.1781
loss : 3.8

loss : 3.7034, time : 26.0958
loss : 3.6778, time : 26.0547
loss : 3.7262, time : 26.0748
loss : 3.7187, time : 26.0987
loss : 3.7296, time : 26.1126
loss : 3.6411, time : 26.1231
loss : 3.6842, time : 26.1251
loss : 3.7595, time : 26.0865
loss : 3.7294, time : 26.0620
loss : 3.7737, time : 26.1162
loss : 3.7666, time : 26.1380
loss : 3.6717, time : 26.2047
loss : 3.7099, time : 26.9456
loss : 3.6679, time : 26.9351
loss : 3.6785, time : 26.2420
loss : 3.6872, time : 26.0667
loss : 3.7302, time : 26.1029
loss : 3.6711, time : 26.0599
loss : 3.7095, time : 26.0629
test loss : 3.608199846125519
loss : 3.7323, time : 26.2422
loss : 3.7495, time : 26.1712
loss : 3.7045, time : 26.1931
loss : 3.6901, time : 26.0521
loss : 3.8006, time : 26.1355
loss : 3.7667, time : 26.1211
loss : 3.7607, time : 26.0507
loss : 3.6567, time : 26.0636
loss : 3.7106, time : 26.1431
loss : 3.7078, time : 26.1115
loss : 3.7229, time : 26.0625
loss : 3.7098, time : 26.0375
loss : 3.7350, time : 26.0356
loss : 3.6

loss : 3.7190, time : 26.0309
loss : 3.6498, time : 26.0311
loss : 3.6357, time : 25.9829
loss : 3.5856, time : 25.9918
loss : 3.5653, time : 26.0018
loss : 3.6547, time : 26.0174
loss : 3.6359, time : 25.9840
loss : 3.6620, time : 25.9906
loss : 3.6030, time : 25.9937
loss : 3.6340, time : 26.0332
loss : 3.6760, time : 26.0010
loss : 3.6602, time : 26.0362
loss : 3.6066, time : 26.0464
loss : 3.5976, time : 26.0854
loss : 3.6175, time : 26.0138
loss : 3.5669, time : 26.2631
loss : 3.6062, time : 26.0515
loss : 3.6197, time : 26.0447
loss : 3.6660, time : 25.9958
loss : 3.6131, time : 26.0009
loss : 3.6437, time : 25.9980
loss : 3.6482, time : 26.0113
loss : 3.6335, time : 25.9804
loss : 3.6202, time : 26.0362
loss : 3.5916, time : 26.0494
loss : 3.6114, time : 26.0198
loss : 3.6376, time : 26.0071
loss : 3.5981, time : 25.9812
loss : 3.6362, time : 26.0715
loss : 3.6568, time : 26.0216
loss : 3.6172, time : 26.0252
loss : 3.5996, time : 26.0076
loss : 3.6603, time : 25.9696
loss : 3.5

loss : 3.5880, time : 26.0277
loss : 3.5657, time : 26.0350
loss : 3.5461, time : 26.1164
loss : 3.5734, time : 26.1211
loss : 3.4739, time : 26.0247
loss : 3.5974, time : 25.9952
loss : 3.5679, time : 26.0662
loss : 3.5549, time : 26.0330
loss : 3.5682, time : 26.0274
loss : 3.5256, time : 26.0085
loss : 3.5641, time : 25.9809
loss : 3.6187, time : 25.9829
loss : 3.5886, time : 25.9873
loss : 3.5752, time : 25.9826
loss : 3.6016, time : 25.9750
loss : 3.5309, time : 25.9878
loss : 3.5912, time : 26.0530
loss : 3.5272, time : 26.0179
loss : 3.6085, time : 26.0526
loss : 3.5654, time : 26.0068
loss : 3.5206, time : 25.9860
loss : 3.5401, time : 25.9998
loss : 3.6147, time : 26.0506
loss : 3.5140, time : 26.0738
loss : 3.5373, time : 26.0601
loss : 3.6111, time : 26.0506
loss : 3.5383, time : 26.1060
loss : 3.5715, time : 26.0564
loss : 3.5758, time : 26.0605
loss : 3.5422, time : 26.0462
loss : 3.5833, time : 26.1053
loss : 3.6428, time : 26.0592
loss : 3.5605, time : 26.0377
loss : 3.5

loss : 3.5171, time : 25.9992
loss : 3.5350, time : 26.0079
loss : 3.4982, time : 26.0576
loss : 3.5019, time : 26.0812
loss : 3.5245, time : 26.0700
loss : 3.5571, time : 26.0095
loss : 3.4732, time : 26.0984
loss : 3.5743, time : 26.1369
loss : 3.5019, time : 26.1344
loss : 3.5047, time : 26.2829
loss : 3.6619, time : 26.6389
loss : 3.4380, time : 26.0935
loss : 3.4965, time : 26.0233
loss : 3.5152, time : 26.0152
loss : 3.5619, time : 26.1291
loss : 3.4796, time : 26.1825
loss : 3.4922, time : 26.1144
loss : 3.5861, time : 26.0196
test loss : 3.4076555671564766
loss : 3.5044, time : 26.1919
loss : 3.5526, time : 26.0158
loss : 3.5270, time : 26.0074
loss : 3.5383, time : 26.0146
loss : 3.4804, time : 26.0014
loss : 3.5150, time : 26.0009
loss : 3.5265, time : 25.9902
loss : 3.7175, time : 26.0503
loss : 3.4909, time : 26.0200
loss : 3.4949, time : 26.0499
loss : 3.5023, time : 26.0092
loss : 3.5049, time : 26.0165
loss : 3.5090, time : 26.0721
loss : 3.5016, time : 25.9922
loss : 3.

KeyboardInterrupt: 

In [None]:
print('starting training')
writes = 0
for epoch in range(args.max_epochs):
    model.train(True)
    torch.cuda.synchronize()
    train_loss = 0.
    time_ = time.time()
    model.train()
    for batch_idx, (input,_) in enumerate(train_loader):
        input = input.to(device)
        input = Variable(input)
        output = model(input)
        output.retain_grad()
        loss = loss_op(input, output)
        optimizer.zero_grad()
        loss.backward()
        if output.grad.isnan().any():
            print('output grad are nan')
            break
        optimizer.step()
        train_loss += loss.detach().cpu().numpy()
        if (batch_idx +1) % args.print_every == 0 :
            deno = args.print_every * args.batch_size * np.prod(obs) * np.log(2.)
            print('loss : {:.4f}, time : {:.4f}'.format(
                (train_loss / deno),
                (time.time() - time_)))
            train_loss = 0.
            writes += 1
            time_ = time.time()

    if loss.isnan() or loss.isinf() or output.grad.isnan().any():
        break

    # decrease learning rate
    scheduler.step()

    torch.cuda.synchronize()
    model.eval()
    test_loss = 0.
    for batch_idx, (input,_) in enumerate(test_loader):
        input = input.to(device)
        input_var = Variable(input)
        output = model(input_var)
        loss = loss_op(input_var, output)
        test_loss += loss.detach().cpu().numpy()
        del loss, output

    deno = batch_idx * args.batch_size * np.prod(obs) * np.log(2.)
    print('test loss : {:4f}, lr : {:4e}'.format(test_loss / deno, optimizer.param_groups[0]['lr']))
    test_losses.append(test_loss / deno)

    if (epoch + 1) % args.save_interval == 0:
        torch.save(model.state_dict(), '{}/{}_{}.pth'.format(args.save_dir, model_name, epoch))
        print('sampling...')
        sample_t = sample(model)
        sample_t = rescaling_inv(sample_t)
        utils.save_image(sample_t,'/data/pixelcnn/images/{}_{}.png'.format(model_name, epoch),
                nrow=5, padding=0)

starting training
loss : 3.5708, time : 25.2757
loss : 3.5111, time : 25.2420
loss : 3.5012, time : 25.5406
loss : 3.4487, time : 25.6778
loss : 3.4521, time : 25.7615
loss : 3.5215, time : 25.8648
loss : 3.5134, time : 26.0589
loss : 3.5112, time : 25.9552
loss : 3.4566, time : 25.9715
loss : 3.4796, time : 26.0163
loss : 3.5026, time : 25.9873
loss : 3.5090, time : 26.0123
loss : 3.5645, time : 26.0454
loss : 3.5357, time : 25.9932
loss : 3.5156, time : 26.0057
loss : 3.5727, time : 25.9949
loss : 3.4599, time : 26.0072
loss : 3.4734, time : 26.0018
loss : 3.5183, time : 26.0068
loss : 3.5916, time : 26.0338
loss : 3.5090, time : 26.0869
loss : 3.5097, time : 26.5344
loss : 3.5750, time : 26.0821
loss : 3.4730, time : 26.0737
loss : 3.5588, time : 26.1316
loss : 3.5070, time : 26.6518
loss : 3.4746, time : 26.4655
loss : 3.4185, time : 26.0517
loss : 3.5334, time : 25.9831
loss : 3.4534, time : 25.9834
loss : 3.5007, time : 26.0631
loss : 3.4491, time : 26.0416
loss : 3.4862, time : 

loss : 3.4548, time : 30.1823
loss : 3.4398, time : 30.2307
loss : 3.5123, time : 30.2412
loss : 3.4906, time : 30.1605
loss : 3.4528, time : 30.1711
loss : 3.4970, time : 30.2350
loss : 3.4509, time : 30.1869
loss : 3.4826, time : 30.1709
loss : 3.4437, time : 30.1831
loss : 3.4766, time : 30.1730
loss : 3.4294, time : 30.0827
loss : 3.4519, time : 30.1815
loss : 3.4510, time : 30.0834
loss : 3.4711, time : 30.1362
loss : 3.4501, time : 30.0940
loss : 3.4675, time : 30.1427
loss : 3.5187, time : 30.0869
loss : 3.4878, time : 30.1512
loss : 3.4451, time : 30.1348
loss : 3.4435, time : 30.1318
loss : 3.4511, time : 30.2532
loss : 3.4586, time : 30.0751
loss : 3.4299, time : 30.0708
loss : 3.4282, time : 30.0884
loss : 3.4801, time : 30.0910
loss : 3.4552, time : 30.1280
loss : 3.4483, time : 30.1539
loss : 3.5804, time : 30.2659
loss : 3.4777, time : 30.3382
loss : 3.4439, time : 30.8577
loss : 3.4161, time : 30.2572
loss : 3.4939, time : 30.1293
loss : 3.4198, time : 30.1096
loss : 3.4

loss : 3.4593, time : 30.1303
loss : 3.4450, time : 30.3956
loss : 3.4580, time : 30.1667
loss : 3.4571, time : 30.1303
loss : 3.5095, time : 30.1469
loss : 3.3855, time : 30.2046
loss : 3.4685, time : 30.1615
loss : 3.4239, time : 30.1355
loss : 3.4226, time : 30.1059
loss : 3.4961, time : 30.0752
loss : 3.4277, time : 30.0160
loss : 3.4011, time : 30.0322
loss : 3.4818, time : 30.0654
loss : 3.4771, time : 30.1107
loss : 3.4716, time : 30.2000
loss : 3.4510, time : 30.2242
loss : 3.4050, time : 30.2892
loss : 3.4655, time : 30.2443
loss : 3.4226, time : 30.2187
loss : 3.4682, time : 30.5220
loss : 3.4135, time : 30.1823
loss : 3.3965, time : 30.1950
test loss : 3.352352, lr : 1.999680e-04
loss : 3.4028, time : 30.7649
loss : 3.4299, time : 30.1073
loss : 3.4285, time : 30.1689
loss : 3.4395, time : 30.1604
loss : 3.4292, time : 30.1735
loss : 3.4427, time : 30.1732
loss : 3.4657, time : 30.1353
loss : 3.4159, time : 30.2311
loss : 3.4050, time : 30.1805
loss : 3.4841, time : 30.1629


loss : 3.4183, time : 30.0539
loss : 3.4479, time : 29.8325
loss : 3.3650, time : 29.8492
loss : 3.4246, time : 29.8334
loss : 3.3897, time : 29.8795
loss : 3.3791, time : 29.8410
loss : 3.4456, time : 29.8785
loss : 3.4579, time : 29.9119
loss : 3.4400, time : 29.9939
loss : 3.4377, time : 30.0027
loss : 3.3918, time : 29.9150
loss : 3.3835, time : 29.9994
loss : 3.4313, time : 30.3347
loss : 3.3627, time : 30.3717
loss : 3.4051, time : 30.3568
loss : 3.3777, time : 30.2799
loss : 3.4399, time : 30.2062
loss : 3.4100, time : 30.3312
loss : 3.4415, time : 30.5565
loss : 3.4094, time : 30.3786
loss : 3.4372, time : 30.1706
loss : 3.4017, time : 30.0462
loss : 3.4698, time : 30.0959
loss : 3.4350, time : 30.0054
loss : 3.4433, time : 30.0957
loss : 3.4339, time : 30.1341
loss : 3.4253, time : 29.9839
loss : 3.3885, time : 30.1130
loss : 3.4188, time : 30.4090
loss : 3.4428, time : 30.1626
loss : 3.4189, time : 30.0824
loss : 3.4184, time : 30.1810
loss : 3.3859, time : 30.3698
loss : 3.4

loss : 3.3355, time : 29.6347
loss : 3.3954, time : 29.9450
loss : 3.4211, time : 29.6746
loss : 3.4264, time : 29.6922
loss : 3.3693, time : 29.7207
loss : 3.3785, time : 29.7287
loss : 3.4312, time : 29.7197
loss : 3.4064, time : 29.7582
loss : 3.3729, time : 30.3449
loss : 3.3496, time : 29.7190
loss : 3.4892, time : 29.7911
loss : 3.3560, time : 29.8094
loss : 3.3404, time : 29.9903
loss : 3.4212, time : 29.9198
loss : 3.4016, time : 29.9467
loss : 3.4385, time : 29.8828
loss : 3.3399, time : 29.9282
loss : 3.3361, time : 29.8801
loss : 3.3831, time : 30.0183
loss : 3.4837, time : 29.9852
loss : 3.3489, time : 29.9739
loss : 3.4018, time : 29.8683
loss : 3.3988, time : 29.8875
loss : 3.4753, time : 29.8400
loss : 3.3785, time : 29.8575
loss : 3.4032, time : 29.8269
loss : 3.3500, time : 29.8531
loss : 3.4115, time : 29.7902
loss : 3.3978, time : 29.8178
loss : 3.3664, time : 29.8346
loss : 3.3355, time : 29.9451
loss : 3.3514, time : 29.8435
loss : 3.4319, time : 29.8738
loss : 3.4

loss : 3.3673, time : 30.3768
loss : 3.3807, time : 29.7320
loss : 3.3854, time : 29.5217
loss : 3.3875, time : 29.5432
loss : 3.3608, time : 29.4700
loss : 3.3878, time : 29.5323
loss : 3.4295, time : 29.7084
loss : 3.3908, time : 30.3862
loss : 3.4335, time : 30.3606
loss : 3.4833, time : 30.3168
loss : 3.3428, time : 30.3043
loss : 3.3868, time : 30.3153
loss : 3.3677, time : 29.5928
loss : 3.3742, time : 29.5155
loss : 3.3683, time : 29.5203
loss : 3.3606, time : 29.5466
loss : 3.4470, time : 29.4759
loss : 3.3486, time : 29.5286
loss : 3.4313, time : 29.6021
loss : 3.3629, time : 29.5531
loss : 3.3560, time : 29.6054
loss : 3.3933, time : 29.5339
test loss : 3.432958, lr : 1.999550e-04
loss : 3.3492, time : 29.7129
loss : 3.4076, time : 29.5276
loss : 3.4411, time : 29.5212
loss : 3.3469, time : 29.5364
loss : 3.4010, time : 29.5242
loss : 3.4009, time : 29.5192
loss : 3.4084, time : 29.5238
loss : 3.3764, time : 29.5524
loss : 3.3484, time : 29.6500
loss : 3.3886, time : 29.6399


loss : 3.3191, time : 29.5164
loss : 3.3694, time : 29.5038
test loss : 3.250740, lr : 1.999510e-04
loss : 3.2914, time : 29.7334
loss : 3.3701, time : 29.5930
loss : 3.4150, time : 29.5897
loss : 3.3882, time : 29.4544
loss : 3.3526, time : 29.5132
loss : 3.4271, time : 30.4131
loss : 3.3308, time : 30.3570
loss : 3.3558, time : 29.6737
loss : 3.3724, time : 29.5471
loss : 3.3468, time : 29.7369
loss : 3.3848, time : 29.5113
loss : 3.4698, time : 29.5457
loss : 3.4006, time : 29.4933
loss : 3.3129, time : 29.8915
loss : 3.4843, time : 30.3081
loss : 3.3566, time : 30.3476
loss : 3.2988, time : 29.6641
loss : 3.3199, time : 29.5669
loss : 3.4049, time : 29.5229
loss : 3.4001, time : 29.5268
loss : 3.3482, time : 29.5443
loss : 3.3702, time : 29.5843
loss : 3.3524, time : 29.5446
loss : 3.3867, time : 29.6068
loss : 3.3749, time : 29.5860
loss : 3.3128, time : 29.5882
loss : 3.3682, time : 30.3519
loss : 3.4801, time : 30.3839
loss : 3.3299, time : 29.6444
loss : 3.4101, time : 29.9885


loss : 3.3377, time : 29.5217
loss : 3.3177, time : 29.5252
loss : 3.3586, time : 29.4927
loss : 3.3552, time : 29.5273
loss : 3.3303, time : 29.5147
loss : 3.3579, time : 29.5739
loss : 3.3591, time : 29.5747
loss : 3.3600, time : 29.4942
loss : 3.3782, time : 29.5425
loss : 3.4105, time : 29.5417
loss : 3.3700, time : 29.5485
loss : 3.2842, time : 29.5123
loss : 3.3397, time : 29.5453
loss : 3.3441, time : 29.5336
loss : 3.3306, time : 29.5292
loss : 3.3741, time : 29.4861
loss : 3.3635, time : 29.5181
loss : 3.3368, time : 29.4702
loss : 3.3577, time : 29.5439
loss : 3.3170, time : 29.5312
loss : 3.4391, time : 30.1546
loss : 3.3734, time : 29.5632
loss : 3.3350, time : 29.5632
loss : 3.3426, time : 29.5728
loss : 3.3899, time : 29.5859
loss : 3.3509, time : 29.4965
loss : 3.3190, time : 29.5380
loss : 3.3376, time : 29.4744
loss : 3.3346, time : 29.5208
loss : 3.3025, time : 29.5333
loss : 3.3902, time : 29.4959
loss : 3.2972, time : 29.5208
loss : 3.3953, time : 30.0563
loss : 3.3

loss : 3.3431, time : 29.5203
loss : 3.3235, time : 29.5438
loss : 3.2699, time : 29.5134
loss : 3.3791, time : 29.4922
loss : 3.3379, time : 29.5222
loss : 3.3478, time : 29.5881
loss : 3.3261, time : 29.5049
loss : 3.3431, time : 29.5303
loss : 3.2839, time : 29.5513
loss : 3.3897, time : 29.5507
loss : 3.3708, time : 30.0151
loss : 3.3188, time : 29.5581
loss : 3.3302, time : 29.5408
loss : 3.3218, time : 29.5146
loss : 3.2851, time : 29.5082
loss : 3.3343, time : 29.5377
loss : 3.3617, time : 29.5159
loss : 3.2608, time : 29.5053
loss : 3.3423, time : 29.5133
loss : 3.3679, time : 29.5034
loss : 3.3413, time : 29.5598
loss : 3.4013, time : 29.5569
loss : 3.3540, time : 29.5089
loss : 3.3077, time : 29.5087
loss : 3.3243, time : 29.4991
test loss : 3.249473, lr : 1.999420e-04
loss : 3.2752, time : 29.7509
loss : 3.3210, time : 29.5607
loss : 3.2849, time : 29.9564
loss : 3.3245, time : 29.5546
loss : 3.3499, time : 29.5514
loss : 3.3672, time : 29.4865
loss : 3.3712, time : 29.4978


loss : 3.2463, time : 29.4968
loss : 3.3405, time : 29.8054
loss : 3.3434, time : 29.8668
loss : 3.3273, time : 29.5390
loss : 3.3036, time : 29.5309
test loss : 3.326365, lr : 1.999380e-04
loss : 3.3577, time : 29.7848
loss : 3.3323, time : 30.3092
loss : 3.4463, time : 29.8012
loss : 3.3512, time : 29.4788
loss : 3.3159, time : 29.5102
loss : 3.2751, time : 29.5102
loss : 3.3089, time : 29.5516
loss : 3.3277, time : 29.5075
loss : 3.3113, time : 29.5858
loss : 3.3584, time : 30.1398
loss : 3.3313, time : 30.0457
loss : 3.3110, time : 29.4864
loss : 3.3837, time : 29.5196
loss : 3.3406, time : 29.4934
loss : 3.3023, time : 29.4865
loss : 3.3437, time : 29.4856
loss : 3.3831, time : 29.4917
loss : 3.3032, time : 29.5415
loss : 3.2528, time : 29.6266
loss : 3.3139, time : 29.6451
loss : 3.3124, time : 29.5589
loss : 3.2723, time : 29.4686
loss : 3.3110, time : 29.5322
loss : 3.3077, time : 29.5094
loss : 3.3315, time : 29.5249
loss : 3.2587, time : 29.4957
loss : 3.2911, time : 29.5505


loss : 3.2965, time : 30.2049
loss : 3.2872, time : 30.3213
loss : 3.3267, time : 30.3433
loss : 3.2601, time : 30.0890
loss : 3.3628, time : 29.5336
loss : 3.2652, time : 29.5785
loss : 3.2904, time : 29.6216
loss : 3.3003, time : 30.1141
loss : 3.3473, time : 29.5463
loss : 3.2765, time : 29.6616
loss : 3.2842, time : 29.4891
loss : 3.3250, time : 29.5182
loss : 3.3481, time : 29.4810
loss : 3.3088, time : 29.7499
loss : 3.3358, time : 30.3228
loss : 3.3426, time : 29.5197
loss : 3.2476, time : 29.5400
loss : 3.3222, time : 29.5176
loss : 3.3275, time : 29.5376
loss : 3.2885, time : 29.5206
loss : 3.3842, time : 29.4980
loss : 3.3187, time : 29.5016
loss : 3.3107, time : 29.5631
loss : 3.3030, time : 29.6038
loss : 3.2287, time : 29.5753
loss : 3.3075, time : 29.5413
loss : 3.2787, time : 29.5606
loss : 3.2936, time : 29.9780
loss : 3.3457, time : 29.6068
loss : 3.3422, time : 29.5108
loss : 3.3428, time : 29.5938
loss : 3.3336, time : 29.5678
loss : 3.3175, time : 29.5621
loss : 3.2

loss : 3.2688, time : 30.3462
loss : 3.2823, time : 30.3549
loss : 3.2968, time : 29.7639
loss : 3.3026, time : 29.5667
loss : 3.3523, time : 29.5757
loss : 3.2524, time : 29.5071
loss : 3.2578, time : 29.9347
loss : 3.2850, time : 29.5332
loss : 3.2912, time : 29.5070
loss : 3.3335, time : 29.6137
loss : 3.3187, time : 29.6196
loss : 3.2774, time : 29.4681
loss : 3.2780, time : 29.5059
loss : 3.2935, time : 29.5112
loss : 3.2948, time : 29.9042
loss : 3.3291, time : 29.5025
loss : 3.3151, time : 29.5438
loss : 3.2832, time : 29.5466
loss : 3.3354, time : 29.5421
loss : 3.2607, time : 29.5124
loss : 3.4127, time : 29.5181
loss : 3.2400, time : 29.5092
loss : 3.2691, time : 29.4776
loss : 3.2542, time : 29.4888
loss : 3.3359, time : 29.4606
loss : 3.2740, time : 29.5003
loss : 3.2506, time : 29.4682
loss : 3.2685, time : 29.4671
test loss : 3.280904, lr : 1.999290e-04
loss : 3.3478, time : 29.7818
loss : 3.3674, time : 29.5324
loss : 3.2920, time : 29.8716
loss : 3.2800, time : 29.6559


In [None]:
nits_model.nits_list[0].final_layer_constraint

In [None]:
test_losses

In [None]:
# torch.save(model.to('cpu').state_dict(), './cifar_model.pth')

In [None]:
input, _ = next(iter(train_loader))

with torch.no_grad():
    output = model(input.to(device))

n_steps = 300
unif_x = torch.linspace(-bounds, bounds, steps=n_steps, device=output.device).reshape(-1, 1).tile(1, 3)
params = output[0,:,10,10].reshape(-1, nits_model.tot_params).tile((n_steps, 1))
pdfs = nits_model.to(unif_x.device).pdf(unif_x, params)

colors = ['red', 'green', 'blue']
for i in range(nits_model.d):
    plt.scatter(unif_x[:,i].detach().cpu(), pdfs[:,i].detach().cpu(), c=colors[i])