In [5]:
import os
import sys
os.environ['CUDA_VISIBLE_DEVICES'] = '5'
sys.path.append('../')

In [6]:
from tqdm import tqdm
from functools import partial

import jax
import jax.numpy as jnp
from flax.training import checkpoints
from flax import jax_utils

from giung2.data.build import build_dataloaders
from giung2.models.resnet import FlaxResNet
from giung2.models.layers import ConvDropFilter
from giung2.metrics import evaluate_acc, evaluate_nll

CPU = jax.devices('cpu')[0]

In [7]:
class DotDict(dict):
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

config = DotDict()
config.data_root = '../data/'
config.data_augmentation = 'standard'
config.data_proportional = 1.0
config.optim_bs = 80
config.drop_rate = 0.01

In [8]:
config.data_name = 'CIFAR10_x32'
CKPT = '../save/CIFAR10_x32/R20x1-BN-ReLU/Dropout/bs-0256_ne-0500_lr-0.03_mo-0.90_wd-0.0030_drop-0.01_fp32/42/best_acc.ckpt'
M = 30

# build dataloaders
dataloaders = build_dataloaders(config)

# build model
model = FlaxResNet(
    depth        = 20,
    widen_factor = 1,
    dtype        = jnp.float32,
    pixel_mean   = (0.49, 0.48, 0.44),
    pixel_std    = (0.2, 0.2, 0.2),
    num_classes  = dataloaders['num_classes'],
    conv         = partial(ConvDropFilter, use_bias=False,
                           kernel_init=jax.nn.initializers.he_normal(),
                           bias_init=jax.nn.initializers.zeros,
                           drop_rate=config.drop_rate))

# initialize model
def initialize_model(key, model):
    @jax.jit
    def init(*args):
        return model.init(*args)
    return init({'params': key}, jnp.ones(dataloaders['image_shape'], model.dtype))
initialize_model(jax.random.PRNGKey(0), model)

# load pre-trained checkpoint
ckpt = checkpoints.restore_checkpoint(CKPT, target=None)

# define predict function
def predict(images, params, image_stats, batch_stats):
    rngs = jax.random.split(jax.random.PRNGKey(0), M)
    return jnp.stack([model.apply({
            'params': params,
            'image_stats': image_stats,
            'batch_stats': batch_stats,
        }, images, rngs={'dropout': rng}, mutable='intermediates', use_running_average=True, deterministic=False
        )[1]['intermediates']['cls.logit'][0] for rng in rngs])
_predict = jax.pmap(partial(predict, params=ckpt['params'], image_stats=ckpt['image_stats'], batch_stats=ckpt['batch_stats']))

# make predictions
tst_logits = []
tst_labels = []
tst_loader = jax_utils.prefetch_to_device(dataloaders['tst_loader'](rng=None), size=2)
for batch_idx, batch in tqdm(enumerate(tst_loader, start=1)):
    _logits, _labels = _predict(batch['images']), batch['labels']
    tst_logits.append(jax.device_put(_logits.transpose(0, 2, 1, 3).reshape(-1, M, dataloaders['num_classes']), CPU))
    tst_labels.append(jax.device_put(_labels.reshape(-1), CPU))
tst_logits = jnp.concatenate(tst_logits)
tst_labels = jnp.concatenate(tst_labels)

# evaluate predictions
_confidences = jnp.mean(jax.nn.softmax(tst_logits, axis=-1), axis=1)
_true_labels = tst_labels
print('{:.4f}'.format(evaluate_acc(_confidences, _true_labels, log_input=False)),
      '{:.4f}'.format(evaluate_nll(_confidences, _true_labels, log_input=False)))

125it [00:55,  2.25it/s]


0.9313 0.2163


In [9]:
config.data_name = 'CIFAR10_x32'
CKPT = '../save/CIFAR10_x32/R20x4-BN-ReLU/Dropout/bs-0256_ne-0500_lr-0.01_mo-0.90_wd-0.0030_drop-0.01_fp32/42/best_acc.ckpt'
M = 30

# build dataloaders
dataloaders = build_dataloaders(config)

# build model
model = FlaxResNet(
    depth        = 20,
    widen_factor = 4,
    dtype        = jnp.float32,
    pixel_mean   = (0.49, 0.48, 0.44),
    pixel_std    = (0.2, 0.2, 0.2),
    num_classes  = dataloaders['num_classes'],
    conv         = partial(ConvDropFilter, use_bias=False,
                           kernel_init=jax.nn.initializers.he_normal(),
                           bias_init=jax.nn.initializers.zeros,
                           drop_rate=config.drop_rate))

# initialize model
def initialize_model(key, model):
    @jax.jit
    def init(*args):
        return model.init(*args)
    return init({'params': key}, jnp.ones(dataloaders['image_shape'], model.dtype))
initialize_model(jax.random.PRNGKey(0), model)

# load pre-trained checkpoint
ckpt = checkpoints.restore_checkpoint(CKPT, target=None)

# define predict function
def predict(images, params, image_stats, batch_stats):
    rngs = jax.random.split(jax.random.PRNGKey(0), M)
    return jnp.stack([model.apply({
            'params': params,
            'image_stats': image_stats,
            'batch_stats': batch_stats,
        }, images, rngs={'dropout': rng}, mutable='intermediates', use_running_average=True, deterministic=False
        )[1]['intermediates']['cls.logit'][0] for rng in rngs])
_predict = jax.pmap(partial(predict, params=ckpt['params'], image_stats=ckpt['image_stats'], batch_stats=ckpt['batch_stats']))

# make predictions
tst_logits = []
tst_labels = []
tst_loader = jax_utils.prefetch_to_device(dataloaders['tst_loader'](rng=None), size=2)
for batch_idx, batch in tqdm(enumerate(tst_loader, start=1)):
    _logits, _labels = _predict(batch['images']), batch['labels']
    tst_logits.append(jax.device_put(_logits.transpose(0, 2, 1, 3).reshape(-1, M, dataloaders['num_classes']), CPU))
    tst_labels.append(jax.device_put(_labels.reshape(-1), CPU))
tst_logits = jnp.concatenate(tst_logits)
tst_labels = jnp.concatenate(tst_labels)

# evaluate predictions
_confidences = jnp.mean(jax.nn.softmax(tst_logits, axis=-1), axis=1)
_true_labels = tst_labels
print('{:.4f}'.format(evaluate_acc(_confidences, _true_labels, log_input=False)),
      '{:.4f}'.format(evaluate_nll(_confidences, _true_labels, log_input=False)))

125it [01:17,  1.61it/s]


0.9529 0.1528


In [10]:
config.data_name = 'CIFAR100_x32'
CKPT = '../save/CIFAR100_x32/R20x1-BN-ReLU/Dropout/bs-0256_ne-0500_lr-0.30_mo-0.90_wd-0.0001_drop-0.03_fp32/42/best_acc.ckpt'
M = 30

# build dataloaders
dataloaders = build_dataloaders(config)

# build model
model = FlaxResNet(
    depth        = 20,
    widen_factor = 1,
    dtype        = jnp.float32,
    pixel_mean   = (0.49, 0.48, 0.44),
    pixel_std    = (0.2, 0.2, 0.2),
    num_classes  = dataloaders['num_classes'],
    conv         = partial(ConvDropFilter, use_bias=False,
                           kernel_init=jax.nn.initializers.he_normal(),
                           bias_init=jax.nn.initializers.zeros,
                           drop_rate=config.drop_rate))

# initialize model
def initialize_model(key, model):
    @jax.jit
    def init(*args):
        return model.init(*args)
    return init({'params': key}, jnp.ones(dataloaders['image_shape'], model.dtype))
initialize_model(jax.random.PRNGKey(0), model)

# load pre-trained checkpoint
ckpt = checkpoints.restore_checkpoint(CKPT, target=None)

# define predict function
def predict(images, params, image_stats, batch_stats):
    rngs = jax.random.split(jax.random.PRNGKey(0), M)
    return jnp.stack([model.apply({
            'params': params,
            'image_stats': image_stats,
            'batch_stats': batch_stats,
        }, images, rngs={'dropout': rng}, mutable='intermediates', use_running_average=True, deterministic=False
        )[1]['intermediates']['cls.logit'][0] for rng in rngs])
_predict = jax.pmap(partial(predict, params=ckpt['params'], image_stats=ckpt['image_stats'], batch_stats=ckpt['batch_stats']))

# make predictions
tst_logits = []
tst_labels = []
tst_loader = jax_utils.prefetch_to_device(dataloaders['tst_loader'](rng=None), size=2)
for batch_idx, batch in tqdm(enumerate(tst_loader, start=1)):
    _logits, _labels = _predict(batch['images']), batch['labels']
    tst_logits.append(jax.device_put(_logits.transpose(0, 2, 1, 3).reshape(-1, M, dataloaders['num_classes']), CPU))
    tst_labels.append(jax.device_put(_labels.reshape(-1), CPU))
tst_logits = jnp.concatenate(tst_logits)
tst_labels = jnp.concatenate(tst_labels)

# evaluate predictions
_confidences = jnp.mean(jax.nn.softmax(tst_logits, axis=-1), axis=1)
_true_labels = tst_labels
print('{:.4f}'.format(evaluate_acc(_confidences, _true_labels, log_input=False)),
      '{:.4f}'.format(evaluate_nll(_confidences, _true_labels, log_input=False)))

125it [00:52,  2.39it/s]


0.6938 1.1213


In [11]:
config.data_name = 'CIFAR100_x32'
CKPT = '../save/CIFAR100_x32/R20x4-BN-ReLU/Dropout/bs-0256_ne-0500_lr-0.03_mo-0.90_wd-0.0030_drop-0.03_fp32/42/best_acc.ckpt'
M = 30

# build dataloaders
dataloaders = build_dataloaders(config)

# build model
model = FlaxResNet(
    depth        = 20,
    widen_factor = 4,
    dtype        = jnp.float32,
    pixel_mean   = (0.49, 0.48, 0.44),
    pixel_std    = (0.2, 0.2, 0.2),
    num_classes  = dataloaders['num_classes'],
    conv         = partial(ConvDropFilter, use_bias=False,
                           kernel_init=jax.nn.initializers.he_normal(),
                           bias_init=jax.nn.initializers.zeros,
                           drop_rate=config.drop_rate))

# initialize model
def initialize_model(key, model):
    @jax.jit
    def init(*args):
        return model.init(*args)
    return init({'params': key}, jnp.ones(dataloaders['image_shape'], model.dtype))
initialize_model(jax.random.PRNGKey(0), model)

# load pre-trained checkpoint
ckpt = checkpoints.restore_checkpoint(CKPT, target=None)

# define predict function
def predict(images, params, image_stats, batch_stats):
    rngs = jax.random.split(jax.random.PRNGKey(0), M)
    return jnp.stack([model.apply({
            'params': params,
            'image_stats': image_stats,
            'batch_stats': batch_stats,
        }, images, rngs={'dropout': rng}, mutable='intermediates', use_running_average=True, deterministic=False
        )[1]['intermediates']['cls.logit'][0] for rng in rngs])
_predict = jax.pmap(partial(predict, params=ckpt['params'], image_stats=ckpt['image_stats'], batch_stats=ckpt['batch_stats']))

# make predictions
tst_logits = []
tst_labels = []
tst_loader = jax_utils.prefetch_to_device(dataloaders['tst_loader'](rng=None), size=2)
for batch_idx, batch in tqdm(enumerate(tst_loader, start=1)):
    _logits, _labels = _predict(batch['images']), batch['labels']
    tst_logits.append(jax.device_put(_logits.transpose(0, 2, 1, 3).reshape(-1, M, dataloaders['num_classes']), CPU))
    tst_labels.append(jax.device_put(_labels.reshape(-1), CPU))
tst_logits = jnp.concatenate(tst_logits)
tst_labels = jnp.concatenate(tst_labels)

# evaluate predictions
_confidences = jnp.mean(jax.nn.softmax(tst_logits, axis=-1), axis=1)
_true_labels = tst_labels
print('{:.4f}'.format(evaluate_acc(_confidences, _true_labels, log_input=False)),
      '{:.4f}'.format(evaluate_nll(_confidences, _true_labels, log_input=False)))

125it [01:14,  1.68it/s]


0.7856 0.8273
