In [None]:
from typing import cast
from pathlib import Path
from functools import partial

import torch
from torch.nn.functional import softmax
from torch.utils.data import random_split

import matplotlib.pyplot as plt
import seaborn as sns

from elasticai.creator.file_generation.on_disk_path import OnDiskPath
from elasticai.creator.vhdl.system_integrations.firmware_env5 import FirmwareENv5

from ballchallenge.model_builder import ModelBuilder
from ballchallenge.accelerometer_dataset import AccelerometerDataset
from ballchallenge.training import run_training


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DATASET_ROOT = Path("../data")
GRID_SIZE = (10, 10)

# Seed torch's global RNG so Restart-&-Run-All is reproducible: it drives
# random_split's default generator and torch's standard layer initialisers
# (and, presumably, any torch-based shuffling inside run_training).
# NOTE: bit-exact reproducibility on CUDA may additionally need deterministic kernels.
SEED = 42
torch.manual_seed(SEED)

In [None]:
def flat_labels(labels: torch.Tensor) -> torch.Tensor:
    """Merge every dimension of ``labels`` from index 1 onward into a single one.

    Dimension 0 is left untouched, so an ``(N, A, B)`` tensor becomes ``(N, A * B)``.
    """
    return torch.flatten(labels, start_dim=1)

def downsample(samples: torch.Tensor, factor: int) -> torch.Tensor:
    """Keep every ``factor``-th element along dimension 2 of ``samples``.

    Dimensions 0 and 1 are kept whole; ``samples`` must have at least three dimensions.
    """
    every_nth = slice(None, None, factor)
    return samples[:, :, every_nth]

ds = AccelerometerDataset(
    dataset_root=DATASET_ROOT,
    grid_size=GRID_SIZE,
    x_position_range=(0, 2),
    y_position_range=(0, 2),
    label_std=0.24,
    transform_samples=partial(downsample, factor=4),
    transform_labels=flat_labels,
)

# Use a dedicated, seeded generator for the split. Without it random_split draws
# from torch's global RNG, so the train/test partition changes on every kernel
# restart and results are not comparable between runs.
SPLIT_SEED = 42
ds_train, ds_test = random_split(
    ds, lengths=[0.75, 0.25], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

print("Train Samples:", ds_train[:][0].shape)
print("Test Samples:", ds_test[:][0].shape)

In [None]:
# Shape of one input sample (everything after the leading sample axis).
input_shape = cast(tuple[int, int], tuple(ds_train[:][0].shape[1:]))

# (filters, kernel_size) of each Conv1d stage, in order; every conv is followed
# by a hard-tanh activation.
CONV_STAGES = ((16, 8), (8, 16), (4, 32), (2, 64), (1, 128))

# total_bits / frac_bits presumably describe a fixed-point number format — TODO confirm.
model_builder = ModelBuilder(total_bits=16, frac_bits=8, input_shape=input_shape)
for n_filters, kernel in CONV_STAGES:
    model_builder.add_conv1d(filters=n_filters, kernel_size=kernel).add_hard_tanh()
model_builder.add_flatten()
# One output unit per grid cell.
model_builder.add_linear(output_units=GRID_SIZE[0] * GRID_SIZE[1])

model = model_builder.build_model()

In [None]:
# Count only parameters that receive gradients, so the number matches the label
# (model.parameters() also yields frozen tensors with requires_grad=False).
print(
    "Trainable model parameters:",
    sum(param.numel() for param in model.parameters() if param.requires_grad),
)

In [None]:
# Optimisation hyperparameters handed to run_training.
TRAINING_HPARAMS = dict(batch_size=8, epochs=800, learning_rate=1e-4)

history = run_training(
    model=model,
    ds_train=ds_train,
    ds_test=ds_test,
    device=DEVICE,
    **TRAINING_HPARAMS,
)

# Switch to inference mode and move the weights back to the CPU so the
# plotting and export cells that follow do not need a GPU.
model.eval()
model.to("cpu")

In [None]:
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))

# One panel per tracked metric, labelled identically so the two plots read the same.
# (Ending the cell with a loop also avoids echoing a Legend repr as cell output.)
for ax, (metric, ylabel) in zip(axs, [("loss", "Loss"), ("accuracy", "Accuracy")]):
    ax.plot(history.train["epoch"], history.train[metric], label="train")
    ax.plot(history.test["epoch"], history.test[metric], label="test")
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.set_title(f"{ylabel} per epoch")
    ax.legend()

In [None]:
def render_target_and_prediction(sample_idx=0, dataset=None):
    """Plot the model's predicted heatmap next to the target heatmap for one sample.

    Args:
        sample_idx: Index of the sample within ``dataset``.
        dataset: Dataset yielding ``(sample, target)`` pairs. Defaults to
            ``ds_train`` (the original behaviour); pass ``ds_test`` to inspect
            samples the model was not trained on.
    """
    if dataset is None:
        dataset = ds_train
    sample, target = dataset[sample_idx]

    # Feed an explicit batch of one instead of relying on every layer accepting
    # unbatched input; no_grad skips building an autograd graph for inference.
    with torch.no_grad():
        logits = model(sample.unsqueeze(0))

    # Same (GRID_SIZE[1], GRID_SIZE[0]) layout for prediction and target.
    grid_shape = GRID_SIZE[::-1]
    prediction = softmax(logits, dim=-1).view(*grid_shape)
    target = target.view(*grid_shape)

    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
    fig.suptitle(f"Sample {sample_idx}")

    sns.heatmap(prediction, cmap="hot", square=True, ax=axs[0])
    axs[0].set_title("Prediction")

    sns.heatmap(target, cmap="hot", square=True, ax=axs[1])
    axs[1].set_title("Target")

N_PREVIEW_SAMPLES = 4  # how many training samples to visualise

for preview_idx in range(N_PREVIEW_SAMPLES):
    render_target_and_prediction(preview_idx)

In [None]:
# Mean of all (flattened) labels over the full dataset, drawn as a heatmap.
_, labels = ds[:]
mean_label = labels.mean(dim=0)
# Reshape with the same (GRID_SIZE[1], GRID_SIZE[0]) layout used by
# render_target_and_prediction, so non-square grids are shown consistently
# across plots (identical to the previous view(GRID_SIZE) while the grid is square).
# NOTE(review): confirm this matches the label flattening order of AccelerometerDataset.
ax = sns.heatmap(mean_label.view(*GRID_SIZE[::-1]), square=True, cmap="hot")
ax.set_title("Mean target over the full dataset");

In [None]:
# Build the hardware design for the trained model under the name "ball_throw".
hw_design = model.create_design("ball_throw")

path = OnDiskPath("build")

# x_num_values: number of values in one flattened input sample;
# y_num_values: one output value per grid cell.
channels, signal_length = ds_train[0][0].shape
total_length = channels * signal_length
firmware = FirmwareENv5(
    hw_design,
    x_num_values=total_length,
    y_num_values=GRID_SIZE[0] * GRID_SIZE[1],
    skeleton_version="v2",
    id=666,  # NOTE(review): magic firmware id — document its meaning
)

firmware.save_to(path)