In [1]:
import pathlib as pl

from tinyphysics import TinyPhysicsModel, TinyPhysicsSimulator, CONTROL_START_IDX
from controllers import pid

import snntorch as snn
from snntorch import spikeplot as splt
import torch
import torch.nn as nn
import numpy as np
import itertools

from matplotlib import pyplot as plt
import seaborn as sns
from IPython.display import HTML

from typing import Any, override

sns.set_theme()

In [2]:
def plot_rollout(sim):
  fig, ax = plt.subplots(figsize=(10, 5))
  ax.plot(sim.target_lataccel_history, label="Target Lateral Acceleration", alpha=0.5)
  ax.plot(sim.current_lataccel_history, label="Actual Lateral Acceleration", alpha=0.5)
  ax.legend()
  ax.set_xlabel("Step")
  ax.set_ylabel("Lateral Acceleration")
  ax.set_title("Rollout")
  plt.show()

In [8]:
from torch.utils.data import DataLoader, Dataset


class CustomImageDataset(Dataset):
    def __init__(self, segs):
        self.segs = list(segs)

    def __len__(self):
        return len(self.segs)

    def __getitem__(self, idx):
        return pl.Path(f'data/{idx:05d}')


training_data = CustomImageDataset(range(5000))
test_data = CustomImageDataset(range(5000))

train_loader = DataLoader(training_data, batch_size=48, shuffle=True)
test_loader = DataLoader(test_data, batch_size=48, shuffle=True)

In [5]:
# Network Architecture
num_inputs = 1 + 1 + 3
num_hidden = 1000
num_outputs = 1

# Temporal Dynamics
num_steps = 100
beta = 0.95

In [6]:
from snntorch import surrogate


# Define Network
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        # spike_grad = surrogate.fast_sigmoid()
        # default is atan()

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    @override
    def forward(self, x):

        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        # time-loop
        for step in range(num_steps):
            cur1 = self.fc1(x.flatten(1))
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            # store in list
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)  # time-steps x batch x num_out


dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# if you're on M1 or M2 GPU:
# device = torch.device("mps")

# Load the network onto CUDA if available
# device = torch.device("cpu")
net = Net().to(device)

In [9]:
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))

num_epochs = 1 # 60000 / 128 = 468
counter = 0

# Outer training loop
for _epoch in range(num_epochs):
    train_batch = iter(train_loader)

    # Minibatch training loop
    for data, targets in train_batch:
        print(data)
        print(dir(data))
        breakpoint()
        data = data.to(device)
        targets = targets.to(device)

        # forward pass
        net.train()
        spk_rec, _ = net(data)

        # initialize the loss & sum over time
        loss_val = torch.zeros((1), dtype=dtype, device=device)
        loss_val = loss(spk_rec.sum(0), targets) # batch x num_out

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Print train/test loss/accuracy
        if counter % 10 == 0:
            print(f"Iteration: {counter} \t Train Loss: {loss_val.item()}")
        counter += 1

        if counter == 100:
          break

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'pathlib.WindowsPath'>

In [None]:
model = TinyPhysicsModel("./models/tinyphysics.onnx", debug=True)
controller = pid.Controller()

In [None]:
sim = TinyPhysicsSimulator(model, "./data/00000.csv", controller=controller, debug=False)
sim.rollout()

In [None]:
plot_rollout(sim)