# Import

In [1]:
import os
import sys
import random
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm

import flow_ssl
from experiments.train_flows import utils

import torch
from torch import distributions
import torch.nn as nn
from torch.nn.modules.utils import _pair, _quadruple
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import SVHN
from torch.utils.data import Dataset

from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
import seaborn as sns

import wandb
import warnings
warnings.filterwarnings('ignore')

from dataloader import get_transform, get_loader
from utils import seed_everything, get_percentile, schedule, MedianPool2d, AverageMeter

# Argument Parsing

In [2]:
def parse_args(notebook=False, print_=False):
    """Build the experiment argument parser and return the parsed options.

    Args:
        notebook (bool): If True, parse an empty argument list so that every
            option takes its default value. If False, parse the process
            command line.
        print_ (bool): If True, print the parser help text after parsing.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser()

    # Data
    parser.add_argument('--dataset', type=str, default="cifar10", metavar='DATA',
                        help='Dataset name (lower case, default: cifar10), '
                             'opt : [cifar10, svhn, mnist, fmnist]')
    parser.add_argument('--data_path', type=str, default=None, metavar='PATH',
                        help='path to datasets location (default: None)')

    parser.add_argument('--ood_dataset', type=str, default="svhn", metavar='DATA',
                        help='OOD dataset name (lower case, default: svhn), '
                             'opt : [cifar10, svhn, mnist, fmnist]')
    parser.add_argument('--ood_data_path', type=str, default=None, metavar='PATH',
                        help='path to ood datasets location (default: None)')

    # Output locations
    parser.add_argument('--logdir', type=str, default=None, metavar='PATH',
                        help='path to log directory (default: None)')
    parser.add_argument('--ckptdir', type=str, default=None, metavar='PATH',
                        help='path to ckpt directory (default: None)')

    # Optimisation
    parser.add_argument('--batch_size', default=32, type=int, help='Batch size')
    parser.add_argument('--lr', default=1e-3, type=float, help='Learning rate')
    parser.add_argument('--max_grad_norm', type=float, default=100., help='Max gradient norm for clipping')
    parser.add_argument('--num_epochs', default=101, type=int, help='Number of epochs to train')
    parser.add_argument('--num_samples', default=20, type=int, help='Number of samples at test time')
    parser.add_argument('--num_workers', default=8, type=int, help='Number of data loader threads')
    parser.add_argument('--resume', type=str, default=None, metavar='PATH', help='path to ckpt')
    parser.add_argument('--weight_decay', default=5e-5, type=float,
                        help='L2 regularization (only applied to the weight norm scale factors)')

    parser.add_argument('--save_freq', default=25, type=int, help='frequency of saving ckpts')
    parser.add_argument('--negative_val', default=-100_000, type=int, help='Negative loss threshold')

    # Model
    parser.add_argument('--flow', type=str, default="RealNVP",
                        help="Flow model to use (default: RealNVP) "
                             "choices=['RealNVP', 'Glow', 'RealNVPNewMask', 'RealNVPNewMask2', 'RealNVPSmall']")
    parser.add_argument('--num_blocks', default=8, type=int, help='number of blocks in ResNet')
    parser.add_argument('--num_scales', default=3, type=int, help='number of scales in multi-layer architecture')
    parser.add_argument('--num_mid_channels', default=64, type=int,
                        help='number of channels in coupling layer parametrizing network')
    # The next two options are read when a Glow model is constructed; without
    # them, selecting --flow Glow fails with AttributeError.
    # NOTE(review): the default of 8 coupling layers per scale is assumed -- confirm.
    parser.add_argument('--num_coupling_layers_per_scale', default=8, type=int,
                        help='number of coupling layers in one scale (Glow only)')
    parser.add_argument('--no_multi_scale', action='store_true')
    parser.add_argument('--no_batchnorm', action='store_true')
    parser.add_argument('--st_type', choices=['highway', 'resnet', 'convnet'], default='resnet')
    parser.add_argument('--aug', action='store_true')
    parser.add_argument('--init_zeros', action='store_true')
    parser.add_argument('--optim', choices=['Adam', 'RMSprop'], default='Adam')
    parser.add_argument('--lr_anneal', action='store_true')

    # An explicit empty list keeps argparse from reading sys.argv in notebook mode.
    args = parser.parse_args([]) if notebook else parser.parse_args()

    if print_:
        parser.print_help()

    return args

In [3]:
# Start from the parser defaults (notebook mode), then apply this run's overrides.
args = parse_args(notebook=True, print_=False)

vars(args).update(
    data_path='experiments/datasets',
    ood_data_path='experiments/datasets',
    ckptdir='ckpt',
    logdir='log',
    lr_anneal=True,
    lr=5e-5,
    save_freq=10,
)

In [4]:
# Run identifier assembled from the key hyper-parameters; it is passed to
# wandb.init as the run name, so it must not contain a newline.
case = (
    f'IN-{args.dataset}_OOD-{args.ood_dataset}_epochs-{args.num_epochs}'
    f'_flow-{args.flow}_lr-{args.lr}_anneal-{args.lr_anneal}'
    f'_blocks-{args.num_blocks}_nscales-{args.num_scales}_st-{args.st_type}'
    f'_aug-{args.aug}_optim-{args.optim}'
)

# Use the first CUDA device when available, otherwise the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [5]:
# Weights & Biases switch. It is read as a global by train(), test() and the
# epoch loop; a run is only created when it is set to True.
use_wandb = False
if use_wandb:
    wandb.init(project="NF_neg_training", entity="jskim0406", name=f'{case}')

In [6]:
# Fix the random seed (0) before any data loading or model construction.
# See utils.seed_everything for which random number generators are covered.
seed_everything(0)

# Data load

In [8]:
# Train/test transforms for the in-distribution (CIFAR-10) and OOD (SVHN) data.
# img_shape comes from the CIFAR-10 call; the SVHN call's third value is discarded.
# NOTE(review): the dataset names are hard-coded here rather than read from
# args.dataset / args.ood_dataset -- confirm this is intended.
transform_train_c10, transform_test_c10, img_shape = get_transform(args, 'cifar10')
transform_train_svhn, transform_test_svhn, _ = get_transform(args, 'svhn')

In [9]:
# In-distribution loaders (CIFAR-10); the third return value is unused.
# NOTE(review): 'cifar10' is hard-coded instead of args.dataset -- confirm.
trainloader, testloader, _ = get_loader(args, 'cifar10', transform_train_c10, transform_test_c10)

Files already downloaded and verified
Files already downloaded and verified


In [10]:
# Out-of-distribution loaders (SVHN); the third return value is unused.
# NOTE(review): 'svhn' is hard-coded instead of args.ood_dataset -- confirm.
ood_trainloader, ood_testloader, _ = get_loader(args, 'svhn', transform_train_svhn, transform_test_svhn)

Using downloaded and verified file: experiments/datasets/train_32x32.mat
Using downloaded and verified file: experiments/datasets/test_32x32.mat


# Define Model

In [11]:
# Look up the model constructor on the flow_ssl package by its name in args.flow
# (raises AttributeError if flow_ssl has no attribute with that name).
model_cfg = getattr(flow_ssl, args.flow)

In [12]:
# Instantiate the flow selected by args.flow. The RealNVP variants and Glow take
# different constructor arguments, so each family has its own branch.
if 'RealNVP' in args.flow:
    net = model_cfg(in_channels=img_shape[0],
                    init_zeros=args.init_zeros,
                    mid_channels=args.num_mid_channels,
                    num_scales=args.num_scales,
                    st_type=args.st_type,
                    use_batch_norm=not args.no_batchnorm)

elif args.flow == 'Glow':
    # These two options are Glow-specific and may be missing from the parsed
    # args; fall back to defaults instead of failing with AttributeError.
    # NOTE(review): 8 coupling layers per scale is an assumed default -- confirm.
    glow_layers_per_scale = getattr(args, 'num_coupling_layers_per_scale', 8)
    glow_multi_scale = not getattr(args, 'no_multi_scale', False)
    net = model_cfg(image_shape=img_shape,
                    mid_channels=args.num_mid_channels,
                    num_scales=args.num_scales,
                    num_coupling_layers_per_scale=glow_layers_per_scale,
                    num_layers=args.num_blocks,
                    multi_scale=glow_multi_scale,
                    st_type=args.st_type)

else:
    # Fail here with a clear message rather than with a NameError on `net` below.
    raise ValueError(f'Unsupported flow model: {args.flow!r}')

print(f'Model contains {format(sum([p.numel() for p in net.parameters()]), ",d")} parameters')

Model contains 87,859,080 parameters


# Define Prior

In [13]:
# Number of latent dimensions: the product of all entries of img_shape.
D = int(np.prod(img_shape))

# Zero-mean, identity-covariance Gaussian over the D-dimensional latent space,
# with both tensors allocated on the target device.
prior = distributions.MultivariateNormal(
    loc=torch.zeros(D, device=device),
    covariance_matrix=torch.eye(D, device=device),
)

# Define Loss

In [14]:
class FlowLoss(nn.Module):
    """Negative log-likelihood loss for a normalizing flow.

    Args:
        prior: Distribution object whose ``log_prob`` scores the flattened
            latent ``z`` (and, when labels are given, is called as
            ``log_prob(z, y)``).
        k (int or float): Number of discrete values in each input dimension.
            E.g., `k` is 256 for natural images.

    See Also:
        Equation (3) in the RealNVP paper: https://arxiv.org/abs/1605.08803
    """
    # The "np.log(self.k) * np.prod(z.size()[1:])" term is subtracted so the
    # result can be converted to bits per dimension (BPD).
    # ref : https://github.com/openai/glow/issues/43

    def __init__(self, prior, k=256):
        super().__init__()
        self.k = k
        self.prior = prior

    def forward(self, z, sldj, y=None, mean=True):
        """Return the NLL of a batch.

        Args:
            z: Latent tensor; everything after the batch axis is flattened.
            sldj: Sum of log-determinants of the Jacobian, added to the prior
                log-likelihood.
            y: Optional labels forwarded to ``prior.log_prob``.
            mean (bool): If True return the batch mean, otherwise the
                per-sample values.
        """
        flat_z = z.reshape((z.shape[0], -1))

        # Prior log-probability: a negative number, closer to 0 means the
        # data is more likely under the model.
        if y is None:
            prior_ll = self.prior.log_prob(flat_z)
        else:
            prior_ll = self.prior.log_prob(flat_z, y)

        # Discretisation correction: log(k) for every latent dimension.
        discretization = np.log(self.k) * np.prod(flat_z.size()[1:])
        log_lik = prior_ll - discretization + sldj

        # NLL: a positive number, closer to 0 means higher data likelihood.
        if mean:
            return -log_lik.mean()
        return -log_lik

In [15]:
# NLL loss bound to the latent prior, with the default k=256 discretisation levels.
loss_fn = FlowLoss(prior)

# Define Optimizer

In [16]:
# Optimizer class selected by --optim ('Adam', otherwise RMSprop).
optimizer_cls = optim.Adam if args.optim == 'Adam' else optim.RMSprop

if 'RealNVP' in args.flow:
    # Weight decay must only reach g, the norm parameter of Weight Normalization,
    # so the parameters are split into groups by the 'weight_g' suffix.
    # RealNVP paper: https://arxiv.org/abs/1605.08803 "3.7 Batch Normalization"
    # RealNVP uses a ResNet with Batch Normalization and Weight Normalization.
    param_groups = utils.get_param_groups(net, args.weight_decay, norm_suffix='weight_g')
    optimizer = optimizer_cls(param_groups, lr=args.lr)

elif args.flow == 'Glow':
    # Glow applies the same weight decay to every parameter.
    optimizer = optimizer_cls(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)

In [17]:
# Inspect the optimizer parameter groups: their count and the keys of the first two.
print(f'len(param_groups) : {len(param_groups)}')
for group_idx in (0, 1):
    print(f'param_groups[{group_idx}].keys : {param_groups[group_idx].keys()}')

len(param_groups) : 2
param_groups[0].keys : dict_keys(['name', 'params', 'weight_decay', 'lr', 'betas', 'eps', 'amsgrad'])
param_groups[1].keys : dict_keys(['name', 'params', 'lr', 'betas', 'eps', 'weight_decay', 'amsgrad'])


# Train

In [18]:
def train(epoch, net, trainloader, ood_loader, device, optimizer, loss_fn,
          max_grad_norm, negative_val=-1e5, num_samples=10, log_freq=100):
    """Train `net` for one epoch on paired in-distribution / OOD batches.

    Every step concatenates one in-distribution batch with one OOD batch,
    minimises the NLL of the in-distribution half and maximises the NLL of the
    OOD half. OOD samples whose negated NLL is at or below `negative_val` are
    excluded from the loss.

    Args:
        epoch (int): Epoch index, used for printing and logging.
        net: Flow model; called as ``net(x)`` and then queried with
            ``net.logdet()``.
        trainloader: In-distribution loader yielding ``(x, label)`` pairs.
        ood_loader: OOD loader yielding ``(x, label)`` pairs. Iteration stops
            when the shorter of the two loaders is exhausted.
        device: Device the concatenated batch is moved to.
        optimizer: Optimizer stepping the parameters of `net`.
        loss_fn: Callable returning per-sample NLL when ``mean=False``.
        max_grad_norm (float): Passed to ``utils.clip_grad_norm``.
        negative_val (float): Lower threshold on the negated OOD NLL.
        num_samples (int): Unused; kept for interface compatibility.
        log_freq (int): Log to W&B every `log_freq` iterations.

    Reads the module-level flag ``use_wandb``.
    """
    print(f'\nEPOCH : {epoch}')

    net.train()
    loss_meter = utils.AverageMeter()
    loss_positive_meter = utils.AverageMeter()
    loss_negative_meter = utils.AverageMeter()

    iter_count, batch_count = 0, 0
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for (x, _), (x_ood, _) in zip(trainloader, ood_loader):

            bs = x.shape[0]  # size of the in-distribution half
            iter_count += 1
            batch_count += bs

            # First `bs` rows are in-distribution, the remainder is OOD.
            x = torch.cat((x, x_ood), dim=0).to(device)

            optimizer.zero_grad()

            z = net(x)
            sldj = net.logdet()

            # Per-sample NLL: positive, closer to 0 means a better density estimate.
            nll = loss_fn(z, sldj=sldj, mean=False)

            # In-distribution term: minimise NLL (maximise likelihood).
            loss_positive = nll[:bs].mean()
            # OOD term: negated NLL, so minimising it pushes OOD likelihood down.
            neg_nll_ood = -nll[bs:]

            # Indicator function (eq. 7): keep only OOD samples whose negated
            # NLL is still above the threshold.
            keep = neg_nll_ood > negative_val
            if keep.any():
                loss_negative = neg_nll_ood[keep].mean()
                loss = 0.5 * (loss_positive + loss_negative)
            else:
                # Every OOD sample is already past the threshold (exploding
                # NLL): drop the OOD term for this step.
                loss_negative = torch.tensor(0.)
                loss = loss_positive

            loss.backward()
            utils.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()

            # Log
            loss_meter.update(loss.item(), bs)
            loss_positive_meter.update(loss_positive.item(), bs)
            loss_negative_meter.update(loss_negative.item(), bs)
            progress_bar.set_postfix(
                pos_bpd=utils.bits_per_dim(x[:bs], loss_positive_meter.avg),
                neg_bpd=utils.bits_per_dim(x[bs:], -loss_negative_meter.avg),
                neg_loss=loss_negative.item())
            progress_bar.update(bs)

            if iter_count % log_freq == 0 or batch_count == len(trainloader.dataset):
                if use_wandb:
                    wandb.log({'epoch' : epoch,
                               'train | loss' : loss_meter.avg,
                               'train | loss_Pos' : loss_positive_meter.avg,
                               'train | loss_Neg' : loss_negative_meter.avg,
                               'train | bpd_Pos' : utils.bits_per_dim(x[:bs], loss_positive_meter.avg),
                               'train | bpd_Neg' : utils.bits_per_dim(x[bs:], -loss_negative_meter.avg)})


In [19]:
def test(epoch, net, testloader, device, loss_fn, mode='in'):
    """Evaluate `net` on `testloader` and return per-sample log-likelihoods.

    Args:
        epoch (int): Epoch index, used only for W&B logging.
        net: Flow model; called as ``net(x)`` and then queried with
            ``net.logdet()``.
        testloader: Loader yielding ``(x, label)`` pairs.
        device: Device each batch is moved to.
        loss_fn: Callable returning per-sample NLL when ``mean=False``.
        mode (str): Suffix for the W&B metric names (e.g. 'in' or 'ood').

    Returns:
        torch.Tensor: float32 CPU tensor holding the negated NLL of every
        sample, in loader order.

    Reads the module-level flag ``use_wandb``.
    """
    net.eval()
    loss_meter = utils.AverageMeter()
    loss_list = []
    with torch.no_grad():
        with tqdm(total=len(testloader.dataset)) as progress_bar:
            for x, _ in testloader:
                x = x.to(device)
                z = net(x)
                sldj = net.logdet()
                losses = loss_fn(z, sldj=sldj, mean=False)
                # One device-to-host copy per batch instead of one .item()
                # synchronisation per sample.
                loss_list.extend(losses.tolist())

                loss = losses.mean()   # batch NLL (positive value)
                loss_meter.update(loss.item(), x.size(0))

                progress_bar.set_postfix(loss=loss_meter.avg,
                                         bpd=utils.bits_per_dim(x, loss_meter.avg))
                progress_bar.update(x.size(0))

    likelihoods = -torch.from_numpy(np.array(loss_list)).float()  # -NLL (negative value)

    if use_wandb:
        # `x` is the last batch seen; it is only passed to bits_per_dim.
        wandb.log({'epoch' : epoch,
                   f'test|loss_{mode}' : loss_meter.avg,
                   f'test|bpd_{mode}' : utils.bits_per_dim(x, loss_meter.avg),
                   f'test|likelihoods_{mode}' : likelihoods})

    return likelihoods

## Training / test loop

In [20]:
start_epoch = 0
net.to(device)
seed_everything(0)

# Output directories are created once up front (the histogram folder lives
# inside the checkpoint folder).
hist_dir = os.path.join(args.ckptdir, 'LL_histogram')
os.makedirs(hist_dir, exist_ok=True)

# NOTE: the upper bound is inclusive, so this runs args.num_epochs + 1 epochs.
for epoch in range(start_epoch, start_epoch + args.num_epochs + 1):

    if args.lr_anneal:
        lr = schedule(args, epoch)
        utils.adjust_learning_rate(optimizer, lr)

    train(epoch, net, trainloader, ood_trainloader, device, optimizer, loss_fn,
          args.max_grad_norm, num_samples=args.num_samples,
          negative_val=args.negative_val)

    test_ll = test(epoch, net, testloader, device, loss_fn, mode='in')  # LL (not NLL)
    # Cut-off used below to drop exceptionally low likelihoods from the plot
    # (per the original note: the lower-5% value -- confirm against get_percentile).
    test_ll_percentile = get_percentile(test_ll)
    test_ll = test_ll.cpu().detach().numpy()

    # Reset every epoch so the checkpoint below never hits an undefined or
    # stale value when OOD evaluation is disabled.
    lls, targets, score = None, None, None

    if args.ood_dataset:
        ood_ll = test(epoch, net, ood_testloader, device, loss_fn, mode='ood')  # LL (not NLL)
        ood_ll_percentile = get_percentile(ood_ll)
        ood_ll = ood_ll.cpu().detach().numpy()

        # AUC-ROC: in-distribution samples are the positive class (1), OOD
        # samples the negative class (0); the log-likelihood is the score.
        n_ood, n_test = len(ood_ll), len(test_ll)
        lls = np.hstack([ood_ll, test_ll])
        targets = np.ones((n_ood + n_test,), dtype=int)
        targets[:n_ood] = 0
        score = roc_auc_score(targets, lls)
        if use_wandb:
            wandb.log({'ood/roc_auc': score})

    # Likelihood histograms (values at or below the cut-off are left out).
    fig = plt.figure(figsize=(8, 8))
    sns.distplot(test_ll[test_ll > test_ll_percentile], label='test')
    if args.ood_dataset:
        sns.distplot(ood_ll[ood_ll > ood_ll_percentile], label='OOD')
    plt.legend()
    hist_path = os.path.join(hist_dir, f'{epoch}.png')
    fig.savefig(hist_path)
    # Release the figure: one is created per epoch and would otherwise stay
    # alive for the whole run.
    plt.close(fig)
    if use_wandb:
        wandb.log({"LL_histogram": [wandb.Image(hist_path)]})

    # Save checkpoint
    if epoch % args.save_freq == 0:
        print('Saving...')
        state = {
            'net': net.state_dict(),
            'epoch': epoch,
            'lls': lls,
            'targets': targets,
            'ood/roc_auc': score,
        }
        os.makedirs(args.ckptdir, exist_ok=True)
        torch.save(state, os.path.join(args.ckptdir, str(epoch) + '.pt'))
        


EPOCH : 0


  0%|          | 0/50000 [00:00<?, ?it/s]


RuntimeError: CUDA out of memory. Tried to allocate 16.00 MiB (GPU 0; 10.75 GiB total capacity; 2.58 GiB already allocated; 13.12 MiB free; 2.62 GiB reserved in total by PyTorch)

# Appendix

## RandomCrop
- [Reference - torchvision official doc](https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py)

In [None]:
# sphinx_gallery_thumbnail_path = "../../gallery/assets/transforms_thumbnail.png"

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np

import torch
import torchvision.transforms as T


# Make saved figures use a tight bounding box.
plt.rcParams["savefig.bbox"] = 'tight'
# Reference image used by every crop demo below; loaded from assets/astronaut.jpg
# relative to the working directory.
orig_img = Image.open(Path('assets') / 'astronaut.jpg')
# if you change the seed, make sure that the randomly-applied transforms
# properly show that the image can be both transformed and *not* transformed!
torch.manual_seed(0)

def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
    """Show a grid of images, optionally prefixed by the original image.

    Args:
        imgs: A list of images (one row) or a list of lists (one list per row).
        with_orig: When truthy, the global `orig_img` is prepended to each row
            and the top-left panel is titled 'Original image'.
        row_title: Optional sequence of y-axis labels, one per row.
        **imshow_kwargs: Forwarded to ``Axes.imshow`` for every panel.
    """
    # Normalise a flat list into a single-row grid.
    grid = imgs if isinstance(imgs[0], list) else [imgs]

    n_rows = len(grid)
    n_cols = len(grid[0]) + with_orig
    _, axes = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False)

    for r, images in enumerate(grid):
        if with_orig:
            images = [orig_img] + images
        for c, image in enumerate(images):
            panel = axes[r, c]
            panel.imshow(np.asarray(image), **imshow_kwargs)
            # Hide ticks and tick labels on every panel.
            panel.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        corner = axes[0, 0]
        corner.set(title='Original image')
        corner.title.set_size(8)

    if row_title is not None:
        for r in range(n_rows):
            axes[r, 0].set(ylabel=row_title[r])

    plt.tight_layout()

In [None]:
# Print the size of the reference image, then display the image itself as the cell output.
print(orig_img.size)
orig_img

In [None]:
# Four random 256x256 crops with padding=4, shown next to the original image.
cropper = T.RandomCrop(size=(256, 256), padding=4)
crops = [cropper(orig_img) for _ in range(4)]
plot(crops)

In [None]:
# Same 256x256 crop size with a much larger padding (200), for comparison.
cropper = T.RandomCrop(size=(256, 256), padding=200)
crops = [cropper(orig_img) for _ in range(4)]
plot(crops)

In [None]:
# Same 256x256 crop size with an intermediate padding (100).
cropper = T.RandomCrop(size=(256, 256), padding=100)
crops = [cropper(orig_img) for _ in range(4)]
plot(crops)

In [None]:
# Smaller 128x128 crops with padding=4.
cropper = T.RandomCrop(size=(128, 128), padding=4)
crops = [cropper(orig_img) for _ in range(4)]
plot(crops)

In [None]:
# 128x128 crops with no padding argument (library default).
cropper = T.RandomCrop(size=(128, 128))
crops = [cropper(orig_img) for _ in range(4)]
plot(crops)

## Gaussian Prior

In [None]:
# 2-D standard Gaussian used only for this illustration. It has its own name
# so the D-dimensional `prior` defined for the loss is not overwritten if
# cells are re-run after this one.
prior_2d = distributions.MultivariateNormal(torch.zeros(2).to(device),
                                            torch.eye(2).to(device))

# Draw 5,000 samples and move them to host memory as NumPy arrays before
# handing them to pandas.
prior_examples = prior_2d.sample(sample_shape=[5_000])
exmp_df = pd.DataFrame({'x': prior_examples[:, 0].cpu().numpy(),
                        'y': prior_examples[:, 1].cpu().numpy()})

plt.figure(figsize=(10,5))
sns.scatterplot(x=exmp_df['x'], y=exmp_df['y'])
plt.show()