In [1]:
import argparse
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils import data
import torch
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state 

from utils import get_accuracy
from lars import LARSWrapper
import cifar_100
from logger import log_metrics as logger
import torch_trainer as trainer
from cifar_resnet import resnet18
import cifar_100

#python helper inputs
import os
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
import wandb
import pytorch_lightning as pl
import time



dataset_args = {
                 'crop_size': 32,
                 'brightness': 0.4, 
                 'contrast': 0.4, 
                 'saturation': .2, 
                 'hue': .1, 
                 'color_jitter_prob': .8, 
                 'gray_scale_prob': 0.2, 
                 'horizontal_flip_prob': 0.5, 
                 'gaussian_prob': .5, 
                 'min_scale': 0.16, 
                 'max_scale': 0.9}
val_dataset_args = {
                 'crop_size': 32,
                 'brightness': 0.4, 
                 'contrast': 0.4, 
                 'saturation': .2, 
                 'hue': .1, 
                 'color_jitter_prob': 0, 
                 'gray_scale_prob': 0, 
                 'horizontal_flip_prob': 0.5, 
                 'gaussian_prob': 0, 
                 'min_scale': 0.9, 
                 'max_scale': 1}




def prepare_data(dataset_args, val_dataset_args):
    
    train_dataset = cifar_100.CIFAR_100_transformations(train = True, **dataset_args)
    dataloader = DataLoader(
        train_dataset,
        shuffle = True,
        batch_size=256,
        num_workers=8,
        pin_memory=False,
        drop_last=True,
        persistent_workers=True
    )

    val_dataset = cifar_100.CIFAR_100_transformations(views = 1, train = True, **val_dataset_args)
    val_dataloader = DataLoader(
        val_dataset,
        shuffle = True,
        batch_size=256,
        num_workers=8,
        pin_memory=False,
        drop_last=True,
        persistent_workers=False
    )
    return train_dataset, dataloader, val_dataloader


def create_params(model, example, rng):
    model = model
    batch = example  # (N, H, W, C) format
    params = model.init(rng, batch)['params']
    return params


def create_optimizer(optimizer, lr, wd):
    optimizer = optimizer(lr)
    return optimizer

def cross_entropy_loss(logits, labels):
    return optax.softmax_cross_entropy(logits=logits, labels=labels).mean()


def create_train_state(rng, optimizer):
    """Creates initial `TrainState`."""
    model = resnet18()
    batch = jnp.ones((4, 32, 32, 3))  # (N, H, W, C) format
    params = model.init(jax.random.PRNGKey(0), batch)['params']
    tx = optimizer
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

def loss_fn(params, data):
    logits, _ = model.apply({'params': params}, data['image0'], mutable=['batchstats'])
    loss = jnp.mean(jax.vmap(cross_entropy_loss)(logits=logits, labels=data['label']), axis= 0)
    return loss, logits



@jax.jit
def training_step(state, data):
    
    def loss_fn(params, data):
        logits, _ = model.apply({'params': params}, data['image0'], mutable=['batch_stats'])
        loss = jnp.mean(jax.vmap(cross_entropy_loss)(logits=logits, labels=data['label']), axis= 0)
        return loss, logits
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params, data)
    #state = state.apply_gradients(grads=grads)


   # metrics['total'] += data['image0'].shape[0]
    #metrics['Accuracy'] += acc1
    #metrics['Accuracy Top 5'] += acc5

    return state#, metrics

In [9]:
train_dataset, dataloader, val_dataloader = prepare_data(dataset_args, val_dataset_args)

Files already downloaded and verified
Files already downloaded and verified


In [16]:
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
learning_rate = 0.1
momentum = 0.9
optimizer = optax.adamw(.001)
state = create_train_state(init_rng, optimizer)
del init_rng  # Must not be used anymore.
model = resnet18()

import time
now = time.time()

for i, data in enumerate(dataloader):
    data['image0'] = jnp.array(data['image0'].permute(0, 2, 3, 1))
    data['label'] = jnp.array(data['label'])
    state = training_step(state, data)
    #logits = resnet18().apply({'params': state.params}, data['image0'], mutable=['batch_stats'])
    if i % 10 == 0: print(i, time.time() - now)

0 3.564973831176758
10 3.9388911724090576
20 4.651685953140259
30 5.6907055377960205
40 6.692582845687866
50 8.019368171691895
60 8.91062617301941
70 9.836599349975586
80 10.723158359527588
90 11.835393190383911
100 12.726295232772827
110 13.458921670913696
120 14.38868522644043
130 15.744547367095947
140 16.639466524124146
150 17.463393211364746
160 18.32213807106018
170 19.504474878311157
180 20.501022338867188
190 21.012498378753662


In [1]:
import haiku as hk
import jax
import jax.numpy as jnp
from haiku_bn import BatchNorm


class MLP(hk.Module):
  """One hidden layer perceptron, with normalization."""

  def __init__(
      self,
  ):
    super().__init__()
    self._hidden_size = 10
    self._output_size = 1
    bn_config = dict()
    bn_config.setdefault("create_scale", True)
    bn_config.setdefault("create_offset", True)
    bn_config.setdefault("decay_rate", 0.999)
    self._bn_config = bn_config

  def __call__(self, inputs: jnp.ndarray, is_training: bool) -> jnp.ndarray:
    out = hk.Linear(output_size=self._hidden_size, with_bias=True)(inputs)
    out = BatchNorm(**self._bn_config)(out, is_training=is_training)
    out = jax.nn.relu(out)
    out = hk.Linear(output_size=self._output_size, with_bias=False)(out)
    return out    

def forward(x):
  module = MLP()
  return module(x, True)

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
forward = hk.transform_with_state(forward)
params, state = forward.init(init_rng, jnp.ones((4, 3)))


In [2]:
print(state)

{'mlp/batch_norm/~/mean_ema': {'counter': DeviceArray(0., dtype=float32), 'hidden': DeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), 'average': DeviceArray([[ 0.34513226, -0.2222743 , -0.07778113, -0.05308276,
               0.3091463 , -0.3052339 , -1.1058168 ,  0.82687795,
               0.16870472, -0.6846972 ]], dtype=float32)}, 'mlp/batch_norm/~/var_ema': {'counter': DeviceArray(0., dtype=float32), 'hidden': DeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), 'average': DeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)}}


In [8]:
x = jax.random.normal(init_rng, (4, 3))
y = jnp.array([1.0, 2.1, 1.4, 1.2])
def loss_fn(state, params, x, y):
    logits, state = forward.apply(params, state, None, x)
    print(type(logits), type(y))
    return ((logits - y)**2).mean()
grad_fn = jax.grad(loss_fn)(state, params, x, y)


<class 'jaxlib.xla_extension.DeviceArray'> <class 'jaxlib.xla_extension.DeviceArray'>


In [3]:
print(state)

{'mlp/batch_norm/~/mean_ema': {'counter': DeviceArray(0, dtype=int32), 'hidden': DeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), 'average': DeviceArray([[ 0.34513226, -0.2222743 , -0.07778113, -0.05308276,
               0.3091463 , -0.3052339 , -1.1058168 ,  0.82687795,
               0.16870472, -0.6846972 ]], dtype=float32)}, 'mlp/batch_norm/~/var_ema': {'counter': DeviceArray(0, dtype=int32), 'hidden': DeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), 'average': DeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)}}
