In [None]:
%reload_ext autoreload 
%autoreload 2

import os
# Restrict this process to the GPU with CUDA index 1.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Ask XLA to compile GPU kernels with a parallelism of 1.
# NOTE(review): presumably both variables must be set before `jax` is first
# imported/initialised for them to take effect — confirm.
os.environ["XLA_FLAGS"] = "--xla_gpu_force_compilation_parallelism=1"
# Alternative flag kept for reference. As written, enabling it would overwrite
# the XLA_FLAGS value set above rather than add to it.
# os.environ["XLA_FLAGS"] = "--xla_gpu_strict_conv_algorithm_picker=false"

In [None]:
import jax
import jax.numpy as jnp
import hamux as hmx
import treex as tx
from flax import linen as nn # For initializers
import optax
import jax.tree_util as jtu
from typing import *
import matplotlib.pyplot as plt
from dataclasses import dataclass


In [None]:
# Sanity check: list the devices JAX can see in this process.
jax.devices()

[GpuDevice(id=0, process_index=0)]

In [None]:
class TrainState(tx.Module):
    """Training state bundling the model, its optimizer, the forward function and RNG keys.

    Besides the annotated fields below, `__init__` also stores
    `do_normal_init`, which has no class-level annotation.
    """

    model: tx.Module  # network being trained
    optimizer: tx.Optimizer  # treex wrapper around the optimizer given to __init__
    apply_fn: Callable  # invoked as apply_fn(model, x, rng=rng) by train_step / eval_step
    filter_betas: bool  # selects which leaves the `params` property returns
    rng: jnp.ndarray = tx.Rng.node()  # key split by train_step when do_normal_init is set
    eval_rng: jnp.ndarray = tx.Rng.node()  # key passed to apply_fn by eval_step when do_normal_init is set

    def __init__(
        self, model, optimizer, apply_fn, rng, filter_betas=False, do_normal_init=False
    ):
        """Build the state and initialise the optimizer on the trainable subset.

        Args:
            model: module to train.
            optimizer: object wrapped in `tx.Optimizer` (presumably an optax
                gradient transformation — confirm).
            apply_fn: forward function, called as `apply_fn(model, x, rng=rng)`.
            rng: PRNG key; split once into `self.rng` and `self.eval_rng`.
            filter_betas: when True, `params` selects by field name instead of
                by `tx.Parameter` (see `params`).
            do_normal_init: when True, the step functions pass an RNG key to
                `apply_fn`; otherwise they pass `rng=None`.
        """
        # `filter_betas` and `model` must be set before the optimizer is
        # initialised, because `self.params` reads both of them.
        self.filter_betas = filter_betas
        self.model = model
        self.optimizer = tx.Optimizer(optimizer).init(self.params)
        self.apply_fn = apply_fn
        self.rng, self.eval_rng = jax.random.split(rng)
        self.do_normal_init = do_normal_init

    @property
    def params(self):
        """Return the subset of `self.model` that is optimised.

        With `filter_betas` set, keeps every entry whose `name` does not
        contain "beta"; otherwise keeps the `tx.Parameter` entries.
        """
        if self.filter_betas:
            # NOTE(review): this branch filters on the name only and does not
            # also restrict to tx.Parameter — confirm that is intended.
            return self.model.filter(lambda x: "beta" not in x.name)
        return self.model.filter(tx.Parameter)

    def apply_updates(self, grads):
        """Apply one optimizer update and merge the result into the model.

        Args:
            grads: gradients with respect to `self.params` (as produced in
                `train_step`).

        Returns:
            `self`, with `self.model` replaced by the merged model.
        """
        new_params = self.optimizer.update(grads, self.params)
        self.model = self.model.merge(new_params)
        return self


def cross_entropy_loss(*, probs, labels, label_smoothing=0.1, eps=1e-6):
    """Label-smoothed cross entropy computed from class probabilities.

    Args:
        probs: array of class probabilities; the last axis is the class axis.
        labels: integer class indices, one per leading position of `probs`.
        label_smoothing: total extra mass spread uniformly over the classes
            before the target is renormalised to sum to one. The default of
            0.1 reproduces the previously hard-coded behaviour; 0.0 gives
            plain one-hot targets.
        eps: constant added to every probability before the logarithm so that
            `log(0)` cannot occur. The default of 1e-6 reproduces the
            previously hard-coded behaviour.

    Returns:
        Scalar loss: the per-example cross entropy averaged over all leading
        axes.
    """
    n_classes = probs.shape[-1]
    labels_onehot = jax.nn.one_hot(labels, num_classes=n_classes)

    # Add `label_smoothing / n_classes` to every class, then renormalise so
    # each target row sums to one again.
    smoothed_labels = label_smoothing / n_classes + labels_onehot
    smoothed_labels = smoothed_labels / jnp.abs(smoothed_labels).sum(-1, keepdims=True)

    # Shift probabilities away from zero; the denominator keeps rows that
    # summed to one summing to one after the shift.
    stable_probs = (probs + eps) / (1 + eps * n_classes)
    loss = -jnp.sum(smoothed_labels * jnp.log(stable_probs), axis=-1).mean()
    return loss


def compute_metrics(*, probs, labels):
    """Summarise a batch of predictions.

    Args:
        probs: class probabilities with the class axis last.
        labels: integer class indices.

    Returns:
        Dict with the smallest and largest probability in the batch, the
        value of `cross_entropy_loss`, and the fraction of positions whose
        arg-max class equals the label.
    """
    predicted = jnp.argmax(probs, -1)
    return {
        "probs_min": probs.min(),
        'probs_max': probs.max(),
        'loss': cross_entropy_loss(probs=probs, labels=labels),
        'accuracy': jnp.mean(predicted == labels),
    }


@jax.jit
def train_step(state, batch):
    """Run one optimisation step on a single batch.

    Args:
        state: training state providing `model`, `params`, `apply_fn`, `rng`,
            `do_normal_init` and `apply_updates`.
        batch: mapping with an "image" entry (model input) and a "label" entry.

    Returns:
        `(state, metrics)`: the updated state and the dict produced by
        `compute_metrics`. The metrics use the probabilities from the forward
        pass that produced the gradients, i.e. before the parameter update.
    """
    # `do_normal_init` is tested with a Python `if`, so it has to be a concrete
    # value at trace time — presumably a static (non-node) field; TODO confirm.
    if state.do_normal_init:
        # The first key is used for this step; the second replaces state.rng.
        rng, state.rng = jax.random.split(state.rng)
    else:
        rng = None

    def loss_fn(params):
        # Merge the differentiated parameters into the model so the forward
        # pass below is a function of `params`.
        state.model = state.model.merge(params)
        x = batch["image"]
        probs = state.apply_fn(state.model, x, rng=rng)
        loss = cross_entropy_loss(probs=probs, labels=batch["label"])
        # Return the mutated state as auxiliary output so the caller can
        # rebind it after differentiation.
        return loss, (probs, state)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    # The loss value itself is discarded; compute_metrics recomputes it below.
    (_, (probs, state)), grads = grad_fn(state.params)

    state = state.apply_updates(grads)
    metrics = compute_metrics(probs=probs, labels=batch["label"])
    return state, metrics

@jax.jit
def eval_step(state, batch):
    """Evaluate the model on one batch without updating anything.

    Args:
        state: training state providing `model`, `apply_fn`, `eval_rng` and
            `do_normal_init`.
        batch: mapping with an "image" entry (model input) and a "label" entry.

    Returns:
        The metrics dict produced by `compute_metrics` for this batch.
    """
    # The evaluation key is passed through unchanged (never split), so the
    # same key is used on every call.
    rng = state.eval_rng if state.do_normal_init else None
    probs = state.apply_fn(state.model, batch["image"], rng=rng)
    return compute_metrics(probs=probs, labels=batch['label'])

def train_epoch(state, train_dl, epoch):
    """Train for a single epoch.

    Args:
        state: training state threaded through `train_step`.
        train_dl: iterable of batches; element 0 of each batch is used as the
            images and element 1 as the labels.
        epoch: currently unused; kept so existing callers keep working.

    Returns:
        `(state, loss, accuracy)`: the state after the last step, and the
        loss / accuracy averaged over batches. Every batch has equal weight
        in the average, regardless of its size.

    Raises:
        IndexError: if `train_dl` yields no batches.

    NOTE(review): `tqdm` and `np` are not imported in the cells visible here;
    they must already be present in the notebook namespace.
    """
    batch_metrics = []
    for batch in tqdm(train_dl, leave=False):
        batch = {"image": jnp.array(batch[0]), "label": jnp.array(batch[1])}
        state, metrics = train_step(state, batch)
        batch_metrics.append(metrics)

    # Pull all per-batch metrics to the host in one call, then average each
    # metric across the batches of the epoch.
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]
    }
    return state, epoch_metrics_np["loss"], epoch_metrics_np["accuracy"]


def eval_model(state, test_dl):
    """Evaluate `state` on every batch of `test_dl`.

    Args:
        state: training state passed to `eval_step`.
        test_dl: iterable of batches; element 0 of each batch is used as the
            images and element 1 as the labels.

    Returns:
        `(loss, accuracy)`, each the mean of the per-batch values (every
        batch has equal weight).
    """
    per_batch = [
        eval_step(state, {"image": jnp.array(b[0]), "label": jnp.array(b[1])})
        for b in test_dl
    ]
    # Move the collected metrics to the host before averaging.
    per_batch_host = jax.device_get(per_batch)
    averaged = {
        name: np.mean([entry[name] for entry in per_batch_host])
        for name in per_batch_host[0]
    }
    return averaged["loss"], averaged["accuracy"]

In [None]:
from hamux.datasets import *

# Dataloading / augmentation settings for CIFAR-10.
# NOTE(review): `args` is not defined in any cell visible here; it must exist
# in the namespace before this cell runs.
# NOTE(review): the meaning of the augmentation fields (aa, reprob, vflip,
# hflip, scale, color_jitter) is defined by hamux.datasets — confirm there.
dl_args = DataloadingArgs(
    dataset="torch/CIFAR10",
    # aa="rand",
    aa=None,
    reprob=0.2,
    vflip=0.0,
    hflip=0.5,
    scale=(0.2, 1.0),
    batch_size=args.batch_size,
    color_jitter=0.5,
    # Validation batches are twice the size of training batches.
    validation_batch_size=2 * args.batch_size,
)
data_config = DataConfigCIFAR10(input_size=(3, 32, 32))

# Build the training and evaluation loaders from the two configs above.
train_dl, eval_dl = create_dataloaders(dl_args, data_config)

Files already downloaded and verified
Files already downloaded and verified
