# Training & Inference Playground

This notebook orchestrates the existing CLI scripts for `LocationCNN`. Follow the cells below to train the model, persist metrics, visualize the curves, and run inference without leaving this workspace.
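
`subprocess.run` resolves a relative script name such as `train_location_cnn.py` against the working directory. If the kernel starts somewhere else, for example a `notebooks/` subdirectory, the launch fails unless `cwd` points at the directory that holds the scripts. The sketch below shows one way to set `cwd` explicitly and check for the script before launching it. `run_cli` is a hypothetical helper and is not part of the repository.

```python
import shlex
import subprocess
import sys
from pathlib import Path


def run_cli(script: str, *args: str, cwd: Path) -> subprocess.CompletedProcess:
    """Run `python <script> <args...>` with `cwd` as the working directory.

    Hypothetical helper: echoes the command, fails early with a clear message
    if the script is missing, and raises CalledProcessError on a non-zero exit.
    """
    cmd = [sys.executable, script, *args]
    print(f"Running in {cwd}: {shlex.join(cmd)}")
    script_path = Path(cwd) / script
    if not script_path.is_file():
        raise FileNotFoundError(f"{script_path} not found; is `cwd` the repository root?")
    return subprocess.run(cmd, check=True, cwd=cwd)
```

Because `cwd` is keyword-only, each call has to state where it runs, for example `run_cli("train_location_cnn.py", "--epochs", "5", cwd=repo_root)`.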

### Prerequisites

- Activate the repository's virtual environment and install the dependencies listed in `requirements/location_cnn.txt`.
- The MATLAB datasets live under `dataset/` (for example `dataset_SNR50_outdoor.mat`).
- Running the commands below writes artifacts under `checkpoints/`, `logs/`, `plots/`, and `artifacts/`.
- Each cell assumes `torch`, `matplotlib`, `numpy`, `pandas`, and `scikit-learn` are installed.
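
You can check these prerequisites from the notebook before training. The sketch below assumes the repository root is the first directory, at or above the current one, that contains `dataset/`. It lists the `.mat` files there and creates the output directories named above. `find_repo_root` and `check_prerequisites` are hypothetical helpers and are not part of the repository.

```python
from pathlib import Path

# Output directories the CLI scripts write to (from the prerequisites list above).
OUTPUT_DIRS = ("checkpoints", "logs", "plots", "artifacts")


def find_repo_root(start: Path, marker: str = "dataset") -> Path:
    """Return the first directory at or above `start` that contains a `marker` directory.

    Hypothetical helper: assumes the `dataset/` directory marks the repository root.
    """
    start = Path(start).resolve()
    for candidate in (start, *start.parents):
        if (candidate / marker).is_dir():
            return candidate
    raise FileNotFoundError(f"no '{marker}/' directory found at or above {start}")


def check_prerequisites(root: Path) -> list[Path]:
    """List the .mat datasets under root/dataset and create missing output directories."""
    datasets = sorted((root / "dataset").glob("*.mat"))
    if not datasets:
        print("Warning: no .mat files found under", root / "dataset")
    for name in OUTPUT_DIRS:
        (root / name).mkdir(parents=True, exist_ok=True)
    return datasets
```

In the notebook, `check_prerequisites(find_repo_root(Path.cwd()))` returns the available datasets, such as `dataset_SNR50_outdoor.mat`.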

In [1]:
import sys
import importlib

print("Python", sys.version.split()[0])
for pkg in ["torch", "matplotlib", "numpy", "pandas", "sklearn"]:
    try:
        module = importlib.import_module(pkg)
        version = getattr(module, "__version__", "<builtin>")
        print(f"{pkg}: {version}")
    except Exception as exc:
        print(f"{pkg}: not installed ({exc})")

Python 3.12.12
torch: 2.3.1
matplotlib: 3.10.7
numpy: 1.26.4
pandas: 2.3.3
sklearn: 1.5.0


## Training the Location CNN

The `train_location_cnn.py` script normalizes the CSI features (writing the normalization statistics to `--stats-path`), trains the compact CNN, saves the checkpoint to `--checkpoint`, and logs per-epoch metrics to `--metrics-path`. Override the dataset path with `--dataset-file`.
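
This section does not show the script's exact normalization. A common choice is per-feature z-scoring: compute the statistics on the training split only, then save them so validation and inference use the same transform. The sketch below assumes that scheme and a simple JSON layout with `mean` and `std` keys. The real contents of `artifacts/feature_stats.json` may differ, and all function names here are hypothetical.

```python
import json
from pathlib import Path

import numpy as np


def fit_feature_stats(train_features: np.ndarray, eps: float = 1e-8) -> dict:
    """Per-feature mean/std over the sample axis (axis 0), from training data only."""
    mean = train_features.mean(axis=0)
    # eps guards against division by zero for constant features.
    std = train_features.std(axis=0) + eps
    return {"mean": mean.tolist(), "std": std.tolist()}


def apply_feature_stats(features: np.ndarray, stats: dict) -> np.ndarray:
    """Z-score any split (train, validation, inference) with previously fitted statistics."""
    return (features - np.asarray(stats["mean"])) / np.asarray(stats["std"])


def save_feature_stats(stats: dict, path: Path) -> None:
    """Write the statistics as JSON, creating the parent directory if needed."""
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(stats))


def load_feature_stats(path: Path) -> dict:
    """Read statistics written by save_feature_stats."""
    return json.loads(path.read_text())
```

Fitting on the training split alone keeps validation statistics from leaking into training. Reloading the same JSON at inference time ensures the model sees inputs on the scale it was trained on.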

In [2]:
import sys
from pathlib import Path
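from subprocess import run

# The kernel starts in notebooks/, while the CLI scripts are assumed to live at the
# repository root. Run from there so the script name and the relative artifact paths
# below resolve against the repository root.
REPO_ROOT = Path.cwd().parent if Path.cwd().name == "notebooks" else Path.cwd()

train_cmd = [
    sys.executable,
    "train_location_cnn.py",
    "--epochs",
    "5",
    "--batch-size",
    "64",
    "--device",
    "cpu",
    "--retrain",
    "--metrics-path",
    str(Path("logs/location_training_metrics.csv")),
    "--checkpoint",
    str(Path("checkpoints/location_cnn.pt")),
    "--stats-path",
    str(Path("artifacts/feature_stats.json")),
]
print("Running in", REPO_ROOT, ":", " ".join(train_cmd))
run(train_cmd, check=True, cwd=REPO_ROOT)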


## Visualize training metrics

`plot_training_curves.py` reads the CSV metrics and writes PNG charts for loss and R². The next two cells run that script and display the saved PNGs, then show the last few metric rows and re-plot the same CSV inline with pandas.
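
R² here is presumably the coefficient of determination, $R^2 = 1 - \sum_i (y_i - \hat{y}_i)^2 / \sum_i (y_i - \bar{y})^2$. This section does not say whether the training script averages it across output coordinates. The sketch below assumes an unweighted average over coordinates, which is the default behaviour of scikit-learn's `r2_score` (`multioutput="uniform_average"`). Both function names are hypothetical.

```python
import numpy as np


def r2_per_output(y_true, y_pred) -> np.ndarray:
    """Coefficient of determination per output column: 1 - SS_res / SS_tot.

    Assumes no target column is constant (SS_tot would be zero).
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    if y_true.ndim == 1:
        # Treat a 1-D target as a single output column.
        y_true, y_pred = y_true[:, None], y_pred[:, None]
    ss_res = ((y_true - y_pred) ** 2).sum(axis=0)
    ss_tot = ((y_true - y_true.mean(axis=0)) ** 2).sum(axis=0)
    return 1.0 - ss_res / ss_tot


def r2_uniform_average(y_true, y_pred) -> float:
    """Unweighted mean of the per-output scores."""
    return float(r2_per_output(y_true, y_pred).mean())
```

For non-constant targets, `sklearn.metrics.r2_score(y_true, y_pred)` should agree with `r2_uniform_average`. A model that always predicts the mean scores 0, and a perfect model scores 1.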

In [None]:
import sys
from pathlib import Path
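from subprocess import run
from IPython.display import Image, display

# The kernel starts in notebooks/; the CLI scripts are assumed to live at the
# repository root, so run from there and read the plots back from there.
REPO_ROOT = Path.cwd().parent if Path.cwd().name == "notebooks" else Path.cwd()
plots_dir = REPO_ROOT / "plots"
plot_cmd = [
    sys.executable,
    "plot_training_curves.py",
    "--metrics-path",
    str(Path("logs/location_training_metrics.csv")),
    "--output-dir",
    str(plots_dir),
    "--dpi",
    "150",
]
print("Running in", REPO_ROOT, ":", " ".join(plot_cmd))
run(plot_cmd, check=True, cwd=REPO_ROOT)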

for image_name in [
    "training_loss.png",
    "validation_loss.png",
    "training_r2.png",
    "validation_r2.png",
]:
    image_path = plots_dir / image_name
    if image_path.exists():
        display(Image(filename=str(image_path)))
    else:
        print("Missing plot:", image_path)

In [None]:
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display

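# Resolve the CSV against the repository root (the kernel starts in notebooks/).
REPO_ROOT = Path.cwd().parent if Path.cwd().name == "notebooks" else Path.cwd()
metrics_path = REPO_ROOT / "logs" / "location_training_metrics.csv"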
if metrics_path.exists():
    df = pd.read_csv(metrics_path)
    display(df.tail(5))

    fig, ax = plt.subplots(figsize=(8, 4))
    df.plot(x="epoch", y=["train_loss", "val_loss"], ax=ax, marker="o")
    ax.set_title("Loss by epoch")
    ax.set_ylabel("Loss")
    ax.grid(True, linestyle="--", alpha=0.5)
    plt.tight_layout()
    plt.show()

    fig, ax = plt.subplots(figsize=(8, 4))
    df.plot(x="epoch", y=["train_r2", "val_r2"], ax=ax, marker="o")
    ax.set_title("R² by epoch")
    ax.set_ylabel("R²")
    ax.grid(True, linestyle="--", alpha=0.5)
    plt.tight_layout()
    plt.show()
else:
    print("Metrics CSV not found; run the training cell first.")
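
Beyond reading the curves, it can help to find the epoch with the lowest validation loss and to track the gap between validation and training loss. The sketch below assumes the CSV columns used in the cell above (`epoch`, `train_loss`, `val_loss`, `train_r2`, `val_r2`). `best_epoch_summary` and `generalization_gap` are hypothetical helpers.

```python
import pandas as pd


def best_epoch_summary(df: pd.DataFrame, metric: str = "val_loss") -> pd.Series:
    """Return the metrics row for the epoch that minimizes `metric`."""
    if df.empty:
        raise ValueError("metrics frame is empty")
    return df.loc[df[metric].idxmin()]


def generalization_gap(df: pd.DataFrame) -> pd.Series:
    """Per-epoch val_loss - train_loss; a steadily growing gap can indicate overfitting."""
    return (df["val_loss"] - df["train_loss"]).rename("loss_gap")
```

With the metrics loaded, `best_epoch_summary(pd.read_csv(metrics_path))` returns the epoch, losses, and R² values for the best validation epoch.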

## Inference

`infer_location.py` loads the latest checkpoint, runs the CNN on validation samples, and prints target/prediction pairs. Use `--fhe-mode`/`--quantized-module-path` when benchmarking encrypted execution.
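
For positioning, target/prediction pairs are usually summarized by the Euclidean distance between them. The sketch below assumes targets and predictions are arrays of shape `(n_samples, n_dims)` in the same units. It does not cover how to collect them from the script's printed output. Both function names are hypothetical.

```python
import numpy as np


def localization_errors(targets, predictions) -> np.ndarray:
    """Euclidean distance between each target and its prediction, row by row."""
    targets = np.asarray(targets, dtype=float)
    predictions = np.asarray(predictions, dtype=float)
    if targets.shape != predictions.shape:
        raise ValueError(f"shape mismatch: {targets.shape} vs {predictions.shape}")
    return np.linalg.norm(targets - predictions, axis=-1)


def summarize_errors(errors: np.ndarray) -> dict:
    """Mean, median, and 90th-percentile error, common summaries for positioning."""
    return {
        "mean": float(np.mean(errors)),
        "median": float(np.median(errors)),
        "p90": float(np.percentile(errors, 90)),
    }
```

Reporting the median and 90th percentile alongside the mean keeps a few large misses from hiding how the model does on typical samples.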

In [None]:
import sys
from pathlib import Path
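from subprocess import run

# The kernel starts in notebooks/; the CLI scripts are assumed to live at the repository root.
REPO_ROOT = Path.cwd().parent if Path.cwd().name == "notebooks" else Path.cwd()

infer_cmd = [
    sys.executable,
    "infer_location.py",
    "--num-samples",
    "5",
    "--fhe-mode",
    "disable",
    "--quantized-module-path",
    str(Path("artifacts/location_quantized.pkl")),
]
print("Running in", REPO_ROOT, ":", " ".join(infer_cmd))
run(infer_cmd, check=True, cwd=REPO_ROOT)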