In [1]:
# ---------------------------------------------------------------
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.

# This work is licensed under the NVIDIA Source Code License
# for NVAE. To view a copy of this license, see the LICENSE file.
# ---------------------------------------------------------------

In [2]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Sat Dec 11 00:13:32 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P8    27W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [1]:
cd /content/drive/MyDrive/Research_copy/knnw/vae/NVAE

/content/drive/MyDrive/Research_copy/knnw/vae/NVAE


In [None]:
!pip install -r requirements.txt

# Imports

In [2]:
import argparse
import torch
import torch.nn as nn
import numpy as np
import os
from tqdm.notebook import tqdm
import gc

import torch.distributed as dist
from torch.multiprocessing import Process
from torch.cuda.amp import autocast, GradScaler

from model import AutoEncoder
from thirdparty.adamax import Adamax
import utils
import datasets

# Rendezvous settings read by init_method='env://' below.
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '6020'

# Set up a one-process group (rank 0 of world size 1). The guard keeps this
# cell re-runnable: initialising the default process group a second time on a
# live kernel raises an error.
if not dist.is_initialized():
    dist.init_process_group(backend='nccl', init_method='env://', rank=0, world_size=1)

import numpy as np
from PIL import Image
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from scipy.io import loadmat
import os
import urllib
from lmdb_datasets import LMDBDataset
from thirdparty.lsun import LSUN

from sklearn.model_selection import train_test_split 

# Dataset and Dataloaders

In [3]:
def get_loaders(args):
    """Build the data loaders for the dataset named by ``args.dataset``.

    Thin wrapper: forwards to ``get_loaders_eval`` and returns its result
    unchanged.
    """
    dataset_name = args.dataset
    return get_loaders_eval(dataset_name, args)

class MyImageFolder(dset.ImageFolder):
    """``ImageFolder`` subclass whose item lookup defers entirely to the parent."""

    def __getitem__(self, index):
        # No customisation: whatever the parent class returns for `index`
        # is handed back as-is.
        item = super().__getitem__(index)
        return item


class KNNWDataset(Dataset):
    """Dataset over the training portion of an array stored in a ``.npy`` file.

    Items are ``(sample, 0)`` pairs; the second element is always the
    constant ``0``.
    """

    def __init__(self, X, transform_train):
        """Load the array at path ``X`` and keep the training part of an 80/20 split.

        Args:
            X: path handed to ``np.load``.
            transform_train: callable applied to every sample in ``__getitem__``.
        """
        # NOTE(review): allow_pickle=True can run arbitrary code when the file
        # is untrusted -- only point this at data you trust.
        samples = np.load(X, allow_pickle=True)
        # shuffle=False keeps the stored order; the held-out part is discarded.
        kept, _held_out = train_test_split(samples, test_size=0.2, shuffle=False)
        self.x = kept
        self.transform = transform_train

    def __len__(self):
        """Number of samples kept for training."""
        return len(self.x)

    def __getitem__(self, index):
        """Return ``(transformed sample with axes rotated, 0)`` for ``index``."""
        transformed = self.transform(self.x[index])
        # Axis rotation (d0, d1, d2) -> (d1, d2, d0).
        # presumably undoes the axis move done by the transform for data
        # stored channel-first -- TODO confirm against the .npy layout
        sample = transformed.permute(1, 2, 0)
        return (sample, 0)



def get_loaders_eval(dataset, args):
    """Create the training and validation ``DataLoader`` objects for ``dataset``.

    Supported names are ``'cifar10'`` and anything starting with ``'knnw'``.
    For knnw the validation loader wraps the same dataset object as the
    training loader.

    Args:
        dataset: dataset name.
        args: config object; ``data``, ``distributed`` and ``batch_size`` are read.

    Returns:
        ``(train_queue, valid_queue, num_classes)``.

    Raises:
        NotImplementedError: for any other dataset name.
    """
    if dataset == 'cifar10':
        num_classes = 10
        train_tf, valid_tf = _data_transforms_cifar10(args)
        train_data = dset.CIFAR10(
            root=args.data, train=True, download=True, transform=train_tf)
        valid_data = dset.CIFAR10(
            root=args.data, train=False, download=True, transform=valid_tf)
    elif dataset.startswith('knnw'):
        num_classes = 0
        train_tf = _data_transforms_knnw()[0]
        # One dataset instance serves both loaders.
        train_data = valid_data = KNNWDataset(args.data, train_tf)
    else:
        raise NotImplementedError

    # Samplers are only created for distributed runs; otherwise the loaders
    # shuffle on their own.
    if args.distributed:
        sampler_cls = torch.utils.data.distributed.DistributedSampler
        train_sampler = sampler_cls(train_data)
        valid_sampler = sampler_cls(valid_data)
    else:
        train_sampler = valid_sampler = None

    loader_cls = torch.utils.data.DataLoader
    train_queue = loader_cls(
        train_data,
        batch_size=args.batch_size,
        shuffle=train_sampler is None,
        sampler=train_sampler,
        pin_memory=True,
        num_workers=8,
        drop_last=True,
    )
    valid_queue = loader_cls(
        valid_data,
        batch_size=args.batch_size,
        shuffle=valid_sampler is None,
        sampler=valid_sampler,
        pin_memory=True,
        num_workers=1,
        drop_last=False,
    )

    return train_queue, valid_queue, num_classes


def _data_transforms_cifar10(args):
    """Return ``(train_transform, valid_transform)`` for cifar10.

    Training applies a random horizontal flip before tensor conversion;
    validation only converts to a tensor. ``args`` is accepted but not read.
    """
    to_tensor_only = transforms.Compose([transforms.ToTensor()])
    flip_then_tensor = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    return flip_then_tensor, to_tensor_only


def _data_transforms_generic(size):
    """Return ``(train_transform, valid_transform)`` that resize to ``size``.

    Both pipelines resize and convert to a tensor; the training pipeline also
    applies a random horizontal flip between those two steps.
    """
    training_steps = [
        transforms.Resize(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    validation_steps = [
        transforms.Resize(size),
        transforms.ToTensor(),
    ]
    return transforms.Compose(training_steps), transforms.Compose(validation_steps)

    #----------------------------
def _data_transforms_knnw():
    """Return ``(train_transform, valid_transform)`` for the knnw data.

    Training converts to a tensor first and then applies a random horizontal
    flip; validation only converts to a tensor.
    """
    to_tensor_only = transforms.Compose([transforms.ToTensor()])
    tensor_then_flip = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(),
    ])
    return tensor_then_flip, to_tensor_only

#---------------------------------------

# Model initialization

In [13]:

class args:
    """Plain-attribute configuration object used in place of parsed CLI arguments.

    Every hyper-parameter is an instance attribute set in ``__init__``; the
    rest of the notebook reads them as ``args.<name>``.
    """

    def __init__(self):

        self.ada_groups=False
        self.arch_instance='res_mbconv'
        self.batch_size=2
        self.cont_training=False
        # Path of the .npy file loaded by KNNWDataset.
        self.data='/content/unique_32x32.npy'
        self.dataset='knnw'
        self.epochs=400
        self.fast_adamax=True
        self.global_rank=0
        self.kl_anneal_portion=0.3
        self.kl_const_coeff=0.0001
        self.kl_const_portion=0.0001
        self.learning_rate=0.01
        self.learning_rate_min=0.0001
        self.local_rank=0
        self.master_address='127.0.0.1'
        self.min_groups_per_scale=1
        self.node_rank=0
        self.num_cell_per_cond_dec=2
        self.num_cell_per_cond_enc=2
        self.num_channels_dec=32
        self.num_channels_enc=32
        self.num_groups_per_scale=30
        self.num_latent_per_group=20
        self.num_latent_scales=1
        self.num_mixture_dec=10
        self.num_nf=1
        self.num_postprocess_blocks=1
        self.num_postprocess_cells=2
        self.num_preprocess_blocks=1
        self.num_preprocess_cells=2
        self.num_proc_node=1
        self.num_process_per_node=8
        self.num_x_bits=8
        self.res_dist=True
        self.root='PATH_TO_CHECKPOINT_DIR'
        self.save='UNIQUE_EXPR_ID'
        self.seed=1
        self.use_se=False
        self.warmup_epochs=5
        self.weight_decay=0.0003
        self.weight_decay_norm_anneal=False
        # Read by train() only when weight_decay_norm_anneal is True; it was
        # previously missing, so enabling annealing raised AttributeError.
        # NOTE(review): 10. is believed to mirror the upstream default -- confirm.
        self.weight_decay_norm_init=10.
        self.weight_decay_norm=0.01
        self.distributed = False

    def __repr__(self):
        """List every setting, so logging the config shows values, not an address."""
        fields = ', '.join('%s=%r' % item for item in sorted(vars(self).items()))
        return '%s(%s)' % (type(self).__name__, fields)

# No summary writer exists at this point.
writer = None
# Rebind the name `args` from the class to a configured instance. Re-running
# only this line fails; re-run the class definition together with it.
args = args()
arch_instance = utils.get_arch_cells(args.arch_instance)


In [18]:
cd /content

/content


In [None]:
# Create the checkpoint/experiment directories. exist_ok=True keeps the cell
# re-runnable (os.mkdir raises FileExistsError on the second run).
os.makedirs(args.root, exist_ok=True)
os.makedirs(args.save, exist_ok=True)
# args.data is the path of the .npy *file* that is downloaded below and later
# passed to np.load, so only its parent directory is created. Creating a
# directory at that exact path would block the download and the load.
os.makedirs(os.path.dirname(args.data) or '.', exist_ok=True)

In [10]:
!gdown --id 1iQu3EzT1VHUZnCyBvYXIGom2_V0Z79wg &     #Unique Label Data

Downloading...
From: https://drive.google.com/uc?id=1iQu3EzT1VHUZnCyBvYXIGom2_V0Z79wg
To: /content/knnw_720p_qscale31_unique.tar.gz
100% 2.08G/2.08G [00:26<00:00, 77.5MB/s]


In [None]:
!tar -xzvf "/content/knnw_720p_qscale31_unique.tar.gz" -C "/content"     #Unzip Unique Label Data

In [24]:
# Download the array file; per the recorded output it is saved as
# /content/unique_32x32.npy, which is the path args.data is set to.
!gdown --id 1ZqUAj6yZItzFG9b4OJdnEEwSumUMEuVO #unique_32x32.npy

Downloading...
From: https://drive.google.com/uc?id=1ZqUAj6yZItzFG9b4OJdnEEwSumUMEuVO
To: /content/unique_32x32.npy
100% 336M/336M [00:02<00:00, 119MB/s]


In [12]:
!mkdir val_data

In [None]:
# Move qscale31_unique (presumably produced by the tar extraction above --
# confirm) into val_data/.
# NOTE(review): not re-runnable as written; after the first run the source
# directory no longer exists in the working directory.
!mv qscale31_unique val_data/.

In [14]:
# Seed torch (CPU and all CUDA devices) and numpy from args.seed before the
# model is constructed below.
torch.manual_seed(args.seed)
np.random.seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

# Logger and writer are both keyed on the global rank and the args.save directory.
logging = utils.Logger(args.global_rank, args.save)
writer = utils.Writer(args.global_rank, args.save)

# Get data loaders.
# NOTE(review): this calls get_loaders from the imported `datasets` module,
# not the get_loaders defined earlier in this notebook -- confirm which one
# is intended.
print(args.batch_size)
train_queue, valid_queue, num_classes = datasets.get_loaders(args)
# Iteration counts derived from the number of batches per epoch.
args.num_total_iter = len(train_queue) * args.epochs
warmup_iters = len(train_queue) * args.warmup_epochs
swa_start = len(train_queue) * (args.epochs - 1)

arch_instance = utils.get_arch_cells(args.arch_instance)

# Build the model and move it to the GPU.
model = AutoEncoder(args, writer, arch_instance)
model = model.cuda()

2
len log norm: 610
len bn: 364


# Training loop

In [15]:
def train(train_queue, model, cnn_optimizer, grad_scalar, global_step, warmup_iters, writer, args, logging, cnn_scheduler, checkpoint_file, arch_instance):
    """Run one pass over ``train_queue`` and return ``(nelbo.avg, global_step)``.

    Args:
        train_queue: iterable of batches; each batch is either the input
            tensor or a list/tuple whose first element is the input tensor.
        model: the autoencoder being trained (put in train mode here).
        cnn_optimizer: optimizer stepped through ``grad_scalar``.
        grad_scalar: AMP gradient scaler.
        global_step: number of optimisation steps taken before this call.
        warmup_iters: number of initial steps with linearly warmed-up lr.
        writer: accepted for signature compatibility; not used in this body.
        args: config object (reads ``num_x_bits``, ``learning_rate``,
            ``distributed``, the ``kl_*`` settings, ``num_total_iter`` and the
            ``weight_decay_norm*`` settings).
        logging: logger exposing ``info``.
        cnn_scheduler: lr scheduler; only its state is stored in checkpoints here.
        checkpoint_file: base path; periodic checkpoints go to ``checkpoint_file + "2"``.
        arch_instance: stored verbatim in the periodic checkpoint.

    Returns:
        Tuple of the running-average loss and the updated ``global_step``.
    """
    alpha_i = utils.kl_balancer_coeff(num_scales=model.num_latent_scales,
                                      groups_per_scale=model.groups_per_scale, fun='square')
    nelbo = utils.AvgrageMeter()
    model.train()
    # Wrap the loader itself (not the enumerate object) so tqdm can read its
    # length and display a total instead of a bare iteration counter.
    for step, x in enumerate(tqdm(train_queue)):
        # Batches that arrive as (inputs, labels) are lists/tuples: keep the
        # inputs only. A type check is used instead of `len(x) > 1`, which is
        # also true for a bare tensor batch and would silently select its
        # first sample.
        if isinstance(x, (list, tuple)):
            x = x[0]
        x = x.cuda()

        # change bit length
        x = utils.pre_process(x, args.num_x_bits)

        # warm-up lr: linear ramp from 0 to args.learning_rate over warmup_iters steps
        if global_step < warmup_iters:
            lr = args.learning_rate * float(global_step) / warmup_iters
            for param_group in cnn_optimizer.param_groups:
                param_group['lr'] = lr

        # sync parameters, it may not be necessary
        if step % 100 == 0:
            utils.average_params(model.parameters(), args.distributed)

        cnn_optimizer.zero_grad()
        with autocast():
            logits, log_q, log_p, kl_all, kl_diag = model(x)

            output = model.decoder_output(logits)
            kl_coeff = utils.kl_coeff(global_step, args.kl_anneal_portion * args.num_total_iter,
                                      args.kl_const_portion * args.num_total_iter, args.kl_const_coeff)

            recon_loss = utils.reconstruction_loss(output, x, crop=model.crop_output)
            balanced_kl, kl_coeffs, kl_vals = utils.kl_balancer(kl_all, kl_coeff, kl_balance=True, alpha_i=alpha_i)

            nelbo_batch = recon_loss + balanced_kl
            loss = torch.mean(nelbo_batch)
            norm_loss = model.spectral_norm_parallel()
            bn_loss = model.batchnorm_loss()
            # get spectral regularization coefficient (lambda)
            if args.weight_decay_norm_anneal:
                assert args.weight_decay_norm_init > 0 and args.weight_decay_norm > 0, 'init and final wdn should be positive.'
                # Interpolate in log space between the initial and final coefficients.
                wdn_coeff = (1. - kl_coeff) * np.log(args.weight_decay_norm_init) + kl_coeff * np.log(args.weight_decay_norm)
                wdn_coeff = np.exp(wdn_coeff)
            else:
                wdn_coeff = args.weight_decay_norm

            loss += norm_loss * wdn_coeff + bn_loss * wdn_coeff

        grad_scalar.scale(loss).backward()
        utils.average_gradients(model.parameters(), args.distributed)
        # Drop per-step tensors and release cached GPU memory before the
        # optimizer step.
        del x, balanced_kl, kl_coeffs, kl_vals, logits, log_q, log_p, kl_all, kl_diag
        gc.collect()
        torch.cuda.empty_cache()
        grad_scalar.step(cnn_optimizer)
        grad_scalar.update()
        nelbo.update(loss.data, 1)

        if (global_step + 1) % 100 == 0:
            # Periodic in-epoch checkpoint, written next to the main one.
            # NOTE(review): 'epoch' is hard-coded to 100 here regardless of
            # the real epoch -- confirm before resuming from this file.
            logging.info('saving the model.')
            torch.save({'epoch': 100, 'state_dict': model.state_dict(),
                    'optimizer': cnn_optimizer.state_dict(), 'global_step': global_step,
                    'args': args, 'arch_instance': arch_instance, 'scheduler': cnn_scheduler.state_dict(),
                    'grad_scalar': grad_scalar.state_dict()}, checkpoint_file + "2")

            logging.info('train %d %f', global_step, nelbo.avg)

        global_step += 1

    utils.average_tensor(nelbo.avg, args.distributed)
    return nelbo.avg, global_step

# Run Training

In [16]:
# ensures that weight initializations are all the same
torch.manual_seed(args.seed)
np.random.seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

logging = utils.Logger(args.global_rank, args.save)

# Get data loaders.
train_queue, valid_queue, num_classes = datasets.get_loaders(args)
args.num_total_iter = len(train_queue) * args.epochs
warmup_iters = len(train_queue) * args.warmup_epochs
swa_start = len(train_queue) * (args.epochs - 1)

arch_instance = utils.get_arch_cells(args.arch_instance)

# `writer` is not created in this cell; it must already exist in the kernel.
model = AutoEncoder(args, writer, arch_instance)
model = model.cuda()

logging.info('args = %s', args)
logging.info('param size = %fM ', utils.count_parameters_in_M(model))
logging.info('groups per scale: %s, total_groups: %d', model.groups_per_scale, sum(model.groups_per_scale))

if args.fast_adamax:
  # Fast adamax has the same functionality as torch.optim.Adamax, except it is faster.
  cnn_optimizer = Adamax(model.parameters(), args.learning_rate,
                          weight_decay=args.weight_decay, eps=1e-3)
else:
  cnn_optimizer = torch.optim.Adamax(model.parameters(), args.learning_rate,
                                      weight_decay=args.weight_decay, eps=1e-3)

# Cosine schedule over the epochs that remain after warm-up.
cnn_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
  cnn_optimizer, float(args.epochs - args.warmup_epochs - 1), eta_min=args.learning_rate_min)
grad_scalar = GradScaler(2**10)

# Factor that converts the validation values logged below into bits per dimension.
num_output = utils.num_output(args.dataset)
bpd_coeff = 1. / np.log(2.) / num_output

# if load
checkpoint_file = os.path.join(args.save, 'checkpoint.pt')
if args.cont_training:
  logging.info('loading the model.')
  checkpoint = torch.load(checkpoint_file, map_location='cpu')
  init_epoch = checkpoint['epoch']
  model.load_state_dict(checkpoint['state_dict'])
  model = model.cuda()
  cnn_optimizer.load_state_dict(checkpoint['optimizer'])
  grad_scalar.load_state_dict(checkpoint['grad_scalar'])
  cnn_scheduler.load_state_dict(checkpoint['scheduler'])
  global_step = checkpoint['global_step']
else:
  global_step, init_epoch = 0, 0

# NOTE(review): `test`, called below, is neither defined nor imported in the
# visible notebook -- make sure it is in scope before the first evaluation.
for epoch in range(init_epoch, args.epochs):
  # update lrs.
  if args.distributed:
      train_queue.sampler.set_epoch(global_step + args.seed)
      valid_queue.sampler.set_epoch(0)

  if epoch > args.warmup_epochs:
      cnn_scheduler.step()

  # Logging.
  logging.info('epoch %d', epoch)

  # Training.
  train_nelbo, global_step = train(train_queue, model, cnn_optimizer, grad_scalar, global_step, warmup_iters, writer, args, logging, cnn_scheduler, checkpoint_file, arch_instance)
  logging.info('train_nelbo %f', train_nelbo)

  model.eval()
  # generate samples less frequently
  eval_freq = 1 if args.epochs <= 50 else 20
  if epoch % eval_freq == 0 or epoch == (args.epochs - 1):
      with torch.no_grad():
          num_samples = 16
          n = int(np.floor(np.sqrt(num_samples)))
          # Samples are generated at several temperatures; the tiled image
          # is computed but not stored or displayed here.
          for t in [0.7, 0.8, 0.9, 1.0]:
              logits = model.sample(num_samples, t)
              output = model.decoder_output(logits)
              output_img = output.mean if isinstance(output, torch.distributions.bernoulli.Bernoulli) else output.sample(t)
              output_tiled = utils.tile_image(output_img, n)

      valid_neg_log_p, valid_nelbo = test(valid_queue, model, num_samples=10, args=args, logging=logging)
      logging.info('valid_nelbo %f', valid_nelbo)
      logging.info('valid neg log p %f', valid_neg_log_p)
      logging.info('valid bpd elbo %f', valid_nelbo * bpd_coeff)
      logging.info('valid bpd log p %f', valid_neg_log_p * bpd_coeff)

  save_freq = int(np.ceil(args.epochs / 100))
  if epoch % save_freq == 0 or epoch == (args.epochs - 1):
      if args.global_rank == 0:
          logging.info('saving the model.')
          torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict(),
                      'optimizer': cnn_optimizer.state_dict(), 'global_step': global_step,
                      'args': args, 'arch_instance': arch_instance, 'scheduler': cnn_scheduler.state_dict(),
                      'grad_scalar': grad_scalar.state_dict()}, checkpoint_file)

# Final validation
# args=args is passed here as in the in-loop call above; the original final
# call omitted it.
valid_neg_log_p, valid_nelbo = test(valid_queue, model, num_samples=1000, args=args, logging=logging)
logging.info('final valid nelbo %f', valid_nelbo)
logging.info('final valid neg log p %f', valid_neg_log_p)

len log norm: 610
len bn: 364
12/11 12:44:30 AM (Elapsed: 00:00:01) args = <__main__.args object at 0x7f35bd582850>
12/11 12:44:30 AM (Elapsed: 00:00:01) param size = 10.823000M 
12/11 12:44:30 AM (Elapsed: 00:00:01) groups per scale: [30], total_groups: 30
12/11 12:44:30 AM (Elapsed: 00:00:01) epoch 0


0it [00:00, ?it/s]

KeyboardInterrupt: ignored

# Evaluation

Run evaluation using knnw_nvae_eval.ipynb