In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from src.twisted_attr import TwistedAttractor
from src.utils import MultiAttractor
from src.utils import get_path

## Torus attractors

In [None]:
# Instantiate the multi-module attractor network with m modules
# (dt is presumably the integration step size — confirm in src.utils).
m = 3
mec = MultiAttractor(m=m, dt=0.4)

# Weight matrix W of the first module; print its shape as a sanity check.
first_module = mec.attr[0]
W = first_module.W
print(W.shape)

In [None]:
# Settle the attractors before any path integration: start each module from a
# single active cell and step it with zero velocity input. The resulting state
# is stored back in `y0`, which the rollout cell later reuses as its initial state.
n_side = 20                  # each module's state is displayed as a 20x20 sheet
n_cells = n_side * n_side    # 400 cells per module

y0 = torch.zeros(m, n_side, n_side)
y0[:, 0, 0] = 10.
y0 = y0.view(m, 1, n_cells)

mec.reset(y0)

# Inputs are identical on every step, so build them once instead of per step.
b = 0.2 * torch.ones(m, 1, n_cells)
zero_vel = torch.zeros(m, 1, 2)

with torch.no_grad():        # pure simulation; no autograd graph is needed
    for t in range(1, 300):  # 299 settling steps
        y0 = mec(zero_vel, b)

plt.imshow(y0.reshape(m * n_side, n_side));

## Path data

Generate 20 trajectories with `get_path`; each one is an `(x, v)` pair.

In [None]:
# Generate 20 trajectories with get_path (T=500, w=40, vmax=0.5);
# each entry is unpacked downstream as an (x, v) pair.
w = 40
n_paths = 20
paths = []
for _ in range(n_paths):
    paths.append(get_path(T=500, w=w, vmax=0.5))

## Path visualization

In [None]:
from src.plotting import scatter

# Draw the x component of the (x, v) pair for path index 1 as a line trace.
sample_path = paths[1]
X = sample_path[0]
scatter(X, mode="lines")

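## Path integration

Project each path's velocity into every module with a random 3→2 matrix `A`. Module 0 uses the first two axes directly. Then drive the attractors with the noisy projected velocity.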

In [None]:
# Per-module velocity projections: A[k] maps a 3-component velocity to the
# 2-component input of module k. Each column is normalised to unit length.
# A dedicated, seeded generator makes A reproducible across runs without
# touching the global RNG (which later cells use for velocity noise).
A_SEED = 0
gen = torch.Generator().manual_seed(A_SEED)
A = torch.randn(m, 3, 2, generator=gen)
A = A / torch.norm(A, dim=1, keepdim=True)
# Module 0 sees the first two velocity axes directly.
A[0] = torch.eye(3)[:, :2]

In [None]:
# Project each path's velocity through A (scaled by 4): one tensor per path.
model_input = [(4 * v) @ A for _x, v in paths]

In [None]:
# Roll out the attractors along every path. Each rollout restarts from the
# settled state `y0` and is driven by the projected velocity plus Gaussian noise.
vel_noise_std = 0.3
n_cells = y0.numel() // m            # cells per module (400 for the 20x20 sheets)
b = .2 * torch.ones(m, 1, n_cells)   # constant drive, identical on every step

model_output = []

with torch.no_grad():                # simulation only; also keeps later .numpy() calls valid
    for i, v_ in enumerate(model_input):
        T = v_.shape[1]              # v_ is indexed as (module, time, 2) below
        mec.reset(y0)

        if i % 10 == 0:
            print(i, T)

        y = torch.zeros(T, m * n_cells)
        y[0] = y0.reshape(-1)

        for t in range(1, T):
            v_t = v_[:, [t], :]
            vel = v_t + vel_noise_std * torch.randn_like(v_t)
            # Flatten whatever shape mec returns into one row of m*n_cells values.
            y[t] = mec(vel, b).reshape(-1)

        model_output.append(y)       # y is freshly allocated per path; no clone needed

In [None]:
# Concatenate all paths along time so row k of `x` (path coordinate) matches
# row k of `y` (activity of all modules at that step).
x = torch.cat([pos for pos, _vel in paths], dim=0)
print(f"x: {x.size()}")

y = torch.cat(model_output, dim=0)
y = y.view(-1, m, 20, 20)
print(f"y: {y.size()}")

# Fail early if the two disagree in length; the binning below needs them aligned.
assert x.shape[0] == y.shape[0], (x.shape, y.shape)



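## Spatial tuning

Bin the activity of unit (0, 0), summed over modules, by position. Smooth the binned map, render an isosurface, and save the result to `./data/data.mat`.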

In [None]:
from scipy.stats import binned_statistic_dd as binst
from scipy.ndimage.filters import gaussian_filter
from src.plotting import plot_3d_iso
from scipy.io import savemat

In [None]:
lo, hi = y.min(), y.max()
print(lo, hi)

# Activity of every module at sample 10, stacked vertically as an (m*20, 20) image.
frame = y[10].view(m * 20, 20)
plt.imshow(frame)

In [None]:
bins = 100

# Activity of unit (0, 0) summed over the module axis, so this works for any m
# rather than assuming exactly three modules.
vals = y[:, :, 0, 0].sum(dim=1).numpy()

# Mean summed activity per position bin; never-visited bins are NaN -> set to 0.
hist, be, _ = binst(x.numpy(), vals, bins=bins, statistic='mean')
hist = np.nan_to_num(hist)
print(hist.min(), hist.max())

In [None]:
# Smooth the binned map with a Gaussian of sigma = 3 bins along every axis.
smoothing_sigma = 3
values = gaussian_filter(hist, smoothing_sigma)
print(values.min(), values.max())

In [None]:
# Thresholds passed to plot_3d_iso (c is presumably the iso level — confirm in
# src.plotting); c is also saved alongside the map below.
eps = .01
c = 0.001

# Integer index grids spanning the bins x bins x bins volume.
iso_grid = np.mgrid[0:bins, 0:bins, 0:bins]
X, Y, Z = iso_grid

plot_3d_iso(X, Y, Z, values, c, eps)

In [None]:
from pathlib import Path

# Only the name `savemat` is imported from scipy.io; the bare `scipy` module is
# not, so call it directly. Create the output directory so a fresh checkout works.
out_path = Path('./data/data.mat')
out_path.parent.mkdir(parents=True, exist_ok=True)
savemat(str(out_path), mdict={'arr': values, 'thresh': c, 'b': float(bins)})