In [None]:
# Automatically re-import edited project modules (e.g. src.cci.*) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import polars as pl
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split

from src.cci.dataset import RandomSample, ToTensor, TransitionDataset

In [None]:
# Split the cleaned metadata into stratified train / validation / test sets,
# build the three datasets, and eyeball the training-time augmentation.

from pathlib import Path

from src.cci.dataset import CropSample, create_cache

SEED = 42  # fixed so the split is identical on every Restart & Run All
TEST_SIZE = 0.1  # fraction of all rows held out for the test set
VAL_SIZE = 0.1  # fraction of the remaining (non-test) rows used for validation
sample_length = 1500
oocha_dir = os.environ["OOCHA_DIR"]

clean_df = pl.read_csv("data/clean_df.csv").with_row_index()
labels = clean_df["Class Label"].to_numpy()

# Two-stage stratified split: first carve out the test set, then split the
# remainder into train / validation, preserving the class ratio each time.
train_idx, test_idx = train_test_split(
    range(len(clean_df)),
    stratify=labels,
    test_size=TEST_SIZE,
    random_state=SEED,
)
train_idx, val_idx = train_test_split(
    train_idx,
    stratify=labels[train_idx],
    test_size=VAL_SIZE,
    random_state=SEED,
)

train_df = clean_df.filter(pl.col("index").is_in(train_idx))
test_df = clean_df.filter(pl.col("index").is_in(test_idx))
val_df = clean_df.filter(pl.col("index").is_in(val_idx))

# The splits must be pairwise disjoint and together cover every row.
split_ids = [set(train_idx), set(val_idx), set(test_idx)]
assert (
    sum(map(len, split_ids)) == len(set().union(*split_ids)) == len(clean_df)
), "train/val/test splits overlap or drop rows"

print(f"train={len(train_df)}  test={len(test_df)}  val={len(val_df)}")
print(train_df.head())
print(test_df.head())
print(val_df.head())

# Building the signal cache is expensive; keep it across re-runs of this cell.
if "signal_cache" not in globals():
    signal_cache = create_cache(clean_df, Path(oocha_dir))


def _make_dataset(df, sample_transform):
    """Build a TransitionDataset over ``df`` using ``sample_transform`` then ToTensor."""
    return TransitionDataset(
        df,
        oocha_dir,
        transforms=[sample_transform, ToTensor()],
        cache=signal_cache,
    )


# Training uses random windows (augmentation); evaluation uses a fixed crop.
ds_train = _make_dataset(train_df, RandomSample(sample_length))
ds_test = _make_dataset(test_df, CropSample(sample_length))
ds_val = _make_dataset(val_df, CropSample(sample_length))

# Draw the same training item three times to show the random-window augmentation.
fig, ax = plt.subplots()
for _ in range(3):
    sample = ds_train[2]
    ax.plot(sample["signal"])
ax.set(
    title="Training sample augmented by shifting",
    xlabel="Sample index",
    ylabel="Signal value",
)
plt.show()

In [None]:
# Model, data loaders, metrics, optimizer and loss functions.
import torch
from torch import nn, optim, tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader

from src.cci.metrics import Metrics
from src.cci.models import MLPModel

TORCH_SEED = 42  # seeds weight initialisation and the training DataLoader shuffle order
torch.manual_seed(TORCH_SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# torch.cuda.get_device_name raises without CUDA, so only query it on a GPU.
device_name = torch.cuda.get_device_name(device) if device.type == "cuda" else "cpu"
print(f"Running with device: {device_name}")

# Move the model to its device *before* building the optimizer over its parameters.
model = MLPModel()
model.to(device)

train_loader = DataLoader(
    ds_train,
    batch_size=16,
    shuffle=True,
)
# Evaluation sets are not shuffled: a fixed order keeps eval runs reproducible.
val_loader = DataLoader(
    ds_val,
    batch_size=32,
    shuffle=False,
)
test_loader = DataLoader(
    ds_test,
    batch_size=32,
    shuffle=False,
)

train_metrics = Metrics("train", len(ds_train), device)
val_metrics = Metrics("val", len(ds_val), device)
test_metrics = Metrics("test", len(ds_test), device)

opt = optim.Adam(model.parameters())
# Unbalanced dataset => weighted loss fn. pos_weight is created as float32 on the
# same device as the logits so it works for non-scalar weights on GPU as well.
loss_fn = nn.BCEWithLogitsLoss(
    pos_weight=tensor(ds_train.get_pos_weight(), dtype=torch.float32, device=device),
)
# Unweighted loss for validation / test.
val_loss_fn = nn.BCEWithLogitsLoss()

model

In [None]:
# Train for `epochs` epochs with a validation pass after each one, then evaluate
# once on the held-out test set. All metrics are logged to an Aim run.
from aim import Run
from tqdm.notebook import tqdm

epochs = 1000
cm_every = 10  # log confusion matrices on epoch 1, every `cm_every` epochs, and the last epoch


def evaluate(loader, metrics, desc=None):
    """Run ``model`` over ``loader`` without gradients, accumulating into ``metrics``.

    ``metrics`` is reset first, so re-running this cell never mixes in batches
    from a previous run. Loss is computed with the unweighted ``val_loss_fn``.
    When ``desc`` is given, the batches are wrapped in a tqdm progress bar.
    """
    metrics.reset()
    model.eval()
    batches = tqdm(loader, desc) if desc else loader
    with torch.no_grad():
        for data in batches:
            sample, label = data["signal"].to(device), data["label"].to(device)
            logits = model(sample)
            loss = val_loss_fn(logits, label.float())
            metrics.update(torch.sigmoid(logits), label, loss)


run = Run(experiment="MLP")
try:
    # Log the values actually in use instead of hand-typed copies that can drift.
    run["hparams"] = {
        "learning_rate": opt.param_groups[0]["lr"],
        "batch_size": train_loader.batch_size,
    }

    for epoch in tqdm(range(1, epochs + 1), "Epochs"):
        plot_cm = epoch == 1 or epoch % cm_every == 0 or epoch == epochs

        train_metrics.reset()
        model.train()
        for data in train_loader:
            sample, label = data["signal"].to(device), data["label"].to(device)
            opt.zero_grad()
            logits = model(sample)

            loss = loss_fn(logits, label.float())
            loss.backward()
            opt.step()

            train_metrics.update(torch.sigmoid(logits), label, loss)
        train_metrics.upload_metrics_epoch(run, epoch, plot_cm)

        evaluate(val_loader, val_metrics)
        val_metrics.upload_metrics_epoch(run, epoch, plot_cm)

    evaluate(test_loader, test_metrics, "Run on test set")

    train_metrics.upload_training_end(run)
    val_metrics.upload_training_end(run)
    test_metrics.upload_test(run)
finally:
    # Close the Aim run even if training is interrupted (e.g. KeyboardInterrupt).
    run.close()