In [None]:

from typing import Callable, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch as t
from torchvision import datasets, transforms

from perturb.experiment import Experiment
from perturb.interventions.weights import PerturbWeights
from perturb.observations.metrics import FullPanelMetrics
from perturb.observations.plots import FullPanelPlotter, Plotter
from perturb.utils import setup
from perturb.variables.models import FCN

# Project-wide initialisation from perturb.utils; whatever it returns is kept as `device`.
# NOTE(review): presumably the torch device to train on — confirm; `device` is not
# referenced in any visible cell.
device = setup()

In [None]:
def get_mnist_data():
    """Load the MNIST train and test splits from ``data/``, downloading them if missing.

    Every image is converted with ``transforms.ToTensor()``.

    Returns:
        ``(train_ds, test_ds)``: the train split followed by the test split.
    """
    return tuple(
        datasets.MNIST(
            root="data", train=is_train, download=True, transform=transforms.ToTensor()
        )
        for is_train in (True, False)
    )


def get_cifar10_data():
    """Load the CIFAR-10 train and test splits from ``data/``, downloading them if missing.

    Every image is converted with ``transforms.ToTensor()``.

    Returns:
        ``(train_ds, test_ds)``: the train split followed by the test split.
    """
    # Arguments shared by both splits.
    shared = dict(root="data", download=True)

    train_split = datasets.CIFAR10(train=True, transform=transforms.ToTensor(), **shared)
    test_split = datasets.CIFAR10(train=False, transform=transforms.ToTensor(), **shared)
    return train_split, test_split


def get_imagenet_data():
    """Load ImageNet from pre-extracted ``ImageFolder`` trees.

    The train split is read from ``data/imagenet/train`` and the validation split,
    used as the test set, from ``data/imagenet/val``. Nothing is downloaded.

    NOTE(review): only ``ToTensor()`` is applied, so every image keeps its native
    resolution. ImageNet images vary in size, so batching presumably needs a
    Resize/CenterCrop. Confirm how Experiment batches these datasets.

    Returns:
        ``(train_ds, test_ds)``.
    """
    return tuple(
        datasets.ImageFolder(root=folder, transform=transforms.ToTensor())
        for folder in ("data/imagenet/train", "data/imagenet/val")
    )


def get_data(dataset: str):
    """Return ``(train_ds, test_ds)`` for the named dataset.

    Args:
        dataset: One of ``"mnist"``, ``"cifar10"`` or ``"imagenet"``.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    if dataset not in ("mnist", "cifar10", "imagenet"):
        raise ValueError(f"Unknown dataset {dataset}")

    if dataset == "mnist":
        return get_mnist_data()
    if dataset == "cifar10":
        return get_cifar10_data()
    return get_imagenet_data()


# Keyword arguments for the FCN model constructor.
DEFAULT_MODEL_HYPERPARAMS = {"n_hidden": 100}

# Plain SGD: no momentum and no weight decay.
DEFAULT_SGD_HYPERPARAMS = {
    "lr": 0.01,
    "momentum": 0.0,
    "weight_decay": 0.0,
}

# "vanilla" experiment: FCN + SGD on MNIST.
# Weight-perturbation variations are built over the epsilon values and 3 perturbation
# seeds (NOTE(review): presumably the Cartesian product; confirm in make_variations).
# Plots average over the perturbation seeds.
exp = Experiment(
    model=(FCN, DEFAULT_MODEL_HYPERPARAMS),
    opt=(t.optim.SGD, DEFAULT_SGD_HYPERPARAMS),
    datasets=get_data("mnist"),
    interventions=[
        PerturbWeights.make_variations(
            epsilon=[0.001, 0.01, 0.1],
            seed_perturbation=range(3),
        ),
    ],
    plotter=FullPanelPlotter(average_over=["seed_perturbation"], dir="plots/vanilla"),
    metrics=FullPanelMetrics(ivl=200),
    name="vanilla",
)

In [None]:
# Run the experiment for 5 epochs.
# NOTE(review): `n_epochs_at_a_time=2` presumably splits the run into chunks of 2 epochs
# (e.g. between metric/plot flushes) — confirm against Experiment.run.
exp.run(n_epochs=5, n_epochs_at_a_time=2)

In [None]:
# Express the distance from the baseline relative to each row's weight norm, so runs
# with different weight scales can be compared.
# NOTE(review): `df` is not created in any visible cell. Presumably it holds the metrics
# collected by `exp`; confirm where it comes from.
# NOTE(review): an earlier comment described multiplying by `w` and dividing by the
# *initial* `w` per seed_perturbation. The code divides by each row's `w_normed`
# instead. Confirm which normalisation is intended.

# Weight value at step 0 for each perturbation seed. It is not used below but is kept
# available for the normalisation described above.
initial_ws = df.loc[df["step"] == 0].groupby("seed_perturbation")["w"].first()

# Keep the raw column once. Re-running this cell then recomputes from the raw values
# instead of dividing by `w_normed` a second time.
if "_d_w_from_baseline_normed" not in df.columns:
    df["_d_w_from_baseline_normed"] = df["d_w_from_baseline_normed"]

# This replaces a per-row loop that used chained assignment (`df.loc[i][col] /= ...`).
# That pattern writes to a temporary row copy: always under pandas Copy-on-Write, and
# whenever the frame has mixed dtypes. In those cases `df` was never updated.
df["d_w_from_baseline_normed"] = df["_d_w_from_baseline_normed"] / df["w_normed"]

df["d_w_from_baseline_normed"].plot()
df.head()

In [None]:
from scipy import stats
import scipy


# One figure row per perturbation strength. The smallest epsilon is dropped.
# NOTE(review): presumably that is the unperturbed baseline; confirm.
epsilons = np.sort(df["epsilon"].unique())[1:]

# squeeze=False keeps `axs` 2-D even when only one epsilon remains, so each row still
# unpacks into (ax, ax2). Passing `tight_layout` to the figure applies the layout when
# the figure is drawn, i.e. after titles and curves exist. Calling fig.tight_layout()
# here would lay out an empty figure.
fig, axs = plt.subplots(
    len(epsilons), 2, figsize=(10, 20), squeeze=False, tight_layout={"pad": 5.0}
)

def plot_fit(
    x: np.ndarray, 
    y: np.ndarray, 
    fit: Callable,
    start: int = 0, 
    end: int = -1, 
    ax: Optional["plt.Axes"] = None, 
    color: str = "red",
    **kwargs
):
    """Fit ``y`` against ``x`` on a window of points and draw the fitted curve.

    Args:
        x: Abscissa, as a NumPy array or pandas Series.
        y: Values to fit, the same length as ``x``.
        fit: Callable ``fit(x, y, start, end)``. It fits on ``[start:end]`` and
            returns predictions for *all* of ``x``.
        start: First index (inclusive) of the fitting window.
        end: Exclusive end index of the fitting window. Negative values count from
            the end, with ``-1`` meaning "through the last point", so ``-50`` leaves
            out the last 49 points. This is one more than Python slicing would keep.
        ax: Axes to draw on. A new 10x10 figure is created when omitted.
        color: Colour of the fitted curve and of the dashed window markers.
        **kwargs: Forwarded to ``ax.plot``. ``label`` defaults to the fit's name.

    Returns:
        Whatever ``fit`` returned, i.e. the predictions over all of ``x``.
    """
    n_points = len(x)
    if end < 0:
        end = n_points + end + 1

    yhat = fit(x, y, start, end)

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(10, 10))

    # functools.partial objects have no __name__, so fall back to the wrapped function's name.
    default_label = getattr(fit, "__name__", None)
    if default_label is None:
        default_label = getattr(getattr(fit, "func", None), "__name__", repr(fit))
    kwargs.setdefault("label", default_label)
    ax.plot(x, yhat, color=color, **kwargs)

    # Index positionally so pandas Series with a non-default index work too.
    x_values = np.asarray(x)
    if start != 0:
        ax.axvline(x_values[start], color=color, linestyle="--")
    # After normalisation, end == len(x) means "through the last point". There is no
    # point to mark then, and x[len(x)] would be out of range.
    if end < n_points:
        ax.axvline(x_values[end], color=color, linestyle="--")

    return yhat
 
def plot_residues(
    x: np.ndarray,
    y: np.ndarray,
    yhat: np.ndarray,
    ax: Optional[plt.Axes] = None,
    color: str = "red",
    **kwargs
):
    """Plot the residuals ``y - yhat`` against ``x``.

    Args:
        x: Abscissa.
        y: Observed values.
        yhat: Predicted values, the same length as ``y``.
        ax: Axes to draw on. A new 10x10 figure is created when omitted.
        color: Line colour.
        **kwargs: Forwarded to ``ax.plot``.
    """
    target_ax = ax if ax is not None else plt.subplots(1, 1, figsize=(10, 10))[1]
    residuals = y - yhat
    target_ax.plot(x, residuals, color=color, **kwargs)



def fit_line(x: np.ndarray, y: np.ndarray, start: int = 0, end: int = -1) -> np.ndarray:
    """Fit a least-squares straight line on ``[start:end]`` and evaluate it on all of ``x``.

    ``start`` and ``end`` are plain slice bounds. When called directly, the default
    ``end=-1`` therefore leaves out the final point; ``plot_fit`` passes a
    normalised end index instead.
    """
    regression = stats.linregress(x[start:end], y[start:end])
    return regression.slope * x + regression.intercept

def fit_parabola(x: np.ndarray, y: np.ndarray, start: int = 0, end: int = -1) -> np.ndarray:
    """Fit a quadratic by least squares on ``[start:end]`` and evaluate it on all of ``x``.

    ``start`` and ``end`` are plain slice bounds, so the default ``end=-1`` leaves out
    the final point when called directly.
    """
    coefficients = np.polyfit(x[start:end], y[start:end], 2)
    return np.polyval(coefficients, x)

def fit_log(x: np.ndarray, y: np.ndarray, start: int = 0, end: int = -1) -> np.ndarray:
    """Power-law fit: a straight line in log-log space, mapped back with ``exp``.

    Fits ``log(y) = m * log(x) + b`` on ``[start:end]`` and returns
    ``exp(m * log(x) + b)``, i.e. ``e**b * x**m``, for all ``x``. Despite the name,
    this is not ``y = a * log(x) + c``. Values of ``x`` or ``y`` that are not strictly
    positive produce non-finite logs, with NumPy warnings.
    """
    log_space_fit = fit_line(np.log(x), np.log(y), start=start, end=end)
    return np.exp(log_space_fit)

def fit_sqrt(x: np.ndarray, y: np.ndarray, start: int = 0, end: int = -1) -> np.ndarray:
    """Fit a straight line between ``sqrt(x)`` and ``sqrt(y)``, then square it back.

    The model is ``y = (m * sqrt(x) + b) ** 2``, with ``m`` and ``b`` fitted on
    ``[start:end]``. Negative ``x`` or ``y`` give NaN square roots.
    """
    root_space_fit = fit_line(np.sqrt(x), np.sqrt(y), start=start, end=end)
    return root_space_fit ** 2.0

def fit_exp_decaying_slope(x: np.ndarray, y: np.ndarray, start: int = 0, end: int = -1) -> np.ndarray:
    """Fit a straight line on ``[start:end]``, then multiply it by ``exp(-x)``.

    NOTE(review): the line is fitted *before* the ``exp(-x)`` factor is applied, so
    this is not a least-squares fit of a decaying model. For large ``x`` (e.g.
    training steps) the factor is vanishingly small. The function is not used in
    any visible cell; confirm the intended model.
    """
    damping = np.exp(-x)
    straight_line = fit_line(x, y, start=start, end=end)
    return straight_line * damping


def fit_spherical_rw(
    x: np.ndarray,
    y: np.ndarray,
    start: int = 0,
    end: int = -1,
    p0: tuple = (0.000001, 5),
) -> np.ndarray:
    """Fit the chord length of a random walk on a sphere of radius ``r``.

    Model:
        cos(theta) = exp(-D * (n - 1) * x / r**2)
        y = 2 * r * sqrt((1 - cos(theta)) / 2)   # chord length, 2 r sin(theta / 2)
    ``D_tilde = D * (n - 1)`` and ``r`` are fitted with ``scipy.optimize.curve_fit``
    on ``x[start:end]``.

    Args:
        x: Abscissa, e.g. training steps.
        y: Observed distances.
        start: Start of the fitting window (plain slice bound).
        end: End of the fitting window (plain slice bound). When called directly, the
            default ``-1`` leaves out the final point.
        p0: Initial guess ``(D_tilde, r)``. The default is the previously hard-coded
            value; pass one suited to the data's scale when the fit fails to converge.

    Returns:
        Model predictions for every ``x``.

    Raises:
        RuntimeError: propagated from ``curve_fit`` when the fit does not converge.
    """
    # Import the submodule explicitly instead of relying on another import having
    # already loaded `scipy.optimize`.
    import scipy.optimize

    def f(x, D_tilde, r):
        return 2 * r * np.sqrt((1 - np.exp(-D_tilde * x / r ** 2)) / 2)

    popt, _ = scipy.optimize.curve_fit(f, x[start:end], y[start:end], p0=p0)

    return f(x, *popt)

# Fits compared on every row: (fit function, colour, extra kwargs for its residual curve).
# plot_fit is called with start=50, end=-50, so each fit only sees the points
# [50, len - 49) of the smoothed curve.
fits_to_compare = [
    (fit_line, "green", {}),
    (fit_parabola, "orange", {}),
    (fit_sqrt, "blue", {}),
    (fit_log, "purple", {}),
    (fit_spherical_rw, "black", {"alpha": 0.5}),
]

for epsilon, (ax, ax2) in zip(epsilons, axs):
    data = df.loc[df["epsilon"] == epsilon]

    ax.set_title(f"$\\epsilon = {epsilon}$")
    ax2.set_title(f"Residuals, $\\epsilon = {epsilon}$")
    average = data.groupby("step").mean(numeric_only=True).reset_index()

    # Individual seeds, drawn faintly. The leading "_" in the label keeps them out of the legend.
    for seed in data["seed_perturbation"].unique():
        data_seed = data.loc[data["seed_perturbation"] == seed]
        ax.plot(data_seed["step"], data_seed["d_w_from_baseline_normed"], label=f"_$\\epsilon = {epsilon}$", alpha=0.1)

    # 10-point moving average of the seed-averaged curve, edge-padded back to full length.
    # NOTE(review): with mode="valid" each smoothed value sits at the *start* of its
    # 10-point window, and only the tail is padded. The curve therefore leads the raw
    # data by ~4.5 samples; pad (4, 5) instead if a centred average is wanted.
    averages = average["d_w_from_baseline_normed"]
    averages = np.convolve(averages, np.ones((10,)) / 10, mode="valid")
    averages = np.pad(averages, (0, len(average) - len(averages)), mode="edge")

    ax.plot(average["step"], averages, label=f"$\\epsilon = {epsilon}$", color="red")

    for fit, color, residual_kwargs in fits_to_compare:
        fitted = plot_fit(average["step"], averages, fit, ax=ax, color=color, alpha=0.5, start=50, end=-50)
        plot_residues(average["step"], averages, fitted, ax=ax2, color=color, **residual_kwargs)

    # ax.set_yscale("log")
    ax.set_ylim(1., averages.max())
    ax.legend()

    # ax2.set_ylim(-0.2, 0.2)


plt.show()

In [None]:
from functools import partial
from scipy import stats
import scipy


# One figure row per perturbation strength. The smallest epsilon is dropped.
# NOTE(review): presumably that is the unperturbed baseline; confirm.
epsilons = np.sort(df["epsilon"].unique())[1:]

# squeeze=False keeps `axs` 2-D even when only one epsilon remains, so each row still
# unpacks into (ax, ax2). Passing `tight_layout` to the figure applies the layout when
# the figure is drawn, i.e. after titles and curves exist. Calling fig.tight_layout()
# here would lay out an empty figure.
fig, axs = plt.subplots(
    len(epsilons), 2, figsize=(10, 20), squeeze=False, tight_layout={"pad": 5.0}
)


# def fit_spherical_rw(x: np.ndarray, y: np.ndarray, start: int = 0, end: int = -1, r=1., n=2) -> np.ndarray:
#     # cos(theta) = np.exp(-D * (n-1) * x / (r ** 2))
#     # y = 2 * r * np.sqrt((1- cos(theta)) / 2)

#     def f(x, D):
#         return 2 * r * np.sqrt((1 - np.exp(-D * (n-1) * x / r ** 2)) / 2) + 1
    
#     popt, pcov = scipy.optimize.curve_fit(f, x[start:end], y[start:end], p0=(10e-18))

#     return f(x, *popt)

def fit_fractional_power(
    x: np.ndarray,
    y: np.ndarray,
    start: int = 0,
    end: int = -1,
    offset: Optional[float] = 1.0,
) -> np.ndarray:
    """Fit ``y = a * x**b + offset`` on ``[start:end]`` and evaluate it on all of ``x``.

    Args:
        x: Abscissa, e.g. training steps. ``x == 0`` is only safe while the fitted
            exponent stays positive; a negative ``b`` gives ``inf`` there.
        y: Values to fit.
        start: Start of the fitting window (plain slice bound).
        end: End of the fitting window (plain slice bound). When called directly, the
            default ``-1`` leaves out the final point.
        offset: Fixed additive constant. The default ``1.0`` is the value previously
            hard-coded. Pass ``None`` to fit it as a third parameter ``c``.

    Returns:
        Model predictions for every ``x``. The fitted parameters are printed.

    Raises:
        RuntimeError: propagated from ``curve_fit`` when the fit does not converge.
    """
    import scipy.optimize

    if offset is None:
        def f(x, a, b, c):
            return a * (x ** b) + c
    else:
        # Only (a, b) are free. Previously an unused `c` was also declared; a parameter
        # with zero gradient stops curve_fit from estimating the covariance.
        def f(x, a, b):
            return a * (x ** b) + offset

    popt, _ = scipy.optimize.curve_fit(f, x[start:end], y[start:end])
    print(popt)

    return f(x, *popt)



# Hard-coded parameter count, only referenced by the commented-out spherical-random-walk
# fit below. It is loop-invariant, so it is computed once. The (fan_in + 1) * fan_out
# terms suggest 784 -> 100 -> 10 -> 10 layers with biases.
# NOTE(review): confirm it matches the FCN actually trained.
n_params = (784 + 1) * 100 + (100 + 1) * 10 + (10 + 1) * 10

for epsilon, (ax, ax2) in zip(epsilons, axs):
    data = df.loc[df["epsilon"] == epsilon]

    ax.set_title(f"$\\epsilon = {epsilon}$")
    ax2.set_title(f"Residuals, $\\epsilon = {epsilon}$")
    average = data.groupby("step").mean(numeric_only=True).reset_index()

    # Individual seeds, drawn faintly. The leading "_" in the label keeps them out of the legend.
    for seed in data["seed_perturbation"].unique():
        data_seed = data.loc[data["seed_perturbation"] == seed]
        ax.plot(data_seed["step"], data_seed["d_w_from_baseline_normed"], label=f"_$\\epsilon = {epsilon}$", alpha=0.1)

    # 10-point moving average of the seed-averaged curve, edge-padded back to full length.
    # NOTE(review): with mode="valid" the smoothed curve leads the raw data by ~4.5
    # samples; pad (4, 5) instead if a centred average is wanted.
    averages = average["d_w_from_baseline_normed"]
    averages = np.convolve(averages, np.ones((10,)) / 10, mode="valid")
    averages = np.pad(averages, (0, len(average) - len(averages)), mode="edge")

    ax.plot(average["step"], averages, label=f"$\\epsilon = {epsilon}$", color="red")

    # fitter = partial(fit_spherical_rw, r=data["w"].to_numpy()[-1], n=n_params)
    # fitter.__name__ = "fit_spherical_rw"

    # spherical_rw_fit = plot_fit(average["step"], averages, fitter, ax=ax, color="black", alpha=0.5, start=50, end=-50)
    # plot_residues(average["step"], averages, spherical_rw_fit, ax=ax2, color="black", alpha=0.5)

    # end=-200 leaves the last 199 points out of the fit, to test how well it extrapolates.
    fractional_power_fit = plot_fit(average["step"], averages, fit_fractional_power, ax=ax, color="black", alpha=0.5, start=0, end=-200)
    plot_residues(average["step"], averages, fractional_power_fit, ax=ax2, color="black", alpha=0.5)

    # ax.set_yscale("log")
    ax.set_ylim(1., averages.max())
    ax.legend()

    # ax2.set_ylim(-0.2, 0.2)


plt.show()

In [None]:
# Plot f_n(theta) = sin(theta) ** n for n = 1..5 over [0, pi]. Every curve peaks at
# theta = pi / 2, and higher powers make the peak narrower.

theta = np.linspace(0, np.pi, 1000)

def f(x, n): 
    """Return ``sin(x) ** n`` elementwise."""
    return np.sin(x) ** n

sin_fig, sin_ax = plt.subplots()
for n in range(1, 6):
    sin_ax.plot(theta, f(theta, n), label=f"n = {n}")

sin_ax.set(title=r"$\sin(\theta)^n$", xlabel=r"$\theta$", ylabel=r"$\sin(\theta)^n$")
sin_ax.legend()
plt.show()

