In [1]:
import os
import random
from collections import defaultdict

input_txt = '/kaggle/input/keypoint/videoid_label.txt'
out_train = '/kaggle/working/train.txt'
out_val = '/kaggle/working/val.txt'
out_test = '/kaggle/working/test.txt'

#os.makedirs(out_train, exist_ok=True)
#os.makedirs(out_val, exist_ok=True)
#os.makedirs(out_test, exist_ok=True)
# Đọc file và gom theo nhãn
label_dict = defaultdict(list)
with open(input_txt, 'r') as f:
    for line in f:
        line = line.strip()
        if not line: continue
        video_id, label = line.split()
        label_dict[label].append(video_id)

# Chia tỉ lệ 6:4:4 cho từng nhãn
train_lines, val_lines, test_lines = [], [], []
for label, vids in label_dict.items():
    vids = list(vids)
    random.shuffle(vids)
    n = len(vids)
    n_train = round(n * 0.6)
    n_val = round(n * 0.2)
    n_test = n - n_train - n_val
    train = vids[:n_train]
    val = vids[n_train:n_train+n_val]
    test = vids[n_train+n_val:]
    train_lines.extend([f"{vid} {label}\n" for vid in train])
    val_lines.extend([f"{vid} {label}\n" for vid in val])
    test_lines.extend([f"{vid} {label}\n" for vid in test])

# Shuffle lại từng tập để tránh cùng nhãn đứng liền nhau
random.shuffle(train_lines)
random.shuffle(val_lines)
random.shuffle(test_lines)

# Lưu file
with open(out_train, 'w') as f: f.writelines(train_lines)
with open(out_val, 'w') as f: f.writelines(val_lines)
with open(out_test, 'w') as f: f.writelines(test_lines)

print("Đã chia xong. Train:", len(train_lines), "Val:", len(val_lines), "Test:", len(test_lines))

Đã chia xong. Train: 456 Val: 145 Test: 150


# 1. Pretrain Vision encoder

In [2]:
import os
from glob import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
import torchvision
from tqdm import tqdm
import json

In [3]:
def build_label_map(txt_files):
    labels = set()
    for txt in txt_files:
        with open(txt, "r") as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) >= 2:
                    label = ' '.join(parts[1:]).strip()
                    labels.add(label)
    labels = sorted(labels)
    label_map = {lbl: idx for idx, lbl in enumerate(labels)}
    return label_map

In [4]:
data_root = "/kaggle/input/image-mask/cropped_hands"
train_txt = "/kaggle/working/train.txt"
valid_txt = "/kaggle/working/val.txt"
test_txt = "/kaggle/working/test.txt"
batch_size = 32
num_workers = 4
num_epochs = 6
lr = 1e-3
out_ckpt = "/kaggle/working/pretrained_hand_rgb.pth"
out_labelmap = "/kaggle/working/label_map.json"

label_map = build_label_map([train_txt, valid_txt, test_txt])
num_classes = len(label_map)
print("Số lớp:", num_classes)
print("Sample label_map:", dict(list(label_map.items())[:5]))

with open(out_labelmap, "w") as f:
    json.dump(label_map, f)

Số lớp: 32
Sample label_map: {'all': 0, 'before': 1, 'black': 2, 'book': 3, 'candy': 4}


In [5]:
# ==== 2. Dataset for cropped hand images ====
class HandImageDataset(Dataset):
    def __init__(self, data_root, list_file, label_map, transform=None):
        self.samples = []
        self.transform = transform
        with open(list_file, "r") as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) == 2:
                    video_id, label = parts
                    label_idx = label_map[label]
                    #img_dir = os.path.join(data_root, video_id)
                    # All *_left.jpg and *_right.jpg (can add filter for frame sampling if needed)
                    imgs = sorted(glob(os.path.join(data_root, f"{video_id}_*_left.jpg"))) + \
                           sorted(glob(os.path.join(data_root, f"{video_id}_*_right.jpg")))
                    for img_path in imgs:
                        if os.path.isfile(img_path):
                            self.samples.append((img_path, label_idx))
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label

In [6]:
transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225])
])

train_ds = HandImageDataset(data_root, train_txt, label_map, transform=transform)
val_ds = HandImageDataset(data_root, valid_txt, label_map, transform=transform)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

print(f"Số ảnh train: {len(train_ds)}, Số ảnh val: {len(val_ds)}")

Số ảnh train: 2266, Số ảnh val: 732


In [7]:
# ==== 3. Vision Model ====
class VisionClassifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        backbone = torchvision.models.efficientnet_b0(pretrained=True)
        self.features = backbone.features
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(1280, num_classes)
    def forward(self, x):
        feat = self.features(x)
        feat = self.pool(feat).view(x.size(0), -1)
        return self.classifier(feat)


In [8]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = VisionClassifier(num_classes)
if torch.cuda.device_count() > 1:
    print("Using DataParallel with {} GPUs".format(torch.cuda.device_count()))
    model = nn.DataParallel(model)
model = model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth
100%|██████████| 20.5M/20.5M [00:00<00:00, 168MB/s]


Using DataParallel with 2 GPUs


In [9]:
best_val_acc = 0
for epoch in range(num_epochs):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for imgs, labels in tqdm(train_loader, desc=f"Train Epoch {epoch+1}", leave=False):
        imgs, labels = imgs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
        logits = model(imgs)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * imgs.size(0)
        preds = logits.argmax(1)
        correct += (preds == labels).sum().item()
        total += imgs.size(0)
    train_acc = correct / total
    train_loss = total_loss / total

    # Validation
    model.eval()
    val_loss, val_correct, val_total = 0, 0, 0
    with torch.no_grad():
        for imgs, labels in tqdm(val_loader, desc="Valid", leave=False):
            imgs, labels = imgs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            logits = model(imgs)
            loss = criterion(logits, labels)
            val_loss += loss.item() * imgs.size(0)
            preds = logits.argmax(1)
            val_correct += (preds == labels).sum().item()
            val_total += imgs.size(0)
    val_acc = val_correct / val_total
    val_loss = val_loss / val_total

    print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

    # Save best
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            "model": model.state_dict(),
            "label_map": label_map
        }, out_ckpt)
        print(f"Best model saved at epoch {epoch+1}, val_acc={val_acc:.4f}")

print("Done. Best val acc:", best_val_acc)

                                                              

Epoch 1/6 | Train Loss: 3.0918 Acc: 0.1703 | Val Loss: 2.9388 Acc: 0.1967
Best model saved at epoch 1, val_acc=0.1967


                                                              

Epoch 2/6 | Train Loss: 2.4080 Acc: 0.3455 | Val Loss: 2.8100 Acc: 0.2842
Best model saved at epoch 2, val_acc=0.2842


                                                              

Epoch 3/6 | Train Loss: 1.9299 Acc: 0.4590 | Val Loss: 2.9182 Acc: 0.2732


                                                              

Epoch 4/6 | Train Loss: 1.5556 Acc: 0.5556 | Val Loss: 2.8235 Acc: 0.3265
Best model saved at epoch 4, val_acc=0.3265


                                                              

Epoch 5/6 | Train Loss: 1.2423 Acc: 0.6390 | Val Loss: 3.0118 Acc: 0.3183


                                                              

Epoch 6/6 | Train Loss: 0.9920 Acc: 0.7039 | Val Loss: 3.3368 Acc: 0.3156
Done. Best val acc: 0.32650273224043713




# 2. pretrain STGCN

In [10]:
import numpy as np
import torch
from torch.utils.data import Dataset

class PoseSpatialPartDataset(Dataset):
    def __init__(self, data_root, txt_file, label_map, part='body'):
        self.samples = []
        self.label_map = label_map
        self.part = part
        with open(txt_file) as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) < 2:
                    print(f"WARNING: dòng bị lỗi format: {line}")
                    continue
                npy_id = parts[0]
                label = ' '.join(parts[1:]).strip()
                if label not in label_map:
                    print(f"WARNING: label '{label}' chưa có trong label_map!")
                    continue
                npy_path = f"{data_root}/{npy_id}_keypoint.npy"
                self.samples.append((npy_path, int(label_map[label])))
        # Define keypoint slices for each part
        if part == 'body':
            self.idx_start, self.idx_end = 0, 25
        elif part == 'left':
            self.idx_start, self.idx_end = 25, 46
        elif part == 'right':
            self.idx_start, self.idx_end = 46, 67
        else:
            raise ValueError("Unknown part: " + part)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        npy_path, label = self.samples[idx]
        keypoints = np.load(npy_path)  # (T, 67, 3) expected
        # Auto fix shape if needed
        if keypoints.shape[-2:] == (67, 3):
            pass
        elif keypoints.shape[0] == 67 and keypoints.shape[1] == 3:
            keypoints = np.transpose(keypoints, (2, 0, 1))
        elif keypoints.shape[1] == 3 and keypoints.shape[2] == 67:
            keypoints = np.transpose(keypoints, (0, 2, 1))
        else:
            raise RuntimeError(f"Unrecognized keypoints shape: {keypoints.shape}")
        part_kp = keypoints[:, self.idx_start:self.idx_end, :]  # (T, N, 3)
        if idx == 0:
            print(f"Dataset part_kp.shape: {part_kp.shape}")
        return torch.tensor(part_kp, dtype=torch.float32), label

In [11]:
class SpatialGCNLayer(nn.Module):
    def __init__(self, in_channels, out_channels, A):
        super().__init__()
        self.register_buffer('A', A)
        self.fc = nn.Linear(in_channels, out_channels)
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x):  # x: (B, N, in_channels)
        h = self.fc(x)  # (B, N, out_channels)
        h = h.permute(0, 2, 1)  # (B, out_channels, N)
        h = torch.matmul(h, self.A)  # (B, out_channels, N)
        h = self.bn(h)  # BatchNorm trên out_channels
        h = h.permute(0, 2, 1)  # (B, N, out_channels)
        return torch.relu(h)

In [12]:
class SpatialPoseEncoder(nn.Module):
    def __init__(self, in_channels, num_joints, num_classes, A, hid_dim=128, out_dim=256):
        super().__init__()
        self.gcn1 = SpatialGCNLayer(in_channels, hid_dim, A)
        self.gcn2 = SpatialGCNLayer(hid_dim, out_dim, A)
        self.classifier = nn.Linear(out_dim * num_joints, num_classes)

    def forward(self, x):  # x: (B, T, N, 3)
        #print("Encoder x.shape:", x.shape)
        B, T, N, C = x.shape
        assert N == self.gcn1.A.shape[0], f"x.shape={x.shape}, A.shape={self.gcn1.A.shape}"
        x = x.view(B * T, N, C)  # (B*T, N, 3)
        h = self.gcn1(x)         # (B*T, N, hid_dim)
        h = self.gcn2(h)         # (B*T, N, out_dim)
        h = h.view(B, T, N, -1)  # (B, T, N, out_dim)
        h = h.mean(1)            # (B, N, out_dim)
        h = h.permute(0, 2, 1)   # (B, out_dim, N)
        logits = self.classifier(h.flatten(1))  # (B, num_classes)
        return logits, h

In [13]:
def gather_special_frames(pose_feat, mask_indices):
    """
    pose_feat: (B, C, N, T)
    mask_indices: list of [tensor(F_b,), ...]  # F_b: số frame của mỗi sample cần fusion
    Return: list of (B, C, N, F_b)
    """
    outputs = []
    for b, idxs in enumerate(mask_indices):
        # idxs: (F_b,), pose_feat[b]: (C, N, T)
        sel = pose_feat[b, :, :, idxs]  # (C, N, F_b)
        outputs.append(sel)
    return outputs

In [14]:
def get_spatial_adjacency(num_node, edge):
    A = np.zeros((num_node, num_node))
    for i, j in edge:
        A[i, j] = 1
        A[j, i] = 1
    # Normalize
    Dl = np.sum(A, 0)
    Dn = np.zeros((num_node, num_node))
    for i in range(num_node):
        if Dl[i] > 0:
            Dn[i, i] = Dl[i] ** (-1)
    A_normalized = np.dot(A, Dn)
    return torch.tensor(A_normalized, dtype=torch.float32)

def get_body_spatial_graph():
    # 25 body keypoints (Mediapipe hoặc OpenPose định nghĩa)
    num_node = 25
    self_link = [(i, i) for i in range(num_node)]
    neighbor_link = [
        (0, 1), (1, 2), (2, 3), (3, 7),
        (0, 4), (4, 5), (5, 6), (6, 8),
        (9, 10), (11, 12), (11, 13), (13, 15), (15, 21), (15, 19), (15, 17),
        (17, 19), (11, 23), (12, 14), (14, 16), (16, 18), (16, 20), (16, 22),
        (18, 20), (12, 24), (23, 24)
    ]
    edge = self_link + neighbor_link
    return get_spatial_adjacency(num_node, edge)

def get_left_hand_spatial_graph():
    # 21 left hand keypoints (Mediapipe)
    num_node = 21
    self_link = [(i, i) for i in range(num_node)]
    neighbor_link = [
        (0, 1),(1, 2),(2, 3),(3, 4),
        (0, 5),(5, 6),(6, 7),(7, 8),
        (0, 9),(9, 10),(10, 11),(11, 12),
        (0, 13),(13, 14),(14, 15),(15, 16),
        (0, 17),(17, 18),(18, 19),(19, 20)
    ]
    edge = self_link + neighbor_link
    return get_spatial_adjacency(num_node, edge)
def get_right_hand_spatial_graph():
    # 21 right hand keypoints (Mediapipe)
    num_node = 21
    self_link = [(i, i) for i in range(num_node)]
    neighbor_link = [
        (0, 1),(1, 2),(2, 3),(3, 4),
        (0, 5),(5, 6),(6, 7),(7, 8),
        (0, 9),(9, 10),(10, 11),(11, 12),
        (0, 13),(13, 14),(14, 15),(15, 16),
        (0, 17),(17, 18),(18, 19),(19, 20)
    ]
    edge = self_link + neighbor_link
    return get_spatial_adjacency(num_node, edge)

In [15]:
PART_INFO = {
    'body':  (0, 25),
    'left':  (25, 46),
    'right': (46, 67)
}

In [16]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.optim as optim
import torch.nn as nn

def train_pose_spatial_part(part, num_joints, get_A_func, best_ckpt_file):
    print(f"\n--- Pretraining {part} ---")
    train_ds = PoseSpatialPartDataset(data_root, train_txt, label_map, part=part)
    val_ds = PoseSpatialPartDataset(data_root, val_txt, label_map, part=part)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    print(f"Số mẫu train: {len(train_ds)}, val: {len(val_ds)}")
    print(f"Số batch train: {len(train_loader)}, val: {len(val_loader)}")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    A = get_A_func().to(device)
    model = SpatialPoseEncoder(in_channels=3, num_joints=num_joints, num_classes=num_classes, A=A)
    if torch.cuda.device_count() > 1:
        print("Using DataParallel with {} GPUs".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    best_val_acc = 0
    
    for epoch in range(num_epochs):
        model.train()
        total_loss, correct, total = 0, 0, 0
        for x, labels in tqdm(train_loader, desc=f"Train {part} Epoch {epoch+1}", leave=False):
            #print("Batch x.shape:", x.shape)
            x, labels = x.to(device), labels.to(device)
            logits, _ = model(x)
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * x.size(0)
            preds = logits.argmax(1)
            correct += (preds == labels).sum().item()
            total += x.size(0)
        if total > 0:
            train_acc = correct / total
            train_loss = total_loss / total
        else:
            train_acc = 0
            train_loss = 0
            print("WARNING: Không có sample nào trong batch train!")
        
        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            for x, labels in tqdm(val_loader, desc=f"Val {part} Epoch {epoch+1}", leave=False):
                x, labels = x.to(device), labels.to(device)
                logits, _ = model(x)
                loss = criterion(logits, labels)
                val_loss += loss.item() * x.size(0)
                preds = logits.argmax(1)
                val_correct += (preds == labels).sum().item()
                val_total += x.size(0)
        if val_total > 0:
            val_acc = val_correct / val_total
            val_loss = val_loss / val_total
        else:
            val_acc = 0
            val_loss = 0
            print("WARNING: Không có sample nào trong batch val!")

        print(f"[{part}] Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({"model": model.state_dict(), "label_map": label_map}, best_ckpt_file)
            print(f"Best model saved at epoch {epoch+1}, val_acc={val_acc:.4f}")

    print(f"Done {part}. Best val acc: {best_val_acc:.4f}")

In [17]:
data_root = "/kaggle/input/keypoint/keypoints"
train_txt = "/kaggle/working/train.txt"
val_txt = "/kaggle/working/val.txt"
test_txt = "/kaggle/working/test.txt"
batch_size = 32
num_workers = 4
num_epochs = 10
lr = 1e-3

label_map = build_label_map([train_txt, val_txt, test_txt])
num_classes = len(label_map)
with open("label_map_pose.json", "w") as f:
    json.dump(label_map, f)

In [18]:
print(get_body_spatial_graph().shape)       # (25, 25)
print(get_left_hand_spatial_graph().shape)  # (21, 21)
print(get_right_hand_spatial_graph().shape) # (21, 21)

torch.Size([25, 25])
torch.Size([21, 21])
torch.Size([21, 21])


In [19]:
train_pose_spatial_part('body',  num_joints=25, get_A_func=get_body_spatial_graph, best_ckpt_file="/kaggle/working/spatial_body_best.pth")
train_pose_spatial_part('left',  num_joints=21, get_A_func=get_left_hand_spatial_graph, best_ckpt_file="/kaggle/working/spatial_left_best.pth")
train_pose_spatial_part('right', num_joints=21, get_A_func=get_right_hand_spatial_graph, best_ckpt_file="/kaggle/working/spatial_right_best.pth")


--- Pretraining body ---
Số mẫu train: 456, val: 145
Số batch train: 15, val: 5
Using DataParallel with 2 GPUs


Train body Epoch 1:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (64, 25, 3)


  return F.linear(input, self.weight, self.bias)
Val body Epoch 1:   0%|          | 0/5 [00:00<?, ?it/s]            

Dataset part_kp.shape: (64, 25, 3)


                                                               

[body] Epoch 1/10 | Train Loss: 4.3657 Acc: 0.0614 | Val Loss: 3.3577 Acc: 0.0621
Best model saved at epoch 1, val_acc=0.0621


Train body Epoch 2:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (64, 25, 3)


Val body Epoch 2:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 25, 3)


                                                               

[body] Epoch 2/10 | Train Loss: 3.2634 Acc: 0.1469 | Val Loss: 3.4562 Acc: 0.0690
Best model saved at epoch 2, val_acc=0.0690


Train body Epoch 3:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (64, 25, 3)


Val body Epoch 3:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 25, 3)


                                                               

[body] Epoch 3/10 | Train Loss: 2.8637 Acc: 0.1930 | Val Loss: 3.2576 Acc: 0.1517
Best model saved at epoch 3, val_acc=0.1517


Train body Epoch 4:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (64, 25, 3)


Val body Epoch 4:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 25, 3)


                                                               

[body] Epoch 4/10 | Train Loss: 2.6601 Acc: 0.2237 | Val Loss: 3.2883 Acc: 0.1448


Train body Epoch 5:   7%|▋         | 1/15 [00:00<00:02,  6.71it/s]

Dataset part_kp.shape: (64, 25, 3)


Val body Epoch 5:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 25, 3)


                                                               

[body] Epoch 5/10 | Train Loss: 2.5869 Acc: 0.2566 | Val Loss: 3.2973 Acc: 0.1172


Train body Epoch 6:   7%|▋         | 1/15 [00:00<00:02,  6.27it/s]

Dataset part_kp.shape: (64, 25, 3)


Val body Epoch 6:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 25, 3)


                                                               

[body] Epoch 6/10 | Train Loss: 2.4920 Acc: 0.2544 | Val Loss: 3.2158 Acc: 0.1655
Best model saved at epoch 6, val_acc=0.1655


Train body Epoch 7:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (64, 25, 3)


Val body Epoch 7:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 25, 3)


                                                               

[body] Epoch 7/10 | Train Loss: 2.4000 Acc: 0.2719 | Val Loss: 3.3411 Acc: 0.1793
Best model saved at epoch 7, val_acc=0.1793


Train body Epoch 8:   7%|▋         | 1/15 [00:00<00:02,  6.77it/s]

Dataset part_kp.shape: (64, 25, 3)


Val body Epoch 8:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 25, 3)


                                                               

[body] Epoch 8/10 | Train Loss: 2.3434 Acc: 0.3114 | Val Loss: 3.2132 Acc: 0.1724


Train body Epoch 9:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (64, 25, 3)


Val body Epoch 9:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 25, 3)


                                                               

[body] Epoch 9/10 | Train Loss: 2.1409 Acc: 0.3355 | Val Loss: 3.2905 Acc: 0.2207
Best model saved at epoch 9, val_acc=0.2207


Train body Epoch 10:   7%|▋         | 1/15 [00:00<00:02,  6.69it/s]

Dataset part_kp.shape: (64, 25, 3)


Val body Epoch 10:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 25, 3)


                                                                

[body] Epoch 10/10 | Train Loss: 2.1077 Acc: 0.3509 | Val Loss: 3.1557 Acc: 0.2000
Done body. Best val acc: 0.2207

--- Pretraining left ---
Số mẫu train: 456, val: 145
Số batch train: 15, val: 5
Using DataParallel with 2 GPUs


Train left Epoch 1:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (64, 21, 3)


Val left Epoch 1:   0%|          | 0/5 [00:00<?, ?it/s]            

Dataset part_kp.shape: (64, 21, 3)


                                                               

[left] Epoch 1/10 | Train Loss: 3.8045 Acc: 0.0658 | Val Loss: 3.3602 Acc: 0.0690
Best model saved at epoch 1, val_acc=0.0690


Train left Epoch 2:   7%|▋         | 1/15 [00:00<00:02,  6.38it/s]

Dataset part_kp.shape: (64, 21, 3)


Val left Epoch 2:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                               

[left] Epoch 2/10 | Train Loss: 3.4826 Acc: 0.0899 | Val Loss: 3.3589 Acc: 0.0828
Best model saved at epoch 2, val_acc=0.0828


Train left Epoch 3:   7%|▋         | 1/15 [00:00<00:02,  6.41it/s]

Dataset part_kp.shape: (64, 21, 3)


Val left Epoch 3:   0%|          | 0/5 [00:00<?, ?it/s]            

Dataset part_kp.shape: (64, 21, 3)


                                                               

[left] Epoch 3/10 | Train Loss: 3.2417 Acc: 0.1140 | Val Loss: 3.1763 Acc: 0.0828


Train left Epoch 4:   7%|▋         | 1/15 [00:00<00:02,  6.56it/s]

Dataset part_kp.shape: (64, 21, 3)


Val left Epoch 4:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                               

[left] Epoch 4/10 | Train Loss: 2.9599 Acc: 0.1513 | Val Loss: 3.1629 Acc: 0.1931
Best model saved at epoch 4, val_acc=0.1931


Train left Epoch 5:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (64, 21, 3)


Val left Epoch 5:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                               

[left] Epoch 5/10 | Train Loss: 2.9665 Acc: 0.1469 | Val Loss: 3.0735 Acc: 0.1724


Train left Epoch 6:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (64, 21, 3)


Val left Epoch 6:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                               

[left] Epoch 6/10 | Train Loss: 2.8759 Acc: 0.1842 | Val Loss: 3.2153 Acc: 0.1448


Train left Epoch 7:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (64, 21, 3)


Val left Epoch 7:   0%|          | 0/5 [00:00<?, ?it/s]            

Dataset part_kp.shape: (64, 21, 3)


                                                               

[left] Epoch 7/10 | Train Loss: 2.8054 Acc: 0.1667 | Val Loss: 2.9121 Acc: 0.2138
Best model saved at epoch 7, val_acc=0.2138


Train left Epoch 8:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (64, 21, 3)


Val left Epoch 8:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                               

[left] Epoch 8/10 | Train Loss: 2.8123 Acc: 0.1645 | Val Loss: 3.0690 Acc: 0.1862


Train left Epoch 9:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (64, 21, 3)


Val left Epoch 9:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                               

[left] Epoch 9/10 | Train Loss: 2.7550 Acc: 0.1974 | Val Loss: 2.9618 Acc: 0.2276
Best model saved at epoch 9, val_acc=0.2276


Train left Epoch 10:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (64, 21, 3)


Val left Epoch 10:   0%|          | 0/5 [00:00<?, ?it/s]            

Dataset part_kp.shape: (64, 21, 3)


                                                                

[left] Epoch 10/10 | Train Loss: 2.6778 Acc: 0.1930 | Val Loss: 3.0419 Acc: 0.2552
Best model saved at epoch 10, val_acc=0.2552
Done left. Best val acc: 0.2552

--- Pretraining right ---
Số mẫu train: 456, val: 145
Số batch train: 15, val: 5
Using DataParallel with 2 GPUs


Train right Epoch 1:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (64, 21, 3)


Val right Epoch 1:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                

[right] Epoch 1/10 | Train Loss: 3.9503 Acc: 0.0548 | Val Loss: 3.6890 Acc: 0.0483
Best model saved at epoch 1, val_acc=0.0483


Train right Epoch 2:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (64, 21, 3)


Val right Epoch 2:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                

[right] Epoch 2/10 | Train Loss: 3.2540 Acc: 0.1535 | Val Loss: 4.0165 Acc: 0.0414


Train right Epoch 3:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (64, 21, 3)


Val right Epoch 3:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                

[right] Epoch 3/10 | Train Loss: 2.8737 Acc: 0.2193 | Val Loss: 3.3095 Acc: 0.1103
Best model saved at epoch 3, val_acc=0.1103


Train right Epoch 4:   7%|▋         | 1/15 [00:00<00:02,  5.70it/s]

Dataset part_kp.shape: (64, 21, 3)


Val right Epoch 4:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                

[right] Epoch 4/10 | Train Loss: 2.6515 Acc: 0.2259 | Val Loss: 2.8566 Acc: 0.2276
Best model saved at epoch 4, val_acc=0.2276


Train right Epoch 5:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (64, 21, 3)


Val right Epoch 5:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                

[right] Epoch 5/10 | Train Loss: 2.4515 Acc: 0.3004 | Val Loss: 2.9559 Acc: 0.2207


Train right Epoch 6:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (64, 21, 3)


Val right Epoch 6:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                

[right] Epoch 6/10 | Train Loss: 2.3108 Acc: 0.3180 | Val Loss: 2.6523 Acc: 0.2759
Best model saved at epoch 6, val_acc=0.2759


Train right Epoch 7:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (64, 21, 3)


Val right Epoch 7:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                

[right] Epoch 7/10 | Train Loss: 2.1734 Acc: 0.3575 | Val Loss: 2.4173 Acc: 0.3724
Best model saved at epoch 7, val_acc=0.3724


Train right Epoch 8:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (64, 21, 3)


Val right Epoch 8:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                

[right] Epoch 8/10 | Train Loss: 2.0875 Acc: 0.3618 | Val Loss: 2.3656 Acc: 0.3448


Train right Epoch 9:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (64, 21, 3)


Val right Epoch 9:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                

[right] Epoch 9/10 | Train Loss: 2.0713 Acc: 0.3794 | Val Loss: 2.4705 Acc: 0.4069
Best model saved at epoch 9, val_acc=0.4069


Train right Epoch 10:   7%|▋         | 1/15 [00:00<00:02,  6.41it/s]

Dataset part_kp.shape: (64, 21, 3)


Val right Epoch 10:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                 

[right] Epoch 10/10 | Train Loss: 2.0224 Acc: 0.3860 | Val Loss: 2.4860 Acc: 0.3241
Done right. Best val acc: 0.4069




# 3. WLASL Module

# 3.1. Fusion Module

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DeformableAttention2D(nn.Module):
    """
    Deformable 2D attention:
    - query: (B, N, d_model)     # N: số keypoints
    - key, value: (B, C, H, W)   # feature map từ RGB backbone
    - ref_points: (B, N, 2)      # reference keypoint position (pixel, normalized [0,1])
    """
    def __init__(self, d_model, n_heads, n_points=4):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_points = n_points

        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Conv2d(d_model, d_model, 1)
        self.value_proj = nn.Conv2d(d_model, d_model, 1)

        # For each head, predict offset for each reference point (dx, dy) per query
        self.offset = nn.Linear(d_model, n_heads * n_points * 2)
        # For each head, predict attention weights for each sampled point
        self.attn_weight = nn.Linear(d_model, n_heads * n_points)

    def forward(self, query, key, value, ref_points):
        # query: (B, N, d_model)
        # key, value: (B, d_model, H, W)
        # ref_points: (B, N, 2) in [0, 1], normalized by feature size
        B, N, d_model = query.shape
        H, W = key.shape[2], key.shape[3]

        # Project query, key, value
        query_proj = self.query_proj(query)  # (B, N, d_model)
        key_proj = self.key_proj(key)        # (B, d_model, H, W)
        value_proj = self.value_proj(value)  # (B, d_model, H, W)

        # Predict offsets and attention weights
        offset = self.offset(query_proj)     # (B, N, n_heads*n_points*2)
        offset = offset.view(B, N, self.n_heads, self.n_points, 2)

        attn_weight = self.attn_weight(query_proj)  # (B, N, n_heads*n_points)
        attn_weight = attn_weight.view(B, N, self.n_heads, self.n_points)
        attn_weight = F.softmax(attn_weight, dim=-1)

        # Calculate sampling locations (normalized in [0, 1])
        # ref_points: (B, N, 2), offset: (B, N, n_heads, n_points, 2)
        # Sampling positions: (B, N, n_heads, n_points, 2)
        sampling_locations = ref_points.unsqueeze(2).unsqueeze(3) + offset / torch.tensor([W, H], device=offset.device)
        # Clamp to [0, 1]
        sampling_locations = sampling_locations.clamp(0, 1)

        # Prepare for grid_sample: scale to [-1, 1]
        # grid_sample expects normalized coords in [-1, 1]
        sampling_grid = sampling_locations.clone()
        sampling_grid[..., 0] = sampling_grid[..., 0] * 2 - 1
        sampling_grid[..., 1] = sampling_grid[..., 1] * 2 - 1

        # Sample value features at sampled locations
        # value_proj: (B, d_model, H, W)
        # For each query (B, N, n_heads, n_points, 2), sample (B, d_model, n_heads, n_points)
        sampled_feats = []
        for b in range(B):
            feats = []
            for n in range(N):
                grid = sampling_grid[b, n]  # (n_heads, n_points, 2)
                grid = grid.view(1, -1, 1, 2)  # (1, n_heads*n_points, 1, 2)
                # grid_sample: input (B, C, H, W), grid (B, out_H*out_W, 1, 2)
                sampled = F.grid_sample(
                    value_proj[b:b+1], grid, mode='bilinear', align_corners=True
                )  # (1, d_model, n_heads*n_points, 1)
                sampled = sampled.view(d_model, self.n_heads, self.n_points)
                feats.append(sampled)
            # feats: list of N tensors (d_model, n_heads, n_points) -> (N, d_model, n_heads, n_points)
            feats = torch.stack(feats, dim=0)
            sampled_feats.append(feats)
        # (B, N, d_model, n_heads, n_points)
        sampled_feats = torch.stack(sampled_feats, dim=0)

        # Weighted sum by attention weights
        # attn_weight: (B, N, n_heads, n_points)
        attn_weight = attn_weight.permute(0, 1, 3, 2)  # (B, N, n_points, n_heads)
        attn_weight = attn_weight.permute(0, 1, 3, 2)  # back to (B, N, n_heads, n_points)
        # (B, N, d_model, n_heads, n_points) * (B, N, 1, n_heads, n_points) -> sum over n_points
        out = (sampled_feats * attn_weight.unsqueeze(2)).sum(-1)  # (B, N, d_model, n_heads)
        out = out.mean(-1)  # mean over heads: (B, N, d_model)

        return out  # (B, N, d_model)


In [21]:
class PGFModule(nn.Module):
    """
    PGF Module: Multi-head deformable attention for fusion keypoint + RGB feature
    - query_feat: (B, N, d_model) -- e.g. STGCN output per joint
    - rgb_feat: (B, d_model, H, W) -- output feature map from CNN/ViT
    - ref_points: (B, N, 2) -- reference position (normalized [0,1]) for each joint
    Output: (B, N, d_model) fused feature per joint
    """
    def __init__(self, d_model, n_heads=4, n_points=4):
        super().__init__()
        self.attn = DeformableAttention2D(d_model, n_heads, n_points)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, query_feat, rgb_feat, ref_points):
        # query_feat: (B, N, d_model)
        # rgb_feat: (B, d_model, H, W)
        # ref_points: (B, N, 2) in [0, 1], normalized by feature map size
        fused = self.attn(query_feat, rgb_feat, rgb_feat, ref_points)  # (B, N, d_model)
        out = self.layer_norm(fused + query_feat)  # residual connection
        return out  # (B, N, d_model)

In [22]:
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=400):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (B, T, d_model)
        return x + self.pe[:, :x.size(1)]

class TemporalEncoder(nn.Module):
    """
    Temporal encoder as in UniSign: Transformer-based sequence encoder
    Input: (B, T, N, d)  --> (B, T, d)
    """
    def __init__(self, d_model, nhead=8, num_layers=2, dim_feedforward=512, dropout=0.1, max_len=400, pool='mean'):
        super().__init__()
        self.pool = pool
        self.position_encoding = PositionalEncoding(d_model, max_len)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True, dropout=dropout)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x):
        # x: (B, T, N, d)
        # UniSign: Thường tổng hợp theo joints (mean), còn nếu muốn giữ riêng body/left/right thì tách theo index rồi pooling từng phần
        x = x.mean(dim=2)  # (B, T, d), mean theo joints
        x = self.position_encoding(x)
        x = self.transformer(x)  # (B, T, d)
        if self.pool == 'mean':
            x = x.mean(dim=1)   # (B, d)
        elif self.pool == 'last':
            x = x[:, -1, :]     # (B, d)
        return x

In [23]:
import os
import numpy as np
import json

# Đường dẫn đến các thư mục
DATA_ROOT = '/kaggle/input'
CROPPED_HANDS_DIR = os.path.join(DATA_ROOT, 'image-mask', 'cropped_hands')
KEYPOINT_DIR = os.path.join(DATA_ROOT, 'keypoint', 'keypoints')
MASK_DIR = CROPPED_HANDS_DIR

SPLIT_TXT = {
    'train': '/kaggle/working/train.txt',
    'val': '/kaggle/working/val.txt',
    'test': '/kaggle/working/test.txt'
}

# Đọc label_map
with open('/kaggle/working/label_map.json') as f:
    label_map = json.load(f)

def build_meta_list(split_txt_path):
    meta_list = []
    with open(split_txt_path, 'r') as f:
        lines = f.readlines()
    for line in lines:
        video_id, label_str = line.strip().split()
        label = label_map[label_str]  # Lấy label id từ label_map

        # Keypoint path
        keypoint_path = os.path.join(KEYPOINT_DIR, f"{video_id}_keypoint.npy")
        # Mask
        mask_left_path = os.path.join(MASK_DIR, f"{video_id}_mask_left.npy")
        mask_right_path = os.path.join(MASK_DIR, f"{video_id}_mask_right.npy")

        mask_left = np.load(mask_left_path)
        mask_right = np.load(mask_right_path)
        num_frames = len(mask_left)

        crop_left_files = []
        kp_j_left_files = []
        for idx in range(num_frames):
            if mask_left[idx] == 1:
                crop_left_files.append(os.path.join(CROPPED_HANDS_DIR, f"{video_id}_frame{idx}_left.jpg"))
                kp_j_left_files.append(os.path.join(CROPPED_HANDS_DIR, f"{video_id}_frame{idx}_left_kp.npy"))
        crop_right_files = []
        kp_j_right_files = []
        for idx in range(num_frames):
            if mask_right[idx] == 1:
                crop_right_files.append(os.path.join(CROPPED_HANDS_DIR, f"{video_id}_frame{idx}_right.jpg"))
                kp_j_right_files.append(os.path.join(CROPPED_HANDS_DIR, f"{video_id}_frame{idx}_right_kp.npy"))

        meta = {
            "video_id": video_id,
            "keypoint_path": keypoint_path,
            "label": label,
            "num_frames": num_frames,
            "mask_left_path": mask_left_path,
            "mask_right_path": mask_right_path,
            "crop_left_files": crop_left_files,
            "kp_j_left_files": kp_j_left_files,
            "crop_right_files": crop_right_files,
            "kp_j_right_files": kp_j_right_files
        }
        meta_list.append(meta)
    return meta_list

# Build cho các tập
for split, txt_path in SPLIT_TXT.items():
    meta_list = build_meta_list(txt_path)
    with open(f'/kaggle/working/{split}_meta.json', 'w') as f:
        json.dump(meta_list, f, indent=2)
    print(f"Saved {split}_meta.json with {len(meta_list)} samples")

Saved train_meta.json with 456 samples
Saved val_meta.json with 145 samples
Saved test_meta.json with 150 samples


In [24]:
import torch
import torch.nn as nn
import numpy as np

class ClassifySign(nn.Module):
    """
    Đầu vào:
        - keypoint: (B, T, N, 3)
        - rgb_imgs: (B, T, 3, H, W)
        - vision_encoder: backbone trích đặc trưng ảnh (pretrained)
        - stgcn_body, stgcn_left, stgcn_right: các backbone STGCN đã pretrain cho từng phần
        - PGFModule: khối fusion ảnh & keypoint cho left/right hand
        - temporal_encoder_body, temporal_encoder_hand: temporal encoder cho body, left, right
        - FCN: phân loại đầu ra với dropout để tránh overfit
    """

    def __init__(self, vision_encoder, stgcn_body, stgcn_left, stgcn_right,
                 pgf_module, temporal_encoder_body, temporal_encoder_hand, 
                 num_classes,
                 left_idx=21, right_idx=21, body_idx=25,
                 fcn_hidden=256, dropout=0.5):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.stgcn_body = stgcn_body
        self.stgcn_left = stgcn_left
        self.stgcn_right = stgcn_right
        self.pgf_module = pgf_module
        self.temporal_encoder_body = temporal_encoder_body
        self.temporal_encoder_hand = temporal_encoder_hand
        out_dim_body = temporal_encoder_body.output_dim
        out_dim_hand = temporal_encoder_hand.output_dim
        self.fcn = nn.Sequential(
            nn.Linear(out_dim_body + 2 * out_dim_hand, fcn_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(fcn_hidden, fcn_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(fcn_hidden, num_classes)
        )
        self.left_idx = left_idx
        self.right_idx = right_idx
        self.body_idx = body_idx

    def forward(self, keypoint, rgb_imgs, mask_left, mask_right, 
                kp_j_left, kp_j_right, rgb_left_imgs, rgb_right_imgs):
        """
        keypoint: (B, T, N, 3)
        rgb_imgs: (B, T, 3, H, W)
        mask_left, mask_right: (B, T)  # 1: dùng fusion, 0: bỏ qua fusion
        kp_j_left, kp_j_right: (B, T, K, 2)
        rgb_left_imgs, rgb_right_imgs: (B, T, 3, H, W)
        """
        B, T, N, C = keypoint.shape

        # 1. Chia keypoint thành 3 phần
        kp_body = keypoint[..., self.body_idx, :]      # (B, T, Nb, 3)
        kp_left = keypoint[..., self.left_idx, :]      # (B, T, Nl, 3)
        kp_right = keypoint[..., self.right_idx, :]    # (B, T, Nr, 3)

        # 2. Body: Chuỗi toàn bộ frames -> STGCN -> TemporalEncoder
        body_feat = self.stgcn_body(kp_body)           # (B, T, d)
        body_out = self.temporal_encoder_body(body_feat)  # (B, d_body)

        # 3. Xử lý left hand
        left_feats = []
        for t in range(T):
            frame_kp = kp_left[:, t, :, :]             # (B, Nl, 3)
            mask = mask_left[:, t]                     # (B,)
            img = rgb_left_imgs[:, t, :, :, :]         # (B, 3, H, W)
            ref_j = kp_j_left[:, t, :, :]              # (B, Kl, 2)
            stgcn_feat = self.stgcn_left(frame_kp)     # (B, d)
            stgcn_feat = stgcn_feat.unsqueeze(1)       # (B, 1, d)
            rgb_feat = self.vision_encoder(img)        # (B, d, H', W') hoặc (B, d)
            fused = []
            for b in range(B):
                if mask[b] == 1:
                    fused_feat = self.pgf_module(
                        stgcn_feat[b:b+1], rgb_feat[b:b+1], ref_j[b:b+1])
                    fused.append(fused_feat.squeeze(0))
                else:
                    fused.append(stgcn_feat[b, 0])
            left_feats.append(torch.stack(fused, dim=0))
        left_feats = torch.stack(left_feats, dim=1)
        left_out = self.temporal_encoder_hand(left_feats)  # (B, d_hand)

        # 4. Xử lý right hand (Tương tự)
        right_feats = []
        for t in range(T):
            frame_kp = kp_right[:, t, :, :]
            mask = mask_right[:, t]
            img = rgb_right_imgs[:, t, :, :, :]
            ref_j = kp_j_right[:, t, :, :]
            stgcn_feat = self.stgcn_right(frame_kp)
            stgcn_feat = stgcn_feat.unsqueeze(1)
            rgb_feat = self.vision_encoder(img)
            fused = []
            for b in range(B):
                if mask[b] == 1:
                    fused_feat = self.pgf_module(
                        stgcn_feat[b:b+1], rgb_feat[b:b+1], ref_j[b:b+1])
                    fused.append(fused_feat.squeeze(0))
                else:
                    fused.append(stgcn_feat[b, 0])
            right_feats.append(torch.stack(fused, dim=0))
        right_feats = torch.stack(right_feats, dim=1)
        right_out = self.temporal_encoder_hand(right_feats)  # (B, d_hand)

        # 5. Ghép 3 feature
        final_feat = torch.cat([body_out, left_out, right_out], dim=-1)  # (B, d_body+2*d_hand)
        logits = self.fcn(final_feat)
        return logits

In [25]:
import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
import os

class SignDataset(Dataset):
    """
    Dataset cho mô hình ClassifySign với logic:
    - Ảnh crop/J chỉ lấy ở frame có mask==1, mapping đúng vị trí mask==1, còn lại (mask==0) điền zeros.
    """
    def __init__(
        self, meta_list, transform_crop=None,
        max_seq=64, crop_shape=(3, 112, 112), kp_num=21
    ):
        """
        meta_list: list các dict, mỗi dict gồm:
            'video_id'
            'keypoint_path'
            'label'
            'mask_left_path'
            'mask_right_path'
            'crop_left_files': list file ảnh left (đúng thứ tự mask==1)
            'crop_right_files': list file ảnh right (đúng thứ tự mask==1)
            'kp_j_left_files': list file J left (đúng thứ tự mask==1)
            'kp_j_right_files': list file J right (đúng thứ tự mask==1)
            'num_frames'
        transform_crop: transform ảnh crop
        crop_shape: shape cho ảnh zeros (C, H, W)
        kp_num: số keypoint bàn tay (J)
        """
        self.meta_list = meta_list
        self.transform_crop = transform_crop
        self.max_seq = max_seq
        self.crop_shape = crop_shape
        self.kp_num = kp_num

    def __len__(self):
        return len(self.meta_list)

    def __getitem__(self, idx):
        meta = self.meta_list[idx]
        T = meta['num_frames']
        label = meta['label']
        keypoint = np.load(meta['keypoint_path'])    # (T, N, 3)
        mask_left = np.load(meta['mask_left_path'])  # (T,)
        mask_right = np.load(meta['mask_right_path'])# (T,)

        # Chuẩn bị ảnh/J left/right (len đúng bằng số frame T)
        crop_left_imgs, kp_j_left = [], []
        left_idx = 0
        for t in range(T):
            if mask_left[t] == 1:
                img = Image.open(meta['crop_left_files'][left_idx]).convert("RGB")
                if self.transform_crop:
                    img = self.transform_crop(img)
                j = np.load(meta['kp_j_left_files'][left_idx])
                j = torch.tensor(j, dtype=torch.float32)
                left_idx += 1
            else:
                img = torch.zeros(self.crop_shape, dtype=torch.float32)
                j = torch.zeros(self.kp_num, 2, dtype=torch.float32)
            crop_left_imgs.append(img)
            kp_j_left.append(j)
        crop_left_imgs = torch.stack(crop_left_imgs, dim=0)   # (T, C, H, W)
        kp_j_left = torch.stack(kp_j_left, dim=0)             # (T, kp_num, 2)

        crop_right_imgs, kp_j_right = [], []
        right_idx = 0
        for t in range(T):
            if mask_right[t] == 1:
                img = Image.open(meta['crop_right_files'][right_idx]).convert("RGB")
                if self.transform_crop:
                    img = self.transform_crop(img)
                j = np.load(meta['kp_j_right_files'][right_idx])
                j = torch.tensor(j, dtype=torch.float32)
                right_idx += 1
            else:
                img = torch.zeros(self.crop_shape, dtype=torch.float32)
                j = torch.zeros(self.kp_num, 2, dtype=torch.float32)
            crop_right_imgs.append(img)
            kp_j_right.append(j)
        crop_right_imgs = torch.stack(crop_right_imgs, dim=0)
        kp_j_right = torch.stack(kp_j_right, dim=0)

        # Nếu cần pad/truncate về self.max_seq
        def pad_seq(seq, pad, max_len):
            if len(seq) >= max_len:
                return seq[:max_len]
            else:
                return seq + [pad]*(max_len-len(seq))
        
        T_cur = T
        crop_left_imgs = pad_seq(list(crop_left_imgs), torch.zeros(self.crop_shape), self.max_seq)
        crop_right_imgs = pad_seq(list(crop_right_imgs), torch.zeros(self.crop_shape), self.max_seq)
        kp_j_left = pad_seq(list(kp_j_left), torch.zeros(self.kp_num, 2), self.max_seq)
        kp_j_right = pad_seq(list(kp_j_right), torch.zeros(self.kp_num, 2), self.max_seq)
        mask_left = np.pad(mask_left, (0, max(self.max_seq - T_cur, 0)), 'constant')
        mask_right = np.pad(mask_right, (0, max(self.max_seq - T_cur, 0)), 'constant')
        keypoint = np.pad(keypoint, ((0, max(self.max_seq - T_cur, 0)), (0,0), (0,0)), 'constant')

        crop_left_imgs = torch.stack(crop_left_imgs, dim=0)
        crop_right_imgs = torch.stack(crop_right_imgs, dim=0)
        kp_j_left = torch.stack(kp_j_left, dim=0)
        kp_j_right = torch.stack(kp_j_right, dim=0)
        keypoint = torch.tensor(keypoint, dtype=torch.float32)
        mask_left = torch.tensor(mask_left, dtype=torch.float32)
        mask_right = torch.tensor(mask_right, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)

        return {
            'keypoint': keypoint,             # (T, N, 3)
            'mask_left': mask_left,           # (T,)
            'mask_right': mask_right,         # (T,)
            'crop_left_imgs': crop_left_imgs, # (T, C, H, W)
            'crop_right_imgs': crop_right_imgs,
            'kp_j_left': kp_j_left,           # (T, kp_num, 2)
            'kp_j_right': kp_j_right,
            'label': label
        }

In [26]:
import torch
import numpy as np
import random
import os

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def collate_fn(batch):
    """
    Batch là list các dict từ Dataset, ghép lại thành tensor.
    """
    batch_out = {}
    for k in batch[0].keys():
        if isinstance(batch[0][k], torch.Tensor):
            batch_out[k] = torch.stack([item[k] for item in batch], dim=0)
        else:
            # Nếu là int/float, convert sang tensor
            arr = [item[k] for item in batch]
            if isinstance(arr[0], (int, float)):
                batch_out[k] = torch.tensor(arr)
            else:
                batch_out[k] = arr
    return batch_out

def train_epoch(model, loader, optimizer, criterion, device="cuda"):
    model.train()
    total_loss, total_acc, n = 0.0, 0.0, 0
    for batch in loader:
        for k in batch:
            if isinstance(batch[k], torch.Tensor):
                batch[k] = batch[k].to(device)
        optimizer.zero_grad()
        logits = model(
            keypoint=batch['keypoint'],
            rgb_imgs=None,  # Nếu không dùng, truyền None hoặc bỏ field này khỏi model
            mask_left=batch['mask_left'],
            mask_right=batch['mask_right'],
            kp_j_left=batch['kp_j_left'],
            kp_j_right=batch['kp_j_right'],
            rgb_left_imgs=batch['crop_left_imgs'],
            rgb_right_imgs=batch['crop_right_imgs']
        )
        loss = criterion(logits, batch['label'])
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch['label'].size(0)
        pred = logits.argmax(dim=1)
        total_acc += (pred == batch['label']).sum().item()
        n += batch['label'].size(0)
    return total_loss / n, total_acc / n

@torch.no_grad()
def validate_epoch(model, loader, criterion, device="cuda"):
    model.eval()
    total_loss, total_acc, n = 0.0, 0.0, 0
    for batch in loader:
        for k in batch:
            if isinstance(batch[k], torch.Tensor):
                batch[k] = batch[k].to(device)
        logits = model(
            keypoint=batch['keypoint'],
            rgb_imgs=None,
            mask_left=batch['mask_left'],
            mask_right=batch['mask_right'],
            kp_j_left=batch['kp_j_left'],
            kp_j_right=batch['kp_j_right'],
            rgb_left_imgs=batch['crop_left_imgs'],
            rgb_right_imgs=batch['crop_right_imgs']
        )
        loss = criterion(logits, batch['label'])
        total_loss += loss.item() * batch['label'].size(0)
        pred = logits.argmax(dim=1)
        total_acc += (pred == batch['label']).sum().item()
        n += batch['label'].size(0)
    return total_loss / n, total_acc / n

def save_checkpoint(model, optimizer, epoch, path, best_acc=None):
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }
    if best_acc is not None:
        state['best_acc'] = best_acc
    torch.save(state, path)

def load_checkpoint(model, optimizer, path, device='cuda'):
    state = torch.load(path, map_location=device)
    model.load_state_dict(state['model'])
    if optimizer is not None:
        optimizer.load_state_dict(state['optimizer'])
    epoch = state.get('epoch', 0)
    best_acc = state.get('best_acc', None)
    return model, optimizer, epoch, best_acc

def adjust_learning_rate(optimizer, epoch, lr, step=10, decay=0.1):
    """ Step LR decay """
    lr_new = lr * (decay ** (epoch // step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr_new
    return lr_new

In [27]:
import json
#from data.sign_dataset import SignDataset
from torch.utils.data import DataLoader
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225])
])

# Load meta_list
with open('/kaggle/working/train_meta.json') as f:
    train_meta_list = json.load(f)
with open('/kaggle/working/val_meta.json') as f:
    val_meta_list = json.load(f)

train_ds = SignDataset(meta_list=train_meta_list, transform_crop=transform, max_seq=64, crop_shape=(3,112,112), kp_num=21)
val_ds = SignDataset(meta_list=val_meta_list, transform_crop=transform, max_seq=64, crop_shape=(3,112,112), kp_num=21)

#from utils.train_utils import collate_fn
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=4, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=8, shuffle=False, num_workers=4, collate_fn=collate_fn)