In [None]:
import math
import os
import pickle
from collections import defaultdict
from typing import Any

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
import torchvision
import torchvision.transforms as transforms
import tqdm
from flax.training import train_state
from torch.utils.data import DataLoader


class RunningMoments:
    """Track the running mean and sample standard deviation of a scalar stream.

    Uses Welford's single-pass update, so the individual values are not stored.
    Adapted from github.com/MadryLab/implementation-matters, which took it from
    github.com/joschu/modular_rl. Math in johndcook.com/blog/standard_deviation
    """

    def __init__(self):
        self.n = 0  # number of values pushed so far
        self.m = 0  # running mean
        self.s = 0  # running sum of squared deviations from the mean

    def push(self, x):
        """Add one scalar ``x`` (a Python int or float) to the statistics."""
        assert isinstance(x, (int, float))
        self.n += 1
        if self.n == 1:
            self.m = x
        else:
            old_m = self.m
            self.m = old_m + (x - old_m) / self.n
            # Welford's increment: avoids the cancellation of sum(x^2) - n*m^2.
            self.s = self.s + (x - old_m) * (x - self.m)

    def mean(self):
        """Return the mean of the pushed values (0 if nothing was pushed)."""
        return self.m

    def std(self):
        """Return the sample (n - 1) standard deviation of the pushed values.

        The spread of fewer than two values is undefined; 0.0 is returned in
        that case so callers get "no spread" rather than an unrelated number.
        """
        if self.n > 1:
            return math.sqrt(self.s / (self.n - 1))
        return 0.0


class Logger:
    """Accumulate per-epoch metric histories and plot them.

    Two ingestion paths feed ``cumulative_data`` (metric name -> list of
    recorded values):

    * ``log`` appends values directly; use it for metrics recorded once per
      epoch.
    * ``push`` feeds a per-key ``RunningMoments`` buffer; use it for metrics
      recorded many times per epoch. ``step`` then appends each buffer's mean
      and standard deviation as ``<key>_mean`` / ``<key>_std`` and clears the
      buffers.
    """

    def __init__(self):
        # Values pushed since the last step(), one RunningMoments per key.
        self._buffer_data = defaultdict(RunningMoments)
        # Metric name -> list of recorded values.
        self.cumulative_data = defaultdict(list)
        # Plot directories already cleaned out by generate_plots().
        self.seen_plot_directories = set()

    def log(self, metrics=None, **kwargs):
        """Append values to ``cumulative_data``; intended for once-per-epoch use.

        ``metrics`` and ``kwargs`` are merged; ``kwargs`` wins on duplicate keys.
        """
        metrics = {} if metrics is None else metrics
        for k, v in {**metrics, **kwargs}.items():
            self.cumulative_data[k].append(v)

    def push(self, metrics=None, **kwargs):
        """Buffer values recorded many times per epoch (e.g. per-batch loss).

        Values must be Python ints or floats (``RunningMoments.push`` asserts
        this). They are summarised on the next ``step``.
        """
        metrics = {} if metrics is None else metrics
        for k, v in {**metrics, **kwargs}.items():
            self._buffer_data[k].push(v)

    def step(self):
        """Flush every buffered metric as ``<key>_mean`` and ``<key>_std``."""
        for k, moments in self._buffer_data.items():
            self.cumulative_data[k + "_mean"].append(moments.mean())
            self.cumulative_data[k + "_std"].append(moments.std())
        self._buffer_data.clear()

    def tune_report(self):
        """Report the latest value of every metric via ``ray.tune.report``."""
        from ray import tune

        tune.report(**{k: v[-1] for k, v in self.cumulative_data.items()})

    def air_report(self, **kwargs):
        """Report the latest value of every metric through a Ray AIR session.

        Extra keyword arguments are forwarded to ``session.report``.
        """
        from ray.air import session

        session.report({k: v[-1] for k, v in self.cumulative_data.items()}, **kwargs)

    def save(self, filename):
        """Pickle this logger to ``filename`` (``.pickle`` appended if missing)."""
        if not filename.endswith(".pickle"):
            filename = filename + ".pickle"
        with open(filename, "wb") as f:
            pickle.dump(self, f)

    def generate_plots(self, dirname="plotgen"):
        """Write one figure per metric history into ``dirname``.

        The first call for a given ``dirname`` creates it and deletes the files
        and symlinks already inside it. A ``<key>_mean`` history is drawn with
        a +/- one ``<key>_std`` band when a matching ``_std`` history exists.
        """
        import matplotlib
        import matplotlib.pyplot as plt
        import seaborn as sns

        matplotlib.use("Agg")
        sns.set_theme()

        if dirname not in self.seen_plot_directories:
            os.makedirs(dirname, exist_ok=True)

            for filename in os.listdir(dirname):
                file_path = os.path.join(dirname, filename)
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
            self.seen_plot_directories.add(dirname)

        # Iterate over a snapshot and look companions up with .get()/`in`:
        # indexing the defaultdict with a missing "_std" key would insert it
        # while the dict is being iterated, raising RuntimeError.
        for k, values in list(self.cumulative_data.items()):
            # A "_std" series is drawn as the band around its "_mean" partner;
            # only an orphan "_std" series gets its own plot.
            is_std = k.endswith("_std")
            if is_std and k[: -len("_std")] + "_mean" in self.cumulative_data:
                continue

            fig, ax = plt.subplots()

            x = np.arange(len(values))
            y = np.array(values)
            if k.endswith("_mean"):
                name = k[: -len("_mean")]

                (line,) = ax.plot(x, y, label=k)
                stds = self.cumulative_data.get(name + "_std")
                if stds is not None and len(stds) == len(y):
                    stds = np.array(stds)
                    ax.fill_between(
                        x, y - stds, y + stds, color=line.get_color(), alpha=0.15
                    )
            else:
                name = k
                (line,) = ax.plot(x, y)
            ax.scatter(x, y, color=line.get_color())

            fig.suptitle(name)
            fig.savefig(os.path.join(dirname, name))
            plt.close(fig)

    def convergence(self, key, p=0.98):
        """Estimates the degree to which some metric has converged.

        A custom metric by me (Jerry). Close to zero when the metric is clearly
        trending upwards or downwards, close to one when changes in the metric
        seem to be dominated by noise. Intended for debugging purposes, not for
        scientific usage.

        p controls the degree to which this metric weights recently measured
        values. p = 0 results in a uniform weighting, independent of time. More
        and more weight is placed on the last few values as p approaches 1.

        Returns 0 when fewer than two values have been recorded for ``key``.
        """
        assert key in self.cumulative_data

        data = self.cumulative_data[key]
        if len(data) <= 1:
            return 0

        diffs = np.diff(np.asarray(data, dtype=float))
        # Weights rise geometrically from (1 - p) for the oldest step to 1 for
        # the newest, then are normalised to sum to one.
        w = np.power(1 - p, 1 - np.linspace(0, 1, num=len(diffs)))
        w = w / np.sum(w)

        m = np.sum(w * diffs)  # weighted mean step
        v = np.sum(w * diffs * diffs)  # weighted second moment of the steps

        # |m| / sqrt(v) is 1 for perfectly consistent steps and near 0 for
        # noise; the 1e-8 avoids dividing by zero for a constant metric.
        return 1 - abs(m) / math.sqrt(v + 1e-8)


In [None]:
class TrainState(train_state.TrainState):
    """Flax ``TrainState`` that also carries the model's BatchNorm statistics."""

    # The non-trainable "batch_stats" variable collection; the training step in
    # this file overwrites it with the statistics returned by each forward pass.
    batch_stats: Any


class CifarResnet(nn.Module):
    """Residual network for 32x32 images (the 6n+2-layer CIFAR layout).

    A 3x3 stem convolution is followed by three stages of ``n`` residual
    blocks with 16, 32 and 64 feature maps, global average pooling and a
    dense classifier. Only the first block of the second and third stage
    downsamples (stride 2, with a strided 3x3 convolution on the shortcut to
    match shapes); every other block keeps the resolution and uses an
    identity shortcut. With the default SAME padding a 32x32 input therefore
    reaches the pooling layer at 8x8.

    Attributes:
        n: number of residual blocks per stage.
        num_classes: length of the output logit vector.
    """

    n: int
    num_classes: int = 10

    @nn.compact
    def __call__(self, x, train: bool):
        """Return class logits for ``x``.

        Args:
            x: channels-last image(s); spatial axes are assumed to be the two
                axes before the feature axis (they are averaged at the end).
            train: if True, BatchNorm uses batch statistics (the caller must
                mark "batch_stats" mutable); otherwise running averages.
        """
        x = nn.Conv(16, kernel_size=(3, 3))(x)
        x = nn.BatchNorm(use_running_average=not train)(x)

        for stage, features in enumerate((16, 32, 64)):
            for block in range(self.n):
                # Downsample once per stage, at the stage boundary only.
                downsample = stage > 0 and block == 0
                strides = (2, 2) if downsample else (1, 1)

                # NOTE(review): ReLU precedes BatchNorm here, unlike the usual
                # conv-BN-ReLU order; kept as in the original design.
                out = nn.Conv(features, kernel_size=(3, 3), strides=strides)(x)
                out = nn.relu(out)
                out = nn.BatchNorm(use_running_average=not train)(out)
                out = nn.Conv(features, kernel_size=(3, 3))(out)
                out = nn.BatchNorm(use_running_average=not train)(out)

                if downsample:
                    # Shapes change (resolution and width): project the input.
                    residual = nn.Conv(
                        features, kernel_size=(3, 3), strides=strides
                    )(x)
                else:
                    residual = x
                x = nn.relu(out + residual)

        x = x.mean(axis=(-2, -3))
        x = nn.Dense(self.num_classes)(x)
        return x


def loss_fn(params, ts, images, labels):
    """Mean softmax cross-entropy of the model run in training mode.

    Args:
        params: parameter pytree to evaluate (the argument callers
            differentiate with respect to).
        ts: train state supplying ``apply_fn`` and the current
            ``batch_stats``.
        images: input batch passed to ``apply_fn``.
        labels: integer class labels for ``images``.

    Returns:
        ``(loss, (logits, updates))`` where ``updates`` is the mutated
        variable dict (holding "batch_stats") returned by the forward pass.
    """
    variables = {"params": params, "batch_stats": ts.batch_stats}
    forward = ts.apply_fn(variables, images, train=True, mutable=["batch_stats"])
    out_logits, new_collections = forward
    per_example = optax.softmax_cross_entropy_with_integer_labels(out_logits, labels)
    return per_example.mean(), (out_logits, new_collections)


@jax.jit
def inner_step(ts: TrainState, images, labels):
    """Apply one optimizer step to ``ts`` using a training batch.

    Returns:
        ``(ts, metrics)``: the train state with updated parameters, optimizer
        state and "batch_stats", plus a dict holding the batch ``"loss"`` and
        ``"accuracy"`` computed with the pre-update parameters.
    """
    loss_and_grad = jax.value_and_grad(loss_fn, has_aux=True)
    (batch_loss, (batch_logits, new_collections)), grads = loss_and_grad(
        ts.params, ts, images, labels
    )
    hits = jnp.argmax(batch_logits, -1) == labels
    next_ts = ts.apply_gradients(grads=grads).replace(
        batch_stats=new_collections["batch_stats"]
    )
    return next_ts, {"loss": batch_loss, "accuracy": jnp.mean(hits)}


@jax.jit
def eval_step(ts: TrainState, images, labels):
    """Score a batch with ``train=False`` and no state updates.

    Returns:
        dict with the mean cross-entropy ``"loss"`` and the ``"accuracy"``.
    """
    variables = {"params": ts.params, "batch_stats": ts.batch_stats}
    out_logits = ts.apply_fn(variables, images, train=False)
    xent = optax.softmax_cross_entropy_with_integer_labels(out_logits, labels)
    hits = jnp.argmax(out_logits, -1) == labels
    return {"loss": xent.mean(), "accuracy": jnp.mean(hits)}


def standard_trial(rng, train_loader, test_loader, config):
    """Train ``CifarResnet(n=3)`` with SGD (momentum 0.9) and masked weight decay.

    Each epoch makes one pass over ``train_loader`` with random-crop
    augmentation and then evaluates on ``test_loader``. The per-batch metrics
    are flushed into a ``Logger`` as ``train_*`` / ``test_*`` mean/std series,
    and the plots are regenerated.

    Args:
        rng: JAX PRNG key used to initialise the model.
        train_loader: iterable of ``(images, labels)`` batches. Images are
            assumed channels-first (as produced by ``ToTensor``); they are
            transposed to channels-last below. Confirm for other loaders.
        test_loader: iterable of ``(images, labels)`` evaluation batches in
            the same layout.
        config: mapping with ``"lr"`` and ``"weight_decay"``. An optional
            ``"epochs"`` entry overrides the default of 100 epochs.

    Returns:
        ``(ts, logger)``: the final train state and the populated logger.
    """
    model = CifarResnet(n=3)
    variables = model.init(rng, np.zeros((32, 32, 3)), train=True)
    # Decay only parameters whose rank is not 1; rank-1 leaves such as biases
    # and BatchNorm scale/bias are left undecayed.
    decay_mask = jax.tree_util.tree_map(lambda x: x.ndim != 1, variables["params"])
    ts = TrainState.create(
        apply_fn=model.apply,
        params=variables["params"],
        batch_stats=variables["batch_stats"],
        tx=optax.chain(
            optax.add_decayed_weights(
                weight_decay=config["weight_decay"],
                mask=decay_mask,
            ),
            optax.sgd(config["lr"], momentum=0.9),
        ),
    )
    logger = Logger()
    # Built once and applied per image: called on a whole batch tensor,
    # RandomCrop draws a single offset and shifts every image identically.
    crop = transforms.RandomCrop((32, 32), padding=4)

    for _ in range(config.get("epochs", 100)):
        for images, labels in tqdm.tqdm(train_loader):
            images = np.stack([np.asarray(crop(image)) for image in images])
            images = images.transpose((0, 2, 3, 1))  # channels-first -> last
            labels = np.asarray(labels)

            ts, train_metrics = inner_step(ts, images, labels)
            train_metrics = jax.tree_util.tree_map(lambda x: x.item(), train_metrics)
            logger.push({"train_" + k: v for k, v in train_metrics.items()})

        for images, labels in test_loader:
            images = np.asarray(images).transpose((0, 2, 3, 1))
            labels = np.asarray(labels)

            test_metrics = eval_step(ts, images, labels)
            test_metrics = jax.tree_util.tree_map(lambda x: x.item(), test_metrics)
            logger.push({"test_" + k: v for k, v in test_metrics.items()})

        logger.step()
        logger.generate_plots()

    return ts, logger


def outer_loss(weight_decay, ts, images, labels, valid_images, valid_labels):
    """Validation loss after one inner training step, as a function of weight decay.

    ``ts.opt_state`` must expose a ``hyperparams`` mapping containing
    ``"weight_decay"``, as produced by ``optax.inject_hyperparams``. This lets
    the coefficient be swapped in and differentiated through.

    Args:
        weight_decay: scalar weight-decay coefficient (the differentiated input).
        ts: current train state.
        images, labels: batch used for the inner training step.
        valid_images, valid_labels: batch used to score the updated parameters.

    Returns:
        ``(loss, ts)``: the scalar validation loss and the train state after
        the inner step. This is the shape ``jax.value_and_grad(...,
        has_aux=True)`` expects.

    Raises:
        ValueError: if the optimizer state has no ``"weight_decay"``
            hyperparameter.
    """
    hyperparams = getattr(ts.opt_state, "hyperparams", None)
    if hyperparams is None or "weight_decay" not in hyperparams:
        raise ValueError(
            "outer_loss needs an optimizer built with optax.inject_hyperparams "
            "exposing a 'weight_decay' hyperparameter"
        )
    # Build a new optimizer state rather than mutating the dict in place:
    # in-place mutation would also modify the caller's `ts` and can leak a
    # JAX tracer into it.
    new_hyperparams = {**hyperparams, "weight_decay": weight_decay}
    ts = ts.replace(opt_state=ts.opt_state._replace(hyperparams=new_hyperparams))

    # inner_step returns (state, metrics); only the state is needed here.
    ts, _ = inner_step(ts, images, labels)

    # loss_fn returns (loss, aux); value_and_grad needs the scalar alone.
    loss, _ = loss_fn(ts.params, ts, valid_images, valid_labels)
    return loss, ts


@jax.jit
def outer_step(
    weight_decay,
    ts: TrainState,
    images,
    labels,
    valid_images,
    valid_labels,
    meta_lr=1e-2,
):
    """Take one inner training step plus one hypergradient step on weight decay.

    The gradient of ``outer_loss`` (validation loss after the inner step)
    with respect to ``weight_decay`` is followed with plain gradient descent.

    Args:
        weight_decay: current scalar weight-decay coefficient.
        ts: current train state (see ``outer_loss`` for its requirements).
        images, labels: batch for the inner training step.
        valid_images, valid_labels: batch for the validation loss.
        meta_lr: step size of the weight-decay update (formerly hard-coded
            to 0.01, which remains the default).

    Returns:
        ``(weight_decay, ts)``: the updated, non-negative weight decay and the
        train state after the inner step.
    """
    (_, ts), grad = jax.value_and_grad(outer_loss, has_aux=True)(
        weight_decay, ts, images, labels, valid_images, valid_labels
    )
    # Clamp at zero: a negative coefficient would grow the weights instead of
    # decaying them.
    weight_decay = jnp.maximum(weight_decay - meta_lr * grad, 0.0)
    return weight_decay, ts


def meta_reg_trial(rng, train_loader, test_loader, config):
    """Training loop intended to meta-learn the weight-decay coefficient.

    NOTE(review): the hypergradient update (``outer_step``) is not called yet.
    This function therefore currently trains like ``standard_trial`` with the
    fixed ``config["weight_decay"]``.

    The optimizer is built with ``optax.inject_hyperparams``, so
    ``ts.opt_state.hyperparams`` holds ``"weight_decay"`` and ``"lr"``. That is
    the layout ``outer_loss`` requires, so the meta step can be added without
    changing how the state is created.

    Args:
        rng: JAX PRNG key used to initialise the model.
        train_loader: iterable of ``(images, labels)`` batches. Images are
            assumed channels-first (as produced by ``ToTensor``). Confirm for
            other loaders.
        test_loader: iterable of ``(images, labels)`` evaluation batches.
        config: mapping with ``"lr"`` and ``"weight_decay"``. An optional
            ``"epochs"`` entry overrides the default of 100 epochs.

    Returns:
        ``(ts, logger)``: the final train state and the populated logger.
    """
    model = CifarResnet(n=3)
    variables = model.init(rng, np.zeros((32, 32, 3)), train=True)
    # Decay only parameters whose rank is not 1 (biases and BatchNorm
    # scale/bias are excluded).
    decay_mask = jax.tree_util.tree_map(lambda x: x.ndim != 1, variables["params"])

    def make_tx(weight_decay, lr):
        # SGD with momentum plus masked weight decay; the numeric arguments
        # become injectable hyperparameters stored in the optimizer state.
        return optax.chain(
            optax.add_decayed_weights(weight_decay=weight_decay, mask=decay_mask),
            optax.sgd(lr, momentum=0.9),
        )

    ts = TrainState.create(
        apply_fn=model.apply,
        params=variables["params"],
        batch_stats=variables["batch_stats"],
        tx=optax.inject_hyperparams(make_tx)(
            weight_decay=config["weight_decay"], lr=config["lr"]
        ),
    )
    logger = Logger()
    # Applied per image: on a whole batch tensor RandomCrop would reuse one
    # random offset for every image.
    crop = transforms.RandomCrop((32, 32), padding=4)

    for _ in range(config.get("epochs", 100)):
        for images, labels in tqdm.tqdm(train_loader):
            images = np.stack([np.asarray(crop(image)) for image in images])
            images = images.transpose((0, 2, 3, 1))  # channels-first -> last
            labels = np.asarray(labels)

            # TODO: draw a held-out batch and use outer_step here so the
            # weight decay is updated alongside the parameters.
            ts, train_metrics = inner_step(ts, images, labels)
            train_metrics = jax.tree_util.tree_map(lambda x: x.item(), train_metrics)
            logger.push({"train_" + k: v for k, v in train_metrics.items()})

        for images, labels in test_loader:
            images = np.asarray(images).transpose((0, 2, 3, 1))
            labels = np.asarray(labels)

            test_metrics = eval_step(ts, images, labels)
            test_metrics = jax.tree_util.tree_map(lambda x: x.item(), test_metrics)
            logger.push({"test_" + k: v for k, v in test_metrics.items()})

        logger.step()
        logger.generate_plots()

    return ts, logger


def main():
    """Fetch CIFAR-10 into ``data/`` and run ``standard_trial`` on it.

    Fixed setup: PRNG seed 42 and batch size 100. Training batches are
    shuffled and the last partial batch is dropped. Learning rate is 1e-2
    and weight decay is 1e-4.
    """
    key = jax.random.PRNGKey(42)
    to_tensor = torchvision.transforms.ToTensor()

    cifar_train = torchvision.datasets.CIFAR10(
        "data/", transform=to_tensor, download=True
    )
    cifar_test = torchvision.datasets.CIFAR10(
        "data/", transform=to_tensor, train=False
    )

    loaders = (
        DataLoader(cifar_train, batch_size=100, shuffle=True, drop_last=True),
        DataLoader(cifar_test, batch_size=100),
    )
    standard_trial(key, *loaders, {"lr": 1e-2, "weight_decay": 1e-4})


if __name__ == "__main__":
    # Start training only when executed as a script, not when imported.
    main()
