In [1]:
!nvidia-smi

Thu Jan 20 18:57:31 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.42.01    Driver Version: 470.42.01    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:1B:00.0 Off |                  N/A |
| 29%   37C    P8     1W / 250W |   5955MiB / 11016MiB |      9%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:1C:00.0 Off |                  N/A |
| 43%   70C    P2   263W / 250W |   8251MiB / 11019MiB |     93%      Default |
|       

In [2]:
device = 'cuda:4'

In [3]:
import time
import os, shutil
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, transforms, utils
from tensorboardX import SummaryWriter
from nits.pixelcnn_model import *
from nits.model import NITS, ConditionalNITS
from nits.discretized_mol import discretized_nits_loss, nits_sample
from PIL import Image

import matplotlib.pyplot as plt

def list_str_to_list(s):
    print(s)
    assert s[0] == '[' and s[-1] == ']'
    s = s[1:-1]
    s = s.replace(' ', '')
    s = s.split(',')

    s = [int(x) for x in s]

    return s

parser = argparse.ArgumentParser()
# data I/O
parser.add_argument('-g', '--gpu', type=str,
                    default='', help='Location for the dataset')
parser.add_argument('-i', '--data_dir', type=str,
                    default='/data/pixelcnn/data/', help='Location for the dataset')
parser.add_argument('-o', '--save_dir', type=str, default='/data/pixelcnn/models/',
                    help='Location for parameter checkpoints and samples')
parser.add_argument('-d', '--dataset', type=str,
                    default='cifar', help='Can be cifar / mnist')
parser.add_argument('-p', '--print_every', type=int, default=50,
                    help='how many iterations between print statements')
parser.add_argument('-t', '--save_interval', type=int, default=10,
                    help='Every how many epochs to write checkpoint/samples?')
parser.add_argument('-r', '--load_params', type=str, default=None,
                    help='Restore training from previous model checkpoint?')

# pixelcnn model
parser.add_argument('-q', '--nr_resnet', type=int, default=5,
                    help='Number of residual blocks per stage of the model')
parser.add_argument('-n', '--nr_filters', type=int, default=160,
                    help='Number of filters to use across the model. Higher = larger model.')
parser.add_argument('-m', '--nr_logistic_mix', type=int, default=10,
                    help='Number of logistic components in the mixture. Higher = more flexible model')
parser.add_argument('-l', '--lr', type=float,
                    default=0.0002, help='Base learning rate')
parser.add_argument('-e', '--lr_decay', type=float, default=(1 - 5e-6),
                    help='Learning rate decay, applied every step of the optimization')
parser.add_argument('-b', '--batch_size', type=int, default=16,
                    help='Batch size during training per GPU')
parser.add_argument('-x', '--max_epochs', type=int,
                    default=5000, help='How many epochs to run in total?')
parser.add_argument('-s', '--seed', type=int, default=1,
                    help='Random seed to use')

# nits model
parser.add_argument('-a', '--nits_arch', type=list_str_to_list, default='[8,8,1]',
                    help='Architecture of NITS model')
parser.add_argument('-nb', '--nits_bound', type=float, default=5.,
                    help='Upper and lower bound of NITS model')
parser.add_argument('-c', '--constraint', type=str, default='neg_exp',
                    help='Upper and lower bound of NITS model')
parser.add_argument('-fc', '--final_constraint', type=str, default='softmax',
                    help='Upper and lower bound of NITS model')


args = parser.parse_args(['-a', '[16,10,1]', '-g', '4'])

device = 'cuda:' + args.gpu if args.gpu else 'cpu'
print('device:', device)

# HOUSEKEEPING

# reproducibility
torch.manual_seed(args.seed)
np.random.seed(args.seed)

model_name = 'lr_{:.5f}_nr_resnet{}_nr_filters{}_nits_arch{}_constraint{}_final_constraint{}'.format(
    args.lr, args.nr_resnet, args.nr_filters, args.nits_arch, args.constraint, args.final_constraint)
if os.path.exists(os.path.join('runs_test', model_name)):
    shutil.rmtree(os.path.join('runs_test', model_name))

sample_batch_size = 25
obs = (1, 28, 28) if 'mnist' in args.dataset else (3, 32, 32)
input_channels = obs[0]
rescaling     = lambda x : (x - .5) * 2.
rescaling_inv = lambda x : .5 * x  + .5
kwargs = {'num_workers':1, 'pin_memory':True, 'drop_last':True}
ds_transforms = transforms.Compose([transforms.ToTensor(), rescaling])

if 'mnist' in args.dataset :
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(args.data_dir, download=True,
                        train=True, transform=ds_transforms), batch_size=args.batch_size,
                            shuffle=True, **kwargs)

    test_loader  = torch.utils.data.DataLoader(datasets.MNIST(args.data_dir, train=False,
                    transform=ds_transforms), batch_size=args.batch_size, shuffle=True, **kwargs)
elif 'cifar' in args.dataset :
    train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(args.data_dir, train=True,
        download=True, transform=ds_transforms), batch_size=args.batch_size, shuffle=True, **kwargs)

    test_loader  = torch.utils.data.DataLoader(datasets.CIFAR10(args.data_dir, train=False,
                    transform=ds_transforms), batch_size=args.batch_size, shuffle=True, **kwargs)
else :
    raise Exception('{} dataset not in {mnist, cifar10}'.format(args.dataset))


# INITIALIZE NITS MODEL
if 'mnist' in args.dataset:
    arch = [1] + args.nits_arch
    nits_model = NITS(d=1, start=-args.nits_bound, end=args.nits_bound, monotonic_const=1e-5,
                      A_constraint=args.constraint, arch=arch,
                      final_layer_constraint=args.final_constraint).to(device)
elif 'cifar' in args.dataset:
    arch = [1] + args.nits_arch
    nits_model = ConditionalNITS(d=3, start=-args.nits_bound, end=args.nits_bound, monotonic_const=1e-5,
                                 A_constraint=args.constraint, arch=arch, autoregressive=True,
                                 pixelrnn=True, normalize_inverse=True,
                                 final_layer_constraint=args.final_constraint,
                                 softmax_temperature=True).to(device)
tot_params = nits_model.tot_params
loss_op = lambda real, params: discretized_nits_loss(real, params, nits_model)
sample_op = lambda params: nits_sample(params, nits_model)

# INITIALIZE PIXELCNN MODEL
model = PixelCNN(nr_resnet=args.nr_resnet, nr_filters=args.nr_filters,
                 input_channels=input_channels, nr_logistic_mix=tot_params, num_mix=1)
model = model.to(device)

if args.load_params:
    load_part_of_model(model, args.load_params)
    print('model parameters loaded')

optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=args.lr_decay)
test_losses = []

def sample(model):
    model.train(False)
    with torch.no_grad():
        data = torch.zeros(sample_batch_size, obs[0], obs[1], obs[2])
        data = data.to(device)
        for i in range(obs[1]):
            for j in range(obs[2]):
                data_v = Variable(data)
                out   = model(data_v, sample=True)
                out_sample = sample_op(out)
                data[:, :, i, j] = out_sample.data[:, :, i, j]
        return data


[16,10,1]
device: cuda:4
Files already downloaded and verified


In [None]:
print('starting training')
writes = 0
for epoch in range(args.max_epochs):
    model.train(True)
    torch.cuda.synchronize()
    train_loss = 0.
    time_ = time.time()
    model.train()
    for batch_idx, (input,_) in enumerate(train_loader):
        input = input.to(device)
        input = Variable(input)
        output = model(input)
        output.retain_grad()
        loss = loss_op(input, output)
        optimizer.zero_grad()
        loss.backward()
        if output.grad.isnan().any():
            print('output grad are nan')
            break
        optimizer.step()
        train_loss += loss.detach().cpu().numpy()
        if (batch_idx +1) % args.print_every == 0 :
            deno = args.print_every * args.batch_size * np.prod(obs) * np.log(2.)
            print('loss : {:.4f}, time : {:.4f}'.format(
                (train_loss / deno),
                (time.time() - time_)))
            train_loss = 0.
            writes += 1
            time_ = time.time()

    if loss.isnan() or loss.isinf() or output.grad.isnan().any():
        break

    # decrease learning rate
    scheduler.step()

    torch.cuda.synchronize()
    model.eval()
    test_loss = 0.
    for batch_idx, (input,_) in enumerate(test_loader):
        input = input.to(device)
        input_var = Variable(input)
        output = model(input_var)
        loss = loss_op(input_var, output)
        test_loss += loss.detach().cpu().numpy()
        del loss, output

    deno = batch_idx * args.batch_size * np.prod(obs) * np.log(2.)
    print('test loss : {:4f}, lr : {:4e}'.format(test_loss / deno, optimizer.param_groups[0]['lr']))
    test_losses.append(test_loss / deno)

    if (epoch + 1) % args.save_interval == 0:
        torch.save(model.state_dict(), '{}/{}_{}.pth'.format(args.save_dir, model_name, epoch))
        print('sampling...')
        sample_t = sample(model)
        sample_t = rescaling_inv(sample_t)
        utils.save_image(sample_t,'/data/pixelcnn/images/{}_{}.png'.format(model_name, epoch),
                nrow=5, padding=0)

starting training
loss : 3.0668, time : 28.3734
loss : 3.0410, time : 28.3665
loss : 3.0675, time : 28.4162
loss : 3.0764, time : 28.4139
loss : 3.0381, time : 28.4378
loss : 3.0558, time : 28.5071
loss : 3.1054, time : 28.4681
loss : 3.1033, time : 28.4500
loss : 3.1288, time : 28.4636
loss : 3.0957, time : 28.5035
loss : 3.1207, time : 28.5214
loss : 3.0875, time : 28.4775
loss : 3.0969, time : 28.4749
loss : 3.0554, time : 28.5275
loss : 3.0987, time : 28.4911
loss : 3.1119, time : 28.4661
loss : 3.0588, time : 28.4763
loss : 3.0877, time : 28.4838
loss : 3.0757, time : 28.4802
loss : 3.1376, time : 28.4779
loss : 3.1525, time : 28.4722
loss : 3.1153, time : 28.4694
loss : 3.0585, time : 28.4835
loss : 3.0988, time : 28.4736
loss : 3.0847, time : 28.4710
loss : 3.0616, time : 28.5035
loss : 3.1251, time : 28.4637
loss : 3.1265, time : 28.4621
loss : 3.1016, time : 28.4645
loss : 3.0890, time : 28.4701
loss : 3.0740, time : 28.5466
loss : 3.1018, time : 28.4863
loss : 3.0500, time : 

loss : 3.0695, time : 28.4736
loss : 3.0755, time : 28.4810
loss : 3.0455, time : 28.4700
loss : 3.0915, time : 28.4845
loss : 3.1149, time : 28.4436
loss : 3.1332, time : 28.4608
loss : 3.0566, time : 28.4380
loss : 3.0387, time : 28.4376
loss : 3.1125, time : 28.4677
loss : 3.0939, time : 28.4634
loss : 3.0283, time : 28.4509
loss : 3.0970, time : 28.4662
loss : 3.0796, time : 28.4497
loss : 3.1181, time : 28.4738
loss : 3.0792, time : 28.4583
loss : 3.0585, time : 28.4550
loss : 3.0974, time : 28.4406
loss : 3.0786, time : 28.4784
loss : 3.0964, time : 28.4677
loss : 3.0538, time : 28.4589
loss : 3.0646, time : 28.4865
loss : 3.0945, time : 28.4842
loss : 3.0470, time : 28.4695
loss : 3.0807, time : 28.4701
loss : 3.0618, time : 28.4543
loss : 3.0230, time : 28.4189
loss : 3.1016, time : 28.4386
loss : 3.0765, time : 28.4472
loss : 3.0440, time : 28.4571
loss : 3.0589, time : 28.4876
loss : 3.0749, time : 28.4529
loss : 3.0566, time : 28.4603
loss : 3.0324, time : 28.4511
loss : 3.0

loss : 3.0655, time : 28.4490
loss : 3.0493, time : 28.4448
loss : 3.0723, time : 28.4341
loss : 3.0636, time : 28.4606
loss : 3.1048, time : 28.4580
loss : 3.0993, time : 28.4562
loss : 3.0683, time : 28.4564
loss : 3.0781, time : 28.4356
loss : 3.0757, time : 28.4408
loss : 3.0745, time : 28.4481
loss : 3.0999, time : 28.4898
loss : 3.1161, time : 28.4586
loss : 3.0619, time : 28.4618
loss : 3.1318, time : 28.4785
loss : 3.0613, time : 28.4643
loss : 3.0242, time : 28.4744
loss : 3.0766, time : 28.4736
loss : 3.1054, time : 28.4459
test loss : 3.068322, lr : 1.999410e-04
loss : 3.0840, time : 28.6445
loss : 3.0597, time : 28.4409
loss : 3.0983, time : 28.3828
loss : 3.0824, time : 28.4237
loss : 3.1313, time : 28.3933
loss : 3.1527, time : 28.3772
loss : 3.0834, time : 28.3575
loss : 3.0864, time : 28.3694
loss : 3.0599, time : 28.3792
loss : 3.1227, time : 28.4285
loss : 3.0882, time : 28.3764
loss : 3.0563, time : 28.3574
loss : 3.0782, time : 28.4354
loss : 3.0772, time : 28.4888


In [None]:
# torch.save(model.to('cpu').state_dict(), './cifar_model.pth')

In [5]:
test_losses

[3.9946067419982145,
 3.7095731649833557,
 3.5397387527198334,
 3.4285993233309866,
 3.506644070579394,
 3.355817989354682,
 3.332827659111572,
 3.318594494726047,
 3.3263378349837764,
 3.337990469515122,
 3.241161852473759,
 3.2009310717845274,
 3.215256899540375,
 3.2043920480781005,
 3.2180318627635573,
 3.1785560517355735,
 3.192287648196525,
 3.2052902351894015,
 3.181655520497033,
 3.1468223925211043,
 3.1636116053244465,
 3.127488230428158,
 3.184809518048376,
 3.1293779934350847,
 3.1181206981426617,
 3.1670931901373796,
 3.1238606558636377,
 3.116251112596015,
 3.1218776676028988,
 3.094179535355388,
 3.103162447551413,
 3.095869164944386,
 3.1980283989486136,
 3.169104161959829,
 3.0808783839082237,
 3.10101412626816,
 3.097537354029577,
 3.1172043411377937,
 3.073054336381765,
 3.0977771033571444,
 3.064713736758534,
 3.064957556709656,
 3.0588753182349704,
 3.065151132737794,
 3.0964910075394005,
 3.1430258023739324,
 3.0580621464108084,
 3.076979823850915,
 3.0595846253999

In [None]:
input, _ = next(iter(train_loader))

with torch.no_grad():
    output = model(input.to(device))

n_steps = 300
unif_x = torch.linspace(-bounds, bounds, steps=n_steps, device=output.device).reshape(-1, 1).tile(1, 3)
params = output[0,:,10,10].reshape(-1, nits_model.tot_params).tile((n_steps, 1))
pdfs = nits_model.to(unif_x.device).pdf(unif_x, params)

colors = ['red', 'green', 'blue']
for i in range(nits_model.d):
    plt.scatter(unif_x[:,i].detach().cpu(), pdfs[:,i].detach().cpu(), c=colors[i])