<a href="https://colab.research.google.com/github/kodai-utsunomiya/memorization-and-generalization/blob/main/Disentangling_feature_and_lazy_training_in_deep_neural_networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# アーキテクチャ

In [1]:
# pylint: disable=E1101, C, arguments-differ
"""
Defines three architectures:
- Fully connected `FC`
- Convolutional `CV`
- And a resnet `Wide_ResNet`
"""
import functools
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class FC(nn.Module):
    def __init__(self, d, h, L, act, bias=False):
        super().__init__()

        hh = d
        for i in range(L):
            W = torch.randn(h, hh)

            # next two line are here to avoid memory issue when computing the kernel
            n = max(1, 128 * 256 // hh)
            W = nn.ParameterList([nn.Parameter(W[j: j+n]) for j in range(0, len(W), n)])

            setattr(self, "W{}".format(i), W)
            if bias:
                self.register_parameter("B{}".format(i), nn.Parameter(torch.zeros(h)))
            hh = h

        self.register_parameter("W{}".format(L), nn.Parameter(torch.randn(1, hh)))
        if bias:
            self.register_parameter("B{}".format(L), nn.Parameter(torch.zeros(1)))

        self.L = L
        self.act = act
        self.bias = bias

    def forward(self, x):
        for i in range(self.L + 1):
            W = getattr(self, "W{}".format(i))

            if isinstance(W, nn.ParameterList):
                W = torch.cat(list(W))

            if self.bias:
                B = self.bias * getattr(self, "B{}".format(i))
            else:
                B = 0

            h = x.size(1)

            if i < self.L:
                x = x @ (W.t() / h ** 0.5)
                x = self.act(x + B)
            else:
                x = x @ (W.t() / h) + B

        return x.view(-1)


class CV(nn.Module):
    """Plain convolutional network without biases and with a scalar output.

    The network has `L1` stages of `L2` convolutions each. Stage `i` uses
    `round(h * h_base ** i)` channels. Filters are standard normal and are
    rescaled by 1/sqrt(fan_in) in `forward`; the read-out vector is rescaled
    by 1/fan_in.

    :param d: number of input channels
    :param h: number of channels of the first stage
    :param L1: number of stages
    :param L2: number of convolutions per stage
    :param act: activation applied after every convolution
    :param h_base: multiplicative growth of the channel count between stages
    :param fsz: spatial size of the (square) filters
    :param pad: padding passed to every convolution
    :param stride_first: when truthy, the very first convolution also uses stride 2
    """

    def __init__(self, d, h, L1, L2, act, h_base, fsz, pad, stride_first):
        super().__init__()

        in_channels = d
        for stage in range(L1):
            out_channels = round(h * h_base ** stage)
            for depth in range(L2):
                # One parameter per output channel.
                filters = [nn.Parameter(torch.randn(in_channels, fsz, fsz)) for _ in range(out_channels)]
                setattr(self, "W{}_{}".format(stage, depth), nn.ParameterList(filters))
                in_channels = out_channels

        self.W = nn.Parameter(torch.randn(in_channels))

        self.L1 = L1
        self.L2 = L2
        self.act = act
        self.pad = pad
        self.stride_first = stride_first

    def forward(self, x):
        """Map a batch of images [batch, d, H, W] to a flat tensor of outputs."""
        for stage in range(self.L1):
            for depth in range(self.L2):
                assert x.size(2) >= 5 and x.size(3) >= 5
                filters = torch.stack(list(getattr(self, "W{}_{}".format(stage, depth))))

                # The first convolution of a stage downsamples, except in the
                # first stage where it depends on `stride_first`.
                downsample = depth == 0 and (stage > 0 or self.stride_first)
                fan_in = filters[0].numel()
                x = F.conv2d(x, filters / fan_in ** 0.5, None, stride=2 if downsample else 1, padding=self.pad)
                x = self.act(x)

        # Global average pooling over the spatial dimensions.
        pooled = x.flatten(2).mean(2)

        readout = self.W
        return (pooled @ (readout / len(readout))).view(-1)


class conv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()

        w = torch.randn(out_planes, in_planes, kernel_size, kernel_size)
        n = max(1, 256**2 // w[0].numel())
        self.w = nn.ParameterList([nn.Parameter(w[j: j + n]) for j in range(0, len(w), n)])

        self.b = nn.Parameter(torch.zeros(out_planes)) if bias else None

        self.stride = stride
        self.padding = padding

    def forward(self, x):
        w = torch.cat(list(self.w))
        h = w[0].numel()
        return F.conv2d(x, w / h ** 0.5, self.b, self.stride, self.padding)

class wide_basic(nn.Module):
    """Residual block made of two 3x3 `conv` layers, each preceded by the activation.

    The shortcut and the residual branch are mixed with weights
    cos(mix_angle) and sin(mix_angle), where `mix_angle` is in degrees.

    :param in_planes: number of input channels
    :param planes: number of output channels
    :param act: activation applied before each convolution
    :param stride: stride of the second convolution (and of the shortcut)
    :param mix_angle: mixing angle in degrees
    """

    def __init__(self, in_planes, planes, act, stride=1, mix_angle=45):
        super().__init__()
        self.conv1 = conv(in_planes, planes, kernel_size=3, padding=1, bias=True)
        self.conv2 = conv(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)

        # A 1x1 convolution is only needed when the shape changes; otherwise
        # the shortcut is the identity (an empty Sequential).
        if stride != 1 or in_planes != planes:
            self.shortcut = conv(in_planes, planes, kernel_size=1, stride=stride, bias=True)
        else:
            self.shortcut = nn.Sequential()

        self.act = act
        self.mix_angle = mix_angle

    def forward(self, x):
        """Return cos(a) * shortcut(x) + sin(a) * residual(x)."""
        residual = self.conv1(self.act(x))
        residual = self.conv2(self.act(residual))
        skip = self.shortcut(x)

        theta = self.mix_angle * math.pi / 180
        return math.cos(theta) * skip + math.sin(theta) * residual

class Wide_ResNet(nn.Module):
    """Wide residual network built from `wide_basic` blocks.

    :param d: number of input channels
    :param depth: total depth, must be of the form 6n+4
    :param h: widening factor (stage widths are 16*h, 32*h and 64*h)
    :param act: activation function
    :param num_classes: number of outputs
    :param mix_angle: mixing angle in degrees, forwarded to every block
    """

    def __init__(self, d, depth, h, act, num_classes, mix_angle=45):
        super().__init__()

        assert (depth % 6 == 4), 'Wide-resnet depth should be 6n+4'
        blocks_per_stage = (depth - 4) // 6

        widths = [16, 16 * h, 32 * h, 64 * h]
        block = functools.partial(wide_basic, act=act, mix_angle=mix_angle)

        self.conv1 = conv(d, widths[0], kernel_size=3, stride=1, padding=1, bias=True)
        # `_wide_layer` reads and updates this running channel count.
        self.in_planes = widths[0]

        self.layer1 = self._wide_layer(block, widths[1], blocks_per_stage, stride=1)
        self.layer2 = self._wide_layer(block, widths[2], blocks_per_stage, stride=2)
        self.layer3 = self._wide_layer(block, widths[3], blocks_per_stage, stride=2)
        self.linear = nn.Parameter(torch.randn(num_classes, widths[3]))
        self.bias = nn.Parameter(torch.zeros(num_classes))
        self.act = act

    def _wide_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses `stride`, the others use 1."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, stride=block_stride))
            self.in_planes = planes

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return [batch, num_classes] outputs, flattened to [batch] when num_classes is 1."""
        feat = self.conv1(x)
        for stage in (self.layer1, self.layer2, self.layer3):
            feat = stage(feat)
        feat = self.act(feat)
        # Global average pooling over the spatial dimensions.
        feat = feat.flatten(2).mean(2)

        fan_in = self.linear.size(1)
        out = F.linear(feat, self.linear / fan_in, self.bias)

        if out.size(1) == 1:
            out = out.flatten(0)

        return out

# データセット

In [2]:
# pylint: disable=no-member, E1102, C
"""
- Load mnist or cifar10
- perform PCA
- shuffle the dataset
- split into train and test sets in a balanced way (same number of samples from each class)
"""
import functools

import torch


def pca(x, d, whitening):
    '''
    Project the samples onto their `d` leading principal components.

    With `whitening`, each component is divided by the square root of its own
    eigenvalue (unit variance per component). Otherwise all components are
    divided by the square root of the mean of the `d` leading eigenvalues.

    :param x: [P, ...]
    :param d: number of principal components to keep
    :param whitening: rescale each component individually when truthy
    :return: [P, d]
    '''

    z = x.flatten(1)
    centered = z - z.mean(0)
    cov = centered.t() @ centered / len(z)

    # `Tensor.symeig` has been removed from recent PyTorch releases;
    # `torch.linalg.eigh` is its replacement. The old call is kept as a
    # fallback for versions that predate `torch.linalg`.
    if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
        val, vec = torch.linalg.eigh(cov)
    else:
        val, vec = cov.symeig(eigenvectors=True)

    # Order the eigenpairs from the largest to the smallest eigenvalue.
    val, idx = val.sort(descending=True)
    vec = vec[:, idx]

    u = centered @ vec[:, :d]
    if whitening:
        u.mul_(val[:d].rsqrt())
    else:
        u.mul_(val[:d].mean().rsqrt())

    return u


def get_binary_pca_dataset(dataset, p, d, whitening, seed=None, device=None):
    """Return PCA-reduced inputs with +/-1 labels, split into train and test parts.

    The label of sample `i` is -1 when `i` is even and +1 when it is odd; the
    class labels returned by `get_normalized_dataset` are only used for their
    length.
    NOTE(review): this presumably relies on the classes being interleaved by
    `get_normalized_dataset`, so that index parity matches class parity.

    :param dataset: dataset name understood by `get_normalized_dataset`
    :param p: number of training samples (the first `p` samples)
    :param d: number of principal components
    :param whitening: forwarded to `pca`
    :param seed: shuffling seed, drawn at random when None
    :param device: device the tensors are moved to
    :return: (xtr, ytr), (xte, yte)
    """
    if seed is None:
        seed = torch.randint(2 ** 32, (), dtype=torch.long).item()

    images, labels = get_normalized_dataset(dataset, seed)

    features = pca(images, d, whitening).to(device)
    parity = torch.arange(len(labels)) % 2
    signs = (2 * parity - 1).type(features.dtype).to(device)

    train = (features[:p], signs[:p])
    test = (features[p:], signs[p:])
    return train, test


def get_dataset(dataset, p, seed=None, device=None):
    """Return the normalized dataset with its class labels, split into train and test parts.

    :param dataset: dataset name understood by `get_normalized_dataset`
    :param p: number of training samples (the first `p` samples)
    :param seed: shuffling seed, drawn at random when None
    :param device: device the tensors are moved to
    :return: (xtr, ytr), (xte, yte)
    """
    if seed is None:
        seed = torch.randint(2 ** 32, (), dtype=torch.long).item()

    images, labels = get_normalized_dataset(dataset, seed)
    images = images.to(device)
    labels = labels.to(device)

    train = (images[:p], labels[:p])
    test = (images[p:], labels[p:])
    return train, test


def get_binary_dataset(dataset, p, seed=None, device=None):
    """Return the normalized dataset with +/-1 labels, split into train and test parts.

    The label of sample `i` is -1 when `i` is even and +1 when it is odd; the
    class labels returned by `get_normalized_dataset` are only used for their
    length.
    NOTE(review): this presumably relies on the classes being interleaved by
    `get_normalized_dataset`, so that index parity matches class parity.

    :param dataset: dataset name understood by `get_normalized_dataset`
    :param p: number of training samples (the first `p` samples)
    :param seed: shuffling seed, drawn at random when None
    :param device: device the tensors are moved to
    :return: (xtr, ytr), (xte, yte)
    """
    if seed is None:
        seed = torch.randint(2 ** 32, (), dtype=torch.long).item()

    images, labels = get_normalized_dataset(dataset, seed)
    images = images.to(device)

    parity = torch.arange(len(labels)) % 2
    signs = (2 * parity - 1).type(images.dtype).to(device)

    train = (images[:p], signs[:p])
    test = (images[p:], signs[p:])
    return train, test


@functools.lru_cache(maxsize=2)
def get_normalized_dataset(dataset, seed):
    """Load a torchvision dataset, balance and shuffle it, and normalize the inputs.

    Train and test splits are merged. Each class is shuffled with `seed`
    (this sets the global torch seed), then the classes are interleaved;
    `zip` truncates every class to the size of the smallest one. The inputs
    are centered with the mean sample, then every sample is rescaled to have
    norm sqrt(number of entries per sample).

    :param dataset: one of "mnist", "kmnist", "emnist-letters", "fashion",
        "cifar10", "cifar_catdog", "cifar_shipbird", "cifar_catplane"
    :param seed: seed used to shuffle the samples inside each class
    :return: x (float64), y (long)
    :raises ValueError: if `dataset` is not one of the names above
    """
    import torchvision
    from itertools import chain

    transform = torchvision.transforms.ToTensor()

    # name -> (torchvision class name, root folder, extra keyword arguments)
    plain = {
        "mnist": ("MNIST", '~/.torchvision/datasets/MNIST', {}),
        "kmnist": ("KMNIST", '~/.torchvision/datasets/KMNIST', {}),
        "emnist-letters": ("EMNIST", '~/.torchvision/datasets/EMNIST', {'split': 'letters'}),
        "fashion": ("FashionMNIST", '~/.torchvision/datasets/FashionMNIST', {}),
        "cifar10": ("CIFAR10", '~/.torchvision/datasets/CIFAR10', {}),
    }
    # name -> the two CIFAR10 class indices that are kept
    cifar_pairs = {
        "cifar_catdog": [3, 5],
        "cifar_shipbird": [8, 2],
        "cifar_catplane": [3, 0],
    }

    if dataset in plain:
        cls_name, root, extra = plain[dataset]
        cls = getattr(torchvision.datasets, cls_name)
        # Only the training split asks for a download.
        tr = cls(root, train=True, download=True, transform=transform, **extra)
        te = cls(root, train=False, transform=transform, **extra)
    elif dataset in cifar_pairs:
        keep = cifar_pairs[dataset]
        root = '~/.torchvision/datasets/CIFAR10'
        tr = [(x, y) for x, y in torchvision.datasets.CIFAR10(root, train=True, download=True, transform=transform) if y in keep]
        te = [(x, y) for x, y in torchvision.datasets.CIFAR10(root, train=False, transform=transform) if y in keep]
    else:
        raise ValueError("unknown dataset")

    samples = [(img.type(torch.float64), int(label)) for img, label in list(tr) + list(te)]
    classes = sorted({label for _, label in samples})

    # One list of samples per class.
    per_class = [[s for s in samples if s[1] == c] for c in classes]

    torch.manual_seed(seed)
    per_class = [
        [group[i] for i in torch.randperm(len(group))]
        for group in per_class
    ]

    # Interleave the classes: one sample of each class in turn.
    interleaved = list(chain(*zip(*per_class)))

    x = torch.stack([img for img, _ in interleaved])
    x = x - x.mean(0)
    norms = x.flatten(1).norm(dim=1).view(-1, *(1,) * (x.dim() - 1))
    x = (x[0].numel() ** 0.5) * x / norms

    y = torch.tensor([label for _, label in interleaved], dtype=torch.long)

    return x, y

# ダイナミクス

In [3]:
# pylint: disable=E1101, C
"""
This file implements a continuous version of momentum SGD.
The dynamics compares the angle of the gradient between steps and keeps it small

- stop when margins are reached

It contains two implementations of the same dynamics:
1. `train_regular` for any kind of models
2. `train_kernel` only for linear models
"""
import copy
import itertools
import math
from time import perf_counter

import torch

# from hessian import gradient

#####################################################################################################
def gradient(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=False):
    r'''
    Return the gradient of `outputs` with respect to `inputs` as one flat tensor.

    `inputs` is a tensor or an iterable of tensors; the individual gradients
    are flattened and concatenated in the order of `inputs`. Inputs that do
    not take part in the computation contribute zeros.
    ```
    gradient(x.sum(), x)
    gradient((x * y).sum(), [x, y])
    ```
    '''
    inputs = [inputs] if torch.is_tensor(inputs) else list(inputs)

    raw = torch.autograd.grad(
        outputs,
        inputs,
        grad_outputs,
        allow_unused=True,
        retain_graph=retain_graph,
        create_graph=create_graph,
    )

    pieces = []
    for grad, reference in zip(raw, inputs):
        if grad is None:
            # Unused input: its gradient is identically zero.
            grad = torch.zeros_like(reference)
        pieces.append(grad.contiguous().view(-1))
    return torch.cat(pieces)
#####################################################################################################


def loglinspace(rate, step, end=None):
    """Yield increasing integers starting at 0.

    The gap between consecutive values is `1 + step * (1 - exp(-t * rate / step))`
    truncated to an integer, so it starts at 1 and saturates near `1 + step`.
    The generator is endless when `end` is None, otherwise it stops after the
    last value that is not larger than `end`.
    """
    t = 0
    while True:
        if end is not None and t > end:
            return
        yield t
        growth = step * (1 - math.exp(-t * rate / step))
        t = int(t + 1 + growth)


class ContinuousMomentum(torch.optim.Optimizer):
    r"""Implements a continuous version of momentum.

    d/dt velocity = -1/tau (velocity + grad)
     or
    d/dt velocity = -mu/t (velocity + grad)

    d/dt parameters = velocity

    The sign of `tau` selects the dynamics:
    - `tau > 0`: first equation, with time constant `tau`
    - `tau < 0`: second equation, with `mu = -tau`
    - `tau == 0`: no momentum, the velocity is `-grad`

    Arguments:
        params: iterable of parameters (or parameter groups) to optimize
        dt: time step
        tau: momentum setting, see above
    """

    def __init__(self, params, dt, tau):
        defaults = dict(dt=dt, tau=tau)
        super().__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.

        Returns:
            The value returned by `closure`, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            tau = group['tau']
            dt = group['dt']

            for p in group['params']:
                if p.grad is None:
                    continue

                param_state = self.state[p]
                # Per-parameter elapsed time, needed by the `tau < 0` dynamics.
                if 't' not in param_state:
                    t = param_state['t'] = 0
                else:
                    t = param_state['t']

                if tau != 0:
                    if 'velocity' not in param_state:
                        v = param_state['velocity'] = torch.zeros_like(p.data)
                    else:
                        v = param_state['velocity']

                # The keyword form `add_(tensor, alpha=...)` replaces the
                # deprecated positional-scalar overload `add_(alpha, tensor)`.
                if tau > 0:
                    x = math.exp(-dt / tau)
                    v.mul_(x).add_(p.grad.data, alpha=-(1 - x))
                elif tau < 0:
                    mu = -tau
                    x = (t / (t + dt)) ** mu
                    v.mul_(x).add_(p.grad.data, alpha=-(1 - x))
                else:
                    v = -p.grad.data

                p.data.add_(v, alpha=dt)
                param_state['t'] += dt

        return loss


def make_step(f, optimizer, dt, grad):
    """Perform one optimizer step of size `dt` using the flat gradient `grad`.

    `grad` is sliced in the order of `f.parameters()` and the slices are
    assigned to the parameters' `.grad`. Every parameter group of `optimizer`
    gets `dt` as time step before `optimizer.step()` is called. All gradients
    are reset to None afterwards.
    """
    offset = 0
    for param in f.parameters():
        count = param.numel()
        param.grad = grad[offset: offset + count].view_as(param)
        offset += count

    for group in optimizer.param_groups:
        group['dt'] = dt

    optimizer.step()

    for param in f.parameters():
        param.grad = None


def train_regular(f0, x, y, tau, max_walltime, alpha, loss, subf0, max_dgrad=math.inf, max_dout=math.inf):
    """Train a copy of `f0` with `ContinuousMomentum` and an adaptive time step.

    This is a generator. The loss that is minimised is
    `loss((f(x) - out0) * y).mean()`, where `out0` is the output of `f0` when
    `subf0` is truthy and 0 otherwise.

    A step is accepted only if the relative change of the gradient (`dgrad`)
    and the change of the outputs scaled by `alpha` (`dout`) stay below
    `max_dgrad` and `max_dout`. A rejected step is undone and retried with a
    time step ten times smaller.

    :param f0: initial model; it is deep-copied and not modified
    :param x: training inputs
    :param y: training labels, multiplied with the outputs
    :param tau: momentum setting forwarded to `ContinuousMomentum`
    :param max_walltime: wall-time budget in seconds
    :param alpha: factor applied to the outputs in `dout` and in the margin test
    :param loss: callable applied to `(out - out0) * y`
    :param subf0: whether the initial output is subtracted
    :param max_dgrad: threshold on `dgrad`
    :param max_dout: threshold on `dout`
    :return: yields `(f, state, converged)` at the checkpoints given by
        `loglinspace(0.01, 100)` and once when all margins are reached.
        `f` is the live model: later steps keep modifying it.
    """
    f = copy.deepcopy(f0)

    with torch.no_grad():
        out0 = f0(x) if subf0 else 0

    dt = 1
    step_change_dt = 0
    optimizer = ContinuousMomentum(f.parameters(), dt=dt, tau=tau)

    checkpoint_generator = loglinspace(0.01, 100)
    checkpoint = next(checkpoint_generator)
    wall = perf_counter()
    t = 0
    converged = False

    out = f(x)
    grad = gradient(loss((out - out0) * y).mean(), f.parameters())

    for step in itertools.count():

        # Snapshot of model, optimizer and time, used to undo a rejected step.
        state = copy.deepcopy((f.state_dict(), optimizer.state_dict(), t))

        while True:
            make_step(f, optimizer, dt, grad)
            t += dt
            # `dt` may be enlarged below; remember the value actually used.
            current_dt = dt

            new_out = f(x)
            new_grad = gradient(loss((new_out - out0) * y).mean(), f.parameters())

            # Largest change of a single output, scaled by alpha.
            dout = (out - new_out).mul(alpha).abs().max().item()
            if grad.norm() == 0 or new_grad.norm() == 0:
                dgrad = 0
            else:
                # |g - g'|^2 / (|g| |g'|): relative change of the gradient.
                dgrad = (grad - new_grad).norm().pow(2).div(grad.norm() * new_grad.norm()).item()

            if dgrad < max_dgrad and dout < max_dout:
                # Step accepted; enlarge dt when both criteria have a factor 2 of margin.
                if dgrad < 0.5 * max_dgrad and dout < 0.5 * max_dout:
                    dt *= 1.1
                break

            # Step rejected: shrink dt, restore the snapshot and retry.
            dt /= 10

            print("[{} +{}] [dt={:.1e} dgrad={:.1e} dout={:.1e}]".format(step, step - step_change_dt, dt, dgrad, dout), flush=True)
            step_change_dt = step
            f.load_state_dict(state[0])
            optimizer.load_state_dict(state[1])
            t = state[2]

        out = new_out
        grad = new_grad

        save = False

        if step == checkpoint:
            checkpoint = next(checkpoint_generator)
            assert checkpoint > step
            save = True

        # Converged once every sample satisfies alpha * (f - f0) * y >= 1.
        if (alpha * (out - out0) * y >= 1).all() and not converged:
            converged = True
            save = True

        if save:
            state = {
                'step': step,
                'wall': perf_counter() - wall,
                't': t,
                'dt': current_dt,
                'dgrad': dgrad,
                'dout': dout,
                'norm': sum(p.norm().pow(2) for p in f.parameters()).sqrt().item(),
                'dnorm': sum((p0 - p).norm().pow(2) for p0, p in zip(f0.parameters(), f.parameters())).sqrt().item(),
                'grad_norm': grad.norm().item(),
            }

            yield f, state, converged

        # Stop on convergence, when the wall-time budget is spent, or on NaN outputs.
        if converged:
            break

        if perf_counter() > wall + max_walltime:
            break

        if torch.isnan(out).any():
            break



def train_kernel(ktrtr, ytr, tau, max_walltime, alpha, loss_prim, max_dgrad=math.inf, max_dout=math.inf):
    """Run the momentum dynamics directly on the training outputs of a kernel (linear) model.

    This is a generator. The outputs `otr` start at zero and follow the
    velocity `velo`, which is driven by `ktrtr @ (loss_prim(otr * ytr) * ytr) / len(ytr)`.
    The time step is adapted as in `train_regular`: a step is rejected and
    retried with `dt / 10` when `dgrad` or `dout` reach their thresholds, and
    `dt` grows by 10% when both are below a tenth of their thresholds.

    :param ktrtr: Gram matrix of the training set
    :param ytr: training labels
    :param tau: momentum setting (> 0: time constant, < 0: mu = -tau, 0: no momentum)
    :param max_walltime: wall-time budget in seconds
    :param alpha: factor applied to the outputs in `dout` and in the margin test
    :param loss_prim: derivative of the loss, applied to `otr * ytr`
    :param max_dgrad: threshold on the relative change of the gradient
    :param max_dout: threshold on the change of the outputs
    :return: yields `(otr, velo, grad, state, converged)` at the checkpoints
        given by `loglinspace(0.01, 100)` and once when all margins are
        reached. `otr` and `velo` are updated in place by later steps.
    """
    otr = ktrtr.new_zeros(len(ytr))
    velo = otr.clone()

    dt = 1
    step_change_dt = 0

    checkpoint_generator = loglinspace(0.01, 100)
    checkpoint = next(checkpoint_generator)
    wall = perf_counter()
    t = 0
    converged = False

    lprim = loss_prim(otr * ytr) * ytr
    grad = ktrtr @ lprim / len(ytr)

    for step in itertools.count():

        # Snapshot used to undo a rejected step.
        state = copy.deepcopy((otr, velo, t))

        while True:

            # The keyword form `add_(tensor, alpha=...)` replaces the
            # deprecated positional-scalar overload `add_(alpha, tensor)`.
            if tau > 0:
                x = math.exp(-dt / tau)
                velo.mul_(x).add_(grad, alpha=-(1 - x))
            elif tau < 0:
                mu = -tau
                x = (t / (t + dt)) ** mu
                velo.mul_(x).add_(grad, alpha=-(1 - x))
            else:
                velo.copy_(-grad)
            otr.add_(velo, alpha=dt)

            t += dt
            # `dt` may be enlarged below; remember the value actually used.
            current_dt = dt

            lprim = loss_prim(otr * ytr) * ytr
            new_grad = ktrtr @ lprim / len(ytr)

            # The outputs moved by dt * velo during this step.
            dout = velo.mul(dt * alpha).abs().max().item()
            if grad.norm() == 0 or new_grad.norm() == 0:
                dgrad = 0
            else:
                # |g - g'|^2 / (|g| |g'|): relative change of the gradient.
                dgrad = (grad - new_grad).norm().pow(2).div(grad.norm() * new_grad.norm()).item()

            if dgrad < max_dgrad and dout < max_dout:
                if dgrad < 0.1 * max_dgrad and dout < 0.1 * max_dout:
                    dt *= 1.1
                break

            # Step rejected: shrink dt, restore the snapshot and retry.
            dt /= 10

            print("[{} +{}] [dt={:.1e} dgrad={:.1e} dout={:.1e}]".format(step, step - step_change_dt, dt, dgrad, dout), flush=True)
            step_change_dt = step
            otr.copy_(state[0])
            velo.copy_(state[1])
            t = state[2]

        grad = new_grad

        save = False

        if step == checkpoint:
            checkpoint = next(checkpoint_generator)
            assert checkpoint > step
            save = True

        # Converged once every sample satisfies alpha * otr * ytr >= 1.
        if (alpha * otr * ytr >= 1).all() and not converged:
            converged = True
            save = True

        if save:
            state = {
                'step': step,
                'wall': perf_counter() - wall,
                't': t,
                'dt': current_dt,
                'dgrad': dgrad,
                'dout': dout,
                'grad_norm': grad.norm().item(),
            }

            yield otr, velo, grad, state, converged

        # Stop on convergence, when the wall-time budget is spent, or on NaN outputs.
        if converged:
            break

        if perf_counter() > wall + max_walltime:
            break

        if torch.isnan(otr).any():
            break

# カーネル

In [4]:
# pylint: disable=no-member, C, not-callable
"""
Computes the Gram matrix of a given model
"""

def compute_kernels(f, xtr, xte):
    """Compute the Gram matrices of the gradients of `f` with respect to its parameters.

    For every sample the gradient of `f(sample)` with respect to the
    parameters is computed with `gradient`; the kernels are the inner
    products of these gradients. The parameters are processed in groups so
    that the two Jacobians of a group stay bounded in size.

    :param f: model
    :param xtr: training inputs
    :param xte: test inputs
    :return: ktrtr [len(xtr), len(xtr)], ktetr [len(xte), len(xtr)], ktete [len(xte), len(xte)]
    """
    ntr = len(xtr)
    nte = len(xte)

    ktrtr = xtr.new_zeros(ntr, ntr)
    ktetr = xtr.new_zeros(nte, ntr)
    ktete = xtr.new_zeros(nte, nte)

    # Group the parameters, largest first. A group is closed as soon as adding
    # a parameter makes it exceed the budget.
    # NOTE(review): the budget presumably targets about 2e9 bytes of Jacobian
    # at 8 bytes per entry - confirm.
    groups = []
    pending = []
    for param in sorted(f.parameters(), key=lambda q: q.numel(), reverse=True):
        pending.append(param)
        budget = 2e9 // (8 * (ntr + nte))
        if sum(q.numel() for q in pending) > budget:
            if len(pending) > 1:
                # Close the group without the parameter that overflowed it.
                groups.append(pending[:-1])
                pending = pending[-1:]
            else:
                # A single parameter above the budget forms its own group.
                groups.append(pending)
                pending = []
    if len(pending) > 0:
        groups.append(pending)

    for index, group in enumerate(groups):
        print("[{}/{}] [len={} numel={}]".format(index, len(groups), len(group), sum(q.numel() for q in group)), flush=True)

        width = sum(q.numel() for q in group)
        jtr = xtr.new_empty(ntr, width)  # Jacobian on the training set
        jte = xte.new_empty(nte, width)  # Jacobian on the test set

        for row, sample in enumerate(xtr):
            jtr[row] = gradient(f(sample[None]), group)

        for row, sample in enumerate(xte):
            jte[row] = gradient(f(sample[None]), group)

        # Accumulate the contribution of this parameter group.
        ktrtr.add_(jtr @ jtr.t())
        ktetr.add_(jte @ jtr.t())
        ktete.add_(jte @ jte.t())
        del jtr, jte

    return ktrtr, ktetr, ktete

# メイン（GPUに設定していることを前提）

In [5]:
# pylint: disable=C, R, bare-except, arguments-differ, no-member, undefined-loop-variable
import argparse
import math
import os
import subprocess
from functools import partial
from time import perf_counter

import torch

# from archi import CV, FC, Wide_ResNet
# from dataset import get_binary_dataset, get_binary_pca_dataset
# from dynamics import train_kernel, train_regular
# from kernels import compute_kernels


def loss_func(args, fy):
    """Return the per-sample loss for `fy` (output times label).

    - 'softhinge': softplus(1 - alpha * fy, beta=lossbeta) / alpha
    - 'qhinge': 0.5 * relu(1 - alpha * fy)^2 / alpha

    :param args: namespace providing `loss`, `alpha` and, for 'softhinge', `lossbeta`
    :param fy: tensor of outputs multiplied by their labels
    :raises ValueError: if `args.loss` is not one of the losses above
    """
    if args.loss == 'softhinge':
        sp = partial(torch.nn.functional.softplus, beta=args.lossbeta)
        return sp(1 - args.alpha * fy) / args.alpha
    if args.loss == 'qhinge':
        return 0.5 * (1 - args.alpha * fy).relu().pow(2) / args.alpha
    # Fail loudly instead of returning None, which would only surface later
    # as an unrelated AttributeError in the caller.
    raise ValueError("unknown loss: {!r}".format(args.loss))


def loss_func_prime(args, fy):
    """Return the derivative with respect to `fy` of the loss computed by `loss_func`.

    - 'softhinge': d/dfy [softplus(1 - alpha * fy, beta) / alpha]
      = -sigmoid(beta * (1 - alpha * fy))
    - 'qhinge': d/dfy [0.5 * relu(1 - alpha * fy)^2 / alpha]
      = -relu(1 - alpha * fy)

    :param args: namespace providing `loss`, `alpha` and, for 'softhinge', `lossbeta`
    :param fy: tensor of outputs multiplied by their labels
    :raises ValueError: if `args.loss` is not one of the losses above
    """
    if args.loss == 'softhinge':
        # The derivative of softplus(z, beta) with respect to z is
        # sigmoid(beta * z): there is no extra factor beta. The chain-rule
        # factor alpha cancels the 1 / alpha of the loss.
        return -torch.sigmoid(args.lossbeta * (1 - args.alpha * fy))
    if args.loss == 'qhinge':
        return -(1 - args.alpha * fy).relu()
    raise ValueError("unknown loss: {!r}".format(args.loss))


class SplitEval(torch.nn.Module):
    """Wrap `f` so that it is evaluated on consecutive chunks of at most `size` samples.

    The outputs of the chunks are concatenated along the first dimension.
    """

    def __init__(self, f, size):
        super().__init__()
        self.f = f
        self.size = size

    def forward(self, x):
        """Evaluate `f` chunk by chunk and concatenate the results."""
        outputs = []
        for start in range(0, len(x), self.size):
            chunk = x[start: start + self.size]
            outputs.append(self.f(chunk))
        return torch.cat(outputs)


def run_kernel(args, ktrtr, ktetr, ktete, f, xtr, ytr, xte, yte):
    """Train the kernel (linearised) model and evaluate it on the test set.

    The training outputs are obtained with `train_kernel`. The test outputs
    are `ktetr @ c`, where `c` solves `ktrtr @ c = otr` in the least-squares
    sense. When there are more test than training samples the product is
    evaluated through the gradients of `f` instead of through `ktetr`.

    :param args: experiment settings; `args.f0` must be 1
    :param ktrtr: train/train kernel
    :param ktetr: test/train kernel
    :param ktete: test/test kernel
    :param f: model whose gradients define the kernel
    :param xtr: training inputs
    :param ytr: training labels
    :param xte: test inputs
    :param yte: test labels
    :return: dict with the dynamics, the final train/test outputs and kernel statistics
    """
    assert args.f0 == 1

    dynamics = []

    tau = args.tau_over_h * args.h
    if args.tau_alpha_crit is not None:
        tau *= min(1, args.tau_alpha_crit / args.alpha)

    for otr, _velo, _grad, state, _converged in train_kernel(ktrtr, ytr, tau, args.train_time, args.alpha, partial(loss_func_prime, args), args.max_dgrad, args.max_dout):
        state['train'] = {
            'loss': loss_func(args, otr * ytr).mean().item(),
            'aloss': args.alpha * loss_func(args, otr * ytr).mean().item(),
            'err': (otr * ytr <= 0).double().mean().item(),
            'nd': (args.alpha * otr * ytr < 1).long().sum().item(),
            'dfnorm': otr.pow(2).mean().sqrt(),
            'outputs': otr if args.save_outputs else None,
            'labels': ytr if args.save_outputs else None,
        }

        print("[i={d[step]:d} t={d[t]:.2e} wall={d[wall]:.0f}] [dt={d[dt]:.1e} dgrad={d[dgrad]:.1e} dout={d[dout]:.1e}] [train aL={d[train][aloss]:.2e} err={d[train][err]:.2f} nd={d[train][nd]}]".format(d=state), flush=True)
        dynamics.append(state)

    # Solve ktrtr @ c = otr. `torch.lstsq(B, A)` has been removed from recent
    # PyTorch releases; `torch.linalg.lstsq(A, B)` is its replacement (note
    # the swapped argument order). The old call is kept as a fallback for
    # versions that predate `torch.linalg.lstsq`.
    if hasattr(torch, "linalg") and hasattr(torch.linalg, "lstsq"):
        c = torch.linalg.lstsq(ktrtr, otr.view(-1, 1)).solution.flatten()
    else:
        c = torch.lstsq(otr.view(-1, 1), ktrtr).solution.flatten()

    if len(xte) > len(xtr):
        # from hessian import gradient
        # `a` is the parameter-space vector sum_i c_i * grad f(x_i); the test
        # outputs are its inner products with the test gradients.
        a = gradient(f(xtr) @ c, f.parameters())
        ote = torch.stack([gradient(f(x[None]), f.parameters()) @ a for x in xte])
    else:
        ote = ktetr @ c

    out = {
        'dynamics': dynamics,
        'train': {
            'outputs': otr,
            'labels': ytr,
        },
        'test': {
            'outputs': ote,
            'labels': yte,
        },
        'kernel': {
            'train': {
                'value': ktrtr.cpu() if args.store_kernel == 1 else None,
                'diag': ktrtr.diag(),
                'mean': ktrtr.mean(),
                'std': ktrtr.std(),
                'norm': ktrtr.norm(),
            },
            'test': {
                'value': ktete.cpu() if args.store_kernel == 1 else None,
                'diag': ktete.diag(),
                'mean': ktete.mean(),
                'std': ktete.std(),
                'norm': ktete.norm(),
            },
        },
    }

    return out


def run_regular(args, f0, xtr, ytr, xte, yte):
    """Train `f0` with `train_regular` and record metrics along the way.

    This is a generator. Every state produced by `train_regular` is completed
    with train and test metrics computed on a random subset of at most
    `10 * args.chunk` indices. `(f, out)` is yielded when training is done,
    or when more than 120 seconds have passed since the last yield.

    :param args: experiment settings
    :param f0: initial model
    :param xtr: training inputs
    :param ytr: training labels
    :param xte: test inputs
    :param yte: test labels
    :return: yields `(f, out)`; `out` holds the recorded dynamics and the
        outputs of the current model on the full train and test sets
    """
    with torch.no_grad():
        otr0 = f0(xtr)
        ote0 = f0(xte)

    # Without subtraction of the initial function the reference output is zero.
    if args.f0 == 0:
        otr0 = torch.zeros_like(otr0)
        ote0 = torch.zeros_like(ote0)

    # Random subset of indices on which the metrics are evaluated.
    j = torch.randperm(min(len(xte), len(xtr)))[:10 * args.chunk]
    ytrj = ytr[j]
    ytej = yte[j]

    last_yield = perf_counter()

    tau = args.tau_over_h * args.h
    if args.tau_alpha_crit is not None:
        tau *= min(1, args.tau_alpha_crit / args.alpha)

    def layer_weight(model, i):
        # Weight of layer `i` of the FC network wrapped by `SplitEval`.
        return torch.cat(list(getattr(model.f, "W{}".format(i))))

    def metrics(delta, base, labels):
        # `delta` is the output minus the reference output `base`.
        mean_loss = loss_func(args, delta * labels).mean().item()
        return {
            'loss': mean_loss,
            'aloss': args.alpha * mean_loss,
            'err': (delta * labels <= 0).double().mean().item(),
            'nd': (args.alpha * delta * labels < 1).long().sum().item(),
            'dfnorm': delta.pow(2).mean().sqrt(),
            'fnorm': (delta + base).pow(2).mean().sqrt(),
            'outputs': delta if args.save_outputs else None,
            'labels': labels if args.save_outputs else None,
        }

    dynamics = []
    trainer = train_regular(f0, xtr, ytr, tau, args.train_time, args.alpha, partial(loss_func, args), bool(args.f0), args.max_dgrad, args.max_dout)
    for f, state, done in trainer:
        with torch.no_grad():
            otr = f(xtr[j]) - otr0[j]
            ote = f(xte[j]) - ote0[j]

        if args.arch.split('_')[0] == 'fc':
            layers = range(f.f.L + 1)
            state['wnorm'] = [layer_weight(f, i).norm().item() for i in layers]
            state['dwnorm'] = [(layer_weight(f, i) - layer_weight(f0, i)).norm().item() for i in layers]

        state['train'] = metrics(otr, otr0[j], ytrj)
        state['test'] = metrics(ote, ote0[j], ytej)
        print("[i={d[step]:d} t={d[t]:.2e} wall={d[wall]:.0f}] [dt={d[dt]:.1e} dgrad={d[dgrad]:.1e} dout={d[dout]:.1e}] [train aL={d[train][aloss]:.2e} err={d[train][err]:.2f} nd={d[train][nd]}/{p}] [test aL={d[test][aloss]:.2e} err={d[test][err]:.2f}]".format(d=state, p=len(j)), flush=True)
        dynamics.append(state)

        if not (done or perf_counter() - last_yield > 120):
            continue

        last_yield = perf_counter()

        # Full evaluation on all training and test samples.
        with torch.no_grad():
            otr = f(xtr) - otr0
            ote = f(xte) - ote0

        out = {
            'dynamics': dynamics,
            'train': {
                'f0': otr0,
                'outputs': otr,
                'labels': ytr,
            },
            'test': {
                'f0': ote0,
                'outputs': ote,
                'labels': yte,
            }
        }
        yield f, out


def run_exp(args, f0, xtr, ytr, xte, yte):
    """Run the parts of the experiment selected by the flags in `args`.

    This is a generator yielding the `run` dict, which is filled progressively:
    - `init_kernel`: kernel dynamics with the kernel of the initial model
    - `regular`: regular training (yielded after every result of `run_regular`)
    - `final_kernel`: kernel dynamics with the kernel of the trained model
    - `delta_kernel`: norm of the difference between initial and final kernels

    The final and delta kernels are only computed when `args.regular == 1`.
    They use the model from the last result of `run_regular`.
    """
    run = {
        'args': args,
        'N': sum(p.numel() for p in f0.parameters()),
    }

    want_init = args.init_kernel == 1
    want_delta = args.delta_kernel == 1

    if want_delta or want_init:
        # The test part of the kernel uses as many test as training samples.
        init_kernel = compute_kernels(f0, xtr, xte[:len(xtr)])

    if want_init:
        run['init_kernel'] = run_kernel(args, *init_kernel, f0, xtr, ytr, xte, yte)

    if want_delta:
        # Keep only the train/train and test/test kernels, on the CPU.
        init_kernel = (init_kernel[0].cpu(), init_kernel[2].cpu())
    elif want_init:
        del init_kernel

    if args.regular == 1:
        for f, out in run_regular(args, f0, xtr, ytr, xte, yte):
            run['regular'] = out
            yield run

        # `f` is the model from the last iteration of the loop above.
        if want_delta or args.final_kernel == 1:
            final_kernel = compute_kernels(f, xtr, xte[:len(xtr)])

        if args.final_kernel == 1:
            run['final_kernel'] = run_kernel(args, *final_kernel, f, xtr, ytr, xte, yte)

        if want_delta:
            final_kernel = (final_kernel[0].cpu(), final_kernel[2].cpu())
            train_delta = (init_kernel[0] - final_kernel[0]).norm().item()
            test_delta = (init_kernel[1] - final_kernel[1]).norm().item()
            run['delta_kernel'] = {
                'train': train_delta,
                'test': test_delta,
            }

    yield run


def execute(args):
    """Build the dataset and the model described by `args` and run the experiment.

    This is a generator yielding the results of `run_exp`.

    :param args: experiment settings; `args.arch` has the form
        "<arch>_<activation>" with arch in {fc, cv, resnet} and activation in
        {relu, tanh, softplus}
    :raises ValueError: if the activation or the architecture is unknown
    """
    torch.backends.cudnn.benchmark = True
    default_dtypes = {'float64': torch.float64, 'float32': torch.float32}
    if args.dtype in default_dtypes:
        torch.set_default_dtype(default_dtypes[args.dtype])

    # Raw inputs when no PCA dimension is given, PCA features otherwise.
    use_pca = not (args.d is None or args.d == 0)
    if use_pca:
        (xtr, ytr), (xte, yte) = get_binary_pca_dataset(args.dataset, args.ptr, args.d, args.whitening, args.data_seed, args.device)
    else:
        (xtr, ytr), (xte, yte) = get_binary_dataset(args.dataset, args.ptr, args.data_seed, args.device)

    xtr, xte, ytr, yte = (t.type(torch.get_default_dtype()) for t in (xtr, xte, ytr, yte))

    assert len(xte) >= args.pte
    xte = xte[:args.pte]
    yte = yte[:args.pte]

    torch.manual_seed(args.init_seed + hash(args.alpha))

    arch, act_name = args.arch.split('_')
    if act_name == 'relu':
        def act(x):
            return 2 ** 0.5 * torch.relu(x)
    elif act_name == 'tanh':
        act = torch.tanh
    elif act_name == 'softplus':
        # Normalisation estimated on standard normal samples: the inverse of
        # the root mean square of the softplus.
        samples = torch.randn(100000, dtype=torch.float64)
        factor = torch.nn.functional.softplus(samples, args.spbeta).pow(2).mean().rsqrt().item()

        def act(x):
            return torch.nn.functional.softplus(x, beta=args.spbeta).mul(factor)
    else:
        raise ValueError('act not specified')

    if arch == 'fc':
        assert args.L is not None
        xtr = xtr.flatten(1)
        xte = xte.flatten(1)
        model = FC(xtr.size(1), args.h, args.L, act, args.bias)
    elif arch == 'cv':
        assert args.bias == 0
        model = CV(xtr.size(1), args.h, L1=args.cv_L1, L2=args.cv_L2, act=act, h_base=args.cv_h_base, fsz=args.cv_fsz, pad=args.cv_pad, stride_first=args.cv_stride_first)
    elif arch == 'resnet':
        assert args.bias == 0
        model = Wide_ResNet(xtr.size(1), 28, args.h, act, 1, args.mix_angle)
    else:
        raise ValueError('arch not specified')

    # Evaluate the model on chunks of at most `args.chunk` samples.
    f = SplitEval(model.to(args.device), args.chunk)

    torch.manual_seed(args.batch_seed)
    yield from run_exp(args, f, xtr, ytr, xte, yte)


def main():
    """Run one experiment with a hard-coded configuration and save it to disk.

    The configuration is written to ``args.pickle`` up front; after every
    result yielded by ``execute`` the file is rewritten with the arguments
    followed by that latest result. If anything goes wrong the file is
    removed and the exception is re-raised.
    """
    # Record the current git commit hash and working-tree status with the results.
    git = {
        'log': subprocess.getoutput('git log --format="%H" -n 1 -z'),
        'status': subprocess.getoutput('git status -z'),
    }

    # Set the variables directly instead of parsing command-line arguments.
    args = argparse.Namespace(
        device='cuda',
        dtype='float64',

        init_seed=0,
        data_seed=0,
        batch_seed=0,

        dataset='fashion',
        ptr=10000,
        pte=50000,
        d=None,
        whitening=1,

        arch='fc_softplus',
        bias=0,
        L=3,
        h=100,
        mix_angle=45,
        spbeta=5.0,
        cv_L1=2,
        cv_L2=2,
        cv_h_base=1,
        cv_fsz=5,
        cv_pad=1,
        cv_stride_first=1,

        init_kernel=0,
        regular=1,
        final_kernel=0,
        store_kernel=0,
        delta_kernel=0,
        save_outputs=0,

        alpha=1e-4,
        f0=1,

        tau_over_h=1e-3,
        tau_alpha_crit=1e3,

        train_time=18000,
        chunk=None,
        max_dgrad=1e-4,
        max_dout=1e-1,

        loss='softhinge',
        lossbeta=20.0,

        pickle='F10k3Lsp_alpha.pkl'
    )

    # Fallbacks: test-set size and evaluation chunk size default to the
    # training-set size when left unset.
    if args.pte is None:
        args.pte = args.ptr

    if args.chunk is None:
        args.chunk = args.ptr

    # Write the arguments alone first, before any result exists.
    torch.save(args, args.pickle)
    try:
        for res in execute(args):
            res['git'] = git
            # Rewrite the whole file each time: arguments first, then the
            # most recent result.
            with open(args.pickle, 'wb') as f:
                torch.save(args, f)
                torch.save(res, f)
    # Bare except: also catches KeyboardInterrupt/SystemExit; the output file
    # is removed and the exception is re-raised.
    except:
        os.remove(args.pickle)
        raise

# Entry point: start the experiment when this module is the program being run.
if __name__ == "__main__":
    main()

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /root/.torchvision/datasets/FashionMNIST/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:04<00:00, 6059786.68it/s] 


Extracting /root/.torchvision/datasets/FashionMNIST/FashionMNIST/raw/train-images-idx3-ubyte.gz to /root/.torchvision/datasets/FashionMNIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to /root/.torchvision/datasets/FashionMNIST/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 278743.85it/s]


Extracting /root/.torchvision/datasets/FashionMNIST/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /root/.torchvision/datasets/FashionMNIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /root/.torchvision/datasets/FashionMNIST/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:00<00:00, 5092269.69it/s]


Extracting /root/.torchvision/datasets/FashionMNIST/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /root/.torchvision/datasets/FashionMNIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /root/.torchvision/datasets/FashionMNIST/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 3948843.63it/s]


Extracting /root/.torchvision/datasets/FashionMNIST/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /root/.torchvision/datasets/FashionMNIST/FashionMNIST/raw



RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

# main_adam

In [None]:
import copy
import itertools
import os
import subprocess
from time import perf_counter

import torch

# from archi import CV, FC, Wide_ResNet
# from dataset import get_binary_dataset, get_binary_pca_dataset
# from dynamics import loglinspace


class SplitEval(torch.nn.Module):
    """Evaluate a wrapped callable on consecutive slices of the input.

    The input is cut along its first dimension into slices of at most
    ``size`` rows; ``f`` is applied to each slice and the results are
    concatenated in order.
    """

    def __init__(self, f, size):
        super().__init__()
        self.f = f
        self.size = size

    def forward(self, x):
        """Return ``f`` applied slice by slice to ``x``, concatenated."""
        pieces = []
        for start in range(0, len(x), self.size):
            stop = start + self.size
            pieces.append(self.f(x[start:stop]))
        return torch.cat(pieces)


def hinge(out, y, alpha):
    """Return the mean of ``relu(1 - alpha * out * y)`` divided by ``alpha``."""
    margin = 1 - alpha * out * y
    return torch.relu(margin).mean() / alpha


def quad_hinge(out, y, alpha):
    """Return half the mean squared hinge ``relu(1 - alpha * out * y)``, over ``alpha ** 2``."""
    margin = 1 - alpha * out * y
    squared = torch.relu(margin).pow(2)
    return 0.5 * squared.mean() / alpha ** 2


def mse(out, y, alpha):
    """Return half the mean of ``(1.1 - alpha * out * y) ** 2``, over ``alpha ** 2``."""
    # NOTE(review): the target here is 1.1 whereas the hinge losses use 1 -
    # confirm this offset is intended.
    residual = 1.1 - alpha * out * y
    return 0.5 * residual.pow(2).mean() / alpha ** 2


def run_regular(args, f0, loss, xtr, ytr, xte, yte):
    """Train a copy of ``f0`` with Adam on mini-batches and log its dynamics.

    The quantity fed to ``loss`` is always the network output minus the
    output of the initial network ``f0`` on the same inputs.

    Args:
        args: settings object; this function reads ``lr``, ``bs``, ``alpha``,
            ``train_time`` and ``arch``.
        f0: the initial network (left untouched; a deep copy is trained).
        loss: callable ``loss(out, y, alpha)`` returning a scalar tensor.
        xtr, ytr: training inputs and labels.
        xte, yte: test inputs and labels.

    Returns:
        ``(f, out)`` where ``f`` is the trained copy and ``out`` is a dict
        holding the list of checkpoint states (``'dynamics'``) plus the
        initial outputs, final output differences and labels for the train
        and test sets.
    """

    # Outputs of the initial network, used as the reference that is subtracted
    # from every later prediction.
    with torch.no_grad():
        otr0 = f0(xtr)
        ote0 = f0(xte)

    f = copy.deepcopy(f0)
    optimizer = torch.optim.Adam(f.parameters(), args.lr)

    dynamics = []
    # Steps at which a checkpoint is recorded (`loglinspace` is defined elsewhere).
    checkpoint_generator = loglinspace(0.1, 1000)
    checkpoint = next(checkpoint_generator)
    wall = perf_counter()

    for step in itertools.count():

        # Random mini-batch: the first `args.bs` indices of a fresh permutation.
        batch = torch.randperm(len(xtr))[:args.bs]
        xb = xtr[batch]

        loss_value = loss(f(xb) - otr0[batch], ytr[batch], args.alpha)

        optimizer.zero_grad()
        loss_value.backward()
        optimizer.step()

        save = False

        if step == checkpoint:
            checkpoint = next(checkpoint_generator)
            assert checkpoint > step
            save = True

        if save:
            # `j` indexes both sets below, so the test set must be larger
            # than the training set.
            assert len(xtr) < len(xte)
            j = torch.randperm(len(xtr))

            # Evaluate on all training points (permuted) and on the test
            # points whose indices appear in `j`.
            with torch.no_grad():
                otr = f(xtr[j]) - otr0[j]
                ote = f(xte[j]) - ote0[j]

            state = {
                'step': step,
                'wall': perf_counter() - wall,
                'batch_loss': loss_value.item(),
                # L2 norm of all parameters, and of their change since initialisation.
                'norm': sum(p.norm().pow(2) for p in f.parameters()).sqrt().item(),
                'dnorm': sum((p0 - p).norm().pow(2) for p0, p in zip(f0.parameters(), f.parameters())).sqrt().item(),
                'train': {
                    'loss': loss(otr, ytr[j], args.alpha).item(),
                    'aloss': args.alpha * loss(otr, ytr[j], args.alpha).item(),
                    'aaloss': args.alpha ** 2 * loss(otr, ytr[j], args.alpha).item(),
                    # Fraction of points with non-positive margin.
                    'err': (otr * ytr[j] <= 0).double().mean().item(),
                    # Number of points whose alpha-scaled margin is below 1.
                    'nd': (args.alpha * otr * ytr[j] < 1).long().sum().item(),
                    # Note: 'dfnorm' and 'fnorm' are stored as tensors (no .item()).
                    'dfnorm': otr.pow(2).mean().sqrt(),
                    'fnorm': (otr + otr0[j]).pow(2).mean().sqrt(),
                },
                'test': {
                    'loss': loss(ote, yte[j], args.alpha).item(),
                    'aloss': args.alpha * loss(ote, yte[j], args.alpha).item(),
                    'aaloss': args.alpha ** 2 * loss(ote, yte[j], args.alpha).item(),
                    'err': (ote * yte[j] <= 0).double().mean().item(),
                    'nd': (args.alpha * ote * yte[j] < 1).long().sum().item(),
                    'dfnorm': ote.pow(2).mean().sqrt(),
                    'fnorm': (ote + ote0[j]).pow(2).mean().sqrt(),
                },
            }

            # For fully-connected networks also record per-layer weight norms
            # and their change since initialisation.
            if args.arch.split('_')[0] == 'fc':
                def getw(f, i):
                    return torch.cat(list(getattr(f.f, "W{}".format(i))))
                state['wnorm'] = [getw(f, i).norm().item() for i in range(f.f.L + 1)]
                state['dwnorm'] = [(getw(f, i) - getw(f0, i)).norm().item() for i in range(f.f.L + 1)]

            print("[i={d[step]:d} wall={d[wall]:.0f}] [train aL={d[train][aloss]:.2e} err={d[train][err]:.2f} nd={d[train][nd]}/{p}] [test aL={d[test][aloss]:.2e} err={d[test][err]:.2f}]".format(d=state, p=len(j)), flush=True)

            dynamics.append(state)

            # Stop as soon as every training point has alpha-scaled margin >= 1.
            if state['train']['nd'] == 0:
                break

        # Stop once the wall-clock budget `args.train_time` is exhausted.
        if perf_counter() > wall + args.train_time:
            break

    # Final output differences on the full (unpermuted) train and test sets.
    with torch.no_grad():
        otr = f(xtr) - otr0
        ote = f(xte) - ote0

    out = {
        'dynamics': dynamics,
        'train': {
            'f0': otr0,
            'outputs': otr,
            'labels': ytr,
        },
        'test': {
            'f0': ote0,
            'outputs': ote,
            'labels': yte,
        }
    }
    return f, out


def run_exp(args, f0, xtr, ytr, xte, yte):
    """Select the loss named by ``args.loss`` and run the regular training.

    Args:
        args: settings object; ``args.loss`` must be ``'hinge'``,
            ``'quad_hinge'`` or ``'mse'``.
        f0: the initial network.
        xtr, ytr: training inputs and labels.
        xte, yte: test inputs and labels.

    Returns:
        A dict with the arguments (``'args'``), the number of parameters of
        ``f0`` (``'N'``) and the output of ``run_regular`` (``'regular'``).

    Raises:
        ValueError: if ``args.loss`` is not one of the supported names.
    """
    # Resolve the loss first so an unsupported name fails fast with a clear
    # message instead of an unbound local further down.
    if args.loss == 'hinge':
        loss = hinge
    elif args.loss == 'quad_hinge':
        loss = quad_hinge
    elif args.loss == 'mse':
        loss = mse
    else:
        raise ValueError(
            "unknown loss {!r}; expected 'hinge', 'quad_hinge' or 'mse'".format(args.loss)
        )

    run = {
        'args': args,
        'N': sum(p.numel() for p in f0.parameters()),
    }

    _f, out = run_regular(args, f0, loss, xtr, ytr, xte, yte)
    run['regular'] = out

    return run


def execute(args):
    """Prepare data and network from ``args`` and return the experiment result.

    Steps, in order: set the default dtype, load the binary dataset (PCA
    variant when ``args.d`` is a non-zero value), cast and truncate the
    tensors, seed the RNG for initialisation, build the activation and the
    network, wrap it in ``SplitEval``, seed the RNG for batching and call
    ``run_exp``.

    Raises:
        ValueError: if the activation or architecture part of ``args.arch``
            is not recognised.
    """
    torch.backends.cudnn.benchmark = True

    # Any other dtype string leaves the current default dtype untouched.
    for dtype_name, dtype in (('float64', torch.float64), ('float32', torch.float32)):
        if args.dtype == dtype_name:
            torch.set_default_dtype(dtype)

    if args.d is None or args.d == 0:
        train_set, test_set = get_binary_dataset(args.dataset, args.ptr, args.data_seed, args.device)
    else:
        train_set, test_set = get_binary_pca_dataset(args.dataset, args.ptr, args.d, args.whitening, args.data_seed, args.device)
    (xtr, ytr), (xte, yte) = train_set, test_set

    default_dtype = torch.get_default_dtype()
    xtr = xtr.type(default_dtype)
    xte = xte.type(default_dtype)
    ytr = ytr.type(default_dtype)
    yte = yte.type(default_dtype)

    # Keep only the first `args.pte` test points.
    assert len(xte) >= args.pte
    xte, yte = xte[:args.pte], yte[:args.pte]

    # Seed for everything random up to and including weight initialisation.
    torch.manual_seed(args.init_seed + hash(args.alpha))

    arch_name, act_name = args.arch.split('_')

    if act_name == 'relu':
        def act(x):
            return 2 ** 0.5 * torch.relu(x)
    elif act_name == 'tanh':
        act = torch.tanh
    elif act_name == 'softplus':
        # Scale factor: inverse root-mean-square of softplus over 100000
        # standard normal samples.
        samples = torch.randn(100000, dtype=torch.float64)
        factor = torch.nn.functional.softplus(samples, args.spbeta).pow(2).mean().rsqrt().item()

        def act(x):
            return torch.nn.functional.softplus(x, beta=args.spbeta).mul(factor)
    else:
        raise ValueError('act not specified')

    if arch_name == 'fc':
        assert args.L is not None
        # The fully-connected network takes flattened inputs.
        xtr = xtr.flatten(1)
        xte = xte.flatten(1)
        net = FC(xtr.size(1), args.h, args.L, act).to(args.device)
    elif arch_name == 'cv':
        net = CV(xtr.size(1), args.h, h_base=1, L1=2, L2=2, act=act, fsz=5, pad=1, stride_first=True).to(args.device)
    elif arch_name == 'resnet':
        net = Wide_ResNet(xtr.size(1), 28, args.h, act, 1, args.mix_angle).to(args.device)
    else:
        raise ValueError('arch not specified')

    net = SplitEval(net, args.chunk)

    # Separate seed for the mini-batch sampling done during training.
    torch.manual_seed(args.batch_seed)
    return run_exp(args, net, xtr, ytr, xte, yte)


# Move Args class outside of main function
class Args:
    """Default settings for the Adam experiment, stored as class attributes."""

    # Runtime
    device = 'cuda'
    dtype = 'float32'
    chunk = 128            # slice size handed to SplitEval
    train_time = 28800.0   # wall-clock budget compared against perf_counter()

    # Random seeds
    init_seed = 0
    data_seed = 0
    batch_seed = 0

    # Data
    dataset = 'cifar10'
    ptr = 10000            # passed to the dataset loader
    pte = 50000            # the test set is truncated to this many points
    d = 0                  # None/0 selects the non-PCA dataset loader
    whitening = True

    # Model
    arch = 'fc_softplus'   # "<architecture>_<activation>"
    h = 100
    L = 3
    spbeta = 5.0           # softplus beta
    mix_angle = 0.5

    # Optimisation
    loss = 'hinge'
    alpha = 1e-3
    lr = 1e-3              # Adam learning rate
    bs = 32                # mini-batch size


def main():
    """Run the ``alpha`` x ``init_seed`` grid and save one result file per run.

    For every combination a fresh settings namespace is built from the
    defaults on ``Args``, ``execute`` is called, and the returned run is
    saved under ``results/``. Git information is printed at the end.

    Each run gets its own namespace instead of ``copy.deepcopy(Args)``:
    ``deepcopy`` returns a class object unchanged, so the old code mutated
    the shared ``Args`` class on every iteration, and the ``args`` stored in
    each saved run was pickled as a mere reference to that class - without
    the ``alpha`` / ``init_seed`` values of the run.
    """
    # Local import so the block does not depend on a file-level import.
    from types import SimpleNamespace

    git = {
        'commit': subprocess.getoutput("git rev-parse HEAD"),
        'branch': subprocess.getoutput("git rev-parse --abbrev-ref HEAD"),
        'message': subprocess.getoutput("git log -1 --pretty=format:'%h %s'"),
    }

    # Grid search over `alpha` and `init_seed`
    alpha_values = [1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7]
    init_seed_values = list(range(20))

    # Snapshot of the public class-level defaults, copied into every run.
    defaults = {k: v for k, v in vars(Args).items() if not k.startswith('_')}

    # The output directory only needs to be created once.
    os.makedirs("results", exist_ok=True)

    for alpha, init_seed in itertools.product(alpha_values, init_seed_values):
        args = SimpleNamespace(**defaults)
        args.alpha = alpha
        args.init_seed = init_seed

        run = execute(args)

        result_file = f"results/experiment_{args.arch}_{args.alpha}_{args.loss}_{args.h}_seed{args.init_seed}.pt"
        torch.save(run, result_file)

        print(f"Results saved to {result_file}")

    # Print Git information
    print("Git information:")
    for k, v in git.items():
        print(f"{k}: {v}")

# Entry point: start the grid search when this module is the program being run.
if __name__ == "__main__":
    main()

[i=0 wall=0] [train aL=1.00e+00 err=0.47 nd=10000/10000] [test aL=1.00e+00 err=0.48]
[i=1 wall=0] [train aL=1.00e+00 err=0.45 nd=10000/10000] [test aL=1.00e+00 err=0.46]
[i=2 wall=0] [train aL=1.00e+00 err=0.45 nd=10000/10000] [test aL=1.00e+00 err=0.46]
[i=3 wall=1] [train aL=1.00e+00 err=0.45 nd=10000/10000] [test aL=1.00e+00 err=0.46]
[i=4 wall=1] [train aL=1.00e+00 err=0.46 nd=10000/10000] [test aL=1.00e+00 err=0.46]
[i=5 wall=1] [train aL=1.00e+00 err=0.46 nd=10000/10000] [test aL=1.00e+00 err=0.46]
[i=6 wall=1] [train aL=1.00e+00 err=0.46 nd=10000/10000] [test aL=1.00e+00 err=0.46]
[i=7 wall=1] [train aL=1.00e+00 err=0.45 nd=10000/10000] [test aL=1.00e+00 err=0.46]
[i=8 wall=1] [train aL=1.00e+00 err=0.45 nd=10000/10000] [test aL=1.00e+00 err=0.46]
[i=9 wall=1] [train aL=1.00e+00 err=0.45 nd=10000/10000] [test aL=1.00e+00 err=0.46]
[i=10 wall=1] [train aL=1.00e+00 err=0.45 nd=10000/10000] [test aL=1.00e+00 err=0.46]
[i=11 wall=2] [train aL=1.00e+00 err=0.45 nd=10000/10000] [test 

# Notation

The conventions used in the code and in the article differ.

change 1
- code: `loss = 1/alpha etc`
- article: `loss = 1/alpha^2 etc`

change 2
- code: `1/h` at the end of the network
- article: `1/sqrt(h)` at the end of the network

```
alpha_code = sqrt(h) alpha_article

t_code = sqrt(h) / alpha_article * t_article
t_article = alpha_code / h * t_code

t_code / h = t_article / (sqrt(h) alpha_article)
```

In [None]:
import torch
import matplotlib.pyplot as plt
import matplotlib as mpl
import glob
import functools
import pickle
import math
import numpy as np
import itertools
# from grid import load, print_info

####################################################################################################
from collections import defaultdict, namedtuple

# One loaded result file: its path, the ctime reported by os.path.getctime,
# the stored arguments and the (possibly converted) data.
Run = namedtuple("Run", "file, ctime, args, data")
# In-memory cache used by _load_iter: keyed by (directory, convertion), each
# value maps a file path to its Run record.
GLOBALCACHE = defaultdict(dict)

def deepmap(fun, data):
    """Apply ``fun`` to every leaf of a nested container structure.

    Dicts are rebuilt with the same keys; lists, tuples, sets and frozensets
    are rebuilt with the same container type. Anything else is a leaf and is
    passed to ``fun``.
    """
    if isinstance(data, dict):
        return {key: deepmap(fun, value) for key, value in data.items()}

    if isinstance(data, (list, tuple, set, frozenset)):
        container = type(data)
        return container(deepmap(fun, item) for item in data)

    return fun(data)

def torch_to_numpy(data):
    """Replace every ``torch.Tensor`` leaf in ``data`` by a NumPy array.

    ``data`` may be an arbitrarily nested structure of the containers handled
    by ``deepmap``; non-tensor leaves are returned unchanged.

    Tensors are detached and moved to the CPU before conversion, because
    ``Tensor.numpy()`` raises for tensors that require grad and for tensors
    that live on a CUDA device.
    """
    import torch

    def convert(x):
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return x

    return deepmap(convert, data)

def identity(x):
    """Return ``x`` unchanged (used as the default ``tqdm`` wrapper of the loaders)."""
    return x

def to_dict(a):
    """Return ``a`` itself if it is a dict, otherwise its attribute dict ``a.__dict__``."""
    return a if isinstance(a, dict) else a.__dict__

def load_file(f):
    """Lazily read the two pickled objects stored back to back in file ``f``.

    Yields the first object converted with ``to_dict``, then the second
    object as-is. The file is opened on the first ``next()`` and stays open
    until the generator finishes or is closed.
    """
    # NOTE(review): pickle.load executes arbitrary code from the file - only
    # use on files from a trusted source.
    with open(f, "rb") as handle:
        first = pickle.load(handle)
        yield to_dict(first)
        second = pickle.load(handle)
        yield second

def _load_iter(
    directory,
    pred_args=None,
    pred_run=None,
    cache=True,
    extractor=None,
    convertion=None,
    tqdm=identity,
):
    """Yield ``(args, data)`` for every ``*.pk`` file found in ``directory``.

    Args:
        directory: path of a single directory to scan.
        pred_args: optional predicate on the stored arguments; files for
            which it returns a falsy value are skipped.
        pred_run: optional predicate on the data (after ``extractor`` for
            freshly loaded files); falsy results are skipped.
        cache: whether to reuse and fill ``GLOBALCACHE``. Forced to False
            when an ``extractor`` is given.
        extractor: optional callable applied to the loaded data.
        convertion: ``None``, ``"torch_to_numpy"``, ``"args"`` or
            ``"file_args"``; selects what is stored as the data.
        tqdm: wrapper applied to the list of files (e.g. a progress bar).

    Raises:
        NotADirectoryError: if ``directory`` is not an existing directory.
    """
    if extractor is not None:
        cache = False

    directory = os.path.normpath(directory)

    if not os.path.isdir(directory):
        raise NotADirectoryError("{} does not exists".format(directory))

    # With caching disabled a throwaway dict is used, so nothing persists.
    cache_runs = GLOBALCACHE[(directory, convertion)] if cache else dict()

    for file in tqdm(sorted(glob.glob(os.path.join(directory, "*.pk")))):
        ctime = os.path.getctime(file)

        # Cache hit: same path and unchanged ctime. The predicates are still
        # evaluated, on the cached values.
        if file in cache_runs and ctime == cache_runs[file].ctime:
            x = cache_runs[file]

            if pred_args is not None and not pred_args(x.args):
                continue

            if pred_run is not None and not pred_run(x.data):
                continue

            yield (x.args, x.data)
            continue

        # Cache miss: read the arguments first so that files rejected by
        # `pred_args` are skipped without loading their data.
        try:
            f = load_file(file)
            args = next(f)

            if pred_args is not None and not pred_args(args):
                continue

            data = next(f)
        except (pickle.PickleError, FileNotFoundError, EOFError):
            # Unreadable, missing or truncated files are silently skipped.
            continue

        if extractor is not None:
            data = extractor(data)

        if pred_run is not None and not pred_run(data):
            continue

        if convertion == "torch_to_numpy":
            data = torch_to_numpy(data)
        elif convertion == "args":
            data = args
        elif convertion == "file_args":
            data = (file, args)
        else:
            assert convertion is None

        x = Run(file=file, ctime=ctime, args=args, data=data)
        cache_runs[file] = x

        yield (x.args, x.data)

def load_iter(
    directory,
    pred_args=None,
    pred_run=None,
    cache=True,
    extractor=None,
    convertion=None,
    tqdm=identity,
    with_args=False,
):
    """Iterate over the runs of one or several directories.

    ``directory`` may hold several paths separated by ``":"``; they are
    scanned in order with ``_load_iter``. Yields ``(args, data)`` pairs when
    ``with_args`` is true, otherwise only the data.
    """
    for sub_directory in directory.split(":"):
        pairs = _load_iter(sub_directory, pred_args, pred_run, cache, extractor, convertion, tqdm)
        for run_args, run_data in pairs:
            yield (run_args, run_data) if with_args else run_data

def load(
    directory,
    pred_args=None,
    pred_run=None,
    cache=True,
    extractor=None,
    convertion=None,
    tqdm=identity,
    with_args=False,
):
    """Eager version of ``load_iter``: return all of its items as a list."""
    runs = load_iter(
        directory,
        pred_args,
        pred_run,
        cache,
        extractor,
        convertion,
        tqdm=tqdm,
        with_args=with_args,
    )
    return list(runs)
####################################################################################################

def mean(x):
    """Return the arithmetic mean of the items of iterable ``x``."""
    items = list(x)
    total = sum(items)
    return total / len(items)

def median(x):
    """Return the middle item of ``x`` once sorted (the upper one for even lengths)."""
    ordered = sorted(x)
    middle = len(ordered) // 2
    return ordered[middle]

def triangle(a, b, c, d=None, slope=None, other=False, color=None, fmt="{:.2f}", textpos=None):
    """Draw a slope triangle with matplotlib and annotate it with its log-log slope.

    The hypotenuse goes from ``(a, c)`` to ``(b, d)``. When ``slope`` is
    given, the missing one of ``c`` / ``d`` is derived from it. ``other``
    selects on which side the two legs are drawn, ``fmt`` formats the slope
    label and ``textpos`` (an ``(x, y)`` pair) overrides where it is placed.

    Returns:
        The slope ``(log d - log c) / (log b - log a)``.
    """
    import math

    log = math.log

    # Derive whichever ordinate is missing from the requested slope.
    if slope is not None and d is None:
        d = math.exp(log(c) + slope * (log(b) - log(a)))
    if slope is not None and c is None:
        c = math.exp(log(d) - slope * (log(b) - log(a)))
    color = 'k' if color is None else color

    # Hypotenuse first, then the two legs.
    plt.plot([a, b], [c, d], color=color)
    if other:
        legs = (([a, b], [c, c]), ([b, b], [c, d]))
    else:
        legs = (([a, b], [d, d]), ([a, a], [c, d]))
    for leg_x, leg_y in legs:
        plt.plot(leg_x, leg_y, color=color)

    s = (log(d) - log(c)) / (log(b) - log(a))

    # Default label position: 70%/30% log-interpolation towards the corner
    # where the two legs meet.
    if other:
        near_x, far_x, near_y, far_y = b, a, c, d
    else:
        near_x, far_x, near_y, far_y = a, b, d, c
    x = math.exp(0.7 * log(near_x) + 0.3 * log(far_x))
    y = math.exp(0.7 * log(near_y) + 0.3 * log(far_y))

    if textpos:
        x, y = textpos[0], textpos[1]

    plt.annotate(fmt.format(s), (x, y), horizontalalignment='center', verticalalignment='center')
    return s

def nd(x, a):
    """Count the entries of ``x`` whose scaled margin ``a * outputs * labels`` is below 1."""
    outputs = x['outputs']
    assert not torch.isnan(outputs).any()
    below_one = a * outputs * x['labels'] < 1
    return below_one.nonzero().numel()

def err(x):
    """Return the fraction of entries of ``x`` with ``outputs * labels <= 0``."""
    outputs = x['outputs']
    assert not torch.isnan(outputs).any()
    wrong = outputs * x['labels'] <= 0
    return wrong.double().mean().item()

def enserr(xs):
    """Return the error of the ensemble obtained by averaging the outputs in ``xs``.

    Every element of ``xs`` must carry the same ``'labels'``; the error is
    the fraction of points where ``mean_output * label <= 0``.
    """
    ensemble = mean(member['outputs'] for member in xs)
    labels = xs[0]['labels']
    assert all((member['labels'] == labels).all() for member in xs)
    wrong = ensemble * labels <= 0
    return wrong.double().mean().item()

def var(outs, alpha):
    """Return the variance across the tensors in ``outs`` (scaled by ``alpha``).

    The squared deviations from the mean over ``outs`` are averaged over
    dimension 1, summed over dimension 0 and divided by ``len(outs) - 1``.
    """
    scaled = alpha * torch.stack(outs)
    centered = scaled.sub(scaled.mean(0))
    total = centered.pow(2).mean(1).sum(0).item()
    return total / (scaled.size(0) - 1)

def texnum(x, mfmt='{}'):
    """Format the number ``x`` as a TeX snippet in scientific notation.

    Args:
        x: the number to format.
        mfmt: format string applied to the mantissa.

    Returns:
        ``"1"`` or the formatted mantissa when the exponent is 0, ``10^{e}``
        when the mantissa is 1, otherwise ``<mantissa>\\;10^{e}``.
    """
    mantissa, exponent = "{:e}".format(x).split('e')
    m, e = float(mantissa), int(exponent)
    mx = mfmt.format(m)
    if e == 0:
        return "1" if m == 1 else mx
    ex = "10^{{{}}}".format(e)
    if m == 1:
        return ex
    # Raw string: "\;" is not a valid Python escape sequence and triggers a
    # warning in a normal string literal.
    return r"{}\;{}".format(mx, ex)

def logfilter(x, y, num):
    """Resample ``y(x)`` on a log-spaced grid and smooth it with a Gaussian filter.

    Args:
        x: positive abscissae (``np.interp`` expects them in increasing order).
        y: values at ``x``.
        num: number of points of the resampling grid.

    Returns:
        ``(xi, yf)``: the ``num`` grid points, evenly spaced in ``log(x)``
        between the smallest and largest ``x``, and the interpolated values
        smoothed with a 1-D Gaussian filter of standard deviation 2 (in grid
        points).
    """
    import numpy as np
    # Imported from scipy.ndimage directly: the scipy.ndimage.filters
    # namespace is deprecated.
    from scipy.ndimage import gaussian_filter1d

    log_x = np.log(np.array(x))
    y = np.array(y)
    xi = np.linspace(min(log_x), max(log_x), num)
    yi = np.interp(xi, log_x, y)
    yf = gaussian_filter1d(yi, 2)
    return np.exp(xi), yf

def yavg(xi, x, y):
    """Average several curves on a common grid.

    Args:
        xi: candidate grid points.
        x: sequence of abscissa arrays, one per curve.
        y: sequence of ordinate arrays, one per curve.

    Returns:
        ``(grid, avg)``: the points of ``xi`` lying strictly between the
        bounds computed below, and the mean over curves of each curve
        linearly interpolated on that grid.
    """
    import numpy as np

    grid = np.array(xi)
    # NOTE(review): the lower bound is the smallest curve start while the
    # upper bound is the smallest curve end - confirm the lower bound should
    # not be the largest start (np.interp holds the first value constant
    # before a curve begins).
    lower = min(np.min(curve_x) for curve_x in x)
    upper = min(np.max(curve_x) for curve_x in x)
    inside = np.logical_and(lower < grid, grid < upper)
    grid = grid[inside]

    curves = [
        np.interp(grid, np.array(curve_x), np.array(curve_y))
        for curve_x, curve_y in zip(x, y)
    ]
    return grid, np.mean(curves, axis=0)

import matplotlib.ticker as ticker

@ticker.FuncFormatter
def format_percent(x, pos=None):
    """Tick formatter: render the fraction ``x`` as a TeX percentage.

    One decimal is shown only when the percentage is more than 0.05 away
    from the nearest integer. Comparing against the nearest integer (rather
    than the fractional part ``x % 1``) avoids labels such as ``29.0%`` when
    floating-point rounding puts ``100 * x`` just below an integer.

    Args:
        x: tick value as a fraction (0.25 means 25%).
        pos: tick position passed by matplotlib; unused.
    """
    x = 100 * x
    if abs(x - round(x)) > 0.05:
        return r"${:.1f}\%$".format(x)
    return r"${:.0f}\%$".format(x)

# Setup

In [None]:
# Figure comparing the hinge loss with soft-hinge (softplus) losses for
# several values of beta; saved to 'loss.pgf'.
fig, ax1 = plt.subplots(1, 1, figsize=(2.3, 2), dpi=130)

plt.sca(ax1)
x = torch.linspace(-0.7, 1.5, 200)

# Hard hinge: relu(1 - x).
plt.plot(x, x.neg().add(1).relu(), label='hinge')

# Soft versions: softplus(1 - x) with decreasing sharpness.
for _beta in (20, 5, 1):
    plt.plot(
        x,
        torch.nn.functional.softplus(x.neg().add(1), beta=_beta),
        label=r'soft-hinge $\beta={}$'.format(_beta),
    )

plt.legend(handlelength=1, labelspacing=0, frameon=False)
plt.xlabel(r'$fy$')
plt.ylabel(r'$\ell(fy)$')
plt.xlim(min(x), max(x))

plt.tight_layout()
plt.savefig('loss.pgf')

# Disentangling feature learning versus lazy learning from performance