In [None]:
"""
Step 1: Collect and Load Data

Use the OpenTouch Interface to record your dataset.

* Make one recording per data class (e.g., one for each type of coin).
* Each `.touch` file should contain data for only one label/class.
* Repeat Steps 1–5 (load, inspect, extract, filter, and save) for every `.touch` file (one per label) before training.

Tip: Seeing a warning from Streamlit is normal and not an error.
"""

from opentouch_interface.decoder import Decoder

# Replace this with the full path to one of your dataset files
# Example: "/home/username/datasets/coin1.touch"
# NOTE: `...` is only a placeholder; assign a str path before running.
# Step 5 reuses `path`: the file name (without extension) becomes the output
# folder coin_data/<name>/, which Step 7 uses as the class name.
path = ...

# Load the dataset for this label
dataset = Decoder(path)


In [None]:
"""
Step 2: Inspect the Raw Data

Now let’s check the structure of your recorded dataset.

* `dataset.sensor_names` (list[str]) lists all sensors that were captured.
* `dataset.stream_names_of(sensor_name)` (list[str]) lists the streams for a given sensor
  (for DIGIT this will just be "camera").
"""

# TODO: Inspect the raw data by printing the stream names for each sensor


In [None]:
"""
Step 3: Grab the Camera Frames

The raw dataset contains both the sensor data and additional metadata
(e.g., timestamps). For training we only need the actual frames.

* `dataset.stream_data_of(sensor_name, stream_name)` returns the list of frames.
* For DIGIT, the stream is `"camera"`, which gives you the captured images.

Hint: Call `dataset.stream_data_of` with `with_delta=False`.
"""
# TODO: Extract the camera frames from the raw data
# `[...]` is a placeholder (a list holding Ellipsis); replace it with the
# frames returned by `dataset.stream_data_of(...)` as described above.
camera_data: list = [...]


In [None]:
"""
Step 4: Filter the Frames

Each dataset should only contain images of its respective label.
Remove frames that don’t match (e.g., "no touch" images in a "coin" dataset).

Why?
The raw data also includes unwanted frames (like empty touches or noise).
Filtering ensures that each dataset is clean and only contains the intended label.

* Exception: If you are creating a "no touch" dataset, keep the empty frames.
* Hint: You can filter programmatically in this cell, manually by deleting unwanted `.png` files in your file explorer after Step 5 has saved them, or combine both.
"""

import numpy as np

# TODO: Filter the frames to only include frames of the correct label
# Step 5 saves each element with `Image.fromarray(frame.astype(np.uint8))`,
# so every element must be an image array that supports `.astype`
# (e.g. a NumPy array).
with_touch: list = [...]


In [None]:
"""
Step 5: Save the Cleaned Dataset

Now save the filtered frames to disk.

* Saving them as `.png` files makes it easy to inspect the images in your file explorer.
* Each dataset (per label) will be stored in its own folder.

DON'T MODIFY THIS CELL. SIMPLY RUN.
"""

import os
from PIL import Image

# Use the .touch filename (without extension) as the dataset name
dset_name = os.path.splitext(os.path.basename(path))[0]
directory = os.path.join("coin_data", dset_name)
os.makedirs(directory, exist_ok=True)

# Remove frames written by a previous run of this cell. Frames are saved by
# index, so re-running with fewer filtered frames would otherwise leave stale
# higher-index images (possibly rejected frames) in the class folder, and
# Step 7 loads every .png in that folder. Only files matching this cell's own
# naming pattern ("<dset_name>_<digits>.png") are removed.
frame_prefix = f"{dset_name}_"
stale_files = [
    fname
    for fname in os.listdir(directory)
    if fname.startswith(frame_prefix)
    and fname.endswith(".png")
    and fname[len(frame_prefix):-len(".png")].isdigit()
]
for fname in stale_files:
    os.remove(os.path.join(directory, fname))
if stale_files:
    print(f"Removed {len(stale_files)} stale images from a previous run in {directory}/")

print(f"Saving {len(with_touch)} images to {directory}/")
for i, frame in enumerate(with_touch):
    img = Image.fromarray(frame.astype(np.uint8))
    img.save(os.path.join(directory, f"{dset_name}_{i:04d}.png"))


In [None]:
"""
Step 6: Define the CNN Model

We now build a Convolutional Neural Network (CNN) to classify the coins.

* The model takes RGB frames as input (3 channels).
* It outputs one class per label in your dataset.
* Preprocessing converts images from [N, H, W, C] to [N, C, H, W]
  and normalizes pixel values to [0, 1].
"""

import torch
import torch.nn as nn
from torch import Tensor
from typing import Dict, Any

from opentouch.core.base_cnn import BaseCNN


class CoinClassifier(BaseCNN):
    """Coin-type classifier on top of ``BaseCNN`` for 3-channel (RGB) frames.

    ``label_mapping`` maps class index -> class name and is forwarded to
    ``BaseCNN`` together with ``input_channels=3``.
    """

    def __init__(self, label_mapping: dict) -> None:
        super().__init__(label_mapping=label_mapping, input_channels=3)

    @property
    def description(self) -> str:
        """Human-readable summary listing every class name in ``label_mapping``."""
        class_names = list(self.label_mapping.values())
        return "A CNN classifier for distinguishing between: " + ", ".join(class_names)

    def build(self) -> None:
        """
        Define CNN architecture for coin classification.

        Assign the network to ``self.model``, e.g. ``self.model = nn.Sequential(...)``.
        ``preprocess`` turns raw frames into float tensors of shape [N, 3, H, W]
        with values in [0, 1] (assumed to be what BaseCNN feeds the model --
        confirm), and the network should output one score per entry in
        ``label_mapping``.
        """

        # TODO: Define the CNN architecture

        self.model = ...

    def preprocess(self, x: Tensor) -> Tensor:
        """
        Convert [N, H, W, C] → [N, C, H, W], normalize to [0, 1].

        Don't modify this part.
        """
        scaled = x.float() / 255.0
        # Channels-last frame layout -> channels-first layout for convolutions.
        return scaled.permute(0, 3, 1, 2)

    def onnx_export(self) -> dict[str, Any]:
        """
        Parameters for ONNX export.

        Don't modify this part if you use a standard DIGIT frame as your model input.

        The tensor names "input"/"output" are the ones used by
        ``session.run`` at inference time (Step 10).
        """
        # One uint8 frame in the [N, H, W, C] layout preprocess expects (H=320, W=240).
        dummy_frame = torch.randint(0, 256, (1, 320, 240, 3), dtype=torch.uint8)
        return dict(
            example_input=dummy_frame,
            input_names=["input"],
            output_names=["output"],
        )

In [None]:
"""
Step 7: Load the Coin Datasets

Now we load the saved images back into memory.

* Each coin type should be in its own subdirectory under `coin_data/`.
* A label mapping is created automatically from the folder names.
* Optionally (`balance_datasets=True`, the default), datasets are balanced so all classes have the same number of images, capped at `max_size` images per class.

DON'T MODIFY THIS CELL. SIMPLY RUN.
"""

import os
import numpy as np
from PIL import Image

def load_coin_datasets(dset_path: str, balance_datasets: bool = True, max_size: int = 2000) -> tuple[np.ndarray, np.ndarray, dict[int, str]]:
    """
    Load coin images from per-class subdirectories.

    Expected structure:
    coin_data/
    ├── two_euro/
    │   ├── two_euro_0001.png
    │   └── ...
    ├── one_euro/
    │   ├── one_euro_0001.png
    │   └── ...

    Labels are assigned in alphabetical order of the subdirectory names
    (label 0 = first name after sorting). Within each class, images are read
    in sorted filename order.

    Args:
        dset_path: Directory containing one subdirectory per class.
        balance_datasets: If True, keep the same number of images for every
            class: the size of the smallest class, capped at ``max_size``.
            If False, every image is kept.
        max_size: Upper bound on images per class. Only used when
            ``balance_datasets`` is True; must then be at least 1.

    Returns:
        X: All images stacked along a new first axis, e.g. (N, H, W, C) for
            RGB PNGs. All images must share one shape.
        Y: Integer class label of each image, shape (N,).
        labels: Mapping from integer label to subdirectory (class) name.

    Raises:
        ValueError: If ``max_size`` < 1 while balancing, if ``dset_path`` has
            no subdirectories, or if a class subdirectory contains no .png files.
        FileNotFoundError: If ``dset_path`` does not exist.
    """
    # A non-positive cap would slice images from the end while producing no
    # labels, silently misaligning X and Y.
    if balance_datasets and max_size < 1:
        raise ValueError(f"max_size must be at least 1 when balancing, got {max_size}.")

    if not os.path.exists(dset_path):
        raise FileNotFoundError(f"Directory '{dset_path}' does not exist.")

    coin_dirs = sorted(
        d for d in os.listdir(dset_path) if os.path.isdir(os.path.join(dset_path, d))
    )
    if not coin_dirs:
        raise ValueError(f"No subdirectories found in '{dset_path}'.")

    # Create label mapping: label index -> class (folder) name
    labels = dict(enumerate(coin_dirs))
    print(f"Found {len(coin_dirs)} coin types:")
    for label, name in labels.items():
        print(f"  Label {label}: {name}")

    # Collect each class's PNG paths first, so an empty class is reported by
    # name instead of failing later inside np.stack (or silently yielding a
    # label the model never sees).
    class_files: dict[int, list[str]] = {}
    for label, coin_name in labels.items():
        coin_path = os.path.join(dset_path, coin_name)
        image_files = sorted(f for f in os.listdir(coin_path) if f.lower().endswith(".png"))
        if not image_files:
            raise ValueError(f"No .png images found in '{coin_path}' (class '{coin_name}').")
        class_files[label] = [os.path.join(coin_path, f) for f in image_files]

    # Balance datasets (optional): truncate every class to the same count
    # before loading, so unused images are never read from disk.
    if balance_datasets:
        per_class = min(min(len(files) for files in class_files.values()), max_size)
        class_files = {label: files[:per_class] for label, files in class_files.items()}

    # Load images; the context manager closes each file handle deterministically.
    all_images, all_labels = [], []
    for label, files in class_files.items():
        for file_path in files:
            with Image.open(file_path) as img:
                all_images.append(np.array(img))
        all_labels.extend([label] * len(files))

    # Convert to numpy arrays
    X = np.stack(all_images, axis=0)  # Shape: (N, H, W, C) for RGB images
    Y = np.array(all_labels)          # Shape: (N,)

    return X, Y, labels


In [None]:
"""
Step 8: Train and Save the Model

Now we bring everything together:

* Load the datasets and convert them to PyTorch tensors.
* Create a DataLoader for batching and shuffling.
* Initialize and train the CNN.
* Save the trained model to disk.
"""

from torch.utils.data import TensorDataset, DataLoader

# TODO: Load datasets with `load_coin_datasets("coin_data")`
# It returns (X, Y, labels); pass `labels` to CoinClassifier(label_mapping=...).

# TODO: Convert the images and labels to PyTorch tensors
# Do not normalize or permute the images here: CoinClassifier.preprocess already
# divides by 255 and converts [N, H, W, C] -> [N, C, H, W].
# Labels presumably need an integer (long) dtype for classification -- check BaseCNN.fit.
X_tensor = ...
y_tensor = ...

# TODO: Wrap the tensors in a TensorDataset and create a DataLoader
# NOTE: this rebinds `dataset`, replacing the Decoder from Step 1. Re-run Step 1
# before processing another .touch file.
dataset = ...
dataloader = ...  # Select a batch size that fits your GPU memory

# TODO: Initialize and compile the model
model = ...
model.compile()

# TODO: Train the model using `model.fit(...)`. Look up the needed arguments
# (e.g. in the BaseCNN source or documentation).

# TODO: Save the model using `model.save(...)`.
# Step 9 expects to load a '<model_name>.zip' file.


In [None]:
"""
Step 9: Load and Inspect the Model

After training, you can reload the saved model and check its metadata.
"""

from opentouch.core.model_loader import ModelLoader

# TODO: Load the model using `ModelLoader.from_path('<model_name>.zip')`
# Use the file written by `model.save(...)` in Step 8. The returned object is
# what predict_coin (Step 10) receives as `session`.
session = ...

# Print the metadata exposed by the loaded session.
print(f"Model type: {session.model_type}")
print(f"Description: {session.description}")
print(f"Label mapping: {session.label_mapping}")


In [None]:
"""
Step 10: Predict from a Single Image

We now define a helper function to classify one image with the trained model.

* The image is expanded with a batch dimension before inference.
* The ONNX session returns model outputs.
* The highest-scoring class is mapped back to its label.

DON'T MODIFY THIS CELL. SIMPLY RUN.
"""

import onnxruntime as ort
import numpy as np

def predict_coin(image: np.ndarray, session: ort.InferenceSession) -> str:
    """
    Predict the coin type of a single frame.

    Args:
        image: One frame in [H, W, C] layout, e.g. a single entry of the X
            array returned by ``load_coin_datasets``. Presumably it must be
            uint8, matching the model's ONNX example input -- confirm.
        session: The loaded model. Besides ``run`` it must expose a
            ``label_mapping`` attribute (class index -> label name), which is
            not part of a plain ``ort.InferenceSession``. The object returned
            by ``ModelLoader.from_path`` is assumed to provide it.

    Returns:
        The label name of the highest-scoring class.
    """
    # Add batch dimension [H, W, C] -> [1, H, W, C]
    input_batch = np.expand_dims(image, axis=0)

    # Run inference; "input"/"output" are the tensor names declared in
    # CoinClassifier.onnx_export.
    output = session.run(["output"], {"input": input_batch})[0]

    # Pick the class with highest score
    predicted_class = int(np.argmax(output[0]))

    # Convert index back to label name. The saved-model mapping is looked up
    # with string keys. Fall back to integer keys (the form load_coin_datasets
    # produces) so an in-memory mapping works as well.
    label_mapping = session.label_mapping
    try:
        return label_mapping[str(predicted_class)]
    except KeyError:
        return label_mapping[predicted_class]

In [None]:
"""
Step 11: Test the Model

Finally, let’s check the trained model on a few random images.

* Pick random samples from the dataset.
* Run predictions with `predict_coin`.
* Compare predicted vs. true labels.
"""

import random

for i in range(20):
    # random.randint(a, b) includes BOTH endpoints, so the valid index range is
    # random.randint(0, len(X) - 1) (assuming X, Y from load_coin_datasets in Step 8).
    idx = random.randint(...)
    test_image = ...
    # predict_coin returns the label *name* (a string), so convert the integer
    # class from Y to its name via the label mapping; otherwise the comparison
    # below is never True.
    true_label = ...
    predicted_label = ...

    print(f"True: {true_label} | Predicted: {predicted_label} | Correct: {true_label == predicted_label}")


In [None]:
"""
Step 12: Evaluate Accuracy

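Instead of just printing individual results,
we can calculate the overall accuracy across 100 random images.

Note: these images come from the same data the model was trained on, so this
is training accuracy. It is optimistic and does not show how well the model
handles new recordings; use a separate held-out recording for that.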
"""

import random

num_tests = 100
correct = 0

# Samples are drawn independently (with replacement), so the same image can be
# tested more than once.
for i in range(num_tests):
    # random.randint(a, b) includes both endpoints: use random.randint(0, len(X) - 1)
    # (assuming X, Y from load_coin_datasets in Step 8).
    idx = random.randint(...)
    test_image = ...
    # Must be the label *name*, since predict_coin returns a name string.
    true_label = ...
    predicted_label = predict_coin(test_image, session)

    if predicted_label == true_label:
        correct += 1

accuracy = correct / num_tests
print(f"Accuracy over {num_tests} random samples: {accuracy:.2%}")
