In [None]:
import torch
from torch.utils.data import random_split

import matplotlib.pyplot as plt
import seaborn as sns

from ballchallenge.model import BallChallengeModel
from ballchallenge.dummy_dataset import DummyDataset
from ballchallenge.training import run_training

# Seed PyTorch's global RNGs (CPU and, when present, CUDA) before anything
# random happens, so dataset generation, the train/test split and weight
# initialisation are reproducible under Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Build the (dummy) dataset and hold out 25% of it for evaluation.
ds = DummyDataset()

# A dedicated, seeded generator keeps the train/test split identical across
# kernel restarts, independent of how many draws the global RNG has already
# served (e.g. during dataset construction).
SPLIT_SEED = 42
ds_train, ds_test = random_split(
    ds,
    lengths=[0.75, 0.25],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
# NOTE(review): total_bits / frac_bits presumably describe a fixed-point
# format (16 bits total, 8 of them fractional) — confirm against
# BallChallengeModel.
model = BallChallengeModel(
    total_bits=16,
    frac_bits=8,
    signal_length=1000,
    impact_grid_size=(10, 10),
)

# Training hyperparameters, named so they are easy to find and tune.
BATCH_SIZE = 128
EPOCHS = 100
LEARNING_RATE = 1e-3

history = run_training(
    model=model,
    ds_train=ds_train,
    ds_test=ds_test,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    device=DEVICE,
)

# Switch to evaluation mode and move the model back to the CPU for the
# inspection cells below (presumably the dataset yields CPU tensors — confirm).
# The trailing `;` stops Jupyter from printing the full module repr, which
# `model.to` would otherwise return as the cell's output.
model.eval()
model.to("cpu");

In [None]:
# Learning curves: loss per epoch on the training and held-out test splits.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))
ax.plot(history.train["epoch"], history.train["loss"], label="train")
ax.plot(history.test["epoch"], history.test["loss"], label="test")
ax.set_title("Loss per epoch")
ax.set_xlabel("Epoch")
ax.set_ylabel("MSE Loss")
# Without legend() the "train"/"test" labels above are never rendered;
# the trailing `;` suppresses the Legend repr in the cell output.
ax.legend();

In [None]:
# Accuracy per epoch on the training and held-out test splits.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))
ax.plot(history.train["epoch"], history.train["accuracy"], label="train")
ax.plot(history.test["epoch"], history.test["accuracy"], label="test")
ax.set_title("Accuracy per epoch")
# Same x-label as the loss plot so the two figures read consistently.
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy")
# Without legend() the "train"/"test" labels above are never rendered;
# the trailing `;` suppresses the Legend repr in the cell output.
ax.legend();

In [None]:
# Side-by-side heatmaps of the model's prediction and the ground truth for
# one test sample.
sample_idx = 0

# Fetch the sample once: indexing the dataset twice could yield mismatched
# input/target pairs if samples are generated on access.
sample = ds_test[sample_idx]
sample_input, sample_target = sample[0], sample[1]

# Inference only: no_grad skips building an autograd graph, so the output
# does not require grad and can go straight to seaborn.
# NOTE(review): assumes the model accepts a single unbatched sample and
# returns a 2-D grid that heatmap can draw — confirm.
with torch.no_grad():
    prediction = model(sample_input)

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
sns.heatmap(prediction, cmap="hot", ax=axs[0])
axs[0].set_title("Prediction")
sns.heatmap(sample_target, cmap="hot", ax=axs[1])
axs[1].set_title("Target");