# Interactive Image Classification: Drive vs Stop

This notebook provides an interactive workflow for collecting, training, and deploying a binary image classifier that distinguishes between "drive" and "stop" using live camera input.

## 1. Camera Initialization

Import and initialize the camera for capturing live images. You can use either the CSI or USB camera depending on your hardware.

In [None]:
# --- Grad-CAM Visualization Cell ---

# Install torchcam into the kernel's environment if it is not already installed
# (uncomment if needed; %pip targets the running kernel, unlike !pip)
# %pip install torchcam

from torchcam.methods import SmoothGradCAMpp
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt

# Initialize the CAM extractor for the model (after model definition).
# NOTE(review): `model` is only created in the "Model Definition and Management"
# cell further down, so on a fresh kernel this line raises NameError when the
# notebook is run top to bottom. Run this cell after the model, dataset and
# live-classification cells (it also needs `Image`, `preprocess`, `torch` and
# `device` from those cells).
cam_extractor = SmoothGradCAMpp(model)

def show_gradcam(image_array):
    """Plot an image next to the same image overlaid with its activation map.

    The image is run through ``model``; the class with the highest score is
    used as the target class for ``cam_extractor``.

    Args:
        image_array: Array accepted by ``Image.fromarray`` (for example
            ``camera.value``).
    """
    model.eval()
    input_image = Image.fromarray(image_array)
    input_tensor = preprocess(input_image).unsqueeze(0).to(device)
    # The forward pass is intentionally not wrapped in torch.no_grad():
    # torchcam's documented usage runs it with autograd enabled so that
    # gradient-based extractors can backpropagate from the scores.
    out = model(input_tensor)
    pred = out.argmax(dim=1).item()
    # Extract CAM
    activation_map = cam_extractor(pred, out)
    # Each returned map keeps a leading batch dimension; drop it so that
    # matplotlib receives a 2-D (H, W) array.
    cam = activation_map[0].squeeze(0).detach().cpu().numpy()
    # Show original image and CAM overlay
    width, height = input_image.size
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(input_image)
    ax[0].set_title("Original Image")
    ax[1].imshow(input_image)
    # The CAM is much smaller than the image. Without an explicit extent the
    # second imshow would rescale the axes to the CAM's own pixel grid and the
    # overlay would no longer line up with the picture underneath.
    ax[1].imshow(
        cam,
        cmap='jet',
        alpha=0.5,
        extent=(-0.5, width - 0.5, height - 0.5, -0.5),
        interpolation='bilinear',
    )
    ax[1].set_title("Grad-CAM Overlay")
    plt.show()

# Example usage: show_gradcam(camera.value)

In [None]:
from jetcam.csi_camera import CSICamera
# from jetcam.usb_camera import USBCamera

# Open the CSI camera at 224x224, the same size the image transforms resize to.
camera = CSICamera(width=224, height=224)
# camera = USBCamera(width=224, height=224)

# NOTE(review): presumably this starts background capture so that
# `camera.value` keeps updating with new frames — confirm against jetcam.
camera.running = True

## 2. Task and Dataset Preparation

Define the classification task, set up the two classes ('drive', 'stop'), and prepare the dataset and transforms for image preprocessing.

In [None]:
import torchvision.transforms as transforms
from xy_dataset import XYDataset  # You may need to adapt this or use a custom dataset for classification

# TASK and each entry of DATASETS are joined as "<TASK>_<name>" to form the
# root directory of one dataset; CLASSES gives the sub-folder names and, by
# position, the integer labels (drive -> 0, stop -> 1).
TASK = 'drive_stop_classification'
CLASSES = ['drive', 'stop']
DATASETS = ['A', 'B']

# Transform applied to every image loaded from a dataset: random colour
# augmentation, resize to 224x224, conversion to a tensor, then per-channel
# normalization.
# NOTE(review): the mean/std values look like the usual ImageNet statistics
# that go with the pretrained ResNet — confirm.
TRANSFORMS = transforms.Compose([
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# You may need to implement a simple classification dataset if XYDataset is for regression
from torch.utils.data import Dataset
import os
from PIL import Image

class ClassificationDataset(Dataset):
    """Folder-per-class image dataset that can also grow at runtime.

    Images live in ``<root>/<class_name>/``. The label of a sample is the
    position of its class name in ``classes``.

    Args:
        root: Directory that contains one sub-directory per class.
        classes: Ordered list of class names.
        transform: Optional callable applied to each PIL image on access.
    """

    # Lower-case file extensions that count as image samples.
    IMAGE_EXTENSIONS = ('.jpg', '.png')

    def __init__(self, root, classes, transform=None):
        self.root = root
        self.classes = classes
        self.transform = transform
        self.samples = []
        for idx, cls in enumerate(classes):
            cls_dir = os.path.join(root, cls)
            for fname in self._image_files(cls_dir):
                self.samples.append((os.path.join(cls_dir, fname), idx))

    def _image_files(self, cls_dir):
        """Return the sorted image file names in ``cls_dir`` ([] if missing).

        Only names ending in one of ``IMAGE_EXTENSIONS`` (case-insensitive)
        are returned, so stray entries such as ``.ipynb_checkpoints`` are not
        mistaken for samples.
        """
        if not os.path.isdir(cls_dir):
            return []
        return sorted(
            fname for fname in os.listdir(cls_dir)
            if fname.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def _next_sample_path(self, cls_dir):
        """Return an unused ``NNNN.jpg`` path inside ``cls_dir``.

        Numbering starts at the current image count and skips forward past
        any name that already exists, so a gap left by a deleted file never
        causes an existing image to be overwritten.
        """
        idx = len(self._image_files(cls_dir))
        fname = os.path.join(cls_dir, f"{idx:04d}.jpg")
        while os.path.exists(fname):
            idx += 1
            fname = os.path.join(cls_dir, f"{idx:04d}.jpg")
        return fname

    def __len__(self):
        """Number of samples currently known to this dataset object."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``, transformed if configured."""
        path, label = self.samples[idx]
        # Use a context manager so the underlying file handle is closed as
        # soon as the pixel data has been converted.
        with Image.open(path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

    def add_sample(self, class_name, image_array):
        """Save ``image_array`` as a new JPEG of ``class_name`` and register it.

        Raises:
            ValueError: If ``class_name`` is not one of ``self.classes``. The
                check happens before anything is written to disk.
        """
        label = self.classes.index(class_name)
        cls_dir = os.path.join(self.root, class_name)
        os.makedirs(cls_dir, exist_ok=True)
        fname = self._next_sample_path(cls_dir)
        Image.fromarray(image_array).save(fname)
        self.samples.append((fname, label))

    def get_count(self, class_name):
        """Return how many image files are stored on disk for ``class_name``."""
        return len(self._image_files(os.path.join(self.root, class_name)))

# One ClassificationDataset per entry of DATASETS, rooted at "<TASK>_<name>".
datasets = {
    name: ClassificationDataset(f'{TASK}_{name}', CLASSES, TRANSFORMS)
    for name in DATASETS
}

## 3. Data Collection Widget for Classification

Create widgets to capture images from the camera and label them as either 'drive' or 'stop'. Save labeled images to the dataset.

In [None]:
import cv2
import ipywidgets
import traitlets
from IPython.display import display
from jetcam.utils import bgr8_to_jpeg
from jupyter_clickable_image_widget import ClickableImageWidget

# Start with the first dataset in DATASETS as the active one; the dataset
# dropdown below switches it.
dataset = datasets[DATASETS[0]]

# Remove all observers from the camera in case this cell is being run a
# second time, so the link created below is not duplicated.
camera.unobserve_all()

# Image preview: camera_widget follows camera.value through a one-way link
# (converted with bgr8_to_jpeg); snapshot_widget shows the last saved frame.
camera_widget = ClickableImageWidget(width=camera.width, height=camera.height)
snapshot_widget = ipywidgets.Image(width=camera.width, height=camera.height)
traitlets.dlink((camera, 'value'), (camera_widget, 'value'), transform=bgr8_to_jpeg)

# Controls: which dataset to write to, which class label to apply, and a
# display of how many files that class already has.
dataset_widget = ipywidgets.Dropdown(options=DATASETS, description='dataset')
class_widget = ipywidgets.Dropdown(options=CLASSES, description='class')
count_widget = ipywidgets.IntText(description='count')

# Observers only fire on changes, so set the initial count by hand.
count_widget.value = dataset.get_count(class_widget.value)

def set_dataset(change):
    """Switch the active dataset to the one picked in the dataset dropdown.

    Also refreshes the count display for the currently selected class.
    """
    global dataset
    selected_name = change['new']
    dataset = datasets[selected_name]
    current_class = class_widget.value
    count_widget.value = dataset.get_count(current_class)
dataset_widget.observe(set_dataset, names='value')

def update_counts(change):
    """Show the file count of the newly selected class in the active dataset."""
    selected_class = change['new']
    count_widget.value = dataset.get_count(selected_class)
class_widget.observe(update_counts, names='value')

def save_snapshot(_=None):
    """Save the current camera frame under the selected class and show it.

    Args:
        _: Ignored. Accepts the button instance passed by ``on_click``.
    """
    # Read the frame exactly once. `camera.value` is a live trait (it is
    # linked to the preview widget), so reading it a second time could return
    # a different frame than the one written to disk.
    snapshot = camera.value.copy()
    # save to disk
    dataset.add_sample(class_widget.value, snapshot)
    # display the frame that was actually saved
    snapshot_widget.value = bgr8_to_jpeg(snapshot)
    count_widget.value = dataset.get_count(class_widget.value)

save_button = ipywidgets.Button(description='Save Image')
save_button.on_click(save_snapshot)

# Layout: live preview and last saved snapshot side by side, with the
# dataset/class selectors, the count display and the save button below.
data_collection_widget = ipywidgets.VBox([
    ipywidgets.HBox([camera_widget, snapshot_widget]),
    dataset_widget,
    class_widget,
    count_widget,
    save_button
])

display(data_collection_widget)

## 4. Model Definition and Management

Define a neural network model (e.g., ResNet18) for binary classification, and provide widgets to save/load model weights.

In [None]:
import torch
import torchvision

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# One output per class, derived from CLASSES so the head stays in sync with
# the labels produced by the dataset.
output_dim = len(CLASSES)  # drive, stop

model = torchvision.models.resnet18(pretrained=True)
# Replace the final fully connected layer with a fresh classification head.
# Reading in_features from the existing layer avoids hard-coding the backbone's
# feature width.
model.fc = torch.nn.Linear(model.fc.in_features, output_dim)
model = model.to(device)

# Controls for persisting the model: two buttons plus a text box holding the
# file path that both load_model and save_model read.
model_save_button = ipywidgets.Button(description='save model')
model_load_button = ipywidgets.Button(description='load model')
model_path_widget = ipywidgets.Text(description='model path', value='drive_stop_model.pth')

def load_model(c):
    """Load weights from the path in the text box into the current model.

    Args:
        c: The clicked button (unused).
    """
    weights_path = model_path_widget.value
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
model_load_button.on_click(load_model)
    
def save_model(c):
    """Write the current model's state dict to the path in the text box.

    Args:
        c: The clicked button (unused).
    """
    target_path = model_path_widget.value
    weights = model.state_dict()
    torch.save(weights, target_path)
model_save_button.on_click(save_model)

# Layout: path text box on top, load/save buttons side by side underneath.
model_widget = ipywidgets.VBox([
    model_path_widget,
    ipywidgets.HBox([model_load_button, model_save_button])
])

display(model_widget)

## 5. Live Classification Execution

Implement a live preview that runs the model on camera images and displays the predicted class in real time.

In [None]:
import threading
import time
from torchvision import transforms as T

# 'live' runs the classification loop in a background thread; 'stop' ends it.
state_widget = ipywidgets.ToggleButtons(options=['stop', 'live'], description='state', value='stop')
# Annotated camera frame and a text label showing the latest prediction.
prediction_widget = ipywidgets.Image(format='jpeg', width=camera.width, height=camera.height)
predicted_class_widget = ipywidgets.Label(value="Prediction: -")

# Inference-time preprocessing. It mirrors TRANSFORMS (resize, tensor
# conversion, normalization) but leaves out the random ColorJitter, which is a
# training-time augmentation: with it, the same frame could be classified
# differently from one call to the next.
preprocess = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

def live_classification(state_widget, model, camera, prediction_widget, predicted_class_widget):
    """Classify camera frames in a loop until the state widget leaves 'live'.

    Each iteration runs the model on the current frame, draws the predicted
    class name onto a copy of the frame, and updates both output widgets.

    Args:
        state_widget: Widget whose ``value`` must equal ``'live'`` for the
            loop to keep running.
        model: Classifier whose output index maps into ``CLASSES``.
        camera: Object exposing the current frame as ``camera.value``.
        prediction_widget: Image widget that receives the annotated JPEG.
        predicted_class_widget: Label widget that receives the class name.
    """
    # A newly constructed model is in training mode. Switch to eval mode so
    # layers such as BatchNorm use their stored statistics instead of updating
    # them from single live frames.
    model.eval()
    while state_widget.value == 'live':
        image = camera.value
        if image is None:
            # NOTE(review): defensive guard in case no frame is available yet;
            # wait and retry instead of crashing the thread.
            time.sleep(0.1)
            continue
        input_image = Image.fromarray(image)
        input_tensor = preprocess(input_image).unsqueeze(0).to(device)
        with torch.no_grad():
            output = model(input_tensor)
            pred_idx = torch.argmax(output, dim=1).item()
            pred_class = CLASSES[pred_idx]
        # Draw predicted class on a copy so the camera's own frame is untouched
        display_image = image.copy()
        cv2.putText(display_image, pred_class, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0,255,0), 2)
        prediction_widget.value = bgr8_to_jpeg(display_image)
        predicted_class_widget.value = f"Prediction: {pred_class}"
        time.sleep(0.1)

def start_live(change):
    """Start the live-classification thread when the state switches to 'live'."""
    if change['new'] != 'live':
        return
    worker_args = (state_widget, model, camera, prediction_widget, predicted_class_widget)
    worker = threading.Thread(target=live_classification, args=worker_args)
    worker.start()

state_widget.observe(start_live, names='value')

# Layout: annotated frame, predicted class label, then the stop/live toggle.
live_execution_widget = ipywidgets.VBox([
    prediction_widget,
    predicted_class_widget,
    state_widget
])

display(live_execution_widget)

## 6. Training and Evaluation Controls

Add widgets to control training (epochs, train/evaluate buttons, progress/loss display) and implement the training/evaluation loop for classification.

In [None]:
# Number of samples per batch handed to the DataLoader in train_eval.
BATCH_SIZE = 8

# Adam with its default hyper-parameters over all model parameters, and
# cross-entropy loss on the raw model outputs with integer class labels.
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.CrossEntropyLoss()

# Training controls and read-outs. train_eval reads the epoch count from
# epochs_widget and writes to the progress, loss and accuracy widgets.
epochs_widget = ipywidgets.IntText(description='epochs', value=1)
eval_button = ipywidgets.Button(description='evaluate')
train_button = ipywidgets.Button(description='train')
loss_widget = ipywidgets.FloatText(description='loss')
progress_widget = ipywidgets.FloatProgress(min=0.0, max=1.0, description='progress')
accuracy_widget = ipywidgets.FloatText(description='accuracy')

def train_eval(is_training):
    """Train or evaluate the model on the active dataset, updating the widgets.

    While running, live inference is paused and both buttons are disabled.
    Afterwards the model is put back in eval mode, the buttons are re-enabled
    and the state widget is set to 'live', whether or not an error occurred.

    Args:
        is_training: If true, train for ``epochs_widget.value`` epochs.
            Otherwise run a single evaluation pass without updating weights.
    """
    import traceback

    try:
        # Pause live inference and lock both buttons while the model is busy.
        state_widget.value = 'stop'
        train_button.disabled = True
        eval_button.disabled = True
        # The live loop polls the state roughly every 0.1 s; wait so it can
        # notice 'stop' and exit before the model is used here.
        time.sleep(1)

        num_samples = len(dataset)
        if num_samples == 0:
            # A shuffling DataLoader cannot be built over zero samples; say so
            # explicitly instead of surfacing an opaque sampler error.
            print('Dataset is empty - collect some images before training or evaluating.')
            return

        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=True
        )

        if is_training:
            model.train()
        else:
            model.eval()

        num_epochs = epochs_widget.value if is_training else 1
        for epoch in range(num_epochs):
            seen = 0
            sum_loss = 0.0
            correct = 0
            for images, labels in train_loader:
                images = images.to(device)
                labels = labels.to(device)
                batch_count = labels.size(0)

                # Only build the autograd graph when weights will be updated.
                with torch.set_grad_enabled(is_training):
                    if is_training:
                        optimizer.zero_grad()

                    outputs = model(images)
                    loss = criterion(outputs, labels)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                predicted = outputs.argmax(dim=1)
                correct += (predicted == labels).sum().item()
                seen += batch_count

                # `loss` is the mean over the batch. Weight it by the batch
                # size so that sum_loss / seen is the per-sample mean loss
                # rather than a value shrunk by the batch size.
                sum_loss += float(loss) * batch_count
                progress_widget.value = seen / num_samples
                loss_widget.value = sum_loss / seen
                accuracy_widget.value = correct / seen

    except Exception:
        # Keep the notebook responsive, but show the full traceback so the
        # failure can be diagnosed.
        traceback.print_exc()
    finally:
        model.eval()

        train_button.disabled = False
        eval_button.disabled = False
        state_widget.value = 'live'

# Both buttons call train_eval; only the is_training flag differs.
train_button.on_click(lambda c: train_eval(is_training=True))
eval_button.on_click(lambda c: train_eval(is_training=False))

# Layout: epoch count, progress bar, loss and accuracy read-outs, then the
# train/evaluate buttons side by side.
train_eval_widget = ipywidgets.VBox([
    epochs_widget,
    progress_widget,
    loss_widget,
    accuracy_widget,
    ipywidgets.HBox([train_button, eval_button])
])

display(train_eval_widget)

## 7. Combined Widget Display

Combine all widgets into a single interface for streamlined data collection, training, and live inference.

In [None]:
# Single combined view: data collection and live prediction side by side,
# with the training controls and the model save/load controls underneath.
# These are the same widget objects displayed by the earlier cells.
all_widget = ipywidgets.VBox([
    ipywidgets.HBox([data_collection_widget, live_execution_widget]), 
    train_eval_widget,
    model_widget
])

display(all_widget)