In [1]:
from dataclasses import replace
import numpy as np
from cifar10 import Classifier, ReLU, Softmax, Layer, BatchNorm
from cifar10.data import load_batch, make_normalizer, vector_to_image

from tqdm.auto import trange

In [2]:
# Load one data batch; the displayed keys show it holds "features" and "labels".
data = load_batch("../data/data_batch_1")
# Keep only the first 20 entries along axis 1 of the features, which matches the
# 20-wide input dimension of the classifier built in the next cell.
# NOTE(review): presumably done to keep the per-entry numeric gradient check cheap.
data["features"] = data["features"][:, :20]
data.keys()

dict_keys(['features', 'labels'])

In [3]:
# Build the network under test from the dimension list 20 -> 20 -> 20 -> 20 -> 10.
# Hidden layers are made with ReLU + batch norm, the final layer with Softmax and
# no batch norm. Inputs are normalised with a normalizer built from the
# (truncated) feature matrix.
classifier = Classifier.from_dims(
    [20, 20, 20, 20, 10],
    make_hidden_layer=Classifier.layer_maker(ReLU, batch_norm=True),
    make_final_layer=Classifier.layer_maker(Softmax, batch_norm=False),
    normalize=make_normalizer(data["features"]),
)
# Analytic gradients on the same batch; these are the values the numeric
# estimates are compared against further down (via `analytic_gradients[0].steps`).
# NOTE(review): regularization=0 presumably removes the penalty term so that the
# result is comparable with `classifier.loss` used by the numeric check - confirm.
analytic_gradients = classifier.gradient(
    input=data["features"],
    targets=data["labels"],
    regularization=0,
)

In [4]:
def centered_difference(param: str, h=1e-5):
    """Numerically estimate the loss gradient with respect to ``param``.

    For every layer in ``classifier.steps`` that contains a step with an
    attribute named ``param``, each entry of that parameter is shifted by
    ``+h`` and ``-h`` in turn, the loss on the notebook-level ``data`` batch is
    re-evaluated, and the gradient entry is estimated as
    ``(loss(+h) - loss(-h)) / (2 * h)``.

    Args:
        param: Attribute name of the parameter to differentiate
            (the notebook uses "weights", "shift" and "scale").
        h: Half-width of the finite-difference step.

    Returns:
        A list with one array per layer that has the parameter, in layer
        order. Layers without the parameter are skipped, so the list can be
        shorter than ``classifier.steps``. Each array has the shape of the
        corresponding parameter.
    """
    gradients = []
    for layer_idx, layer in enumerate(classifier.steps):
        try:
            original = next(
                getattr(step, param)
                for step in layer.steps
                if hasattr(step, param)
            )
        except StopIteration:
            # This layer has no such parameter (e.g. no batch norm step).
            continue

        # Private copy of the layer list so one slot can be swapped per probe.
        layers = classifier.steps.copy()

        # Write into the array created for *this* layer. Indexing `gradients`
        # by `layer_idx` is wrong as soon as an earlier layer was skipped,
        # because the result list only has entries for layers with `param`.
        gradient = np.zeros(original.shape)
        gradients.append(gradient)

        def loss_with_offset(array_idx, diff):
            """Loss of the classifier with one parameter entry shifted by ``diff``."""
            attempt = original.copy()
            attempt[array_idx] += diff
            layers[layer_idx] = replace(
                layer,
                steps=[
                    replace(step, **{param: attempt}) if hasattr(step, param) else step
                    for step in layer.steps
                ],
            )
            probe = replace(classifier, steps=layers)
            outputs = probe.forward(data["features"])
            return probe.loss(outputs, data["labels"])

        # int(...) because np.prod(()) is the float 1.0 for 0-d parameters.
        for flat_idx in trange(int(np.prod(original.shape))):
            array_idx = np.unravel_index(flat_idx, original.shape)
            gradient[array_idx] = (
                loss_with_offset(array_idx, h) - loss_with_offset(array_idx, -h)
            ) / (2 * h)

    return gradients

# Numeric gradient estimates keyed by parameter name; each value is the
# per-layer list returned by centered_difference.
numeric_gradients = {
    name: centered_difference(name) for name in ("weights", "shift", "scale")
}

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

In [5]:
def relative_error(a, b, eps=1e-6):
    """Summarise the element-wise relative error between ``a`` and ``b``.

    Each element's error is ``|a - b| / max(eps, |a| + |b|)``; ``eps`` keeps the
    denominator away from zero when both values are (almost) zero.

    Returns:
        A dict with the ``mean`` and ``max`` of the element-wise errors.
    """
    difference = np.abs(a - b)
    magnitude = np.maximum(eps, np.abs(a) + np.abs(b))
    ratios = difference / magnitude
    return {"mean": np.mean(ratios), "max": np.max(ratios)}

def param_errors(param: str):
    """Compare numeric and analytic gradients of ``param`` layer by layer.

    The numeric gradients come from ``numeric_gradients[param]``; the analytic
    ones are read from every step in ``analytic_gradients[0].steps`` that has an
    attribute named ``param``. The two sequences are paired in order.

    Returns:
        A list of ``relative_error`` summaries, one per paired layer.
    """
    numeric_per_layer = numeric_gradients[param]
    analytic_per_layer = (
        getattr(step, param)
        for layer in analytic_gradients[0].steps
        for step in layer.steps
        if hasattr(step, param)
    )
    errors = []
    for numeric, analytic in zip(numeric_per_layer, analytic_per_layer):
        errors.append(relative_error(numeric, analytic))
    return errors

# Relative error between numeric and analytic "weights" gradients, one summary per layer that has weights.
param_errors("weights")

[{'mean': 2.3607450715512787e-05, 'max': 0.00195062644852372},
 {'mean': 1.9594892703145383e-05, 'max': 0.0049478174137972935},
 {'mean': 2.69610441865422e-06, 'max': 0.0006946453558802901},
 {'mean': 2.4045212816507504e-09, 'max': 2.0498676078445554e-07}]

In [6]:
# Same comparison for the "shift" parameter; only layers that have a step with a `shift` attribute appear.
param_errors("shift")

[{'mean': 2.429017294644777e-06, 'max': 1.6676546671737633e-05},
 {'mean': 7.375072037062794e-07, 'max': 1.4676646879081515e-05},
 {'mean': 4.899215024336978e-10, 'max': 3.15950496592603e-09}]

In [7]:
# Same comparison for the "scale" parameter; only layers that have a step with a `scale` attribute appear.
param_errors("scale")

[{'mean': 5.506953105013933e-07, 'max': 3.2055461856375996e-06},
 {'mean': 2.0105473889857947e-09, 'max': 1.250254822128786e-08},
 {'mean': 3.8945109063091744e-09, 'max': 7.065683330834366e-08}]