In [1]:
import os, sys
from tqdm import trange
import tqdm
from IPython.utils import io

import scipy
import math
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import plotly.graph_objects as go

import torch
from torch import nn

source = "/home/loek/projects/rnn/source"
sys.path.append(source)

from data import fun_data, grid_data
from preprocessing import Direct
from compilation import Compiler, Tracker, ScalarTracker, ActivationTracker
from data_analysis.automata import to_automaton_history
from data_analysis.visualization.animation import SliderAnimation
from data_analysis.visualization.activations import (
    ActivationsAnimation,
    FunctionAnimation,
)
from data_analysis.visualization.automata import AutomatonAnimation
from data_analysis.visualization.epochs import EpochAnimation

from model import Model

import cProfile
import pstats


# Device selection.
# Training is deliberately kept on the CPU: set USE_GPU = True to use CUDA when
# it is available. (Previously the cuda branch was silently overridden by an
# unconditional `device = torch.device("cpu")` while still printing
# "GPU available", which misreported where training actually ran.)
USE_GPU = False

is_cuda = torch.cuda.is_available()
if USE_GPU and is_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"GPU available: {is_cuda}; using device: {device}")

GPU available


In [2]:
# Experiment configuration.

# Target functions are sampled on the interval `bounds`; X_vol is its width.
bounds = [-3, 3]
X_vol = bounds[-1] - bounds[0]

# `diff` is only referenced by the commented-out sine target in the grid-search
# cell; `gain` is passed to the model as its `init_std`.
diff = 1
gain = 1

# Grid over the number of training datapoints: n_N values spanning [N_min, N_max].
N_min, N_max, n_N = 2, 30, 20

# Grid over the target's frequency: n_freq values spanning [0, freq_max].
# freq_max inverts complexity = X_vol * exp(0.5) * freq**2 / 2 at compl_max.
# NOTE(review): the plotting cell's complexity formula carries an extra sqrt(3)
# factor that this inversion omits — confirm which is intended.
compl_max = 30
freq_max = np.sqrt(compl_max * (2 / (X_vol * np.exp(0.5))))
n_freq = 20

In [3]:
## Grid search
# For every (number of training points N, frequency) pair, train a freshly
# initialised model on N samples of a Gaussian-shaped target and record the
# final training and validation loss in train_loss[k, l] / val_loss[k, l]
# (rows indexed by N, columns by frequency).
#
# NOTE(review): the normalisation 1 / (sigma * sqrt(2*pi)) treats `sigma` as a
# standard deviation, but the exponent divides by `2 * sigma` rather than
# `2 * sigma**2`, i.e. treats it as a variance. Kept unchanged so results stay
# comparable with earlier runs and with the complexity formula in the plotting
# cell — confirm which form is intended.

SEED = 0  # every grid point restarts from the same RNG state
FREQ_EPS = 1e-7  # keeps sigma finite at freq == 0
N_VAL_POINTS = 10000  # dense validation sample of the target
n_epochs = 3000
lr = 0.0005
criterion = nn.MSELoss()


def make_target(sigma):
    """Return the Gaussian-shaped target function of width ``sigma``.

    Binding ``sigma`` through a factory (instead of a lambda that closes over
    the loop variable) ties each function to its own ``sigma`` even if it is
    evaluated after the loop has moved on to the next grid point.
    """

    def target(x):
        return 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-(x**2) / (2 * sigma))

    return target


Ns = np.linspace(N_min, N_max, n_N, dtype=int)
freqs = np.linspace(0, freq_max, n_freq)
train_loss = np.zeros(shape=(n_N, n_freq))
val_loss = np.zeros(shape=(n_N, n_freq))
n_runs = n_N * n_freq

for k, N in enumerate(Ns):
    for l, freq in enumerate(freqs):
        run = k * n_freq + l + 1  # 1-based progress counter

        # Reseed so whatever torch/numpy randomness data generation and weight
        # initialisation draw on is reproducible and identical across runs.
        torch.manual_seed(SEED)
        np.random.seed(SEED)

        ## Generate data
        # Alternative target: lambda x: diff * np.sin(freq * x) + 1
        sigma = 1 / (freq + FREQ_EPS)
        function = make_target(sigma)

        train_datasets = [
            fun_data(device, function, bounds=bounds, n_datapoints=N),
        ]
        val_dataset = [
            fun_data(device, function, bounds=bounds, n_datapoints=N_VAL_POINTS)
        ]
        tracked_datasets = val_dataset + train_datasets
        encoding = Direct()

        ## Instantiate model
        model = Model(
            encoding=encoding,
            input_size=1,
            output_size=1,
            hidden_dim=50,
            n_hid_layers=30,
            device=device,
            init_std=gain,
        )

        ## Setup compiler
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        compiler = Compiler(model, criterion, optimizer)
        compiler.trackers = {}

        ## Training run (its printed output is suppressed)
        with io.capture_output():
            compiler.training_run(
                train_datasets, tracked_datasets, n_epochs=n_epochs, batch_size=1000
            )

        # NOTE(review): assumes validation(...)[0][0] is the scalar loss of the
        # single dataset passed in — confirm against Compiler.validation.
        train_loss[k, l] = compiler.validation(train_datasets)[0][0]
        val_loss[k, l] = compiler.validation(val_dataset)[0][0]

        print(
            f"{run} out of {n_runs}, train_loss: {train_loss[k, l]}, val_loss:{val_loss[k, l]}"
        )

In [None]:
# Plot
# Validation loss as a surface over (problem complexity, number of datapoints),
# lightly smoothed, plus a vertical reference plane along the diagonal where
# the complexity value equals the number of datapoints.
#
# NOTE(review): this formula has a sqrt(3) factor that the `freq_max`
# derivation in the configuration cell omits, so the largest plotted
# complexity is sqrt(3) * compl_max rather than compl_max — confirm which is
# intended.
complexity = X_vol * 0.5 * freqs**2 * np.sqrt(3) * np.exp(0.5)

SMOOTH_SIGMA = 0.5  # std of the Gaussian smoothing kernel, in grid cells

x = complexity  # columns of val_loss are indexed by frequency
y = Ns  # rows of val_loss are indexed by number of datapoints
# `scipy.ndimage.filters` is a deprecated namespace; call gaussian_filter from
# `scipy.ndimage` directly. mode="nearest" extends the edge values instead of
# zero-padding, which would drag the border of the loss surface towards 0.
z = scipy.ndimage.gaussian_filter(val_loss, SMOOTH_SIGMA, mode="nearest")

# Reference plane: the point at (i, j) sits at x = y = complexity[i] with a
# height proportional to complexity[j]. Since freqs starts at 0, the plane
# stands on the diagonal x == y and spans heights from 0 to max(val_loss).
h = np.outer(complexity, np.ones(len(complexity)))
v = (np.max(val_loss) / np.max(h)) * h.T

fig = go.Figure(
    data=[
        go.Surface(x=x, y=y, z=z),
        # surfacecolor must match the plane's grid shape (a hard-coded 50x50
        # array did not match the n_freq x n_freq plane).
        go.Surface(x=h, y=h, z=v, surfacecolor=np.ones_like(h), opacity=0.5),
    ]
)

fig.update_layout(
    title="Validation loss",
    autosize=False,
    width=600,
    height=600,
    margin=dict(l=65, r=50, b=65, t=90),
)
fig.update_scenes(
    xaxis_title_text="Problem complexity",
    yaxis_title_text="Number of datapoints",
    zaxis_title_text="Loss",
)

fig.show()

In [None]:
# fig.write_html("plots/" + "validation_loss.html")