In [27]:
import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as T
from torch.utils.data import DataLoader

In [24]:
class SignLanguageDataset(Dataset):
    """Image-classification dataset read from a ``root_dir/<class>/<image>`` tree.

    Every non-hidden sub-directory of ``root_dir`` is treated as one class, and
    class names are mapped to integer labels by their sorted order.  Every
    non-hidden regular file inside a class directory becomes one sample.

    Args:
        root_dir: Directory whose sub-directories are the class folders.
        transform: Optional callable applied to the loaded RGB PIL image in
            ``__getitem__``; when ``None`` the PIL image is returned as-is.

    Attributes:
        classes: Sorted list of class (folder) names.
        class_to_idx: Mapping from class name to its integer label.
        samples: List of ``(image_path, label)`` tuples, ordered by class and
            then by file name.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # Only real, non-hidden sub-directories count as classes.  Stray files
        # or hidden folders (e.g. ``.ipynb_checkpoints``) in root_dir would
        # otherwise either crash the scan below or shift every label index.
        self.classes = sorted(
            entry
            for entry in os.listdir(root_dir)
            if not entry.startswith(".")
            and os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

        self.samples = []
        for cls in self.classes:
            cls_path = os.path.join(root_dir, cls)
            label = self.class_to_idx[cls]
            # os.listdir returns entries in arbitrary order; sorting makes the
            # sample order reproducible across runs and machines.
            for img_name in sorted(os.listdir(cls_path)):
                img_path = os.path.join(cls_path, img_name)
                # Skip hidden files and anything that is not a regular file
                # (nested directories, ...) so __getitem__ never tries to
                # open them as images.
                if img_name.startswith(".") or not os.path.isfile(img_path):
                    continue
                self.samples.append((img_path, label))

    def __len__(self):
        """Return the number of samples found under ``root_dir``."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is opened from disk, converted to RGB and, if a transform
        was given, passed through it.
        """
        img_path, label = self.samples[idx]
        # The context manager releases the file handle as soon as the pixel
        # data has been converted.  Use "L" instead of "RGB" for grayscale.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        return img, label


In [26]:
# Training pipeline: resize, apply light random augmentation, then convert
# the PIL image to a tensor.
train_transform = T.Compose(
    [
        T.Resize((224, 224)),
        T.RandomHorizontalFlip(p=0.5),  # horizontal flip
        T.RandomRotation(degrees=15),   # slight rotation
        T.ToTensor(),
    ]
)


def _eval_pipeline():
    """Build the augmentation-free pipeline: resize, then convert to tensor."""
    return T.Compose([T.Resize((224, 224)), T.ToTensor()])


# Validation and test use the same deterministic steps; each split gets its
# own Compose instance.
val_transform = _eval_pipeline()
test_transform = _eval_pipeline()

In [28]:
# Root folders of the three splits; each is expected to contain one
# sub-directory per class.
train_dir = "/kaggle/working/vsl_split/train"
val_dir   = "/kaggle/working/vsl_split/valid"
test_dir  = "/kaggle/working/vsl_split/test"

train_ds = SignLanguageDataset(train_dir, transform=train_transform)
val_ds   = SignLanguageDataset(val_dir, transform=val_transform)
test_ds  = SignLanguageDataset(test_dir, transform=test_transform)

# Each dataset derives its integer labels from its own folder listing, so the
# splits only share a label space when they contain the same class folders.
# Fail loudly here instead of silently training/evaluating on shifted labels.
assert train_ds.classes == val_ds.classes == test_ds.classes, (
    "train/valid/test splits must contain identical class folders; "
    "otherwise the same integer label refers to different classes per split"
)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=2)
test_loader  = DataLoader(test_ds, batch_size=32, shuffle=False, num_workers=2)


In [29]:
# Sanity check: pull a single batch from the training loader and show the
# image tensor shape together with the first ten labels.
first_batch = next(iter(train_loader))
imgs, labels = first_batch
print(imgs.shape, labels[:10])

torch.Size([32, 3, 224, 224]) tensor([ 2,  4,  7,  7,  8,  8, 21,  5, 21, 10])
