In [None]:
"""
Step 1: Collect and Load Data

Use the OpenTouch Interface to record your dataset.

* Make one recording per data class (e.g., one for each type of coin).
* Each `.touch` file should contain data for only one label/class.
* Repeat the loading and preprocessing steps (Steps 1–5) for every `.touch` file (one per label) before training.

Tip: Seeing a warning from Streamlit is normal and not an error.
"""
from opentouch_interface.decoder import Decoder

# Expects the full path e.g. "/home/someone/datasets/my_dataset.touch"
path = ...

# `...` is only a placeholder: fail early with a clear message instead of
# handing the Ellipsis object to the decoder.
if path is ...:
    raise ValueError('Set `path` to the full path of your .touch recording before running this cell.')

dataset = Decoder(path)

In [None]:
"""
Step 2: Inspect the Raw Data

Now let’s check the structure of your recorded dataset.

* `dataset.sensor_names` (list[str]) lists all sensors that were captured.
* `dataset.stream_names_of(sensor_name)` (list[str]) lists the streams for a given sensor
  (for DIGIT this will just be "camera").
"""

captured_sensors = dataset.sensor_names
print(f'The following sensors have been captured: {captured_sensors}')

# One line per sensor, listing the names of the streams it recorded.
print('The sensors have the following streams:')
for sensor in captured_sensors:
    stream_names = dataset.stream_names_of(sensor)
    print(f'\t- {sensor}: {stream_names}')

In [None]:
"""
Step 3: Grab the Camera Frames

The raw dataset contains both the sensor data and additional metadata
(e.g., timestamps). For training we only need the actual frames.

* `dataset.stream_data_of(sensor_name, stream_name)` returns the list of frames.
* For DIGIT, the stream is `"camera"`, which gives you the captured images.

Hint: Call `dataset.stream_data_of` with `with_delta=False`.
"""

# Use the first captured sensor; for DIGIT the frames live in the "camera" stream.
sensor_name, data_stream = dataset.sensor_names[0], 'camera'

# with_delta=False: for training we only need the frames themselves.
camera_data = dataset.stream_data_of(
    sensor_name,
    data_stream,
    with_delta=False,
)

In [None]:
"""
Step 4: Filter the Frames

Each dataset should only contain images of its respective label.
Remove frames that don’t match (e.g., "no touch" images in a "coin" dataset).

Why?
The raw data also includes unwanted frames (like empty touches or noise).
Filtering ensures that each dataset is clean and only contains the intended label.

* Exception: If you are creating a "no touch" dataset, keep the empty frames.
* Hint: You can filter programmatically (as in this cell) or by deleting unwanted images in your file explorer after Step 5 has saved them.
"""

import numpy as np

# Reference "empty" image: the pixel-wise mean of the first 20 frames, which are
# assumed to show no touch. Adjust the slice to match your recording.
no_touch = camera_data[:20]
avg_empty_image = np.stack(no_touch, axis=0).mean(axis=0)

def mean_square_error(image_a: np.ndarray, image_b: np.ndarray) -> float:
    """Return the mean squared element-wise difference between two images.

    Both inputs are promoted to float64 before subtracting. Subtracting two
    uint8 arrays directly wraps around (e.g. ``0 - 20`` becomes ``236``), and
    squaring then overflows too, so the error would be meaningless for raw
    camera frames.

    Args:
        image_a: First image (any shape broadcast-compatible with ``image_b``).
        image_b: Second image.

    Returns:
        The mean of the squared differences, as a Python float.
    """
    diff = np.asarray(image_a, dtype=np.float64) - np.asarray(image_b, dtype=np.float64)
    return float(np.mean(diff ** 2))

# Frames whose MSE against the empty reference exceeds this value count as a touch.
threshold = 40.0  # <-- TODO: Adjust as needed
# Tuning aid: print(mean_square_error(avg_empty_image, camera_data[100]))

with_touch = [
    frame
    for frame in camera_data
    if threshold < mean_square_error(frame, avg_empty_image)
]
# For a "no touch" dataset keep every frame instead:
# with_touch = camera_data

num_touch_frames = len(with_touch)
print(f'There are {num_touch_frames} images with recognized touch')

In [None]:
"""
Step 5: Save the Cleaned Dataset

Now save the filtered frames to disk.

* Saving them as `.png` files makes it easy to inspect the images in your file explorer.
* Each dataset (per label) will be stored in its own folder.

DON'T MODIFY THIS CELL. SIMPLY RUN.
"""

import os
from PIL import Image

# One folder per label, named after the .touch file (without extension).
dset_name = os.path.splitext(os.path.basename(path))[0]
directory = os.path.join('coin_data', dset_name)
os.makedirs(directory, exist_ok=True)

# Remove images written by an earlier run of this cell. Otherwise re-running
# with fewer frames (e.g. after raising the threshold) would leave stale,
# higher-numbered images behind, and they would silently end up in training.
stale_files = [
    name for name in os.listdir(directory)
    if name.startswith(f'{dset_name}_') and name.lower().endswith('.png')
]
for stale in stale_files:
    os.remove(os.path.join(directory, stale))
if stale_files:
    print(f'Removed {len(stale_files)} images from a previous run in {directory}/')

print(f'Saving {len(with_touch)} images to {directory}/')
for i, frame in enumerate(with_touch):
    img = Image.fromarray(frame.astype(np.uint8))
    img.save(os.path.join(directory, f'{dset_name}_{i:04d}.png'))

In [None]:
"""
Create a simple model to convert RGB images to grayscale.

* The filter converts RGB images to grayscale using the standard luminosity method.
* It inherits from `BaseFilter` and implements `forward` and `onnx_export`.
* Finally, the model is saved to disk for later use.
"""

from typing import Dict, Any

import numpy as np
import torch.onnx

from opentouch.core.base_filter import BaseFilter


class GrayscaleFilter(BaseFilter):
    """
    A simple model that converts an input image tensor to a grayscale image.

    Args:
        height: Expected input frame height in pixels (used for the ONNX example input).
        width: Expected input frame width in pixels (used for the ONNX example input).
    """

    def __init__(self, height: int, width: int):
        super().__init__()
        self.height: int = height
        self.width: int = width

    @property
    def description(self) -> str:
        return "A model that converts an input RGB image to a grayscale image."

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass to convert the input tensor to grayscale.

        Args:
            x (torch.Tensor): Input tensor of shape (N, H, W, C) from CV2, where C is expected to be 3 (RGB channels).
                uint8 input is rescaled to [0, 1]; other dtypes are used as-is (no rescaling).
                NOTE(review): OpenCV frames are usually BGR; if so, the R and B weights below
                are applied to the wrong channels. Confirm the channel order of the source.
        Returns:
            torch.Tensor: Grayscale image tensor of shape (N, 1, H, W).
        """

        # Convert from (N, H, W, C) to (N, C, H, W) format for PyTorch processing
        x = x.permute(0, 3, 1, 2)

        # Convert to float if input is uint8
        if x.dtype == torch.uint8:
            x = x.float() / 255.0

        # Convert RGB to grayscale using the standard luminosity method
        # Grayscale = 0.2989 * R + 0.5870 * G + 0.1140 * B
        # Slicing with 0:1 etc. keeps the channel axis, giving (N, 1, H, W).
        r, g, b = x[:, 0:1, :, :], x[:, 1:2, :, :], x[:, 2:3, :, :]
        grayscale = 0.2989 * r + 0.5870 * g + 0.1140 * b

        return grayscale

    def onnx_export(self) -> Dict[str, Any]:
        """
        Constructs the parameters needed for torch.onnx.export().

        The example input is a random uint8 frame of shape (1, height, width, 3),
        matching the frame size this filter was constructed with.
        """

        return {
            'example_input': torch.randint(0, 256, (1, self.height, self.width, 3), dtype=torch.uint8),
            'input_names': ['input'],
            'output_names': ['output'],
        }

# Build the filter for 320x240 (height x width) frames and save it to disk.
# This runs every time the cell executes; there is no "already saved" check.
gray_filter = GrayscaleFilter(height=320, width=240)
gray_filter.save("grayscale_filter")

In [None]:
"""
Step 6: Define the CNN Model

We now build a Convolutional Neural Network (CNN) to classify the coins.

* The model takes RGB frames as input (3 channels).
* It outputs one class per label in your dataset.
* Preprocessing converts images from [N, H, W, C] to [N, C, H, W]
  and normalizes pixel values to [0, 1].
"""

from typing import Dict, Any

from torch import Tensor

from opentouch.core.base_cnn import BaseCNN


class CoinClassifier(BaseCNN):
    """Small from-scratch CNN that classifies 3-channel frames into coin labels.

    Args:
        label_mapping: Mapping from class index to label name (as returned by
            ``load_coin_datasets``); its values are listed in ``description``.
    """

    def __init__(self, label_mapping: dict) -> None:
        super().__init__(input_channels=3, label_mapping=label_mapping)

    @property
    def description(self) -> str:
        labels = ', '.join(self.label_mapping.values())
        return f"A CNN classifier for distinguishing between: {labels}"

    def build(self) -> None:
        """CNN architecture for coin classification"""
        # Imported here because this cell never imports ``nn``. Without it,
        # build() only works if the EfficientNet cell happened to run first.
        from torch import nn

        self.model = nn.Sequential(
            # First conv block
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 320x240 -> 160x120

            # Second conv block
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 160x120 -> 80x60

            # Third conv block
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 80x60 -> 40x30

            # Global average pooling and classifier
            nn.AdaptiveAvgPool2d(1),  # 40x30 -> 1x1
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, self.num_classes)
        )

    def preprocess(self, x: Tensor) -> Tensor:
        """
        Preprocess input images from [N, H, W, C] to [N, C, H, W] format.
        Normalizes pixel values from [0, 255] to [0, 1].
        """
        x = x.float() / 255.0
        return x.permute(0, 3, 1, 2)

    def onnx_export(self) -> Dict[str, Any]:
        """Defines parameters needed for ONNX export (random uint8 frame of shape (1, 320, 240, 3))."""
        # This cell only imports ``Tensor`` from torch, so bring ``torch`` in locally.
        import torch

        return {
            'example_input': torch.randint(0, 256, (1, 320, 240, 3), dtype=torch.uint8),
            'input_names': ['input'],
            'output_names': ['output'],
        }

In [None]:
"""
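Alternative to Step 6: transfer learning with an ImageNet-pretrained EfficientNet-B4 from torchvision.
The backbone is frozen and only the new classification head is trained.
Step 8 selects which of the two models is used.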
"""
from typing import Dict, Any

import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset

from opentouch.core.base_cnn import BaseCNN
from opentouch.core.model_loader import ModelLoader


class CoinClassifierEfficientNet(BaseCNN):
    """Transfer-learning coin classifier built on torchvision's EfficientNet-B4.

    The pretrained backbone is frozen; only the replaced final linear layer of
    the classifier head is trained.

    Args:
        label_mapping: Mapping from class index to label name; its values are
            listed in ``description``.
    """

    def __init__(self, label_mapping: dict) -> None:
        super().__init__(input_channels=3, label_mapping=label_mapping)

    @property
    def description(self) -> str:
        labels = ', '.join(self.label_mapping.values())
        return f"EfficientNet-B4 transfer learning classifier for: {labels}"

    def build(self) -> None:
        """Build EfficientNet with frozen backbone, only train classifier"""
        from torchvision.models import efficientnet_b4, EfficientNet_B4_Weights

        weights = EfficientNet_B4_Weights.DEFAULT
        backbone = efficientnet_b4(weights=weights)

        # Freeze ALL backbone parameters - only train the classifier
        for param in backbone.parameters():
            param.requires_grad = False

        # Replace classifier head. The new Linear layer is created after the
        # freeze loop, so its parameters stay trainable.
        backbone.classifier[1] = nn.Linear(
            backbone.classifier[1].in_features,
            self.num_classes,
            bias=True
        )

        self.model = backbone

    def preprocess(self, x: Tensor) -> Tensor:
        """
        EfficientNet preprocessing with aspect ratio preservation.

        Input:  [N, H, W, C] with pixel values in [0, 255] (e.g. uint8 frames)
        Output: [N, C, 380, 380] float: scaled to fit, zero-padded to a centred
                square, then ImageNet-normalized
        """
        target_size = 380  # EfficientNet-B4 input resolution

        x = x.float() / 255.0  # Convert to float and normalize to [0, 1]
        x = x.permute(0, 3, 1, 2)  # Permute to [N, C, H, W]

        # Take the frame size from the tensor instead of hard-coding it. The old
        # constants (h=240, w=320) were swapped relative to the (N, 320, 240, 3)
        # frames used for export, which stretched frames instead of preserving
        # their aspect ratio. int() keeps these as plain Python numbers (under
        # ONNX tracing this fixes them to the example input's size).
        input_h, input_w = int(x.shape[2]), int(x.shape[3])

        # Largest scale at which the whole frame fits inside target_size x target_size.
        scale = min(target_size / input_h, target_size / input_w)
        # round() avoids losing a pixel to float error (e.g. 379.999...); min()
        # guarantees the padding below is never negative.
        new_h = min(target_size, round(input_h * scale))
        new_w = min(target_size, round(input_w * scale))

        # Resize maintaining aspect ratio
        x = torch.nn.functional.interpolate(x, size=(new_h, new_w), mode='bilinear', align_corners=False)

        # Pad to square, centring the resized image.
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        pad_h = target_size - new_h
        pad_w = target_size - new_w
        pad_top = pad_h // 2
        pad_left = pad_w // 2
        x = torch.nn.functional.pad(x, (pad_left, pad_w - pad_left, pad_top, pad_h - pad_top))

        # ImageNet normalization (what EfficientNet expects)
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
        x = (x - mean) / std

        return x

    def onnx_export(self) -> Dict[str, Any]:
        """Defines parameters needed for ONNX export (random uint8 frame of shape (1, 320, 240, 3))."""
        return {
            'example_input': torch.randint(0, 256, (1, 320, 240, 3), dtype=torch.uint8),
            'input_names': ['input'],
            'output_names': ['output'],
        }

In [None]:
"""
Step 7: Load the Coin Datasets

Now we load the saved images back into memory.

* Each coin type should be in its own subdirectory under `coin_data/`.
* A label mapping is created automatically from the folder names.
* By default, datasets are balanced so all classes have the same number of images (the size of the smallest class, capped at `max_size`).

DON'T MODIFY THIS CELL. SIMPLY RUN.
"""
import os
from PIL import Image

def load_coin_datasets(dset_path: str, balance_datasets: bool = True, max_size: int = 2000) -> tuple[np.ndarray, np.ndarray, dict[int, str]]:
    """
    Load coin images from per-class subdirectories.

    Expected structure::

        coin_data/
        ├── two_euro/
        │   ├── two_euro_0001.png
        │   └── ...

    Every subdirectory of ``dset_path`` is one class. Labels are assigned as
    0..K-1 in sorted folder-name order.

    Args:
        dset_path: Root directory containing one subdirectory per class.
        balance_datasets: If True, every class contributes the same number of
            images: the size of the smallest class, capped at ``max_size``.
        max_size: Per-class cap; only applied when ``balance_datasets`` is True.

    Returns:
        ``(X, Y, labels)``: the images stacked along a new first axis, the
        integer label of each image, and a mapping from label to folder name.

    Raises:
        FileNotFoundError: If ``dset_path`` does not exist.
        ValueError: If there are no subdirectories, or a class folder contains
            no ``.png`` images.
    """
    if not os.path.exists(dset_path):
        raise FileNotFoundError(f"Directory '{dset_path}' does not exist.")

    coin_dirs = sorted(
        d for d in os.listdir(dset_path) if os.path.isdir(os.path.join(dset_path, d))
    )
    if not coin_dirs:
        raise ValueError(f"No subdirectories found in the directory '{dset_path}'.")

    # Create label mapping: label index -> folder name
    labels = dict(enumerate(coin_dirs))
    print(f"Found {len(coin_dirs)} coin types:")
    for label, name in labels.items():
        print(f"  Label {label}: {name}")

    # Load all images: label -> list of image arrays
    coin_images = {}
    for label, coin_name in labels.items():
        coin_path = os.path.join(dset_path, coin_name)
        image_files = sorted(f for f in os.listdir(coin_path) if f.lower().endswith('.png'))

        images = []
        for img_file in image_files:
            # Context manager closes the file deterministically, as Pillow recommends.
            with Image.open(os.path.join(coin_path, img_file)) as img:
                images.append(np.array(img))
        coin_images[label] = images

    # An empty class would shrink a balanced dataset to nothing (np.stack then
    # fails with an opaque error) or leave a label with no training examples.
    empty_classes = [labels[label] for label, images in coin_images.items() if not images]
    if empty_classes:
        raise ValueError(
            f"No .png images found in '{dset_path}' for: {', '.join(empty_classes)}"
        )

    # Balance datasets; a slice bound of None keeps every image.
    per_class = None
    if balance_datasets:
        per_class = min(min(len(images) for images in coin_images.values()), max_size)

    all_images, all_labels = [], []
    for label, images in coin_images.items():
        selected = images[:per_class]
        all_images.extend(selected)
        all_labels.extend([label] * len(selected))

    # Convert to numpy arrays
    X = np.stack(all_images, axis=0)  # Shape: (N, H, W, C) for RGB images
    Y = np.array(all_labels)  # Shape: (N,)

    return X, Y, labels

In [None]:
"""
Step 8: Train and Save the Model

Now we bring everything together:

* Load the datasets and convert them to PyTorch tensors.
* Create a DataLoader for batching and shuffling.
* Initialize and train the model (the EfficientNet alternative by default; swap the commented line to use the Step 6 CNN).
* Save the trained model to disk.
"""

# Load datasets
X, y, label_mapping = load_coin_datasets('coin_data')

# Convert to PyTorch tensors
X_tensor = torch.from_numpy(X).float()
y_tensor = torch.from_numpy(y).long()

# Create dataset and dataloader. Named `train_dataset` so it does not overwrite
# `dataset` (the Decoder from Step 1), which Steps 2-5 reuse when you
# preprocess another .touch file.
train_dataset = TensorDataset(X_tensor, y_tensor)
dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Create the model: either the Step 6 CNN or the EfficientNet alternative.
# model = CoinClassifier(label_mapping=label_mapping)
model = CoinClassifierEfficientNet(label_mapping=label_mapping)
model.compile()

# Train the model
model.fit(dataloader, num_epochs=20, log_interval=5)

# Save the trained model
model.save("coin_classifier")
print("Training complete! Model saved as coin_classifier.zip")

In [None]:
"""
Step 9: Load and Inspect the Model

After training, you can reload the saved model and check its metadata.
"""

session = ModelLoader.from_path("coin_classifier.zip")

# Print the metadata stored with the model, one attribute per line.
for caption, attribute in (
    ('Model type', 'model_type'),
    ('Description', 'description'),
    ('Label mapping', 'label_mapping'),
):
    print(f'{caption}: {getattr(session, attribute)}')

In [None]:
"""
Step 10: Predict from a Single Image

We now define a helper function to classify one image with the trained model.

* The image is expanded with a batch dimension before inference.
* The ONNX session returns model outputs.
* The highest-scoring class is mapped back to its label.

DON'T MODIFY THIS CELL. SIMPLY RUN.
"""

import onnxruntime as ort

def predict_coin(image: np.ndarray, session: ort.InferenceSession) -> str:
    """Classify a single image and return the predicted label name.

    Args:
        image: One frame without a batch dimension; a batch axis of size 1 is
            prepended before inference.
        session: Loaded model. It must provide ``run`` and a ``label_mapping``
            keyed by the class index as a string. NOTE(review): annotated as
            ``ort.InferenceSession``, but Steps 11-12 pass the object returned
            by ``ModelLoader.from_path``.

    Returns:
        The label of the highest-scoring class.
    """
    feeds = {'input': np.expand_dims(image, axis=0)}
    scores = session.run(['output'], feeds)[0]

    # Row 0 holds the scores for our single image.
    best_index = np.argmax(scores[0])
    return session.label_mapping[str(best_index)]

In [None]:
"""
Step 11: Test the Model

Finally, let’s check the trained model on a few random images.

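* Pick random samples from the loaded dataset (`X`, `y`). These are the images the model was trained on, so this is a sanity check, not a measure of how well the model generalizes.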
* Run predictions with `predict_coin`.
* Compare predicted vs. true labels.
"""

import random

for _ in range(20):
    idx = random.randint(0, len(X) - 1)
    test_image, true_label = X[idx], label_mapping[y[idx]]

    predicted_label = predict_coin(test_image, session)
    is_match = true_label == predicted_label
    print(f"True label: {true_label} | Predicted label: {predicted_label} | {is_match}")

In [None]:
"""
Step 12: Evaluate Accuracy

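Instead of just printing individual results,
we can calculate overall accuracy across 100 random images (drawn with replacement).
Note: the images come from the training data, so this accuracy is optimistic;
evaluate on a separate recording for a realistic estimate.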
"""

import random

num_tests = 100
correct = 0

# Each iteration draws one index uniformly at random (with replacement).
for _ in range(num_tests):
    idx = random.randint(0, len(X) - 1)
    test_image, true_label = X[idx], label_mapping[y[idx]]
    predicted_label = predict_coin(test_image, session)
    correct += predicted_label == true_label  # bool adds as 0 or 1

accuracy = correct / num_tests
print(f"Accuracy over {num_tests} random samples: {accuracy:.2%}")