In [None]:
import torch
from torch.nn.functional import softmax
from torch.utils.data import random_split

import matplotlib.pyplot as plt
import seaborn as sns

from ballchallenge.model import BallChallengeModel
from ballchallenge.dummy_dataset import DummyDataset
from ballchallenge.training import run_training

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
GRID_SIZE = (40, 40)

# Seed torch's global RNG (on all devices) so model weight initialisation, and
# presumably run_training's shuffling, are reproducible under Restart & Run All.
# NOTE(review): if DummyDataset or run_training draw from NumPy or `random`,
# seed those generators here as well.
SEED = 42
torch.manual_seed(SEED);

In [None]:
# Build a model with the training configuration and report its size
# (total number of scalar parameters across all parameter tensors).
model = BallChallengeModel(
    grid_size=GRID_SIZE,
    signal_length=1000,
    total_bits=16,
    frac_bits=8,
)
sum(map(lambda tensor: tensor.numel(), model.parameters()))

In [None]:
ds = DummyDataset(grid_size=GRID_SIZE, label_std=4)

# Use a dedicated, seeded generator so the train/test partition does not depend
# on how many draws earlier cells (e.g. model construction) took from torch's
# global RNG: re-running cells in any order yields the same split.
SPLIT_SEED = 42
ds_train, ds_test = random_split(
    ds,
    lengths=[0.75, 0.25],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
# Train a freshly constructed model (independent of the one built above for
# the parameter count) and keep the per-epoch metrics in `history`.
model = BallChallengeModel(
    grid_size=GRID_SIZE,
    signal_length=1000,
    total_bits=16,
    frac_bits=8,
)

history = run_training(
    model=model,
    ds_train=ds_train,
    ds_test=ds_test,
    device=DEVICE,
    epochs=30,
    batch_size=1024,
    learning_rate=1e-3,
)

# Move to CPU and switch to inference mode for the evaluation cells below;
# both calls return the module itself, which is displayed as the cell output.
model.to("cpu").eval()

In [None]:
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))

# One panel per tracked metric, each showing train and test curves over epochs.
for ax, metric in zip(axs, ("loss", "accuracy")):
    label = metric.capitalize()
    ax.plot(history.train["epoch"], history.train[metric], label="train")
    ax.plot(history.test["epoch"], history.test[metric], label="test")
    ax.set_xlabel("Epoch")
    ax.set_ylabel(label)
    ax.set_title(f"{label} per epoch")
    ax.legend()

# Render the figure without echoing the last artist's repr as cell output.
plt.show()

In [None]:
sample_idx = 2

sample, target = ds_test[sample_idx]

# Inference only: no_grad avoids recording an autograd graph for this forward
# pass (a later `.detach()` would only discard it after it was built).
# Assumes the model maps one unbatched sample to a flat vector of per-cell
# logits, hence softmax over dim=0 — TODO confirm against BallChallengeModel.
with torch.no_grad():
    logits = model(sample)

# GRID_SIZE is reversed for the 2-D view — presumably (width, height) ->
# (rows, cols) for the heatmap; verify against DummyDataset's label layout.
prediction = softmax(logits, dim=0).view(*GRID_SIZE[::-1])
target = target.view(*GRID_SIZE[::-1])

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

sns.heatmap(prediction, cmap="hot", square=True, ax=axs[0])
axs[0].set_title("Prediction")

sns.heatmap(target, cmap="hot", square=True, ax=axs[1])
axs[1].set_title("Target")

# Render the figure without echoing the last Text artist's repr.
plt.show()