In [1]:
import os, sys
from tqdm import trange, tqdm
from IPython.utils import io
import itertools

import math
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy
from numba import njit

import torch
from torch import nn
from torch.utils.data import TensorDataset, ConcatDataset
import wandb

source = "../source"
sys.path.append(source)

from data import fun_data, grid_data
from preprocessing import Direct, Encoding, OneHot
from compilation import Compiler, Tracker, ScalarTracker, ActivationTracker
from activations import get_activations
from data_analysis.automata import to_automaton_history
from data_analysis.visualization.animation import SliderAnimation
from data_analysis.visualization.activations import (
    ActivationsAnimation,
    FunctionAnimation,
    PointAnimation,
)
from data_analysis.visualization.automata import AutomatonAnimation
from data_analysis.visualization.epochs import EpochAnimation
import data_analysis.visualization.publication as publication
import simulate
import two_points

import models as models
from models import MLP, CNN, ResNet

# Report whether PyTorch can see a CUDA device.
is_cuda = torch.cuda.is_available()
print("GPU available" if is_cuda else "GPU not available")

# NOTE: the availability check above only affects the printed message; the
# notebook always runs on the CPU because `device` is set unconditionally here.
device = torch.device("cpu")

GPU available


In [2]:
## Load settings
settings = "default"

# The settings file is a space-separated table whose first line is the header.
# The row labelled `settings` is unpacked positionally, so the order of the
# names below has to match the order of the values in that row.
(
    model_type, nonlinearity, gain, lr,
    P, L, n_epochs, hidden_layer,
    dx2, dy2, in_dim, out_dim,
) = pd.read_csv(
    "model_settings/2 points.txt", sep=" ", header=0
).loc[settings].to_numpy()

# Replace the textual entries by the objects they name: the model class is
# looked up in `models`; the nonlinearity is either disabled ("none"), the
# custom `simulate.Discontinuous` function, or a function of
# `torch.nn.functional` with that name.
model_type = getattr(models, model_type)
if nonlinearity == "none":
    nonlinearity = None
elif nonlinearity == "discontinuous":
    nonlinearity = simulate.Discontinuous.apply
else:
    nonlinearity = getattr(torch.nn.functional, nonlinearity)

# Epoch subsampling factor: the analysis keeps epochs with `Epoch % mod == 0`.
mod = 1

# Disabled: run fewer epochs with a proportionally larger learning rate.
# factor = 0.25
# n_epochs = int(factor * n_epochs)
# lr = lr / factor

In [3]:
## Generate data

input_dim, output_dim = 1, 1

# data, encoding = two_points.data_set(dx2, dy2, input_dim, output_dim, device)

# Two points, A (row 0) and B (row 1). Every coordinate of A is -1 for the
# inputs and 0.6 for the outputs; B is shifted by sqrt(dx2) (inputs) or
# sqrt(dy2) (outputs) on every coordinate. Dividing by sqrt(dim) makes the
# squared distance between A and B equal to dx2 in input space and dy2 in
# output space, whatever the dimension.
inputs = np.array([[-1] * input_dim, [-1 + np.sqrt(dx2)] * input_dim])
inputs = inputs / np.sqrt(input_dim)
outputs = np.array([[0.6] * output_dim, [0.6 + np.sqrt(dy2)] * output_dim])
outputs = outputs / np.sqrt(output_dim)

names = ["A", "B"]

# Both arrays are cast to float32 and moved to `device` before being paired up.
data = TensorDataset(
    *(
        torch.from_numpy(points.astype(np.float32)).to(device)
        for points in (inputs, outputs)
    )
)

# Maps each point name to its (float64) input vector.
encoding = Encoding({name: point for name, point in zip(names, inputs)})

# The same two-point dataset serves for training and validation. In
# `tracked_datasets` the validation copy comes first (index 0), the training
# copy second (index 1).
train_datasets, val_dataset = [data], [data]
tracked_datasets = val_dataset + train_datasets

In [4]:
# Depth sweep: one model is trained per depth L, and the loss returned by the
# effective-learning-rate fit is stored together with the width and depth used.
losses, Ps, Ls = [], [], []


def _per_epoch(trace, input_name):
    """Return a list with one array per epoch: `trace` at `input_name`, dataset 0.

    NOTE(review): the `.loc[epoch, 0, input_name]` lookup suggests the trace is
    indexed by (Epoch, Dataset, input name) -- confirm against the trackers.
    """
    first_dataset = trace.query("Dataset == 0")
    return [
        np.array(group.loc[epoch, 0, input_name])
        for epoch, group in first_dataset.groupby("Epoch")
    ]


# NOTE(review): the loop reassigns `L`, `P`, `hidden_layer` and `dy2`, which
# were also loaded from the settings file; their settings values are lost once
# this cell has run.
for L in trange(1, 30):
    # Width scales with depth; activations are tracked at layer index L // 2.
    P = 50 * L
    hidden_layer = L // 2

    ## Instantiate model
    model = model_type(
        encoding=encoding,
        input_size=input_dim,
        output_size=output_dim,
        hidden_dim=P,
        n_hid_layers=L,
        device=device,
        init_std=gain,
        non_linearity=nonlinearity,
    )

    ## Setup compiler
    # Define Loss, Optimizer
    def criterion(x, y):
        """Half of the mean squared error between `x` and `y`."""
        return 0.5 * nn.functional.mse_loss(x, y)

    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    compiler = Compiler(model, criterion, optimizer)

    # `model(inputs)[0]` is tracked as the output and `model(inputs)[1]`
    # indexed by `hidden_layer` as the hidden activation.
    # NOTE(review): presumably the model returns (output, per-layer
    # activations) -- confirm in `models`.
    compiler.trackers = {
        "loss": ScalarTracker(lambda: compiler.validation(tracked_datasets)),
        "hidden": ActivationTracker(
            model,
            lambda inputs: model(inputs)[1][hidden_layer],
            datasets=tracked_datasets,
        ),
        "output": ActivationTracker(
            model, lambda inputs: model(inputs)[0], datasets=tracked_datasets
        ),
    }

    ## Training run
    with io.capture_output() as captured_output:
        compiler.training_run(
            train_datasets, tracked_datasets, n_epochs=n_epochs, batch_size=100
        )

    # Keep every `mod`-th epoch of the recorded traces.
    query = f"Epoch % {mod} == 0"
    data_hid = compiler.trackers["hidden"].get_trace().copy().query(query)
    data_output = compiler.trackers["output"].get_trace().copy().query(query)

    # Per-epoch hidden (h) and output (y) activations for the points A and B.
    h_A, h_B = _per_epoch(data_hid, "A"), _per_epoch(data_hid, "B")
    y_A, y_B = _per_epoch(data_output, "A"), _per_epoch(data_output, "B")

    epochs = np.arange(0, len(h_A))

    y_true_A, y_true_B = outputs[0], outputs[1]
    # Squared distance between the two target outputs.
    dy2 = np.sum((y_true_B - y_true_A) ** 2)
    # Squared distance between A and B in hidden / output space, per epoch.
    h2 = np.array([np.sum((a - b) ** 2) for a, b in zip(h_A, h_B)])
    y2 = np.array([np.sum((a - b) ** 2) for a, b in zip(y_A, y_B)])
    # Output distance minus the overlap of the predicted A-B difference with
    # the target A-B difference, per epoch.
    w = np.array(
        [
            y2_t - np.dot(y_true_A - y_true_B, a - b)
            for y2_t, a, b in zip(y2, y_A, y_B)
        ]
    )
    # Squared error of the mean prediction (over A and B) at the first epoch.
    y0_mean = np.sum((0.5 * ((y_A[0] + y_B[0]) - (y_true_B + y_true_A))) ** 2)

    # Initial values; not used further in this cell.
    h0, y0, w0, dy = h2[0], y2[0], w[0], dy2
    epochs = epochs * mod

    ## Fit effective learning rates
    # The fit is given 100 log-spaced guesses between 1e-6 and 1e2.
    eta_h_opt, eta_y_opt, loss = simulate.optimize_eta(
        h2, y2, w, dx2, dy2, guesses=np.logspace(-6, 2, 100)
    )

    Ls.append(L)
    Ps.append(P)
    losses.append(loss)

Training: 100%|██████████| 3000/3000 [00:21<00:00, 137.63steps/s, train_loss=0.00000, val_loss=0.00000]


h0: 0.7920787930488586
y0: 0.026708265766501427
w0: 0.14226834947500516
dy: 0.5000000000000001


  3%|▎         | 1/29 [00:33<15:34, 33.37s/it]

Loss: 0.06347797557158782


Training: 100%|██████████| 3000/3000 [00:22<00:00, 131.19steps/s, train_loss=0.00000, val_loss=0.00000]


h0: 0.5775344371795654
y0: 0.00019407583749853075
w0: -0.009656706867422305
dy: 0.5000000000000001


  7%|▋         | 2/29 [01:02<13:53, 30.89s/it]

Loss: 0.32555788169532984


Training: 100%|██████████| 3000/3000 [00:23<00:00, 126.80steps/s, train_loss=0.00000, val_loss=0.00000]


h0: 0.5677255988121033
y0: 2.8383142307575326e-06
w0: -0.0011884454870160753
dy: 0.5000000000000001


 10%|█         | 3/29 [01:37<14:06, 32.57s/it]

Loss: 0.19156300928327835


Training: 100%|██████████| 3000/3000 [00:34<00:00, 87.89steps/s, train_loss=0.00000, val_loss=0.00000]


h0: 0.29200783371925354
y0: 0.00013796288112644106
w0: 0.00844346958097681
dy: 0.5000000000000001


 14%|█▍        | 4/29 [02:14<14:22, 34.52s/it]

Loss: 0.27177577802851566


Training: 100%|██████████| 3000/3000 [00:36<00:00, 81.88steps/s, train_loss=0.00000, val_loss=0.00000]


h0: 0.19714519381523132
y0: 0.001283206045627594
w0: 0.026613090581572407
dy: 0.5000000000000001


 17%|█▋        | 5/29 [02:55<14:39, 36.66s/it]

Loss: 0.10115269760430964


Training: 100%|██████████| 3000/3000 [00:39<00:00, 76.78steps/s, train_loss=0.00000, val_loss=0.00000]


h0: 0.08959214389324188
y0: 4.094744144822471e-05
w0: -0.004483842067183193
dy: 0.5000000000000001


 21%|██        | 6/29 [03:37<14:48, 38.64s/it]

Loss: 0.02480841126502093


Training: 100%|██████████| 3000/3000 [00:41<00:00, 71.43steps/s, train_loss=0.00000, val_loss=0.00000]


h0: 0.12686318159103394
y0: 3.446510527282953e-05
w0: -0.00411674593341466
dy: 0.5000000000000001


 24%|██▍       | 7/29 [04:29<15:44, 42.91s/it]

Loss: 0.033095321820426425


Training: 100%|██████████| 3000/3000 [00:45<00:00, 66.41steps/s, train_loss=0.00000, val_loss=0.00000]


h0: 0.0793415904045105
y0: 0.0002558072446845472
w0: 0.011565255498074979
dy: 0.5000000000000001


 28%|██▊       | 8/29 [05:24<16:22, 46.77s/it]

Loss: 0.008776110904909304


Training: 100%|██████████| 3000/3000 [00:48<00:00, 62.14steps/s, train_loss=0.00000, val_loss=0.00000]


h0: 0.06489614397287369
y0: 1.0538107744650915e-05
w0: 0.002305980107196076
dy: 0.5000000000000001


 31%|███       | 9/29 [06:16<16:08, 48.40s/it]

Loss: 0.024784144416706615


Training: 100%|██████████| 3000/3000 [00:52<00:00, 57.34steps/s, train_loss=0.00000, val_loss=0.00000]


h0: 0.02042507939040661
y0: 1.5304981388908345e-06
w0: -0.0008732546723017961
dy: 0.5000000000000001


 34%|███▍      | 10/29 [07:13<16:09, 51.04s/it]

Loss: 0.0014232748498424536


Training: 100%|██████████| 3000/3000 [00:56<00:00, 53.05steps/s, train_loss=0.00000, val_loss=0.00000]


h0: 0.027690060436725616
y0: 2.3453421249541861e-07
w0: -0.0003422083218252964
dy: 0.5000000000000001


 38%|███▊      | 11/29 [08:13<16:11, 53.99s/it]

Loss: 0.05172962685890511


Training: 100%|██████████| 3000/3000 [01:07<00:00, 44.71steps/s, train_loss=0.00000, val_loss=0.00000]


h0: 0.01497587002813816
y0: 4.117758933830373e-08
w0: 0.00014352913552791508
dy: 0.5000000000000001


 41%|████▏     | 12/29 [09:26<16:55, 59.76s/it]

Loss: 0.00810414952046273


Training: 100%|██████████| 3000/3000 [01:14<00:00, 40.08steps/s, train_loss=0.00000, val_loss=0.00000]


h0: 0.017987292259931564
y0: 2.3570739813294495e-06
w0: -0.001083247373573421
dy: 0.5000000000000001


 45%|████▍     | 13/29 [10:45<17:27, 65.49s/it]

Loss: 0.03877467649921658


Training: 100%|██████████| 3000/3000 [01:28<00:00, 33.81steps/s, train_loss=0.00000, val_loss=0.00000]


h0: 0.005935323424637318
y0: 3.644300079486129e-07
w0: -0.00042650206211486256
dy: 0.5000000000000001


 48%|████▊     | 14/29 [12:19<18:31, 74.11s/it]

Loss: 0.0010843084464671566


Training: 100%|██████████| 3000/3000 [01:49<00:00, 27.28steps/s, train_loss=0.00000, val_loss=0.00000]


h0: 0.006096958182752132
y0: 8.070663426451574e-08
w0: 0.00020096205321448773
dy: 0.5000000000000001


 52%|█████▏    | 15/29 [14:13<20:06, 86.16s/it]

Loss: 0.024369327483640724


Training: 100%|██████████| 3000/3000 [01:59<00:00, 25.13steps/s, train_loss=0.00000, val_loss=0.00000]


h0: 0.004306002985686064
y0: 1.3486150862718205e-08
w0: -8.210274374503842e-05
dy: 0.5000000000000001


 55%|█████▌    | 16/29 [16:17<21:09, 97.64s/it]

Loss: 0.032511608851780296


Training: 100%|██████████| 3000/3000 [02:12<00:00, 22.69steps/s, train_loss=0.00000, val_loss=0.00000]


h0: 0.0040100268088281155
y0: 1.277218419915016e-08
w0: -7.990025590121406e-05
dy: 0.5000000000000001


 59%|█████▊    | 17/29 [18:33<21:49, 109.16s/it]

Loss: 0.03695050569645375


Training: 100%|██████████| 3000/3000 [02:33<00:00, 19.55steps/s, train_loss=0.00000, val_loss=0.00000]


h0: 0.002510880120098591
y0: 1.4511121193550025e-08
w0: 8.519408941795053e-05
dy: 0.5000000000000001


 62%|██████▏   | 18/29 [21:13<22:48, 124.44s/it]

Loss: 0.05178590617640152


Training:  19%|█▉        | 567/3000 [00:35<02:32, 15.90steps/s, train_loss=0.04579, val_loss=0.04579]
 62%|██████▏   | 18/29 [21:49<13:20, 72.76s/it] 


KeyboardInterrupt: 