In [1]:
import os
import random
from collections import defaultdict

input_txt = '/kaggle/input/keypoint/videoid_label.txt'
out_train = '/kaggle/working/train.txt'
out_val = '/kaggle/working/val.txt'
out_test = '/kaggle/working/test.txt'

#os.makedirs(out_train, exist_ok=True)
#os.makedirs(out_val, exist_ok=True)
#os.makedirs(out_test, exist_ok=True)
# Đọc file và gom theo nhãn
label_dict = defaultdict(list)
with open(input_txt, 'r') as f:
    for line in f:
        line = line.strip()
        if not line: continue
        video_id, label = line.split()
        label_dict[label].append(video_id)

# Chia tỉ lệ 6:4:4 cho từng nhãn
train_lines, val_lines, test_lines = [], [], []
for label, vids in label_dict.items():
    vids = list(vids)
    random.shuffle(vids)
    n = len(vids)
    n_train = round(n * 0.6)
    n_val = round(n * 0.2)
    n_test = n - n_train - n_val
    train = vids[:n_train]
    val = vids[n_train:n_train+n_val]
    test = vids[n_train+n_val:]
    train_lines.extend([f"{vid} {label}\n" for vid in train])
    val_lines.extend([f"{vid} {label}\n" for vid in val])
    test_lines.extend([f"{vid} {label}\n" for vid in test])

# Shuffle lại từng tập để tránh cùng nhãn đứng liền nhau
random.shuffle(train_lines)
random.shuffle(val_lines)
random.shuffle(test_lines)

# Lưu file
with open(out_train, 'w') as f: f.writelines(train_lines)
with open(out_val, 'w') as f: f.writelines(val_lines)
with open(out_test, 'w') as f: f.writelines(test_lines)

print("Đã chia xong. Train:", len(train_lines), "Val:", len(val_lines), "Test:", len(test_lines))

Đã chia xong. Train: 456 Val: 145 Test: 150


# 1. Pretrain Vision encoder

In [2]:
import os
from glob import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
import torchvision
from tqdm import tqdm
import json

In [3]:
def build_label_map(txt_files):
    labels = set()
    for txt in txt_files:
        with open(txt, "r") as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) >= 2:
                    label = ' '.join(parts[1:]).strip()
                    labels.add(label)
    labels = sorted(labels)
    label_map = {lbl: idx for idx, lbl in enumerate(labels)}
    return label_map

In [4]:
data_root = "/kaggle/input/image-mask/cropped_hands"
train_txt = "/kaggle/working/train.txt"
valid_txt = "/kaggle/working/val.txt"
test_txt = "/kaggle/working/test.txt"
batch_size = 32
num_workers = 4
num_epochs = 6
lr = 1e-3
out_ckpt = "/kaggle/working/pretrained_hand_rgb.pth"
out_labelmap = "/kaggle/working/label_map.json"

label_map = build_label_map([train_txt, valid_txt, test_txt])
num_classes = len(label_map)
print("Số lớp:", num_classes)
print("Sample label_map:", dict(list(label_map.items())[:5]))

with open(out_labelmap, "w") as f:
    json.dump(label_map, f)

Số lớp: 32
Sample label_map: {'all': 0, 'before': 1, 'black': 2, 'book': 3, 'candy': 4}


In [5]:
# ==== 2. Dataset for cropped hand images ====
class HandImageDataset(Dataset):
    def __init__(self, data_root, list_file, label_map, transform=None):
        self.samples = []
        self.transform = transform
        with open(list_file, "r") as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) == 2:
                    video_id, label = parts
                    label_idx = label_map[label]
                    #img_dir = os.path.join(data_root, video_id)
                    # All *_left.jpg and *_right.jpg (can add filter for frame sampling if needed)
                    imgs = sorted(glob(os.path.join(data_root, f"{video_id}_*_left.jpg"))) + \
                           sorted(glob(os.path.join(data_root, f"{video_id}_*_right.jpg")))
                    for img_path in imgs:
                        if os.path.isfile(img_path):
                            self.samples.append((img_path, label_idx))
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label

In [6]:
transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225])
])

train_ds = HandImageDataset(data_root, train_txt, label_map, transform=transform)
val_ds = HandImageDataset(data_root, valid_txt, label_map, transform=transform)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

print(f"Số ảnh train: {len(train_ds)}, Số ảnh val: {len(val_ds)}")

Số ảnh train: 2298, Số ảnh val: 721


In [7]:
import torch
import torch.nn as nn
import torchvision

class VisionClassifier(nn.Module):
    def __init__(self, num_classes, out_channels=256):
        super().__init__()
        backbone = torchvision.models.efficientnet_b0(pretrained=True)
        # Reduce feature channels to out_channels (C=64) with a 1x1 conv
        self.features = nn.Sequential(
            backbone.features,                     # (B, 1280, H, W)
            nn.Conv2d(1280, out_channels, 1),      # (B, 256, H, W)
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(out_channels, num_classes)

    def forward(self, x, return_feature=False):
        feat = self.features(x)                              # (B, 256, H, W)
        pooled = self.pool(feat).view(x.size(0), -1)         # (B, 256)
        logits = self.classifier(pooled)                     # (B, num_classes)
        if return_feature:
            return logits, feat                              # (B, num_classes), (B, 256, H, W)
        else:
            return logits

In [8]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = VisionClassifier(num_classes)
if torch.cuda.device_count() > 1:
    print("Using DataParallel with {} GPUs".format(torch.cuda.device_count()))
    model = nn.DataParallel(model)
model = model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()



Using DataParallel with 2 GPUs


In [9]:
best_val_acc = 0
for epoch in range(num_epochs):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for imgs, labels in tqdm(train_loader, desc=f"Train Epoch {epoch+1}", leave=False):
        imgs, labels = imgs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
        logits = model(imgs)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * imgs.size(0)
        preds = logits.argmax(1)
        correct += (preds == labels).sum().item()
        total += imgs.size(0)
    train_acc = correct / total
    train_loss = total_loss / total

    # Validation
    model.eval()
    val_loss, val_correct, val_total = 0, 0, 0
    with torch.no_grad():
        for imgs, labels in tqdm(val_loader, desc="Valid", leave=False):
            imgs, labels = imgs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            logits = model(imgs)
            loss = criterion(logits, labels)
            val_loss += loss.item() * imgs.size(0)
            preds = logits.argmax(1)
            val_correct += (preds == labels).sum().item()
            val_total += imgs.size(0)
    val_acc = val_correct / val_total
    val_loss = val_loss / val_total

    print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

    # Save best
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            "model": model.state_dict(),
            "label_map": label_map
        }, out_ckpt)
        print(f"Best model saved at epoch {epoch+1}, val_acc={val_acc:.4f}")

print("Done. Best val acc:", best_val_acc)

                                                              

Epoch 1/6 | Train Loss: 3.0880 Acc: 0.1819 | Val Loss: 2.9932 Acc: 0.2205
Best model saved at epoch 1, val_acc=0.2205


                                                              

Epoch 2/6 | Train Loss: 2.5333 Acc: 0.3216 | Val Loss: 2.9418 Acc: 0.2677
Best model saved at epoch 2, val_acc=0.2677


                                                              

Epoch 3/6 | Train Loss: 2.1871 Acc: 0.4038 | Val Loss: 2.7004 Acc: 0.3010
Best model saved at epoch 3, val_acc=0.3010


                                                              

Epoch 4/6 | Train Loss: 1.8409 Acc: 0.5009 | Val Loss: 2.9763 Acc: 0.2732


                                                              

Epoch 5/6 | Train Loss: 1.5737 Acc: 0.5688 | Val Loss: 3.0074 Acc: 0.3065
Best model saved at epoch 5, val_acc=0.3065


                                                              

Epoch 6/6 | Train Loss: 1.4008 Acc: 0.6097 | Val Loss: 3.0779 Acc: 0.3315
Best model saved at epoch 6, val_acc=0.3315
Done. Best val acc: 0.3314840499306519


# 2. pretrain STGCN

In [10]:
import numpy as np
import torch
from torch.utils.data import Dataset

class PoseSpatialPartDataset(Dataset):
    def __init__(self, data_root, txt_file, label_map, part='body'):
        self.samples = []
        self.label_map = label_map
        self.part = part
        with open(txt_file) as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) < 2:
                    print(f"WARNING: dòng bị lỗi format: {line}")
                    continue
                npy_id = parts[0]
                label = ' '.join(parts[1:]).strip()
                if label not in label_map:
                    print(f"WARNING: label '{label}' chưa có trong label_map!")
                    continue
                npy_path = f"{data_root}/{npy_id}_keypoint.npy"
                self.samples.append((npy_path, int(label_map[label])))
        # Define keypoint slices for each part
        if part == 'body':
            self.idx_start, self.idx_end = 0, 25
        elif part == 'left':
            self.idx_start, self.idx_end = 25, 46
        elif part == 'right':
            self.idx_start, self.idx_end = 46, 67
        else:
            raise ValueError("Unknown part: " + part)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        npy_path, label = self.samples[idx]
        keypoints = np.load(npy_path)  # (T, 67, 3) expected
        # Auto fix shape if needed
        if keypoints.shape[-2:] == (67, 3):
            pass
        elif keypoints.shape[0] == 67 and keypoints.shape[1] == 3:
            keypoints = np.transpose(keypoints, (2, 0, 1))
        elif keypoints.shape[1] == 3 and keypoints.shape[2] == 67:
            keypoints = np.transpose(keypoints, (0, 2, 1))
        else:
            raise RuntimeError(f"Unrecognized keypoints shape: {keypoints.shape}")
        part_kp = keypoints[:, self.idx_start:self.idx_end, :]  # (T, N, 3)
        # Lấy frame giữa
        t = part_kp.shape[0] // 2
        part_kp = part_kp[t]  # (N, 3)
        if idx == 0:
            print(f"Dataset part_kp.shape: {part_kp.shape}")
        return torch.tensor(part_kp, dtype=torch.float32), label

In [11]:
import torch
import torch.nn as nn

class SpatialGCNLayer(nn.Module):
    def __init__(self, in_channels, out_channels, A):
        super().__init__()
        self.register_buffer('A', A)
        self.fc = nn.Linear(in_channels, out_channels)
        self.bn = nn.BatchNorm1d(out_channels)
        if in_channels != out_channels:
            self.residual = nn.Linear(in_channels, out_channels)
        else:
            self.residual = nn.Identity()

    def forward(self, x):  # x: (B, N, in_channels)
        # Đảm bảo self.A cùng device với x
        A = self.A
        if A.device != x.device:
            A = A.to(x.device)

        res = self.residual(x)                   # (B, N, out_channels)
        h = self.fc(x)                           # (B, N, out_channels)
        h = h.permute(0, 2, 1)                   # (B, out_channels, N)
        h = torch.matmul(h, A)                   # (B, out_channels, N)
        # BatchNorm1d expects (B, C, N) => reshape lại cho an toàn
        h = self.bn(h)                           # (B, out_channels, N)
        h = h.permute(0, 2, 1)                   # (B, N, out_channels)
        out = torch.relu(h + res)
        return out

In [12]:
import torch
import torch.nn as nn

class SpatialPoseEncoder(nn.Module):
    def __init__(self, in_channels, num_joints, num_classes, A, hid_dim=128, out_dim=256, input_dim=3):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, in_channels)  # (B, N, 3) -> (B, N, in_channels)
        self.gcn1 = SpatialGCNLayer(in_channels, hid_dim, A)
        self.gcn2 = SpatialGCNLayer(hid_dim, out_dim, A)
        self.classifier = nn.Linear(out_dim * num_joints, num_classes)

    def forward(self, x):  # x: (B, N, 3)
        B, N, C = x.shape
        assert N == self.gcn1.A.shape[0], f"x.shape={x.shape}, A.shape={self.gcn1.A.shape}"
        x = self.input_proj(x)         # (B, N, in_channels)
        h = self.gcn1(x)               # (B, N, hid_dim)
        h = self.gcn2(h)               # (B, N, out_dim)
        h = h.permute(0, 2, 1)         # (B, out_dim, N)
        logits = self.classifier(h.flatten(1))  # (B, num_classes)
        return logits, h

In [13]:
def get_spatial_adjacency(num_node, edge):
    A = np.zeros((num_node, num_node))
    for i, j in edge:
        A[i, j] = 1
        A[j, i] = 1
    # Normalize
    Dl = np.sum(A, 0)
    Dn = np.zeros((num_node, num_node))
    for i in range(num_node):
        if Dl[i] > 0:
            Dn[i, i] = Dl[i] ** (-1)
    A_normalized = np.dot(A, Dn)
    return torch.tensor(A_normalized, dtype=torch.float32)

def get_body_spatial_graph():
    # 25 body keypoints (Mediapipe hoặc OpenPose định nghĩa)
    num_node = 25
    self_link = [(i, i) for i in range(num_node)]
    neighbor_link = [
        (0, 1), (1, 2), (2, 3), (3, 7),
        (0, 4), (4, 5), (5, 6), (6, 8),
        (9, 10), (11, 12), (11, 13), (13, 15), (15, 21), (15, 19), (15, 17),
        (17, 19), (11, 23), (12, 14), (14, 16), (16, 18), (16, 20), (16, 22),
        (18, 20), (12, 24), (23, 24)
    ]
    edge = self_link + neighbor_link
    return get_spatial_adjacency(num_node, edge)

def get_left_hand_spatial_graph():
    # 21 left hand keypoints (Mediapipe)
    num_node = 21
    self_link = [(i, i) for i in range(num_node)]
    neighbor_link = [
        (0, 1),(1, 2),(2, 3),(3, 4),
        (0, 5),(5, 6),(6, 7),(7, 8),
        (0, 9),(9, 10),(10, 11),(11, 12),
        (0, 13),(13, 14),(14, 15),(15, 16),
        (0, 17),(17, 18),(18, 19),(19, 20)
    ]
    edge = self_link + neighbor_link
    return get_spatial_adjacency(num_node, edge)
def get_right_hand_spatial_graph():
    # 21 right hand keypoints (Mediapipe)
    num_node = 21
    self_link = [(i, i) for i in range(num_node)]
    neighbor_link = [
        (0, 1),(1, 2),(2, 3),(3, 4),
        (0, 5),(5, 6),(6, 7),(7, 8),
        (0, 9),(9, 10),(10, 11),(11, 12),
        (0, 13),(13, 14),(14, 15),(15, 16),
        (0, 17),(17, 18),(18, 19),(19, 20)
    ]
    edge = self_link + neighbor_link
    return get_spatial_adjacency(num_node, edge)

In [14]:
PART_INFO = {
    'body':  (0, 25),
    'left':  (25, 46),
    'right': (46, 67)
}

In [15]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.optim as optim
import torch.nn as nn

def train_pose_spatial_part(part, num_joints, get_A_func, best_ckpt_file):
    print(f"\n--- Pretraining {part} ---")
    train_ds = PoseSpatialPartDataset(data_root, train_txt, label_map, part=part)
    val_ds = PoseSpatialPartDataset(data_root, val_txt, label_map, part=part)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    print(f"Số mẫu train: {len(train_ds)}, val: {len(val_ds)}")
    print(f"Số batch train: {len(train_loader)}, val: {len(val_loader)}")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    A = get_A_func().to(device)
    model = SpatialPoseEncoder(in_channels=3, num_joints=num_joints, num_classes=num_classes, A=A)
    if torch.cuda.device_count() > 1:
        print("Using DataParallel with {} GPUs".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    best_val_acc = 0
    
    for epoch in range(num_epochs):
        model.train()
        total_loss, correct, total = 0, 0, 0
        for x, labels in tqdm(train_loader, desc=f"Train {part} Epoch {epoch+1}", leave=False):
            #print("Batch x.shape:", x.shape)
            x, labels = x.to(device), labels.to(device)
            logits, _ = model(x)
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * x.size(0)
            preds = logits.argmax(1)
            correct += (preds == labels).sum().item()
            total += x.size(0)
        if total > 0:
            train_acc = correct / total
            train_loss = total_loss / total
        else:
            train_acc = 0
            train_loss = 0
            print("WARNING: Không có sample nào trong batch train!")
        
        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            for x, labels in tqdm(val_loader, desc=f"Val {part} Epoch {epoch+1}", leave=False):
                x, labels = x.to(device), labels.to(device)
                logits, _ = model(x)
                loss = criterion(logits, labels)
                val_loss += loss.item() * x.size(0)
                preds = logits.argmax(1)
                val_correct += (preds == labels).sum().item()
                val_total += x.size(0)
        if val_total > 0:
            val_acc = val_correct / val_total
            val_loss = val_loss / val_total
        else:
            val_acc = 0
            val_loss = 0
            print("WARNING: Không có sample nào trong batch val!")

        print(f"[{part}] Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({"model": model.state_dict(), "label_map": label_map}, best_ckpt_file)
            print(f"Best model saved at epoch {epoch+1}, val_acc={val_acc:.4f}")

    print(f"Done {part}. Best val acc: {best_val_acc:.4f}")

In [16]:
data_root = "/kaggle/input/keypoint/keypoints"
train_txt = "/kaggle/working/train.txt"
val_txt = "/kaggle/working/val.txt"
test_txt = "/kaggle/working/test.txt"
batch_size = 32
num_workers = 4
num_epochs = 10
lr = 1e-3

label_map = build_label_map([train_txt, val_txt, test_txt])
num_classes = len(label_map)
with open("label_map_pose.json", "w") as f:
    json.dump(label_map, f)

In [17]:
print(get_body_spatial_graph().shape)       # (25, 25)
print(get_left_hand_spatial_graph().shape)  # (21, 21)
print(get_right_hand_spatial_graph().shape) # (21, 21)

torch.Size([25, 25])
torch.Size([21, 21])
torch.Size([21, 21])


In [18]:
train_pose_spatial_part('body',  num_joints=25, get_A_func=get_body_spatial_graph, best_ckpt_file="/kaggle/working/spatial_body_best.pth")
train_pose_spatial_part('left',  num_joints=21, get_A_func=get_left_hand_spatial_graph, best_ckpt_file="/kaggle/working/spatial_left_best.pth")
train_pose_spatial_part('right', num_joints=21, get_A_func=get_right_hand_spatial_graph, best_ckpt_file="/kaggle/working/spatial_right_best.pth")


--- Pretraining body ---
Số mẫu train: 456, val: 145
Số batch train: 15, val: 5
Using DataParallel with 2 GPUs


  return F.linear(input, self.weight, self.bias)
Train body Epoch 1:   7%|▋         | 1/15 [00:00<00:04,  3.40it/s]

Dataset part_kp.shape: (25, 3)


Val body Epoch 1:   0%|          | 0/5 [00:00<?, ?it/s]            

Dataset part_kp.shape: (25, 3)


                                                               

[body] Epoch 1/10 | Train Loss: 3.9770 Acc: 0.0702 | Val Loss: 3.4690 Acc: 0.0552
Best model saved at epoch 1, val_acc=0.0552


Train body Epoch 2:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (25, 3)


Val body Epoch 2:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (25, 3)


                                                               

[body] Epoch 2/10 | Train Loss: 2.9187 Acc: 0.1996 | Val Loss: 3.4556 Acc: 0.0759
Best model saved at epoch 2, val_acc=0.0759


Train body Epoch 3:   7%|▋         | 1/15 [00:00<00:02,  6.16it/s]

Dataset part_kp.shape: (25, 3)


Val body Epoch 3:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (25, 3)


                                                               

[body] Epoch 3/10 | Train Loss: 2.5772 Acc: 0.2171 | Val Loss: 3.7884 Acc: 0.0897
Best model saved at epoch 3, val_acc=0.0897


Train body Epoch 4:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (25, 3)


Val body Epoch 4:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (25, 3)


                                                               

[body] Epoch 4/10 | Train Loss: 2.3808 Acc: 0.2566 | Val Loss: 3.6181 Acc: 0.1034
Best model saved at epoch 4, val_acc=0.1034


Train body Epoch 5:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (25, 3)


Val body Epoch 5:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (25, 3)


                                                               

[body] Epoch 5/10 | Train Loss: 2.2740 Acc: 0.2982 | Val Loss: 3.2567 Acc: 0.1517
Best model saved at epoch 5, val_acc=0.1517


Train body Epoch 6:   7%|▋         | 1/15 [00:00<00:02,  6.29it/s]

Dataset part_kp.shape: (25, 3)


Val body Epoch 6:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (25, 3)


                                                               

[body] Epoch 6/10 | Train Loss: 2.1701 Acc: 0.3333 | Val Loss: 2.9016 Acc: 0.2207
Best model saved at epoch 6, val_acc=0.2207


Train body Epoch 7:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (25, 3)


Val body Epoch 7:   0%|          | 0/5 [00:00<?, ?it/s]            

Dataset part_kp.shape: (25, 3)


                                                               

[body] Epoch 7/10 | Train Loss: 2.0709 Acc: 0.3553 | Val Loss: 3.2621 Acc: 0.1793


Train body Epoch 8:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (25, 3)


Val body Epoch 8:   0%|          | 0/5 [00:00<?, ?it/s]            

Dataset part_kp.shape: (25, 3)


                                                               

[body] Epoch 8/10 | Train Loss: 1.9551 Acc: 0.3794 | Val Loss: 3.2639 Acc: 0.2000


Train body Epoch 9:   7%|▋         | 1/15 [00:00<00:02,  6.58it/s]

Dataset part_kp.shape: (25, 3)


Val body Epoch 9:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (25, 3)


                                                               

[body] Epoch 9/10 | Train Loss: 1.9683 Acc: 0.3860 | Val Loss: 3.0793 Acc: 0.2138


Train body Epoch 10:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (25, 3)


Val body Epoch 10:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (25, 3)


                                                                

[body] Epoch 10/10 | Train Loss: 1.8564 Acc: 0.4079 | Val Loss: 3.0937 Acc: 0.2552
Best model saved at epoch 10, val_acc=0.2552
Done body. Best val acc: 0.2552

--- Pretraining left ---
Số mẫu train: 456, val: 145
Số batch train: 15, val: 5
Using DataParallel with 2 GPUs


Train left Epoch 1:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (21, 3)


Val left Epoch 1:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (21, 3)


                                                               

[left] Epoch 1/10 | Train Loss: 3.9994 Acc: 0.0636 | Val Loss: 3.3982 Acc: 0.0414
Best model saved at epoch 1, val_acc=0.0414


Train left Epoch 2:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (21, 3)


Val left Epoch 2:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (21, 3)


                                                               

[left] Epoch 2/10 | Train Loss: 2.9818 Acc: 0.1338 | Val Loss: 3.3133 Acc: 0.0276


Train left Epoch 3:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (21, 3)


Val left Epoch 3:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (21, 3)


                                                               

[left] Epoch 3/10 | Train Loss: 2.8101 Acc: 0.1557 | Val Loss: 3.2798 Acc: 0.0828
Best model saved at epoch 3, val_acc=0.0828


Train left Epoch 4:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (21, 3)


Val left Epoch 4:   0%|          | 0/5 [00:00<?, ?it/s]            

Dataset part_kp.shape: (21, 3)


                                                               

[left] Epoch 4/10 | Train Loss: 2.6940 Acc: 0.2039 | Val Loss: 3.0431 Acc: 0.1655
Best model saved at epoch 4, val_acc=0.1655


Train left Epoch 5:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (21, 3)


Val left Epoch 5:   0%|          | 0/5 [00:00<?, ?it/s]            

Dataset part_kp.shape: (21, 3)


                                                               

[left] Epoch 5/10 | Train Loss: 2.6751 Acc: 0.1996 | Val Loss: 3.0014 Acc: 0.1379


Train left Epoch 6:   7%|▋         | 1/15 [00:00<00:02,  6.59it/s]

Dataset part_kp.shape: (21, 3)


Val left Epoch 6:   0%|          | 0/5 [00:00<?, ?it/s]            

Dataset part_kp.shape: (21, 3)


                                                               

[left] Epoch 6/10 | Train Loss: 2.6486 Acc: 0.1930 | Val Loss: 2.9682 Acc: 0.1793
Best model saved at epoch 6, val_acc=0.1793


Train left Epoch 7:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (21, 3)


Val left Epoch 7:   0%|          | 0/5 [00:00<?, ?it/s]            

Dataset part_kp.shape: (21, 3)


                                                               

[left] Epoch 7/10 | Train Loss: 2.5875 Acc: 0.2061 | Val Loss: 3.0088 Acc: 0.1862
Best model saved at epoch 7, val_acc=0.1862


Train left Epoch 8:   7%|▋         | 1/15 [00:00<00:02,  6.65it/s]

Dataset part_kp.shape: (21, 3)


Val left Epoch 8:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (21, 3)


                                                               

[left] Epoch 8/10 | Train Loss: 2.5227 Acc: 0.2368 | Val Loss: 2.9322 Acc: 0.2276
Best model saved at epoch 8, val_acc=0.2276


Train left Epoch 9:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (21, 3)


Val left Epoch 9:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (21, 3)


                                                               

[left] Epoch 9/10 | Train Loss: 2.5406 Acc: 0.2675 | Val Loss: 3.0394 Acc: 0.1586


Train left Epoch 10:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (21, 3)


Val left Epoch 10:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (21, 3)


                                                                

[left] Epoch 10/10 | Train Loss: 2.4836 Acc: 0.2303 | Val Loss: 2.9927 Acc: 0.2276
Done left. Best val acc: 0.2276

--- Pretraining right ---
Số mẫu train: 456, val: 145
Số batch train: 15, val: 5
Using DataParallel with 2 GPUs


Train right Epoch 1:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (21, 3)


Val right Epoch 1:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (21, 3)


                                                                

[right] Epoch 1/10 | Train Loss: 3.9868 Acc: 0.0768 | Val Loss: 3.4628 Acc: 0.0276
Best model saved at epoch 1, val_acc=0.0276


Train right Epoch 2:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (21, 3)


Val right Epoch 2:   0%|          | 0/5 [00:00<?, ?it/s]            

Dataset part_kp.shape: (21, 3)


                                                                

[right] Epoch 2/10 | Train Loss: 3.0264 Acc: 0.1425 | Val Loss: 3.4256 Acc: 0.0414
Best model saved at epoch 2, val_acc=0.0414


Train right Epoch 3:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (21, 3)


Val right Epoch 3:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (21, 3)


                                                                

[right] Epoch 3/10 | Train Loss: 2.7901 Acc: 0.1645 | Val Loss: 3.2728 Acc: 0.1034
Best model saved at epoch 3, val_acc=0.1034


Train right Epoch 4:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (21, 3)


Val right Epoch 4:   0%|          | 0/5 [00:00<?, ?it/s]            

Dataset part_kp.shape: (21, 3)


                                                                

[right] Epoch 4/10 | Train Loss: 2.5561 Acc: 0.2917 | Val Loss: 3.1556 Acc: 0.1241
Best model saved at epoch 4, val_acc=0.1241


Train right Epoch 5:   7%|▋         | 1/15 [00:00<00:02,  6.65it/s]

Dataset part_kp.shape: (21, 3)


Val right Epoch 5:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (21, 3)


                                                                

[right] Epoch 5/10 | Train Loss: 2.5337 Acc: 0.2675 | Val Loss: 2.9142 Acc: 0.2345
Best model saved at epoch 5, val_acc=0.2345


Train right Epoch 6:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (21, 3)


Val right Epoch 6:   0%|          | 0/5 [00:00<?, ?it/s]            

Dataset part_kp.shape: (21, 3)


                                                                

[right] Epoch 6/10 | Train Loss: 2.4139 Acc: 0.2961 | Val Loss: 2.8280 Acc: 0.2759
Best model saved at epoch 6, val_acc=0.2759


Train right Epoch 7:   0%|          | 0/15 [00:00<?, ?it/s]

Dataset part_kp.shape: (21, 3)


Val right Epoch 7:   0%|          | 0/5 [00:00<?, ?it/s]            

Dataset part_kp.shape: (21, 3)


                                                                

[right] Epoch 7/10 | Train Loss: 2.3448 Acc: 0.2961 | Val Loss: 2.9001 Acc: 0.2897
Best model saved at epoch 7, val_acc=0.2897


Train right Epoch 8:   7%|▋         | 1/15 [00:00<00:02,  6.24it/s]

Dataset part_kp.shape: (21, 3)


Val right Epoch 8:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (21, 3)


                                                                

[right] Epoch 8/10 | Train Loss: 2.1588 Acc: 0.3794 | Val Loss: 3.0839 Acc: 0.2483


Train right Epoch 9:   7%|▋         | 1/15 [00:00<00:02,  6.15it/s]

Dataset part_kp.shape: (21, 3)


Val right Epoch 9:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (21, 3)


                                                                

[right] Epoch 9/10 | Train Loss: 2.0740 Acc: 0.3640 | Val Loss: 2.8324 Acc: 0.2828


Train right Epoch 10:   7%|▋         | 1/15 [00:00<00:02,  6.29it/s]

Dataset part_kp.shape: (21, 3)


Val right Epoch 10:   0%|          | 0/5 [00:00<?, ?it/s]           

Dataset part_kp.shape: (21, 3)


                                                                 

[right] Epoch 10/10 | Train Loss: 2.1001 Acc: 0.3904 | Val Loss: 2.9052 Acc: 0.2483
Done right. Best val acc: 0.2897




# 3. WLASL Module

# 3.1. Fusion Module

# sửa lại mask

In [19]:
import os
import numpy as np
import glob
import cv2

def get_n_frames(video_path):
    cap = cv2.VideoCapture(video_path)
    n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    return n_frames
    
def get_sorted_indexes(crop_dir, video_id, hand):
    # hand: 'left' hoặc 'right'
    pattern = os.path.join(crop_dir, f"{video_id}_frame*_{hand}.jpg")
    files = glob.glob(pattern)
    idxs = []
    for f in files:
        # Lấy index từ tên file: ..._frame<idx>_left.jpg hoặc ..._frame<idx>_right.jpg
        basename = os.path.basename(f)
        idx = int(basename.split("_frame")[1].split(f"_{hand}")[0])
        idxs.append(idx)
    idxs = sorted(list(set(idxs)))
    return idxs

def make_mask_for_hand(n_frames, crop_idxs, out_len=64):
    """
    n_frames: số frame gốc của video
    crop_idxs: list index frame gốc có ảnh crop tay
    out_len: số frame muốn nội suy (thường = 64)
    Đảm bảo số mask[i]=1 đúng bằng số crop_idxs (số ảnh crop)
    """
    idxs_interp = np.linspace(0, n_frames - 1, out_len)
    mask = np.zeros(out_len, dtype=np.uint8)
    used_positions = set()
    for crop_idx in crop_idxs:
        # Tìm vị trí chưa được gán nào gần nhất crop_idx
        distances = np.abs(idxs_interp - crop_idx)
        for i in np.argsort(distances):
            if i not in used_positions:
                mask[i] = 1
                used_positions.add(i)
                break
    return mask

def save_masks_for_video(video_id, crop_dir, video_dir, out_dir, out_len=64):
    video_path = os.path.join(video_dir, f"{video_id}.mp4")
    n_frames = get_n_frames(video_path)
    idxs_left = get_sorted_indexes(crop_dir, video_id, 'left')
    idxs_right = get_sorted_indexes(crop_dir, video_id, 'right')

    mask_left = make_mask_for_hand(n_frames, idxs_left, out_len)
    mask_right = make_mask_for_hand(n_frames, idxs_right, out_len)

    np.save(os.path.join(out_dir, f"{video_id}_mask_left.npy"), mask_left)
    np.save(os.path.join(out_dir, f"{video_id}_mask_right.npy"), mask_right)
    print(f"Saved masks for {video_id}: left sum={mask_left.sum()}, right sum={mask_right.sum()} (left imgs={len(idxs_left)}, right imgs={len(idxs_right)})")

CROPPED_HANDS_DIR = "/kaggle/input/image-mask/cropped_hands"
VIDEO_DIR = "/kaggle/input/wlasl-processed/WLASL/videos"
OUT_DIR = '/kaggle/working/mask'

os.makedirs(OUT_DIR, exist_ok=True)

def get_video_ids_from_txt(txt_path):
    video_ids = []
    with open(txt_path, "r") as f:
        for line in f:
            video_id = line.strip().split()[0]
            video_ids.append(video_id)
    return video_ids

# Ví dụ xử lý cho train
txt_path = "/kaggle/input/keypoint/videoid_label.txt"
video_ids = get_video_ids_from_txt(txt_path)

for video_id in video_ids:
    save_masks_for_video(video_id, CROPPED_HANDS_DIR, VIDEO_DIR, OUT_DIR, out_len=64)

Saved masks for 69241: left sum=4, right sum=3 (left imgs=4, right imgs=3)
Saved masks for 65225: left sum=3, right sum=3 (left imgs=3, right imgs=3)
Saved masks for 68011: left sum=4, right sum=3 (left imgs=4, right imgs=3)
Saved masks for 68208: left sum=3, right sum=3 (left imgs=3, right imgs=3)
Saved masks for 68012: left sum=2, right sum=3 (left imgs=2, right imgs=3)
Saved masks for 70212: left sum=0, right sum=0 (left imgs=0, right imgs=0)
Saved masks for 70266: left sum=1, right sum=1 (left imgs=1, right imgs=1)
Saved masks for 07085: left sum=2, right sum=3 (left imgs=2, right imgs=3)
Saved masks for 07086: left sum=3, right sum=4 (left imgs=3, right imgs=4)
Saved masks for 07087: left sum=4, right sum=4 (left imgs=4, right imgs=4)
Saved masks for 07069: left sum=3, right sum=2 (left imgs=3, right imgs=2)
Saved masks for 07088: left sum=3, right sum=4 (left imgs=3, right imgs=4)
Saved masks for 07089: left sum=4, right sum=3 (left imgs=4, right imgs=3)
Saved masks for 07090: le

In [20]:
import numpy as np

kp = np.load('/kaggle/working/mask/63678_mask_right.npy')
print(kp)

[0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0]


# Continue ... 

In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DeformableAttention2D(nn.Module):
    """
    Deformable 2D attention:
    - query: (B, N, d_model)     # N: số keypoints
    - key, value: (B, C, H, W)   # feature map từ RGB backbone
    - ref_points: (B, N, 2)      # reference keypoint position (pixel, normalized [0,1])
    """
    def __init__(self, d_model, n_heads, n_points=4):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_points = n_points

        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Conv2d(d_model, d_model, 1)
        self.value_proj = nn.Conv2d(d_model, d_model, 1)

        # For each head, predict offset for each reference point (dx, dy) per query
        self.offset = nn.Linear(d_model, n_heads * n_points * 2)
        # For each head, predict attention weights for each sampled point
        self.attn_weight = nn.Linear(d_model, n_heads * n_points)

    def forward(self, query, key, value, ref_points):
        # query: (B, N, d_model)
        # key, value: (B, d_model, H, W)
        # ref_points: (B, N, 2) in [0, 1], normalized by feature size
        B, N, d_model = query.shape
        H, W = key.shape[2], key.shape[3]

        # Project query, key, value
        query_proj = self.query_proj(query)  # (B, N, d_model)
        key_proj = self.key_proj(key)        # (B, d_model, H, W)
        value_proj = self.value_proj(value)  # (B, d_model, H, W)

        # Predict offsets and attention weights
        offset = self.offset(query_proj)     # (B, N, n_heads*n_points*2)
        offset = offset.view(B, N, self.n_heads, self.n_points, 2)

        attn_weight = self.attn_weight(query_proj)  # (B, N, n_heads*n_points)
        attn_weight = attn_weight.view(B, N, self.n_heads, self.n_points)
        attn_weight = F.softmax(attn_weight, dim=-1)

        # Calculate sampling locations (normalized in [0, 1])
        # ref_points: (B, N, 2), offset: (B, N, n_heads, n_points, 2)
        # Sampling positions: (B, N, n_heads, n_points, 2)
        sampling_locations = ref_points.unsqueeze(2).unsqueeze(3) + offset / torch.tensor([W, H], device=offset.device)
        # Clamp to [0, 1]
        sampling_locations = sampling_locations.clamp(0, 1)

        # Prepare for grid_sample: scale to [-1, 1]
        # grid_sample expects normalized coords in [-1, 1]
        sampling_grid = sampling_locations.clone()
        sampling_grid[..., 0] = sampling_grid[..., 0] * 2 - 1
        sampling_grid[..., 1] = sampling_grid[..., 1] * 2 - 1

        # Sample value features at sampled locations
        # value_proj: (B, d_model, H, W)
        # For each query (B, N, n_heads, n_points, 2), sample (B, d_model, n_heads, n_points)
        sampled_feats = []
        for b in range(B):
            feats = []
            for n in range(N):
                grid = sampling_grid[b, n]  # (n_heads, n_points, 2)
                grid = grid.view(1, -1, 1, 2)  # (1, n_heads*n_points, 1, 2)
                # grid_sample: input (B, C, H, W), grid (B, out_H*out_W, 1, 2)
                sampled = F.grid_sample(
                    value_proj[b:b+1], grid, mode='bilinear', align_corners=True
                )  # (1, d_model, n_heads*n_points, 1)
                sampled = sampled.view(d_model, self.n_heads, self.n_points)
                feats.append(sampled)
            # feats: list of N tensors (d_model, n_heads, n_points) -> (N, d_model, n_heads, n_points)
            feats = torch.stack(feats, dim=0)
            sampled_feats.append(feats)
        # (B, N, d_model, n_heads, n_points)
        sampled_feats = torch.stack(sampled_feats, dim=0)

        # Weighted sum by attention weights
        # attn_weight: (B, N, n_heads, n_points)
        attn_weight = attn_weight.permute(0, 1, 3, 2)  # (B, N, n_points, n_heads)
        attn_weight = attn_weight.permute(0, 1, 3, 2)  # back to (B, N, n_heads, n_points)
        # (B, N, d_model, n_heads, n_points) * (B, N, 1, n_heads, n_points) -> sum over n_points
        out = (sampled_feats * attn_weight.unsqueeze(2)).sum(-1)  # (B, N, d_model, n_heads)
        out = out.mean(-1)  # mean over heads: (B, N, d_model)

        return out  # (B, N, d_model)


In [22]:
class PGFModule(nn.Module):
    """
    PGF Module: Multi-head deformable attention for fusion keypoint + RGB feature
    - query_feat: (B, N, d_model) -- e.g. STGCN output per joint
    - rgb_feat: (B, d_model, H, W) -- output feature map from CNN/ViT
    - ref_points: (B, N, 2) -- reference position (normalized [0,1]) for each joint
    Output: (B, N, d_model) fused feature per joint
    """
    def __init__(self, d_model, n_heads=4, n_points=4):
        super().__init__()
        self.attn = DeformableAttention2D(d_model, n_heads, n_points)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, query_feat, rgb_feat, ref_points):
        # query_feat: (B, N, d_model)
        # rgb_feat: (B, d_model, H, W)
        # ref_points: (B, N, 2) in [0, 1], normalized by feature map size
        fused = self.attn(query_feat, rgb_feat, rgb_feat, ref_points)  # (B, N, d_model)
        out = self.layer_norm(fused + query_feat)  # residual connection
        return out  # (B, N, d_model)

In [23]:
import torch
import torch.nn as nn

class TemporalEncoder(nn.Module):
    """
    Temporal encoder cho input dạng (B, T, C, N).
    Học đặc trưng temporal cho từng joint (giữ spatial/joint riêng).
    Output: (B, C, N) (nếu pool theo T), hoặc (B, T, C, N) nếu pool=None.
    """
    def __init__(self, c_model, n_joints, nhead=8, num_layers=2, dim_feedforward=512, dropout=0.1, pool='mean'):
        """
        c_model: số channels đầu vào cho mỗi joint (ví dụ C).
        n_joints: số joint (N).
        pool: 'mean' (mean theo T), 'last' (lấy frame cuối), hoặc None (giữ nguyên (B, T, C, N)).
        """
        super().__init__()
        self.n_joints = n_joints
        self.c_model = c_model
        self.pool = pool
        # Một transformer encoder riêng cho mỗi joint
        self.transformer_list = nn.ModuleList([
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=c_model, nhead=nhead, 
                                          dim_feedforward=dim_feedforward, 
                                          batch_first=True, dropout=dropout),
                num_layers=num_layers
            ) for _ in range(n_joints)
        ])
        self.output_dim = c_model * n_joints if pool is None else c_model * n_joints

    def forward(self, x):
        """
        x: (B, T, C, N)
        Output:
            - (B, C, N) nếu pool
            - (B, T, C, N) nếu pool=None
        """
        B, T, C, N = x.shape
        # Chuyển sang list các tensor (B, T, C) cho từng joint
        outs = []
        for j in range(N):
            x_j = x[..., j]            # (B, T, C)
            # Transformer yêu cầu (B, T, C_model)
            out_j = self.transformer_list[j](x_j)   # (B, T, C)
            if self.pool == 'mean':
                out_j = out_j.mean(dim=1)           # (B, C)
            elif self.pool == 'last':
                out_j = out_j[:, -1, :]             # (B, C)
            # else giữ (B, T, C)
            outs.append(out_j)
        if self.pool is not None:
            # outs: list (B, C), ghép thành (B, C, N)
            out = torch.stack(outs, dim=-1)         # (B, C, N)
        else:
            # outs: list (B, T, C), ghép thành (B, T, C, N)
            out = torch.stack(outs, dim=-1)         # (B, T, C, N)
        return out

In [24]:
import os
import numpy as np
import json
import glob
import re

DATA_ROOT = '/kaggle/input'
CROPPED_HANDS_DIR = os.path.join(DATA_ROOT, 'image-mask', 'cropped_hands')
KEYPOINT_DIR = os.path.join(DATA_ROOT, 'keypoint', 'keypoints')
MASK_DIR = '/kaggle/working/mask'

SPLIT_TXT = {
    'train': '/kaggle/working/train.txt',
    'val': '/kaggle/working/val.txt',
    'test': '/kaggle/working/test.txt'
}

# Đọc label_map
with open('/kaggle/working/label_map_pose.json') as f:
    label_map = json.load(f)

def get_sorted_crop_idxs(crop_dir, video_id, hand):
    pattern = os.path.join(crop_dir, f"{video_id}_frame*_{hand}.jpg")
    files = glob.glob(pattern)
    idxs = []
    for f in files:
        basename = os.path.basename(f)
        match = re.search(rf"{video_id}_frame(\d+)_{hand}\.jpg", basename)
        if match:
            idx = int(match.group(1))
            idxs.append(idx)
    idxs = sorted(list(set(idxs)))
    return idxs

def build_meta_list(split_txt_path):
    meta_list = []
    with open(split_txt_path, 'r') as f:
        lines = f.readlines()
    for line in lines:
        video_id, label_str = line.strip().split()
        label = label_map[label_str]

        keypoint_path = os.path.join(KEYPOINT_DIR, f"{video_id}_keypoint.npy")
        mask_left_path = os.path.join(MASK_DIR, f"{video_id}_mask_left.npy")
        mask_right_path = os.path.join(MASK_DIR, f"{video_id}_mask_right.npy")
        mask_left = np.load(mask_left_path)
        mask_right = np.load(mask_right_path)
        num_frames = len(mask_left)

        # Lấy index các frame thực sự có ảnh crop left/right
        crop_left_idxs = get_sorted_crop_idxs(CROPPED_HANDS_DIR, video_id, 'left')
        crop_right_idxs = get_sorted_crop_idxs(CROPPED_HANDS_DIR, video_id, 'right')

        # Đếm số lượng 1 trong mask và số lượng ảnh crop
        n_img_left = len(crop_left_idxs)
        n_mask_left = int(mask_left.sum())
        n_img_right = len(crop_right_idxs)
        n_mask_right = int(mask_right.sum())
        if n_img_left != n_mask_left:
            print(f"[CHECK] video_id={video_id}: LEFT - {n_img_left} images, {n_mask_left} mask=1")
        if n_img_right != n_mask_right:
            print(f"[CHECK] video_id={video_id}: RIGHT - {n_img_right} images, {n_mask_right} mask=1")

        crop_left_files = [os.path.join(CROPPED_HANDS_DIR, f"{video_id}_frame{idx}_left.jpg") for idx in crop_left_idxs]
        kp_j_left_files = [os.path.join(CROPPED_HANDS_DIR, f"{video_id}_frame{idx}_left_kp.npy") for idx in crop_left_idxs]
        crop_right_files = [os.path.join(CROPPED_HANDS_DIR, f"{video_id}_frame{idx}_right.jpg") for idx in crop_right_idxs]
        kp_j_right_files = [os.path.join(CROPPED_HANDS_DIR, f"{video_id}_frame{idx}_right_kp.npy") for idx in crop_right_idxs]

        meta = {
            "video_id": video_id,
            "keypoint_path": keypoint_path,
            "label": label,
            "num_frames": num_frames,
            "mask_left_path": mask_left_path,
            "mask_right_path": mask_right_path,
            "crop_left_files": crop_left_files,
            "kp_j_left_files": kp_j_left_files,
            "crop_right_files": crop_right_files,
            "kp_j_right_files": kp_j_right_files
        }
        meta_list.append(meta)
    return meta_list


In [25]:
# Ví dụ chạy cho 1 tập:
split = 'train'
txt_path = SPLIT_TXT[split]
meta_list = build_meta_list(txt_path)
# Nếu muốn lưu lại:
with open(f'/kaggle/working/{split}_meta.json', 'w') as f:
    json.dump(meta_list, f, indent=2)
print(f"Saved {split}_meta.json with {len(meta_list)} samples")

Saved train_meta.json with 456 samples


In [26]:
# Ví dụ chạy cho 1 tập:
split = 'test'
txt_path = SPLIT_TXT[split]
meta_list = build_meta_list(txt_path)
# Nếu muốn lưu lại:
with open(f'/kaggle/working/{split}_meta.json', 'w') as f:
    json.dump(meta_list, f, indent=2)
print(f"Saved {split}_meta.json with {len(meta_list)} samples")

Saved test_meta.json with 150 samples


In [27]:
# Ví dụ chạy cho 1 tập:
split = 'val'
txt_path = SPLIT_TXT[split]
meta_list = build_meta_list(txt_path)
# Nếu muốn lưu lại:
with open(f'/kaggle/working/{split}_meta.json', 'w') as f:
    json.dump(meta_list, f, indent=2)
print(f"Saved {split}_meta.json with {len(meta_list)} samples")

Saved val_meta.json with 145 samples


In [28]:
import torch
import torch.nn as nn

class ClassifySign(nn.Module):
    def __init__(
        self,
        vision_encoder, stgcn_body, stgcn_left, stgcn_right,
        pgf_module, temporal_encoder_body, temporal_encoder_hand, 
        num_classes,
        left_idx=21, right_idx=21, body_idx=25,
        fcn_hidden=256, dropout=0.5,
        n_joint=67, feature_dim=256
    ):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.stgcn_body = stgcn_body
        self.stgcn_left = stgcn_left
        self.stgcn_right = stgcn_right
        self.pgf_module = pgf_module
        self.temporal_encoder_body = temporal_encoder_body
        self.temporal_encoder_hand = temporal_encoder_hand
        self.left_idx = left_idx
        self.right_idx = right_idx
        self.body_idx = body_idx
        self.n_joint = n_joint
        self.feature_dim = feature_dim

        # Tính toán số chiều đầu vào FCN sau pooling, flatten, concat
        fcn_input_dim = feature_dim * 2 + feature_dim * n_joint

        # FCN phức tạp với nhiều tầng, dropout, relu
        self.fcn = nn.Sequential(
            nn.Linear(fcn_input_dim, fcn_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(fcn_hidden, fcn_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(fcn_hidden, num_classes)
        )

    def forward(
        self, keypoint, rgb_imgs, mask_left, mask_right, 
        kp_j_left, kp_j_right, rgb_left_imgs, rgb_right_imgs
    ):
        B, T, N, C = keypoint.shape

        # 1. Chia keypoint thành 3 phần
        kp_body = keypoint[..., :self.body_idx, :]
        kp_left = keypoint[..., self.body_idx:self.body_idx+self.left_idx, :]
        kp_right = keypoint[..., self.body_idx+self.left_idx:self.body_idx+self.left_idx+self.right_idx, :]

        # 2. BODY: ST-GCN từng frame --> (B, C_body, Nb); stack lại (B, T, C_body, Nb)
        body_stgcn_feats = []
        for t in range(T):
            x = kp_body[:, t, :, :]
            _, f = self.stgcn_body(x)
            body_stgcn_feats.append(f)
        body_stgcn_feats = torch.stack(body_stgcn_feats, dim=1)
        body_out = self.temporal_encoder_body(body_stgcn_feats)

        # 3. LEFT HAND: ST-GCN + Fusion, giữ spatial, stack lại (B, T, C_hand, Nl)
        left_stgcn_feats = []
        for t in range(T):
            x = kp_left[:, t, :, :]
            _, f = self.stgcn_left(x)
            mask = mask_left[:, t]
            img = rgb_left_imgs[:, t, :, :, :]
            ref_j = kp_j_left[:, t, :, :]
            _, rgb_feat = self.vision_encoder(img, return_feature=True)
            fused = []
            for b in range(B):
                hand_feat_bn = f[b].permute(1, 0)
                if mask[b] == 1:
                    fused_feat = self.pgf_module(
                        hand_feat_bn.unsqueeze(0), rgb_feat[b:b+1], ref_j[b:b+1]
                    )
                    if fused_feat.ndim == 2:
                        fused_feat = fused_feat.unsqueeze(1).expand(-1, hand_feat_bn.shape[0], -1)
                    fused.append(fused_feat.squeeze(0).permute(1, 0))
                else:
                    fused.append(f[b])
            left_stgcn_feats.append(torch.stack(fused, dim=0))
        left_stgcn_feats = torch.stack(left_stgcn_feats, dim=1)
        left_out = self.temporal_encoder_hand(left_stgcn_feats)

        # 4. RIGHT HAND: ST-GCN + Fusion, giữ spatial, stack lại (B, T, C_hand, Nr)
        right_stgcn_feats = []
        for t in range(T):
            x = kp_right[:, t, :, :]
            _, f = self.stgcn_right(x)
            mask = mask_right[:, t]
            img = rgb_right_imgs[:, t, :, :, :]
            ref_j = kp_j_right[:, t, :, :]
            _, rgb_feat = self.vision_encoder(img, return_feature=True)
            fused = []
            for b in range(B):
                hand_feat_bn = f[b].permute(1, 0)
                if mask[b] == 1:
                    fused_feat = self.pgf_module(
                        hand_feat_bn.unsqueeze(0), rgb_feat[b:b+1], ref_j[b:b+1]
                    )
                    if fused_feat.ndim == 2:
                        fused_feat = fused_feat.unsqueeze(1).expand(-1, hand_feat_bn.shape[0], -1)
                    fused.append(fused_feat.squeeze(0).permute(1, 0))
                else:
                    fused.append(f[b])
            right_stgcn_feats.append(torch.stack(fused, dim=0))
        right_stgcn_feats = torch.stack(right_stgcn_feats, dim=1)
        right_out = self.temporal_encoder_hand(right_stgcn_feats)

        # 5. Ghép feature (B, 256, 67)
        final_feat = torch.cat([body_out, left_out, right_out], dim=-1)  # (B, 256, 67)

        # Pooling + flatten + concat
        mean_pool = final_feat.mean(dim=-1)               # (B, 256)
        max_pool = final_feat.max(dim=-1).values          # (B, 256)
        flatten = final_feat.view(B, -1)                  # (B, 256*67)
        concat_feat = torch.cat([mean_pool, max_pool, flatten], dim=-1)  # (B, 256*2 + 256*67)

        logits = self.fcn(concat_feat)
        return logits

In [29]:
import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
import os

class SignDataset(Dataset):
    def __init__(
        self, meta_list, transform_crop=None,
        max_seq=64, crop_shape=(3, 112, 112), kp_num=21
    ):
        self.meta_list = meta_list
        self.transform_crop = transform_crop
        self.max_seq = max_seq
        self.crop_shape = crop_shape
        self.kp_num = kp_num
        # Pre-allocate zeros for padding
        self.zero_img = torch.zeros(self.crop_shape, dtype=torch.float32)
        self.zero_j = torch.zeros(self.kp_num, 2, dtype=torch.float32)

    def __len__(self):
        return len(self.meta_list)

    def pad_seq(self, seq, pad, max_len):
        # seq: list[tensor], pad: tensor
        if len(seq) >= max_len:
            return seq[:max_len]
        else:
            return seq + [pad]*(max_len-len(seq))

    def __getitem__(self, idx):
        meta = self.meta_list[idx]
        T = meta['num_frames']
        label = meta['label']
        keypoint = np.load(meta['keypoint_path'])    # (T, N, 3)
        mask_left = np.load(meta['mask_left_path'])  # (T,)
        mask_right = np.load(meta['mask_right_path'])# (T,)

        # Check consistency
        assert len(meta['crop_left_files']) == int(mask_left.sum()), f"{meta['video_id']}: crop_left_files/mask_left mismatch"
        assert len(meta['crop_right_files']) == int(mask_right.sum()), f"{meta['video_id']}: crop_right_files/mask_right mismatch"
        assert len(meta['kp_j_left_files']) == int(mask_left.sum()), f"{meta['video_id']}: kp_j_left_files/mask_left mismatch"
        assert len(meta['kp_j_right_files']) == int(mask_right.sum()), f"{meta['video_id']}: kp_j_right_files/mask_right mismatch"

        crop_left_imgs, kp_j_left = [], []
        left_idx = 0
        for t in range(T):
            if mask_left[t] == 1:
                img = Image.open(meta['crop_left_files'][left_idx]).convert("RGB")
                if self.transform_crop:
                    img = self.transform_crop(img)
                j = torch.tensor(np.load(meta['kp_j_left_files'][left_idx]), dtype=torch.float32)
                left_idx += 1
            else:
                img = self.zero_img
                j = self.zero_j
            crop_left_imgs.append(img)
            kp_j_left.append(j)
        crop_left_imgs = torch.stack(crop_left_imgs, dim=0)   # (T, C, H, W)
        kp_j_left = torch.stack(kp_j_left, dim=0)             # (T, kp_num, 2)

        crop_right_imgs, kp_j_right = [], []
        right_idx = 0
        for t in range(T):
            if mask_right[t] == 1:
                img = Image.open(meta['crop_right_files'][right_idx]).convert("RGB")
                if self.transform_crop:
                    img = self.transform_crop(img)
                j = torch.tensor(np.load(meta['kp_j_right_files'][right_idx]), dtype=torch.float32)
                right_idx += 1
            else:
                img = self.zero_img
                j = self.zero_j
            crop_right_imgs.append(img)
            kp_j_right.append(j)
        crop_right_imgs = torch.stack(crop_right_imgs, dim=0)
        kp_j_right = torch.stack(kp_j_right, dim=0)

        # Pad/truncate
        crop_left_imgs = self.pad_seq(list(crop_left_imgs), self.zero_img, self.max_seq)
        crop_right_imgs = self.pad_seq(list(crop_right_imgs), self.zero_img, self.max_seq)
        kp_j_left = self.pad_seq(list(kp_j_left), self.zero_j, self.max_seq)
        kp_j_right = self.pad_seq(list(kp_j_right), self.zero_j, self.max_seq)
        mask_left = np.pad(mask_left, (0, max(self.max_seq - T, 0)), 'constant')
        mask_right = np.pad(mask_right, (0, max(self.max_seq - T, 0)), 'constant')
        keypoint = np.pad(keypoint, ((0, max(self.max_seq - T, 0)), (0,0), (0,0)), 'constant')

        crop_left_imgs = torch.stack(crop_left_imgs, dim=0)
        crop_right_imgs = torch.stack(crop_right_imgs, dim=0)
        kp_j_left = torch.stack(kp_j_left, dim=0)
        kp_j_right = torch.stack(kp_j_right, dim=0)
        keypoint = torch.tensor(keypoint, dtype=torch.float32)
        mask_left = torch.tensor(mask_left, dtype=torch.float32)
        mask_right = torch.tensor(mask_right, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)

        if keypoint.shape[0] != self.max_seq:
            print(f"[WRONG SHAPE] idx={idx}, video_id={meta['video_id']}, path={meta['keypoint_path']}, shape={keypoint.shape}")

        return {
            'keypoint': keypoint,             # (T, N, 3)
            'mask_left': mask_left,           # (T,)
            'mask_right': mask_right,         # (T,)
            'crop_left_imgs': crop_left_imgs, # (T, C, H, W)
            'crop_right_imgs': crop_right_imgs,
            'kp_j_left': kp_j_left,           # (T, kp_num, 2)
            'kp_j_right': kp_j_right,
            'label': label
        }

In [42]:
import torch
try:
    torch.cuda.empty_cache()
    torch.cuda.manual_seed_all(42)
    print("OK")
except Exception as e:
    print(e)

OK


In [43]:
import torch
import numpy as np
import random
import os

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def collate_fn(batch):
    batch_out = {}
    for k in batch[0].keys():
        if isinstance(batch[0][k], torch.Tensor):
            batch_out[k] = torch.stack([item[k] for item in batch], dim=0)
        else:
            arr = [item[k] for item in batch]
            if isinstance(arr[0], (int, float)):
                batch_out[k] = torch.tensor(arr)
            else:
                batch_out[k] = arr
    return batch_out

def print_batch_info(batch, batch_idx=None):
    if batch_idx is not None:
        print(f"BATCH ERROR at batch_idx={batch_idx}")
    for k in batch:
        v = batch[k]
        if isinstance(v, torch.Tensor):
            print(f"{k}: shape={v.shape}, dtype={v.dtype}, device={v.device}, min={v.min().item() if v.numel()>0 else 'EMPTY'}, max={v.max().item() if v.numel()>0 else 'EMPTY'}")
        elif isinstance(v, list):
            print(f"{k}: list of len {len(v)}; first element type: {type(v[0]) if len(v)>0 else 'EMPTY'}")
        else:
            print(f"{k}: type={type(v)}")

def train_epoch(model, loader, optimizer, criterion, device="cuda"):
    model.train()
    total_loss, total_acc, n = 0.0, 0.0, 0
    for batch_idx, batch in enumerate(loader):
        try:
            print("batch['keypoint'].shape:", batch['keypoint'].shape)
            for k in batch:
                if isinstance(batch[k], torch.Tensor):
                    batch[k] = batch[k].to(device)
            optimizer.zero_grad()
            logits = model(
                keypoint=batch['keypoint'],
                rgb_imgs=None,
                mask_left=batch['mask_left'],
                mask_right=batch['mask_right'],
                kp_j_left=batch['kp_j_left'],
                kp_j_right=batch['kp_j_right'],
                rgb_left_imgs=batch['crop_left_imgs'],
                rgb_right_imgs=batch['crop_right_imgs']
            )
            loss = criterion(logits, batch['label'])
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * batch['label'].size(0)
            pred = logits.argmax(dim=1)
            total_acc += (pred == batch['label']).sum().item()
            n += batch['label'].size(0)
        except Exception as e:
            print(f"Exception in train_epoch at batch_idx={batch_idx}: {e}")
            print_batch_info(batch, batch_idx)
            raise
    return total_loss / n, total_acc / n

@torch.no_grad()
def validate_epoch(model, loader, criterion, device="cuda"):
    model.eval()
    total_loss, total_acc, n = 0.0, 0.0, 0
    for batch_idx, batch in enumerate(loader):
        try:
            for k in batch:
                if isinstance(batch[k], torch.Tensor):
                    batch[k] = batch[k].to(device)
            logits = model(
                keypoint=batch['keypoint'],
                rgb_imgs=None,
                mask_left=batch['mask_left'],
                mask_right=batch['mask_right'],
                kp_j_left=batch['kp_j_left'],
                kp_j_right=batch['kp_j_right'],
                rgb_left_imgs=batch['crop_left_imgs'],
                rgb_right_imgs=batch['crop_right_imgs']
            )
            loss = criterion(logits, batch['label'])
            total_loss += loss.item() * batch['label'].size(0)
            pred = logits.argmax(dim=1)
            total_acc += (pred == batch['label']).sum().item()
            n += batch['label'].size(0)
        except Exception as e:
            print(f"Exception in validate_epoch at batch_idx={batch_idx}: {e}")
            print_batch_info(batch, batch_idx)
            raise
    return total_loss / n, total_acc / n

def save_checkpoint(model, optimizer, epoch, path, best_acc=None):
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }
    if best_acc is not None:
        state['best_acc'] = best_acc
    torch.save(state, path)

def load_checkpoint(model, optimizer, path, device='cuda'):
    state = torch.load(path, map_location=device)
    model.load_state_dict(state['model'])
    if optimizer is not None:
        optimizer.load_state_dict(state['optimizer'])
    epoch = state.get('epoch', 0)
    best_acc = state.get('best_acc', None)
    return model, optimizer, epoch, best_acc

def adjust_learning_rate(optimizer, epoch, lr, step=10, decay=0.1):
    """ Step LR decay """
    lr_new = lr * (decay ** (epoch // step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr_new
    return lr_new

In [44]:
import json
#from data.sign_dataset import SignDataset
from torch.utils.data import DataLoader
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225])
])

# Load meta_list
with open('/kaggle/working/train_meta.json') as f:
    train_meta_list = json.load(f)
with open('/kaggle/working/val_meta.json') as f:
    val_meta_list = json.load(f)

train_ds = SignDataset(meta_list=train_meta_list, transform_crop=transform, max_seq=64, crop_shape=(3,112,112), kp_num=21)
val_ds = SignDataset(meta_list=val_meta_list, transform_crop=transform, max_seq=64, crop_shape=(3,112,112), kp_num=21)

#from utils.train_utils import collate_fn
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=4, collate_fn=collate_fn, drop_last = True) 
val_loader = DataLoader(val_ds, batch_size=8, shuffle=False, num_workers=4, collate_fn=collate_fn, drop_last = True)

In [45]:
data_root = "/kaggle/input/keypoint/keypoints"
train_txt = "/kaggle/working/train.txt"
val_txt = "/kaggle/working/val.txt"
test_txt = "/kaggle/working/test.txt"
#batch_size = 32
#num_workers = 4
#num_epochs = 10
#lr = 1e-3

label_map = build_label_map([train_txt, val_txt, test_txt])
num_classes = len(label_map)

In [46]:
import torch
import torchvision
#from models.vision import VisionClassifier

def load_vision_encoder(ckpt_path, num_classes=1000):
    vision_encoder = VisionClassifier(num_classes=num_classes)
    ckpt = torch.load(ckpt_path, map_location='cpu')
    vision_encoder.load_state_dict(ckpt.get('model_state_dict', ckpt), strict=False)
    return vision_encoder

vision_encoder = load_vision_encoder('/kaggle/working/pretrained_hand_rgb.pth', num_classes=num_classes)

In [47]:
#from models.skeleton import SpatialPoseEncoder

def load_stgcn(ckpt_path, in_channels, num_joints, A, hid_dim=128, out_dim=256):
    model = SpatialPoseEncoder(
        in_channels=in_channels,
        num_joints=num_joints,
        num_classes=1,  # Không quan trọng nếu chỉ lấy feature
        A = A,
        hid_dim=hid_dim,
        out_dim=out_dim
    )
    ckpt = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(ckpt.get('model_state_dict', ckpt), strict=False)
    return model

stgcn_body = load_stgcn('/kaggle/working/spatial_body_best.pth', in_channels=3, num_joints=25, A = get_body_spatial_graph())
stgcn_left = load_stgcn('/kaggle/working/spatial_left_best.pth', in_channels=3, num_joints=21, A = get_left_hand_spatial_graph())
stgcn_right = load_stgcn('/kaggle/working/spatial_right_best.pth', in_channels=3, num_joints=21, A = get_right_hand_spatial_graph())

In [48]:
d_model = 256
max_seq = 64  # hoặc giá trị max_seq bạn sử dụng

temporal_encoder_body = TemporalEncoder(c_model=d_model, n_joints = 25, pool='mean')
temporal_encoder_hand = TemporalEncoder(c_model=d_model, n_joints = 21, pool='mean')


In [49]:
pgf_module = PGFModule(d_model=256, n_heads=8, n_points=4)

In [50]:
model = ClassifySign(
    vision_encoder = vision_encoder,
    pgf_module = pgf_module,
    stgcn_body=stgcn_body,
    stgcn_left=stgcn_left,
    stgcn_right=stgcn_right,
    temporal_encoder_body=temporal_encoder_body,
    temporal_encoder_hand=temporal_encoder_hand,
    num_classes=num_classes,  # Sửa cho đúng dataset của bạn
    # Có thể cần các tham số khác như vision_encoder, pgf_module, left_idx, right_idx,... nếu repo yêu cầu
)

In [51]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [52]:
if torch.cuda.device_count() > 1:
    print("Using", torch.cuda.device_count(), "GPUs.")
    model = torch.nn.DataParallel(model)
model = model.to(device)

Using 2 GPUs.


In [53]:
set_seed(42)



optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()



num_epochs = 50
best_acc = 0.0
earlystop_patience = 7
earlystop_counter = 0

for epoch in range(num_epochs):
    adjust_learning_rate(optimizer, epoch, lr=1e-4, step=10, decay=0.1)

    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device=device)
    val_loss, val_acc = validate_epoch(model, val_loader, criterion, device=device)

    print(f"Epoch {epoch+1}/{num_epochs} | Train loss: {train_loss:.4f} acc: {train_acc:.4f} | Val loss: {val_loss:.4f} acc: {val_acc:.4f}")

    # Lưu checkpoint tốt nhất
    if val_acc > best_acc:
        best_acc = val_acc
        save_checkpoint(model, optimizer, epoch, "/kaggle/working/best_model.pth", best_acc)
        print("Saved best model.")
        earlystop_counter = 0
    else:
        earlystop_counter += 1

    # Early stopping nếu không cải thiện
    if earlystop_counter >= earlystop_patience:
        print(f"Early stopping at epoch {epoch+1}")
        break

print("Train hoàn tất. Best val acc:", best_acc)

batch['keypoint'].shape: torch.Size([8, 64, 67, 3])
batch['keypoint'].shape: torch.Size([8, 64, 67, 3])
batch['keypoint'].shape: torch.Size([8, 64, 67, 3])
batch['keypoint'].shape: torch.Size([8, 64, 67, 3])
batch['keypoint'].shape: torch.Size([8, 64, 67, 3])
batch['keypoint'].shape: torch.Size([8, 64, 67, 3])
batch['keypoint'].shape: torch.Size([8, 64, 67, 3])
batch['keypoint'].shape: torch.Size([8, 64, 67, 3])
batch['keypoint'].shape: torch.Size([8, 64, 67, 3])
batch['keypoint'].shape: torch.Size([8, 64, 67, 3])
batch['keypoint'].shape: torch.Size([8, 64, 67, 3])
batch['keypoint'].shape: torch.Size([8, 64, 67, 3])
batch['keypoint'].shape: torch.Size([8, 64, 67, 3])
batch['keypoint'].shape: torch.Size([8, 64, 67, 3])
batch['keypoint'].shape: torch.Size([8, 64, 67, 3])
batch['keypoint'].shape: torch.Size([8, 64, 67, 3])
batch['keypoint'].shape: torch.Size([8, 64, 67, 3])
batch['keypoint'].shape: torch.Size([8, 64, 67, 3])
batch['keypoint'].shape: torch.Size([8, 64, 67, 3])
batch['keypo

RuntimeError: CUDA error: misaligned address
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
