In [1]:
from jetcam.csi_camera import CSICamera

# Width/height of captured frames; snapshots are saved at this size.
IMAGE_SIZE = 224
# Set to True to capture from a USB camera instead of the CSI camera.
USE_USB_CAMERA = False

if USE_USB_CAMERA:
    from jetcam.usb_camera import USBCamera
    camera = USBCamera(width=IMAGE_SIZE, height=IMAGE_SIZE)
else:
    camera = CSICamera(width=IMAGE_SIZE, height=IMAGE_SIZE)

# Enable continuous capture so `camera.value` keeps receiving new frames
# (the preview widget streams from it).
camera.running = True


In [2]:
import os
import cv2
from torch.utils.data import Dataset

class ClassificationDataset(Dataset):
    """Image-classification dataset backed by one folder of JPEGs per label.

    Images live under ``<root_dir>/<category>/*.jpg``. The label of an image is
    the index of its category in ``categories``. ``save_entry`` writes new
    images into the same layout during data collection.

    Args:
        root_dir: Directory holding one sub-directory per category. Missing
            category directories are created.
        categories: Ordered list of category names; list position = label index.
        transform: Optional callable applied to each image. The RGB image is
            converted to ``PIL.Image`` first, because torchvision transforms
            such as ``Resize`` operate on PIL images rather than numpy arrays.
    """

    def __init__(self, root_dir, categories, transform=None):
        self.root_dir = root_dir
        self.categories = categories
        self.transform = transform
        self.entries = []  # list of (file path, label index)

        for label_index, label in enumerate(categories):
            class_dir = os.path.join(root_dir, label)
            os.makedirs(class_dir, exist_ok=True)
            # os.listdir order is arbitrary; sort so idx -> file is reproducible.
            for file_name in sorted(os.listdir(class_dir)):
                if file_name.endswith(".jpg"):
                    self.entries.append((os.path.join(class_dir, file_name), label_index))

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        """Return ``(image, label_index)`` for sample ``idx``.

        Without a transform the image is the RGB array produced by OpenCV.
        With a transform, the RGB image is wrapped as a ``PIL.Image`` and the
        transform's output is returned.

        Raises:
            OSError: if the file cannot be read or decoded.
        """
        path, label = self.entries[idx]
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError(f"Could not read image file: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            from PIL import Image  # torchvision Resize needs PIL input, not ndarray
            image = self.transform(Image.fromarray(image))

        return image, label

    def save_entry(self, category, image):
        """Write an RGB ``image`` as a new JPEG for ``category`` and register it.

        Files are named ``<category>_<NNNN>.jpg`` using the first index not
        already taken. Plain file counting would collide with, and overwrite,
        an existing file once any earlier image has been deleted.

        Args:
            category: One of ``self.categories``.
            image: RGB image array; converted to BGR for ``cv2.imwrite``.

        Returns:
            Path of the written file.

        Raises:
            ValueError: if ``category`` is not in ``self.categories``.
            OSError: if OpenCV fails to write the file.
        """
        if category not in self.categories:
            raise ValueError(f"Unknown category {category!r}; expected one of {self.categories}")

        folder = os.path.join(self.root_dir, category)
        os.makedirs(folder, exist_ok=True)

        # Start from the current count and skip any name that already exists.
        index = self.get_count(category)
        path = os.path.join(folder, f"{category}_{index:04d}.jpg")
        while os.path.exists(path):
            index += 1
            path = os.path.join(folder, f"{category}_{index:04d}.jpg")

        # cv2.imwrite reports failure via its return value instead of raising.
        if not cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR)):
            raise OSError(f"Failed to write image file: {path}")

        # Keep the in-memory index in sync so len(dataset) includes new samples.
        self.entries.append((path, self.categories.index(category)))
        return path

    def get_count(self, category):
        """Return the number of ``.jpg`` files currently in ``category``'s folder."""
        folder = os.path.join(self.root_dir, category)
        return len([f for f in os.listdir(folder) if f.endswith(".jpg")])


In [3]:
import torchvision.transforms as transforms

# Images are stored under <TASK>/<category>/; a category's position in
# CATEGORIES is its label index.
TASK = 'traffic_signs'
CATEGORIES = ['stop', 'left', 'right', 'forward']

# Preprocessing: resize to 224x224, convert to a tensor, then normalise each
# channel with the widely used ImageNet mean/std (presumably to match a
# pretrained backbone — confirm against the training notebook).
TRANSFORMS = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ]
)

dataset = ClassificationDataset(TASK, CATEGORIES, transform=TRANSFORMS)


In [4]:
import ipywidgets
import traitlets
from IPython.display import display
from jetcam.utils import bgr8_to_jpeg
from jupyter_clickable_image_widget import ClickableImageWidget

# Camera UI: clickable live preview plus a panel showing the last saved snapshot.
camera_widget = ClickableImageWidget(width=camera.width, height=camera.height)
# bgr8_to_jpeg yields JPEG bytes, so declare the matching image format.
snapshot_widget = ipywidgets.Image(format='jpeg', width=camera.width, height=camera.height)

# Stream each camera frame into the preview, JPEG-encoded.
traitlets.dlink((camera, 'value'), (camera_widget, 'value'), transform=bgr8_to_jpeg)

# Control UI: label selector and a counter of images saved for that label.
category_widget = ipywidgets.Dropdown(options=CATEGORIES, description='label')
# The counter is only ever set programmatically; typing into it had no effect,
# so make it read-only for the user.
count_widget = ipywidgets.IntText(description='count', disabled=True)

# Refresh the image counter whenever a different label is selected.
def update_counts(change):
    """Observer for the label dropdown.

    ``change`` is a traitlets change dict; ``change['new']`` holds the newly
    selected category. The counter shows how many images that category has.
    """
    selected_label = change['new']
    count_widget.value = dataset.get_count(selected_label)

# Show the count for the initially selected label, then keep it in sync with
# the dropdown.
count_widget.value = dataset.get_count(category_widget.value)
category_widget.observe(update_counts, names='value')


In [5]:
def save_snapshot(_, content, msg):
    """Save the current camera frame under the selected label on a preview click.

    Used as a custom-message handler on the clickable preview. Messages whose
    ``event`` is not ``'click'`` are ignored.

    Args:
        _: The widget that sent the message (unused).
        content: Message payload; only ``content['event']`` is inspected.
        msg: Raw comm message (unused).
    """
    if content.get('event') != 'click':
        return

    label = category_widget.value
    # Read the frame once: while the camera is running, camera.value may be
    # replaced between reads, so the saved and displayed images could differ.
    frame = camera.value.copy()

    # camera.value is BGR (it is rendered with bgr8_to_jpeg). save_entry expects
    # RGB and converts RGB->BGR before writing, so convert here; otherwise the
    # JPEG on disk has red and blue swapped.
    dataset.save_entry(label, cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    snapshot_widget.value = bgr8_to_jpeg(frame)
    count_widget.value = dataset.get_count(label)

# Attach the click handler. Re-running this cell creates a new save_snapshot
# object, and on_msg keeps previously registered callbacks (removal needs
# remove=True). Detach the handler from an earlier run first; otherwise a
# single click would save the frame once per registered handler.
if '_registered_snapshot_handler' in globals():
    camera_widget.on_msg(_registered_snapshot_handler, remove=True)
_registered_snapshot_handler = save_snapshot
camera_widget.on_msg(_registered_snapshot_handler)


In [6]:
# Collection UI: live preview and last snapshot side by side, with the label
# selector and image counter stacked underneath.
data_collection_widget = ipywidgets.VBox(
    children=[
        ipywidgets.HBox(children=[camera_widget, snapshot_widget]),
        category_widget,
        count_widget,
    ]
)
display(data_collection_widget)


VBox(children=(HBox(children=(ClickableImageWidget(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x0…