In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy import stats
from scipy.spatial.distance import pdist, cdist
from scipy.spatial import cKDTree
import seaborn as sns
import pandas as pd
import umap
from sklearn.manifold import MDS
from scipy.spatial.distance import squareform, pdist

from tqdm import tqdm
import os

from src.models import SpaceNet, RecurrentSpaceNet, Decoder
from src.utils import ratemap_collage, SimpleDatasetMaker

# Reset matplotlib's rcParams to the library defaults before applying the project style,
# so settings left over from an earlier run of this cell do not leak into the figures.
plt.rcdefaults()
# Relative path: the style sheet is looked up from the current working directory.
plt.style.use("figures/project_style.mplstyle")
# Render figures inline in the notebook output.
%matplotlib inline

In [2]:
# Output folders for figures, model checkpoints and results, all located
# directly under the directory the notebook is executed from.
figure_path, model_path, results_path = (
    os.path.join(os.getcwd(), folder) for folder in ("figures", "models", "results")
)

# Recurrent Network (no context)

In [3]:
# ----------------------- Params -----------------------
# Hyperparameters for the recurrent (no context) training run below.
train_steps = 60_000   # optimisation steps per model
timesteps = 10         # number of time points in each trajectory
bs = 64                # trajectories per batch
lr = 1e-4              # optimizer learning rate
n_models = 1           # independently initialised models per configuration

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One list of `n_models` independently initialised networks per configuration.
# Each grid entry changes a single hyperparameter relative to the default "256" model.
models = {
    # Default model
    "256": [RecurrentSpaceNet(n_in=2, n_out=256, corr_across_space=True, scale=0.25, beta=0.5, device=device) for _ in range(n_models)],
    
    # Beta grid
    "256_0beta": [RecurrentSpaceNet(n_in=2, n_out=256, corr_across_space=True, scale=0.25, beta=0., device=device) for _ in range(n_models)],
    "256_025beta": [RecurrentSpaceNet(n_in=2, n_out=256, corr_across_space=True, scale=0.25, beta=0.25, device=device) for _ in range(n_models)],
    "256_075beta": [RecurrentSpaceNet(n_in=2, n_out=256, corr_across_space=True, scale=0.25, beta=0.75, device=device) for _ in range(n_models)],
    
    # Scale grid
    "256_01scale": [RecurrentSpaceNet(n_in=2, n_out=256, corr_across_space=True, scale=0.1, beta=0.5, device=device) for _ in range(n_models)],
    "256_05scale": [RecurrentSpaceNet(n_in=2, n_out=256, corr_across_space=True, scale=0.5, beta=0.5, device=device) for _ in range(n_models)],
    
    # n grid
    "512": [RecurrentSpaceNet(n_in=2, n_out=512, corr_across_space=True, scale=0.25, beta=0.5, device=device) for _ in range(n_models)],
    "1024": [RecurrentSpaceNet(n_in=2, n_out=1024, corr_across_space=True, scale=0.25, beta=0.5, device=device) for _ in range(n_models)],
}

# Loss curves per configuration, appended in the same order as the model lists
loss_histories = {name: [] for name in models.keys()}

# --------------------- Training ----------------------

# The save calls below do not create missing directories; make sure the
# checkpoint folder exists before any time is spent on training.
os.makedirs(model_path, exist_ok=True)

for name, model_list in models.items():
    
    print(f"Training {name}")
    for i, model in enumerate(model_list):
        
        print(f"Model {i+1}")
        
        checkpoint_file = os.path.join(model_path, f"{name}_{i}.pt")
        history_file = os.path.join(model_path, f"{name}_{i}_loss_history.npy")
        
        # Initialize dataset generator
        genny = SimpleDatasetMaker()    
        
        # Reuse a previous run if its checkpoint is on disk.
        if os.path.exists(checkpoint_file):
            # NOTE: torch.load unpickles the whole model object, so only load
            # checkpoints written by this notebook. Recent PyTorch releases default
            # to weights_only=True and may need weights_only=False here -- TODO
            # confirm against the installed version.
            # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
            model = torch.load(checkpoint_file, map_location=device)
            loss_history = np.load(history_file)
            # Store the restored network; otherwise `models` would keep the
            # freshly initialised, untrained instance.
            models[name][i] = model
            loss_histories[name].append(loss_history)
            continue
        
        # Only models that are actually trained need an optimizer
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        
        loss_history = []
        progress = tqdm(range(train_steps))
        for k in progress:  
            
            # Create batch of trajectories
            r, v = genny.generate_dataset(bs, timesteps, device=device)
        
            # Perform training step: velocities and the first position are the
            # input, the remaining positions are the target
            loss = model.train_step(x=(v, r[:, 0]), y=r[:, 1:], optimizer=optimizer)
        
            loss_history.append(loss)
            
            if k % 10 == 0:
                progress.set_description(f"loss: {loss:>7f}")
                
        models[name][i] = model
        loss_histories[name].append(loss_history)

        # Save loss history and model. The checkpoint is the cache key checked
        # above, so it is written last: its presence implies the history exists.
        np.save(history_file, loss_history)
        torch.save(model, checkpoint_file)


Training 256
Model 1
Training 256_0beta
Model 1
Training 256_025beta
Model 1
Training 256_075beta
Model 1
Training 256_01scale
Model 1
Training 256_05scale
Model 1
Training 512
Model 1
Training 1024
Model 1


# Recurrent Network (context)

In [9]:
# ----------------------- Params -----------------------
# Hyperparameters for the recurrent (context) training run below.
train_steps = 60_000   # optimisation steps per model
timesteps = 10         # number of time points in each trajectory
bs = 64                # trajectories per batch
lr = 1e-4              # optimizer learning rate
n_models = 1           # independently initialised models per configuration
cmin, cmax = -2, 2     # lower / upper bound of the context value

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One list of `n_models` independently initialised networks per configuration.
# n_in=3: two velocity components plus one context value per time step.
# initial_state_size selects whether the initial input is position + context (3)
# or position only (2); see `context_in_initial` below.
models = {
    "256_context": [RecurrentSpaceNet(n_in=3, n_out=256, corr_across_space=True, scale=0.25, beta=0.5, device=device, initial_state_size=3) for _ in range(n_models)],
    "256_context_not_initial": [RecurrentSpaceNet(n_in=3, n_out=256, corr_across_space=True, scale=0.25, beta=0.5, device=device, initial_state_size=2) for _ in range(n_models)],
    
    # Beta grid
    "256_context_0beta": [RecurrentSpaceNet(n_in=3, n_out=256, corr_across_space=True, scale=0.25, beta=0., device=device, initial_state_size=3) for _ in range(n_models)],
    "256_context_025beta": [RecurrentSpaceNet(n_in=3, n_out=256, corr_across_space=True, scale=0.25, beta=0.25, device=device, initial_state_size=3) for _ in range(n_models)],
    "256_context_075beta": [RecurrentSpaceNet(n_in=3, n_out=256, corr_across_space=True, scale=0.25, beta=0.75, device=device, initial_state_size=3) for _ in range(n_models)],
}

# Loss curves per configuration, appended in the same order as the model lists
loss_histories = {name: [] for name in models.keys()}

# --------------------- Training ----------------------

# The save calls below do not create missing directories; make sure the
# checkpoint folder exists before any time is spent on training.
os.makedirs(model_path, exist_ok=True)

for name, model_list in models.items():
    
    print(f"Training {name}")
    for i, model in enumerate(model_list):
        
        print(f"Model {i+1}")
        
        # Models built with initial_state_size=3 get the context appended to the
        # starting position in their initial input
        context_in_initial = model.initial_state_size == 3
        
        checkpoint_file = os.path.join(model_path, f"{name}_{i}.pt")
        history_file = os.path.join(model_path, f"{name}_{i}_loss_history.npy")
        
        # Initialize dataset generator
        genny = SimpleDatasetMaker()    
        
        # Reuse a previous run if its checkpoint is on disk.
        if os.path.exists(checkpoint_file):
            # NOTE: torch.load unpickles the whole model object, so only load
            # checkpoints written by this notebook. Recent PyTorch releases default
            # to weights_only=True and may need weights_only=False here -- TODO
            # confirm against the installed version.
            # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
            model = torch.load(checkpoint_file, map_location=device)
            loss_history = np.load(history_file)
            # Store the restored network; otherwise `models` would keep the
            # freshly initialised, untrained instance.
            models[name][i] = model
            loss_histories[name].append(loss_history)
            continue
        
        # Only models that are actually trained need an optimizer
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        
        loss_history = []
        progress = tqdm(range(train_steps))
        for k in progress:  
            
            # Create batch of trajectories
            r, v = genny.generate_dataset(bs, timesteps, device=device)
            
            # Get random contexts and use for all timesteps along a trajectory:
            # one scalar per trajectory, repeated to shape (bs, timesteps - 1, 1)
            c = torch.tensor(np.random.uniform(cmin, cmax, bs), dtype=torch.float32, device=device)
            c = c[:, None, None] * torch.ones((1, timesteps - 1, 1), device=device)
            
            # Build initial input
            if context_in_initial:
                initial_input = torch.cat((r[:, 0], c[:, 0]), dim=-1)
            else:
                initial_input = r[:, 0]
            
            # Concatenate velocity and context
            inputs = (torch.cat((v, c), dim=-1), initial_input)
            labels = (r[:, 1:], c)
        
            # Perform training step
            loss = model.train_step(x=inputs, y=labels, optimizer=optimizer)
        
            loss_history.append(loss)
            
            if k % 10 == 0:
                progress.set_description(f"loss: {loss:>7f}")
                
        models[name][i] = model
        loss_histories[name].append(loss_history)

        # Save loss history and model. The checkpoint is the cache key checked
        # above, so it is written last: its presence implies the history exists.
        np.save(history_file, loss_history)
        torch.save(model, checkpoint_file)


Training 256_context
Model 1


loss: 0.000619: 100%|██████████| 60000/60000 [16:38<00:00, 60.09it/s]


Training 256_context_not_initial
Model 1


loss: 0.002729:   6%|▋         | 3797/60000 [01:33<22:59, 40.74it/s]


KeyboardInterrupt: 

# Feedforward Network

In [7]:
# ----------------------- Params -----------------------
# Hyperparameters for the feedforward training run below.
train_steps = 60_000   # optimisation steps per model
bs = 64                # positions per batch
lr = 1e-4              # optimizer learning rate
n_models = 10          # independently initialised models per configuration

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One list of `n_models` independently initialised networks per configuration.
# Each grid entry changes a single hyperparameter relative to the default "256_ff" model.
models = {
    "256_ff": [SpaceNet(n_in=2, n_out=256, scale=0.25, beta=0.5, device=device) for _ in range(n_models)],
    "256_ff_01scale": [SpaceNet(n_in=2, n_out=256, scale=0.1, beta=0.5, device=device) for _ in range(n_models)],
    "256_ff_05scale": [SpaceNet(n_in=2, n_out=256, scale=0.5, beta=0.5, device=device) for _ in range(n_models)],
    "256_ff_0beta": [SpaceNet(n_in=2, n_out=256, scale=0.25, beta=0, device=device) for _ in range(n_models)],
    "256_ff_025beta": [SpaceNet(n_in=2, n_out=256, scale=0.25, beta=0.25, device=device) for _ in range(n_models)],
    "256_ff_075beta": [SpaceNet(n_in=2, n_out=256, scale=0.25, beta=0.75, device=device) for _ in range(n_models)],
}

# Loss curves per configuration, appended in the same order as the model lists
loss_histories = {name: [] for name in models.keys()}

# --------------------- Training ----------------------

# The save calls below do not create missing directories; make sure the
# checkpoint folder exists before any time is spent on training.
os.makedirs(model_path, exist_ok=True)

for name, model_list in models.items():
    
    print(f"Training {name}")
    for i, model in enumerate(model_list):
        
        print(f"Model {i+1}")
        
        checkpoint_file = os.path.join(model_path, f"{name}_{i}.pt")
        history_file = os.path.join(model_path, f"{name}_{i}_loss_history.npy")
        
        # Not used by the feedforward loop below; kept so the notebook-level name
        # `genny` is still defined after this cell -- TODO confirm whether any
        # later cell relies on it.
        genny = SimpleDatasetMaker()    
        
        # Reuse a previous run if its checkpoint is on disk.
        if os.path.exists(checkpoint_file):
            # NOTE: torch.load unpickles the whole model object, so only load
            # checkpoints written by this notebook. Recent PyTorch releases default
            # to weights_only=True and may need weights_only=False here -- TODO
            # confirm against the installed version.
            # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
            model = torch.load(checkpoint_file, map_location=device)
            loss_history = np.load(history_file)
            # Store the restored network; otherwise `models` would keep the
            # freshly initialised, untrained instance.
            models[name][i] = model
            loss_histories[name].append(loss_history)
            continue
        
        # Only models that are actually trained need an optimizer
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        
        loss_history = []
        progress = tqdm(range(train_steps))
        for k in progress:  
            
            # Create batch of positions, uniform in [-1, 1) x [-1, 1), on the
            # same device the model was constructed with
            r = torch.tensor(np.random.uniform(-1, 1, (bs, 2)), dtype=torch.float32, device=device)
        
            # Perform training step (the positions serve as both input and target)
            loss = model.train_step(x=r, y=r, optimizer=optimizer)
        
            loss_history.append(loss)
            
            if k % 10 == 0:
                progress.set_description(f"loss: {loss:>7f}")
                
        models[name][i] = model
        loss_histories[name].append(loss_history)

        # Save loss history and model. The checkpoint is the cache key checked
        # above, so it is written last: its presence implies the history exists.
        np.save(history_file, loss_history)
        torch.save(model, checkpoint_file)


Training 256_ff
Model 1
Model 2


loss: 0.000151: 100%|██████████| 60000/60000 [00:49<00:00, 1205.75it/s]


Model 3


loss: 0.000151: 100%|██████████| 60000/60000 [00:48<00:00, 1225.13it/s]


Model 4


loss: 0.000161: 100%|██████████| 60000/60000 [00:39<00:00, 1503.19it/s]


Model 5


loss: 0.000152: 100%|██████████| 60000/60000 [00:43<00:00, 1371.91it/s]


Model 6


loss: 0.000150: 100%|██████████| 60000/60000 [00:47<00:00, 1255.16it/s]


Model 7


loss: 0.000147: 100%|██████████| 60000/60000 [00:44<00:00, 1337.79it/s]


Model 8


loss: 0.000149: 100%|██████████| 60000/60000 [00:50<00:00, 1198.57it/s]


Model 9


loss: 0.000159: 100%|██████████| 60000/60000 [01:10<00:00, 848.62it/s]


Model 10


loss: 0.000145: 100%|██████████| 60000/60000 [01:10<00:00, 849.71it/s]


Training 256_ff_01scale
Model 1


loss: 0.000502: 100%|██████████| 60000/60000 [01:04<00:00, 929.69it/s] 


Model 2


loss: 0.000441: 100%|██████████| 60000/60000 [01:02<00:00, 955.90it/s] 


Model 3


loss: 0.000404: 100%|██████████| 60000/60000 [00:57<00:00, 1035.64it/s]


Model 4


loss: 0.000503: 100%|██████████| 60000/60000 [01:00<00:00, 985.50it/s] 


Model 5


loss: 0.000378: 100%|██████████| 60000/60000 [01:12<00:00, 831.39it/s]


Model 6


loss: 0.000435: 100%|██████████| 60000/60000 [00:55<00:00, 1079.03it/s]


Model 7


loss: 0.000423: 100%|██████████| 60000/60000 [01:02<00:00, 952.66it/s] 


Model 8


loss: 0.000505: 100%|██████████| 60000/60000 [01:05<00:00, 914.27it/s] 


Model 9


loss: 0.000406: 100%|██████████| 60000/60000 [01:08<00:00, 874.33it/s] 


Model 10


loss: 0.000378: 100%|██████████| 60000/60000 [01:09<00:00, 859.14it/s] 


Training 256_ff_05scale
Model 1


loss: 0.000135: 100%|██████████| 60000/60000 [01:11<00:00, 838.33it/s] 


Model 2


loss: 0.000135: 100%|██████████| 60000/60000 [01:05<00:00, 917.56it/s] 


Model 3


loss: 0.000137: 100%|██████████| 60000/60000 [01:10<00:00, 856.08it/s] 


Model 4


loss: 0.000135: 100%|██████████| 60000/60000 [01:14<00:00, 807.78it/s] 


Model 5


loss: 0.000137: 100%|██████████| 60000/60000 [00:59<00:00, 1016.65it/s]


Model 6


loss: 0.000136: 100%|██████████| 60000/60000 [00:52<00:00, 1142.15it/s]


Model 7


loss: 0.000135: 100%|██████████| 60000/60000 [01:00<00:00, 985.36it/s] 


Model 8


loss: 0.000135: 100%|██████████| 60000/60000 [00:51<00:00, 1157.09it/s]


Model 9


loss: 0.000136: 100%|██████████| 60000/60000 [01:04<00:00, 931.15it/s] 


Model 10


loss: 0.000139: 100%|██████████| 60000/60000 [00:50<00:00, 1188.28it/s]


Training 256_ff_0beta
Model 1


loss: 0.001199: 100%|██████████| 60000/60000 [00:55<00:00, 1088.22it/s]


Model 2


loss: 0.001189: 100%|██████████| 60000/60000 [01:26<00:00, 689.75it/s] 


Model 3


loss: 0.001178: 100%|██████████| 60000/60000 [01:36<00:00, 623.74it/s] 


Model 4


loss: 0.001194: 100%|██████████| 60000/60000 [01:11<00:00, 843.70it/s] 


Model 5


loss: 0.001174: 100%|██████████| 60000/60000 [01:20<00:00, 741.48it/s] 


Model 6


loss: 0.001167: 100%|██████████| 60000/60000 [00:58<00:00, 1017.65it/s]


Model 7


loss: 0.001154: 100%|██████████| 60000/60000 [00:49<00:00, 1215.50it/s]


Model 8


loss: 0.001176: 100%|██████████| 60000/60000 [00:57<00:00, 1050.73it/s]


Model 9


loss: 0.001166: 100%|██████████| 60000/60000 [00:53<00:00, 1127.82it/s]


Model 10


loss: 0.001172: 100%|██████████| 60000/60000 [01:01<00:00, 978.89it/s] 


Training 256_ff_025beta
Model 1


loss: 0.000305: 100%|██████████| 60000/60000 [00:46<00:00, 1299.84it/s]


Model 2


loss: 0.000306: 100%|██████████| 60000/60000 [01:07<00:00, 891.22it/s] 


Model 3


loss: 0.000304: 100%|██████████| 60000/60000 [00:50<00:00, 1179.83it/s]


Model 4


loss: 0.000311: 100%|██████████| 60000/60000 [01:03<00:00, 943.51it/s] 


Model 5


loss: 0.000314: 100%|██████████| 60000/60000 [01:00<00:00, 999.40it/s] 


Model 6


loss: 0.000307: 100%|██████████| 60000/60000 [01:16<00:00, 786.78it/s] 


Model 7


loss: 0.000305: 100%|██████████| 60000/60000 [01:04<00:00, 924.71it/s] 


Model 8


loss: 0.000306: 100%|██████████| 60000/60000 [01:08<00:00, 877.89it/s] 


Model 9


loss: 0.000300: 100%|██████████| 60000/60000 [01:03<00:00, 939.49it/s] 


Model 10


loss: 0.000303: 100%|██████████| 60000/60000 [00:48<00:00, 1237.04it/s]


Training 256_ff_075beta
Model 1


loss: 0.000064: 100%|██████████| 60000/60000 [00:47<00:00, 1260.15it/s]


Model 2


loss: 0.000062: 100%|██████████| 60000/60000 [01:14<00:00, 801.31it/s] 


Model 3


loss: 0.000064: 100%|██████████| 60000/60000 [01:15<00:00, 799.24it/s] 


Model 4


loss: 0.000060: 100%|██████████| 60000/60000 [00:55<00:00, 1089.11it/s]


Model 5


loss: 0.000064: 100%|██████████| 60000/60000 [01:07<00:00, 893.59it/s] 


Model 6


loss: 0.000062: 100%|██████████| 60000/60000 [00:50<00:00, 1198.67it/s]


Model 7


loss: 0.000062: 100%|██████████| 60000/60000 [00:53<00:00, 1125.27it/s]


Model 8


loss: 0.000061: 100%|██████████| 60000/60000 [00:57<00:00, 1042.01it/s]


Model 9


loss: 0.000061: 100%|██████████| 60000/60000 [00:55<00:00, 1089.57it/s]


Model 10


loss: 0.000063: 100%|██████████| 60000/60000 [00:54<00:00, 1100.95it/s]


## Feedforward Network (context)

In [3]:
# ----------------------- Params -----------------------
# Hyperparameters for the feedforward (context) training run below.
train_steps = 60_000   # optimisation steps per model
bs = 64                # samples per batch
lr = 1e-4              # optimizer learning rate
n_models = 1           # independently initialised models per configuration
cmin, cmax = -2, 2     # lower / upper bound of the context value

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# n_in=3: two position coordinates plus one context value per sample
models = {
    "256_ff_context": [SpaceNet(n_in=3, n_out=256, scale=0.25, beta=0.5, device=device) for _ in range(n_models)],
}

# Loss curves per configuration, appended in the same order as the model lists
loss_histories = {name: [] for name in models.keys()}

# --------------------- Training ----------------------

# The save calls below do not create missing directories; make sure the
# checkpoint folder exists before any time is spent on training.
os.makedirs(model_path, exist_ok=True)

for name, model_list in models.items():
    
    print(f"Training {name}")
    for i, model in enumerate(model_list):
        
        print(f"Model {i+1}")
        
        checkpoint_file = os.path.join(model_path, f"{name}_{i}.pt")
        history_file = os.path.join(model_path, f"{name}_{i}_loss_history.npy")
        
        # Not used by the feedforward loop below; kept so the notebook-level name
        # `genny` is still defined after this cell -- TODO confirm whether any
        # later cell relies on it.
        genny = SimpleDatasetMaker()    
        
        # Reuse a previous run if its checkpoint is on disk.
        if os.path.exists(checkpoint_file):
            # NOTE: torch.load unpickles the whole model object, so only load
            # checkpoints written by this notebook. Recent PyTorch releases default
            # to weights_only=True and may need weights_only=False here -- TODO
            # confirm against the installed version.
            # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
            model = torch.load(checkpoint_file, map_location=device)
            loss_history = np.load(history_file)
            # Store the restored network; otherwise `models` would keep the
            # freshly initialised, untrained instance.
            models[name][i] = model
            loss_histories[name].append(loss_history)
            continue
        
        # Only models that are actually trained need an optimizer
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        
        loss_history = []
        progress = tqdm(range(train_steps))
        for k in progress:  
            
            # Create batch of positions and contexts on the same device the
            # model was constructed with
            r = torch.tensor(np.random.uniform(-1, 1, (bs, 2)), dtype=torch.float32, device=device)
            c = torch.tensor(np.random.uniform(cmin, cmax, bs), dtype=torch.float32, device=device)[:, None]
            inputs = torch.cat((r, c), dim=-1)
            labels = (r, c)
                
            # Perform training step
            loss = model.train_step(x=inputs, y=labels, optimizer=optimizer)
        
            loss_history.append(loss)
            
            if k % 10 == 0:
                progress.set_description(f"loss: {loss:>7f}")
                
        models[name][i] = model
        loss_histories[name].append(loss_history)

        # Save loss history and model. The checkpoint is the cache key checked
        # above, so it is written last: its presence implies the history exists.
        np.save(history_file, loss_history)
        torch.save(model, checkpoint_file)


Training 256_ff_context
Model 1


loss: 0.000514: 100%|██████████| 60000/60000 [01:11<00:00, 834.45it/s] 
