In [1]:
import numpy as np
import matplotlib.pyplot as plt 
import scipy
import torch
from tqdm import tqdm

In [2]:
from models import FFGC, RNNGC
from dataset import DatasetMaker

# Training cell: build the dataset and model, then optimise the model on
# freshly generated batches for `train_steps` iterations.
dataset = DatasetMaker()

train_steps = 50000

ng = 128  # passed to the model as `ng`; presumably the number of units -- confirm in models.py
bs = 256 # batch size 

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# model = FFGC(ng = ng, alpha=0.9)
model = RNNGC(ng = ng, alpha=0.9)
# Batches below are created on `device`, so the parameters have to live there
# too; otherwise a CUDA host feeds GPU inputs to a CPU model. This is a no-op
# when device == 'cpu'.
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

loss_history = []  # one entry every 10th step
progress = tqdm(range(train_steps))
for i in progress: # train loop
    # Feed-forward (FFGC) variant, kept for reference:
    # r = torch.rand((bs, 2), device = device)*4*np.pi - 2*np.pi
    # loss = model.train_step(inputs = r, labels = r, optimizer = optimizer)
    
    # NOTE(review): the second argument (5) is presumably the trajectory
    # length and r[:,0] the starting position -- confirm against dataset.py.
    r, v = dataset.generate_data(bs, 5, device)
    loss = model.train_step(inputs = (r[:,0], v), labels = r, optimizer = optimizer)

    if i % 10 == 0:
        loss_history.append(loss)
        progress.set_description(f"loss: {loss:>7f}")

loss: -0.003607:   4%|▎         | 1817/50000 [00:45<19:52, 40.40it/s][E thread_pool.cpp:109] Exception in thread pool task: mutex lock failed: Invalid argument
[E thread_pool.cpp:109] Exception in thread pool task: mutex lock failed: Invalid argument
loss: -0.003607:   4%|▎         | 1817/50000 [00:45<19:59, 40.18it/s]


KeyboardInterrupt: 

In [None]:
# Evaluation setup: move the model to the CPU and fix n, the side length of
# the n x n grid that the plotting cells below reshape each map to.
model.to('cpu')
n = 32

In [None]:
# x = np.linspace(-1, 1, n)*2*np.pi
# y = x.copy()
# xx, yy = np.meshgrid(x,y)
# u = torch.tensor(np.stack([xx.ravel(), yy.ravel()], axis = -1), dtype= torch.float32)
# p = model(u).detach().numpy()
# p.shape # 1024, 128

In [None]:
# Build spatial rate maps: run the model on a large batch of trajectories and
# average every unit's activity inside each position bin.

# The model was moved to the CPU before this cell, so the evaluation batch has
# to be generated there as well (using the training `device` would mix CUDA
# inputs with CPU parameters on a GPU machine).
eval_device = 'cpu'
r, v = dataset.generate_data(10000, 5, eval_device)
g = model((r[:,0], v))

r = r.detach().cpu().numpy()
g = g.detach().cpu().numpy()
# bins = n keeps the binning consistent with the reshape(n, n) calls in the
# plotting cells. The statistic defaults to the mean; bins that receive no
# samples come back as NaN.
p = scipy.stats.binned_statistic_2d(r[...,0].ravel(), r[...,1].ravel(), g.reshape(-1, g.shape[-1]).T, bins = n)[0]
# (ng, n, n) -> (n*n, ng): one row per spatial bin, one column per unit
p = p.reshape(ng, -1).T
p.shape # 1024, 128

In [None]:
# Mosaic of the first n_p**2 rate maps, one panel per column of p.
n_p = 10
fig, ax = plt.subplots(n_p, n_p, figsize =(12, 12))

# ax.ravel() walks the panels row by row, matching (i // n_p, i % n_p).
for panel, ratemap in zip(ax.ravel(), p.T[:n_p**2]):
    panel.axis("off")
    panel.imshow(ratemap.reshape(n, n), cmap = "jet", interpolation = "none")

plt.subplots_adjust(wspace=0.05, hspace=0.05)

In [None]:
# Fit a square matrix w0 that acts on the population vectors as a proper
# rotation while keeping the rotated activity non-negative.
#
# p is the NumPy rate-map array from the binning cell; p0 is its float32
# transpose as a torch tensor, so that w0 @ p0 is defined.
p0 = torch.tensor(p.astype("float32").T)
# small random initialisation, uniform in [-0.001, 0.001)
w0 = torch.nn.Parameter((torch.rand((ng, ng), dtype=torch.float32) * 2 - 1)*0.001)

# create a torch optimizer
optimizer = torch.optim.Adam([w0], lr=1e-4)
relu = torch.nn.ReLU()
losses = []
steps = 200000
# The identity is loop-invariant, so build it once rather than twice per step.
identity = torch.eye(len(w0))
# define a training loop
for _ in tqdm(range(steps)):
    optimizer.zero_grad()
    
    # z = Wp  # rotated version of population vector
    # p0 is already a tensor: re-wrapping it in torch.tensor() would copy it
    # (and emit a UserWarning) on every iteration.
    z = w0@p0
    a = w0.T@w0 - identity # be orthogonal
    b = w0@w0.T - identity # be orthogonal
    c = (torch.linalg.det(w0) - 1) # proper rotation
    d = relu(-z) # non-negative result everywhere
    
    # NOTE(review): if any spatial bin of p is NaN, the last term (and hence
    # the whole loss) becomes NaN -- confirm that all bins are populated.
    loss = torch.mean(a**2) + torch.mean(b**2) + torch.mean(c**2) + torch.mean(d) 
    losses.append(loss.item())
    loss.backward()
    optimizer.step()

In [None]:
# Loss of the rotation fit per optimisation step, on a log-scaled y axis.
plt.semilogy(losses)
plt.xlabel("optimisation step")
plt.ylabel("loss")

In [None]:
# def inf_rotate(v0, J, theta, n):
#     I = np.eye(len(v0))
#     R = I + theta*J
#     v = v0.copy()
#     for i in range(n):
#         v = R@v # infinitesimal rotation
#     return v

# def random_skew_symmetric_matrix(n):
#     J = np.random.choice([0, 1], (n, n))
#     for i in range(n):
#         for j in range(n):
#             if i == j:
#                 J[i,j] = 0
#             elif j < i:
#                 J[i,j] = -J[j,i]
#     return J

# j0 = random_skew_symmetric_matrix(len(p.T))
# z = inf_rotate(p.T, j0, 1e-5, 50000)

In [None]:
# NumPy version of the learned matrix for the inspection cells below.
# detach() is needed because w0 is a Parameter that tracks gradients.
w = w0.detach().numpy()

In [None]:
# Determinant of the learned matrix; the fit pushes it towards 1.
det_w = np.linalg.det(w)
f"Determinant: {det_w}"

In [None]:
# Orthogonality check: w w^T is the identity when the rows of w are orthonormal.
gram = w @ w.T
gram_image = plt.imshow(gram)
plt.colorbar(gram_image)

In [None]:
# p comes straight from binned_statistic_2d (float64 by default) while w0 is
# float32. torch matmul does not promote dtypes, so cast explicitly instead of
# relying on torch.tensor(p.T) keeping the NumPy dtype.
z = (w0@torch.tensor(p.T, dtype=torch.float32)).detach().numpy() # rotate population by W

In [None]:
# Two mosaics of the first n_p**2 maps: the original population (rows of p.T)
# and the rotated one (rows of z).
n_p = 10
mosaics = (
    ("Before Rotation", p.T, (12, 12)),
    ("After Rotation", z, (10, 10)),
)

for heading, maps, size in mosaics:
    fig, ax = plt.subplots(n_p, n_p, figsize =size)

    # ax.ravel() walks the panels row by row, matching (i // n_p, i % n_p).
    for panel, representation in zip(ax.ravel(), maps[:n_p**2]):
        panel.axis("off")
        panel.imshow(representation.reshape(n, n), interpolation = "none")

    plt.suptitle(heading)
    plt.subplots_adjust(wspace=0.05, hspace=0.05)

In [None]:
# Gaussian similarity between the rotated population vector at the centre of
# the grid and the population vector at every other location.
# Flat index of bin (n//2, n//2) in the n x n grid; equals 512+16 for n = 32.
centre = (n // 2) * n + n // 2
ps = z[:, centre]
sim = np.exp(-np.sum((ps[None] - z.T)**2, axis = -1))
plt.imshow(sim.reshape(n, n), interpolation = "None")
plt.colorbar()

In [None]:
# First 20 units side by side: original map on the left, rotated map on the
# right, with the right panel's colour scale capped at the original's maximum.
n_units = 20
fig, ax = plt.subplots(n_units, 2, figsize = (2, 10))

for unit, (left, right) in enumerate(ax):
    original_map = p[:, unit]
    right.imshow(z[unit].reshape(32, 32), vmax = np.amax(original_map))
    left.imshow(original_map.reshape(32, 32))
    for panel in (left, right):
        panel.axis("off")

In [None]:
# Mosaic of the element-wise squared difference between the original and
# rotated maps for the first n_p**2 units.
n_p = 10
fig, ax = plt.subplots(n_p, n_p, figsize =(10, 10))

squared_diff = (p.T - z)**2
for panel, diff_map in zip(ax.ravel(), squared_diff[:n_p**2]):
    panel.axis("off")
    panel.imshow(diff_map.reshape(n, n), interpolation = "none")


In [None]:
# Heat map of the learned matrix itself, with its colour scale.
weight_image = plt.imshow(w)
plt.colorbar(weight_image)