# 1.  All classes

In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from einops import rearrange
import torchvision

In [2]:
# --------- Graph and ST-GCN -------------
class Graph:
    """Graph for skeleton representation (e.g., Mediapipe keypoints)."""
    def __init__(self, layout='body', strategy='uniform', max_hop=1, dilation=1):
        self.max_hop = max_hop
        self.dilation = dilation
        self.get_edge(layout)
        self.hop_dis = self.get_hop_distance(self.num_node, self.edge, max_hop=max_hop)
        self.A = self.get_adjacency(strategy)

    def get_edge(self, layout):
        if layout == 'body':
            self.num_node = 25
            self_link = [(i, i) for i in range(self.num_node)]
            neighbor_1base = [
                [0, 1],[1, 2],[2, 3],[3, 7],
                [0, 4],[4, 5],[5, 6],[6, 8],
                [9, 10],[11, 12],[11, 13],[13, 15],[15, 21],[15, 19],[15, 17],
                [17, 19],[11, 23],[12, 14],[14, 16],[16, 18],[16, 20],[16, 22],
                [18, 20],[12, 24],[23, 24]
            ]
            neighbor_link = neighbor_1base
            self.edge = self_link + neighbor_link
            self.center = 0
        elif layout == 'left' or layout == 'right':
            self.num_node = 21
            self_link = [(i, i) for i in range(self.num_node)]
            neighbor_1base = [
                [0, 1],[1, 2],[2, 3],[3, 4],
                [0, 5],[5, 6],[6, 7],[7, 8],
                [0, 9],[9, 10],[10, 11],[11, 12],
                [0, 13],[13, 14],[14, 15],[15, 16],
                [0, 17],[17, 18],[18, 19],[19, 20]
            ]
            neighbor_link = neighbor_1base
            self.edge = self_link + neighbor_link
            self.center = 0
        else:
            raise ValueError("Unknown layout: {}".format(layout))

    def get_hop_distance(self, num_node, edge, max_hop=1):
        A = np.zeros((num_node, num_node))
        for i, j in edge:
            A[j, i] = 1
            A[i, j] = 1
        hop_dis = np.zeros((num_node, num_node)) + np.inf
        transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)]
        arrive_mat = np.stack(transfer_mat) > 0
        for d in range(max_hop, -1, -1):
            hop_dis[arrive_mat[d]] = d
        return hop_dis

    def normalize_digraph(self, A):
        Dl = np.sum(A, 0)
        num_node = A.shape[0]
        Dn = np.zeros((num_node, num_node))
        for i in range(num_node):
            if Dl[i] > 0:
                Dn[i, i] = Dl[i] ** (-1)
        AD = np.dot(A, Dn)
        return AD

    def get_adjacency(self, strategy):
        valid_hop = range(0, self.max_hop + 1, self.dilation)
        adjacency = np.zeros((self.num_node, self.num_node))
        for hop in valid_hop:
            adjacency[self.hop_dis == hop] = 1
        normalize_adjacency = self.normalize_digraph(adjacency)
        if strategy == 'uniform':
            A = np.zeros((1, self.num_node, self.num_node))
            A[0] = normalize_adjacency
        else:
            raise NotImplementedError
        return torch.tensor(A, dtype=torch.float32)

In [3]:
class STGCNBlock(nn.Module):
    """A single ST-GCN block (spatio-temporal)."""
    def __init__(self, in_channels, out_channels, A, kernel_size=1, stride=1, dropout=0, residual=True):
        super().__init__()
        self.A = A
        self.gcn = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
        self.tcn = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, (9, 1), (stride, 1), (4, 0)),
            nn.BatchNorm2d(out_channels),
            nn.Dropout(dropout, inplace=True)
        )
        if not residual:
            self.residual = lambda x: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=(stride, 1)),
                nn.BatchNorm2d(out_channels)
            )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # x: (B, C, T, N)
        y = self.gcn(x)
        y = torch.einsum('nctv,kvw->nctw', (y, self.A))
        y = self.tcn(y)
        y = y + self.residual(x)
        return self.relu(y)

In [4]:
class PoseEncoder(nn.Module):
    """Pose encoder with 3 stacked ST-GCN blocks."""
    def __init__(self, A, in_dim=3, hid_dim=64, out_dim=256):
        super().__init__()
        self.input_proj = nn.Linear(in_dim, hid_dim)
        self.stgcn1 = STGCNBlock(hid_dim, 128, A)
        self.stgcn2 = STGCNBlock(128, 256, A)
        self.stgcn3 = STGCNBlock(256, out_dim, A)
    def forward(self, x):
        # x: (B, T, N, 3)
        x = self.input_proj(x)           # (B, T, N, hid_dim)
        x = x.permute(0, 3, 1, 2)        # (B, hid_dim, T, N)
        x = self.stgcn1(x)
        x = self.stgcn2(x)
        x = self.stgcn3(x)
        x = x.mean(2)                    # (B, out_dim, N), mean over time
        return x

In [5]:
# --------- Deformable Attention (Fusion) -------------
class DeformableAttention2D(nn.Module):
    """Deformable attention for pose-vision fusion."""
    def __init__(self, dim, dim_head=64, heads=8, offset_groups=8):
        super().__init__()
        self.heads = heads
        self.dim_head = dim_head
        self.scale = dim_head ** -0.5
        self.offset_groups = offset_groups
        inner_dim = dim_head * heads
        self.to_q = nn.Conv1d(dim, inner_dim, 1, bias=False)
        self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias=False)
        self.to_offsets = nn.Sequential(
            nn.Conv1d(dim, dim, 1),
            nn.ReLU(),
            nn.Conv1d(dim, offset_groups * 2, 1)
        )
        self.proj = nn.Conv1d(inner_dim, dim, 1)

    def forward(self, query_feat, context_feat, ref_points):
        # query_feat: (B, C, N), context_feat: (B, C, H, W), ref_points: (B, N, 2) normalized [0,1]
        B, C, N = query_feat.shape
        _, _, H, W = context_feat.shape
        q = self.to_q(query_feat)
        q = q.view(B, self.heads, self.dim_head, N).permute(0, 1, 3, 2)  # (B, heads, N, dim_head)
        kv = self.to_kv(context_feat)
        kv = kv.view(B, 2, self.heads, self.dim_head, H, W)
        k, v = kv[:, 0], kv[:, 1]
        offsets = self.to_offsets(query_feat).view(B, self.offset_groups, 2, N).permute(0, 1, 3, 2)  # (B, G, N, 2)
        ref_points = ref_points.unsqueeze(1).repeat(1, self.offset_groups, 1, 1)  # (B, G, N, 2)
        coords = ref_points + offsets / torch.tensor([W, H], device=query_feat.device)
        coords = coords.clamp(0, 1)
        coords = coords.view(B, self.offset_groups * N, 2)
        coords = coords * 2 - 1
        coords = coords.view(B, 1, self.offset_groups * N, 1, 2)
        k = k.view(B * self.heads, self.dim_head, H, W)
        v = v.view(B * self.heads, self.dim_head, H, W)
        k_sampled = F.grid_sample(k, coords.expand(-1, self.dim_head, -1, -1, -1), align_corners=True)
        v_sampled = F.grid_sample(v, coords.expand(-1, self.dim_head, -1, -1, -1), align_corners=True)
        k_sampled = k_sampled.squeeze(-1).view(B, self.heads, self.dim_head, N, self.offset_groups)
        v_sampled = v_sampled.squeeze(-1).view(B, self.heads, self.dim_head, N, self.offset_groups)
        q = q.unsqueeze(-1)  # (B, heads, N, dim_head, 1)
        attn = (q * k_sampled).sum(3) * self.scale  # (B, heads, N, G)
        attn = F.softmax(attn, dim=-1)
        out = (attn.unsqueeze(3) * v_sampled).sum(-1)  # (B, heads, dim_head, N)
        out = out.permute(0, 1, 3, 2).contiguous().view(B, -1, N)
        out = self.proj(out)
        return out

In [6]:
class PGFModule(nn.Module):
    """Pose-guided fusion of pose and vision feature."""
    def __init__(self, dim, dim_head=64, heads=8):
        super().__init__()
        self.gater = nn.Sequential(
            nn.Conv1d(dim * 2, dim, 1),
            nn.ReLU(),
            nn.Conv1d(dim, dim, 1),
            nn.Sigmoid()
        )
        self.self_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, batch_first=True)
        self.deform_attn = DeformableAttention2D(dim, dim_head, heads, offset_groups=heads)

    def forward(self, pose_feat, vision_feat, J):
        # pose_feat: (B, C, N), vision_feat: (B, C, H, W), J: (B, N, 2)
        B, C, N = pose_feat.shape
        Fp = pose_feat
        Fr = vision_feat
        Fp_ = Fp.permute(0, 2, 1)  # (B, N, C)
        Fr_ = Fr.flatten(2).permute(0, 2, 1)  # (B, HW, C)
        attn_output, _ = self.self_attn(Fp_, Fr_, Fr_)
        attn_output = attn_output.permute(0, 2, 1)  # (B, C, N)
        gater_input = torch.cat([Fp, attn_output], dim=1)
        gate = self.gater(gater_input)
        Fp_fused = Fp * gate + attn_output * (1 - gate)
        fused = self.deform_attn(Fp_fused, Fr, J)
        return fused

In [7]:
# --------- Vision Encoder -------------
class VisionEncoder(nn.Module):
    def __init__(self, out_dim=256):
        super().__init__()
        backbone = torchvision.models.efficientnet_b0(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(backbone.children())[:-2])
        self.proj = nn.Conv2d(1280, out_dim, 1)
    def forward(self, x):
        # x: (B, 3, H, W)
        feat = self.feature_extractor(x)  # (B, 1280, H', W')
        return self.proj(feat)            # (B, out_dim, H', W')

In [8]:
# --------- Dataset Example -------------
class SignKeypointDataset(Dataset):
    """Dataset returning (keypoint, label) with file check."""
    def __init__(self, data_root, video_list_txt, part='body', label_map=None):
        self.samples = []
        with open(video_list_txt, 'r') as f:
            for line in f:
                items = line.strip().split()
                if len(items) == 2:
                    vid, label = items
                    path = os.path.join(data_root, vid, "keypoints.npy")
                    if os.path.exists(path):
                        label_idx = label_map[label] if label_map else int(label)
                        self.samples.append((path, label_idx))
        self.part = part

    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        kp_path, label = self.samples[idx]
        keypoints = np.load(kp_path)  # (T, 67, 3)
        if self.part == 'body':
            part_kp = keypoints[:, :25, :]
        elif self.part == 'left':
            part_kp = keypoints[:, 25:46, :]
        else:
            part_kp = keypoints[:, 46:, :]
        return torch.tensor(part_kp, dtype=torch.float32), label

# 2. Pretrain Vision encoder 

In [9]:
import os
from glob import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
import torchvision
from tqdm import tqdm
import json

In [10]:
def build_label_map(txt_files):
    labels = set()
    for txt in txt_files:
        with open(txt, "r") as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) >= 2:
                    label = ' '.join(parts[1:]).strip()
                    labels.add(label)
    labels = sorted(labels)
    label_map = {lbl: idx for idx, lbl in enumerate(labels)}
    return label_map

In [11]:
data_root = "/kaggle/input/data-wlasl/DATA"
train_txt = "/kaggle/input/train-test-valid/train.txt"
valid_txt = "/kaggle/input/train-test-valid/val.txt"
test_txt = "/kaggle/input/train-test-valid/test.txt"
batch_size = 64
num_workers = 4
num_epochs = 6
lr = 1e-3
out_ckpt = "/kaggle/working/pretrained_hand_rgb.pth"
out_labelmap = "/kaggle/working/label_map.json"

label_map = build_label_map([train_txt, valid_txt, test_txt])
num_classes = len(label_map)
print("Số lớp:", num_classes)
print("Sample label_map:", dict(list(label_map.items())[:5]))

with open(out_labelmap, "w") as f:
    json.dump(label_map, f)

Số lớp: 355
Sample label_map: {'accept': 0, 'accident': 1, 'add': 2, 'africa': 3, 'again': 4}


In [12]:
# ==== 2. Dataset for cropped hand images ====
class HandImageDataset(Dataset):
    def __init__(self, data_root, list_file, label_map, transform=None):
        self.samples = []
        self.transform = transform
        with open(list_file, "r") as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) == 2:
                    video_id, label = parts
                    label_idx = label_map[label]
                    img_dir = os.path.join(data_root, video_id)
                    # All *_left.jpg and *_right.jpg (can add filter for frame sampling if needed)
                    imgs = sorted(glob(os.path.join(img_dir, "*_left.jpg"))) + \
                           sorted(glob(os.path.join(img_dir, "*_right.jpg")))
                    for img_path in imgs:
                        if os.path.isfile(img_path):
                            self.samples.append((img_path, label_idx))
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label

In [13]:
transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225])
])

train_ds = HandImageDataset(data_root, train_txt, label_map, transform=transform)
val_ds = HandImageDataset(data_root, valid_txt, label_map, transform=transform)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

print(f"Số ảnh train: {len(train_ds)}, Số ảnh val: {len(val_ds)}")

Số ảnh train: 12891, Số ảnh val: 2555


In [14]:
# ==== 3. Vision Model ====
class VisionClassifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        backbone = torchvision.models.efficientnet_b0(pretrained=True)
        self.features = backbone.features
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(1280, num_classes)
    def forward(self, x):
        feat = self.features(x)
        feat = self.pool(feat).view(x.size(0), -1)
        return self.classifier(feat)


In [15]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = VisionClassifier(num_classes)
if torch.cuda.device_count() > 1:
    print("Using DataParallel with {} GPUs".format(torch.cuda.device_count()))
    model = nn.DataParallel(model)
model = model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth
100%|██████████| 20.5M/20.5M [00:00<00:00, 136MB/s] 


Using DataParallel with 2 GPUs


In [16]:
best_val_acc = 0
for epoch in range(num_epochs):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for imgs, labels in tqdm(train_loader, desc=f"Train Epoch {epoch+1}", leave=False):
        imgs, labels = imgs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
        logits = model(imgs)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * imgs.size(0)
        preds = logits.argmax(1)
        correct += (preds == labels).sum().item()
        total += imgs.size(0)
    train_acc = correct / total
    train_loss = total_loss / total

    # Validation
    model.eval()
    val_loss, val_correct, val_total = 0, 0, 0
    with torch.no_grad():
        for imgs, labels in tqdm(val_loader, desc="Valid", leave=False):
            imgs, labels = imgs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            logits = model(imgs)
            loss = criterion(logits, labels)
            val_loss += loss.item() * imgs.size(0)
            preds = logits.argmax(1)
            val_correct += (preds == labels).sum().item()
            val_total += imgs.size(0)
    val_acc = val_correct / val_total
    val_loss = val_loss / val_total

    print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

    # Save best
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            "model": model.state_dict(),
            "label_map": label_map
        }, out_ckpt)
        print(f"Best model saved at epoch {epoch+1}, val_acc={val_acc:.4f}")

print("Done. Best val acc:", best_val_acc)

                                                                

Epoch 1/6 | Train Loss: 5.2582 Acc: 0.0600 | Val Loss: 5.5364 Acc: 0.0556
Best model saved at epoch 1, val_acc=0.0556


                                                                

Epoch 2/6 | Train Loss: 3.8517 Acc: 0.2129 | Val Loss: 5.3591 Acc: 0.0779
Best model saved at epoch 2, val_acc=0.0779


                                                                

Epoch 3/6 | Train Loss: 2.9059 Acc: 0.3547 | Val Loss: 5.6515 Acc: 0.0916
Best model saved at epoch 3, val_acc=0.0916


                                                                

Epoch 4/6 | Train Loss: 2.1887 Acc: 0.4837 | Val Loss: 5.9073 Acc: 0.0939
Best model saved at epoch 4, val_acc=0.0939


                                                                

Epoch 5/6 | Train Loss: 1.6259 Acc: 0.6020 | Val Loss: 6.6308 Acc: 0.0935


                                                                

Epoch 6/6 | Train Loss: 1.1757 Acc: 0.7018 | Val Loss: 6.6481 Acc: 0.0986
Best model saved at epoch 6, val_acc=0.0986
Done. Best val acc: 0.09863013698630137




# 3. Pretrain Pose encoder với Spatial GCN

In [33]:
import numpy as np
import torch
from torch.utils.data import Dataset

class PoseSpatialPartDataset(Dataset):
    def __init__(self, data_root, txt_file, label_map, part='body'):
        self.samples = []
        self.label_map = label_map
        self.part = part
        with open(txt_file) as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) < 2:
                    print(f"WARNING: dòng bị lỗi format: {line}")
                    continue
                npy_id = parts[0]
                label = ' '.join(parts[1:]).strip()
                if label not in label_map:
                    print(f"WARNING: label '{label}' chưa có trong label_map!")
                    continue
                npy_path = f"{data_root}/{npy_id}/keypoints.npy"
                self.samples.append((npy_path, int(label_map[label])))
        # Define keypoint slices for each part
        if part == 'body':
            self.idx_start, self.idx_end = 0, 25
        elif part == 'left':
            self.idx_start, self.idx_end = 25, 46
        elif part == 'right':
            self.idx_start, self.idx_end = 46, 67
        else:
            raise ValueError("Unknown part: " + part)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        npy_path, label = self.samples[idx]
        keypoints = np.load(npy_path)  # (T, 67, 3) expected
        # Auto fix shape if needed
        if keypoints.shape[-2:] == (67, 3):
            pass
        elif keypoints.shape[0] == 67 and keypoints.shape[1] == 3:
            keypoints = np.transpose(keypoints, (2, 0, 1))
        elif keypoints.shape[1] == 3 and keypoints.shape[2] == 67:
            keypoints = np.transpose(keypoints, (0, 2, 1))
        else:
            raise RuntimeError(f"Unrecognized keypoints shape: {keypoints.shape}")
        part_kp = keypoints[:, self.idx_start:self.idx_end, :]  # (T, N, 3)
        if idx == 0:
            print(f"Dataset part_kp.shape: {part_kp.shape}")
        return torch.tensor(part_kp, dtype=torch.float32), label

In [18]:
class SpatialGCNLayer(nn.Module):
    def __init__(self, in_channels, out_channels, A):
        super().__init__()
        self.register_buffer('A', A)
        self.fc = nn.Linear(in_channels, out_channels)
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x):  # x: (B, N, in_channels)
        h = self.fc(x)  # (B, N, out_channels)
        h = h.permute(0, 2, 1)  # (B, out_channels, N)
        h = torch.matmul(h, self.A)  # (B, out_channels, N)
        h = self.bn(h)  # BatchNorm trên out_channels
        h = h.permute(0, 2, 1)  # (B, N, out_channels)
        return torch.relu(h)

In [36]:
class SpatialPoseEncoder(nn.Module):
    def __init__(self, in_channels, num_joints, num_classes, A, hid_dim=128, out_dim=256):
        super().__init__()
        self.gcn1 = SpatialGCNLayer(in_channels, hid_dim, A)
        self.gcn2 = SpatialGCNLayer(hid_dim, out_dim, A)
        self.classifier = nn.Linear(out_dim * num_joints, num_classes)

    def forward(self, x):  # x: (B, T, N, 3)
        #print("Encoder x.shape:", x.shape)
        B, T, N, C = x.shape
        assert N == self.gcn1.A.shape[0], f"x.shape={x.shape}, A.shape={self.gcn1.A.shape}"
        x = x.view(B * T, N, C)  # (B*T, N, 3)
        h = self.gcn1(x)         # (B*T, N, hid_dim)
        h = self.gcn2(h)         # (B*T, N, out_dim)
        h = h.view(B, T, N, -1)  # (B, T, N, out_dim)
        h = h.mean(1)            # (B, N, out_dim)
        h = h.permute(0, 2, 1)   # (B, out_dim, N)
        logits = self.classifier(h.flatten(1))  # (B, num_classes)
        return logits, h

In [20]:
# ---- Example: Fusion lấy feature frame đặc biệt ----
def gather_special_frames(pose_feat, mask_indices):
    """
    pose_feat: (B, C, N, T)
    mask_indices: list of [tensor(F_b,), ...]  # F_b: số frame của mỗi sample cần fusion
    Return: list of (B, C, N, F_b)
    """
    outputs = []
    for b, idxs in enumerate(mask_indices):
        # idxs: (F_b,), pose_feat[b]: (C, N, T)
        sel = pose_feat[b, :, :, idxs]  # (C, N, F_b)
        outputs.append(sel)
    return outputs


In [21]:
def get_spatial_adjacency(num_node, edge):
    A = np.zeros((num_node, num_node))
    for i, j in edge:
        A[i, j] = 1
        A[j, i] = 1
    # Normalize
    Dl = np.sum(A, 0)
    Dn = np.zeros((num_node, num_node))
    for i in range(num_node):
        if Dl[i] > 0:
            Dn[i, i] = Dl[i] ** (-1)
    A_normalized = np.dot(A, Dn)
    return torch.tensor(A_normalized, dtype=torch.float32)

def get_body_spatial_graph():
    # 25 body keypoints (Mediapipe hoặc OpenPose định nghĩa)
    num_node = 25
    self_link = [(i, i) for i in range(num_node)]
    neighbor_link = [
        (0, 1), (1, 2), (2, 3), (3, 7),
        (0, 4), (4, 5), (5, 6), (6, 8),
        (9, 10), (11, 12), (11, 13), (13, 15), (15, 21), (15, 19), (15, 17),
        (17, 19), (11, 23), (12, 14), (14, 16), (16, 18), (16, 20), (16, 22),
        (18, 20), (12, 24), (23, 24)
    ]
    edge = self_link + neighbor_link
    return get_spatial_adjacency(num_node, edge)

def get_left_hand_spatial_graph():
    # 21 left hand keypoints (Mediapipe)
    num_node = 21
    self_link = [(i, i) for i in range(num_node)]
    neighbor_link = [
        (0, 1),(1, 2),(2, 3),(3, 4),
        (0, 5),(5, 6),(6, 7),(7, 8),
        (0, 9),(9, 10),(10, 11),(11, 12),
        (0, 13),(13, 14),(14, 15),(15, 16),
        (0, 17),(17, 18),(18, 19),(19, 20)
    ]
    edge = self_link + neighbor_link
    return get_spatial_adjacency(num_node, edge)

def get_right_hand_spatial_graph():
    # 21 right hand keypoints (Mediapipe)
    num_node = 21
    self_link = [(i, i) for i in range(num_node)]
    neighbor_link = [
        (0, 1),(1, 2),(2, 3),(3, 4),
        (0, 5),(5, 6),(6, 7),(7, 8),
        (0, 9),(9, 10),(10, 11),(11, 12),
        (0, 13),(13, 14),(14, 15),(15, 16),
        (0, 17),(17, 18),(18, 19),(19, 20)
    ]
    edge = self_link + neighbor_link
    return get_spatial_adjacency(num_node, edge)

In [22]:
# ========= 2. Dataset =========
PART_INFO = {
    'body':  (0, 25),
    'left':  (25, 46),
    'right': (46, 67)
}

In [38]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.optim as optim
import torch.nn as nn

def train_pose_spatial_part(part, num_joints, get_A_func, best_ckpt_file):
    print(f"\n--- Pretraining {part} ---")
    train_ds = PoseSpatialPartDataset(data_root, train_txt, label_map, part=part)
    val_ds = PoseSpatialPartDataset(data_root, val_txt, label_map, part=part)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    print(f"Số mẫu train: {len(train_ds)}, val: {len(val_ds)}")
    print(f"Số batch train: {len(train_loader)}, val: {len(val_loader)}")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    A = get_A_func().to(device)
    model = SpatialPoseEncoder(in_channels=3, num_joints=num_joints, num_classes=num_classes, A=A)
    if torch.cuda.device_count() > 1:
        print("Using DataParallel with {} GPUs".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    best_val_acc = 0

    for epoch in range(num_epochs):
        model.train()
        total_loss, correct, total = 0, 0, 0
        for x, labels in tqdm(train_loader, desc=f"Train {part} Epoch {epoch+1}", leave=False):
            #print("Batch x.shape:", x.shape)
            x, labels = x.to(device), labels.to(device)
            logits, _ = model(x)
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * x.size(0)
            preds = logits.argmax(1)
            correct += (preds == labels).sum().item()
            total += x.size(0)
        if total > 0:
            train_acc = correct / total
            train_loss = total_loss / total
        else:
            train_acc = 0
            train_loss = 0
            print("WARNING: Không có sample nào trong batch train!")

        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            for x, labels in tqdm(val_loader, desc=f"Val {part} Epoch {epoch+1}", leave=False):
                x, labels = x.to(device), labels.to(device)
                logits, _ = model(x)
                loss = criterion(logits, labels)
                val_loss += loss.item() * x.size(0)
                preds = logits.argmax(1)
                val_correct += (preds == labels).sum().item()
                val_total += x.size(0)
        if val_total > 0:
            val_acc = val_correct / val_total
            val_loss = val_loss / val_total
        else:
            val_acc = 0
            val_loss = 0
            print("WARNING: Không có sample nào trong batch val!")

        print(f"[{part}] Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({"model": model.state_dict(), "label_map": label_map}, best_ckpt_file)
            print(f"Best model saved at epoch {epoch+1}, val_acc={val_acc:.4f}")

    print(f"Done {part}. Best val acc: {best_val_acc:.4f}")

In [40]:
data_root = "/kaggle/input/data-wlasl/DATA"
train_txt = "/kaggle/input/train-test-valid/train.txt"
val_txt = "/kaggle/input/train-test-valid/val.txt"
test_txt = "/kaggle/input/train-test-valid/test.txt"
batch_size = 32
num_workers = 4
num_epochs = 10
lr = 1e-3

label_map = build_label_map([train_txt, val_txt, test_txt])
num_classes = len(label_map)
with open("label_map_pose.json", "w") as f:
    json.dump(label_map, f)

In [25]:
len(label_map)

355

In [26]:
print(get_body_spatial_graph().shape)       # (25, 25)
print(get_left_hand_spatial_graph().shape)  # (21, 21)
print(get_right_hand_spatial_graph().shape) # (21, 21)

torch.Size([25, 25])
torch.Size([21, 21])
torch.Size([21, 21])


In [41]:
train_pose_spatial_part('body',  num_joints=25, get_A_func=get_body_spatial_graph, best_ckpt_file="/kaggle/working/spatial_body_best.pth")
train_pose_spatial_part('left',  num_joints=21, get_A_func=get_left_hand_spatial_graph, best_ckpt_file="/kaggle/working/spatial_left_best.pth")
train_pose_spatial_part('right', num_joints=21, get_A_func=get_right_hand_spatial_graph, best_ckpt_file="/kaggle/working/spatial_right_best.pth")


--- Pretraining body ---
Số mẫu train: 2340, val: 464
Số batch train: 74, val: 15
Using DataParallel with 2 GPUs


Train body Epoch 1:  30%|██▉       | 22/74 [00:00<00:01, 30.96it/s]

Dataset part_kp.shape: (64, 25, 3)


Val body Epoch 1:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 25, 3)


                                                                 

[body] Epoch 1/10 | Train Loss: 8.1367 Acc: 0.0060 | Val Loss: 6.6547 Acc: 0.0172
Best model saved at epoch 1, val_acc=0.0172


Train body Epoch 2:  64%|██████▎   | 47/74 [00:00<00:00, 61.15it/s]

Dataset part_kp.shape: (64, 25, 3)


Val body Epoch 2:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 25, 3)


                                                                 

[body] Epoch 2/10 | Train Loss: 6.0157 Acc: 0.0248 | Val Loss: 6.6470 Acc: 0.0194
Best model saved at epoch 2, val_acc=0.0194


Train body Epoch 3:  76%|███████▌  | 56/74 [00:01<00:00, 61.73it/s]

Dataset part_kp.shape: (64, 25, 3)


Val body Epoch 3:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 25, 3)


                                                                 

[body] Epoch 3/10 | Train Loss: 5.3881 Acc: 0.0466 | Val Loss: 6.0784 Acc: 0.0302
Best model saved at epoch 3, val_acc=0.0302


Train body Epoch 4:  68%|██████▊   | 50/74 [00:01<00:00, 56.97it/s]

Dataset part_kp.shape: (64, 25, 3)


Val body Epoch 4:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 25, 3)


                                                                 

[body] Epoch 4/10 | Train Loss: 4.8871 Acc: 0.0714 | Val Loss: 6.0456 Acc: 0.0366
Best model saved at epoch 4, val_acc=0.0366


Train body Epoch 5:  46%|████▌     | 34/74 [00:00<00:00, 56.39it/s]

Dataset part_kp.shape: (64, 25, 3)


Val body Epoch 5:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 25, 3)


                                                                 

[body] Epoch 5/10 | Train Loss: 4.5587 Acc: 0.0902 | Val Loss: 5.8852 Acc: 0.0237


Train body Epoch 6:  66%|██████▌   | 49/74 [00:00<00:00, 62.00it/s]

Dataset part_kp.shape: (64, 25, 3)


Val body Epoch 6:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 25, 3)


                                                                 

[body] Epoch 6/10 | Train Loss: 4.2505 Acc: 0.1179 | Val Loss: 5.8830 Acc: 0.0496
Best model saved at epoch 6, val_acc=0.0496


Train body Epoch 7:  27%|██▋       | 20/74 [00:00<00:01, 48.51it/s]

Dataset part_kp.shape: (64, 25, 3)


Val body Epoch 7:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 25, 3)


                                                                

[body] Epoch 7/10 | Train Loss: 4.0637 Acc: 0.1261 | Val Loss: 5.8756 Acc: 0.0431


Train body Epoch 8:  19%|█▉        | 14/74 [00:00<00:01, 41.46it/s]

Dataset part_kp.shape: (64, 25, 3)


Val body Epoch 8:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 25, 3)


                                                                 

[body] Epoch 8/10 | Train Loss: 3.9039 Acc: 0.1419 | Val Loss: 5.8607 Acc: 0.0496


Train body Epoch 9:   8%|▊         | 6/74 [00:00<00:02, 23.01it/s]

Dataset part_kp.shape: (64, 25, 3)


Val body Epoch 9:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 25, 3)


                                                                

[body] Epoch 9/10 | Train Loss: 3.7417 Acc: 0.1462 | Val Loss: 5.8270 Acc: 0.0539
Best model saved at epoch 9, val_acc=0.0539


Train body Epoch 10:  86%|████████▋ | 64/74 [00:01<00:00, 63.92it/s]

Dataset part_kp.shape: (64, 25, 3)


Val body Epoch 10:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 25, 3)


                                                                 

[body] Epoch 10/10 | Train Loss: 3.6070 Acc: 0.1744 | Val Loss: 5.9243 Acc: 0.0582
Best model saved at epoch 10, val_acc=0.0582
Done body. Best val acc: 0.0582

--- Pretraining left ---
Số mẫu train: 2340, val: 464
Số batch train: 74, val: 15
Using DataParallel with 2 GPUs


Train left Epoch 1:  68%|██████▊   | 50/74 [00:00<00:00, 63.37it/s]

Dataset part_kp.shape: (64, 21, 3)


Val left Epoch 1:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                

[left] Epoch 1/10 | Train Loss: 7.6256 Acc: 0.0073 | Val Loss: 6.2909 Acc: 0.0216
Best model saved at epoch 1, val_acc=0.0216


Train left Epoch 2:   0%|          | 0/74 [00:00<?, ?it/s]

Dataset part_kp.shape: (64, 21, 3)


Val left Epoch 2:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                

[left] Epoch 2/10 | Train Loss: 5.8474 Acc: 0.0291 | Val Loss: 5.8927 Acc: 0.0216


Train left Epoch 3:  46%|████▌     | 34/74 [00:00<00:00, 59.93it/s]

Dataset part_kp.shape: (64, 21, 3)


Val left Epoch 3:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                

[left] Epoch 3/10 | Train Loss: 5.2596 Acc: 0.0530 | Val Loss: 5.7182 Acc: 0.0453
Best model saved at epoch 3, val_acc=0.0453


Train left Epoch 4:  30%|██▉       | 22/74 [00:00<00:01, 51.77it/s]

Dataset part_kp.shape: (64, 21, 3)


Val left Epoch 4:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                

[left] Epoch 4/10 | Train Loss: 4.8199 Acc: 0.0795 | Val Loss: 5.2725 Acc: 0.0431


Train left Epoch 5:  38%|███▊      | 28/74 [00:00<00:00, 55.61it/s]

Dataset part_kp.shape: (64, 21, 3)


Val left Epoch 5:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                

[left] Epoch 5/10 | Train Loss: 4.4395 Acc: 0.1111 | Val Loss: 5.2689 Acc: 0.0647
Best model saved at epoch 5, val_acc=0.0647


Train left Epoch 6:   9%|▉         | 7/74 [00:00<00:02, 26.72it/s]

Dataset part_kp.shape: (64, 21, 3)


Val left Epoch 6:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                

[left] Epoch 6/10 | Train Loss: 4.2360 Acc: 0.1312 | Val Loss: 5.2525 Acc: 0.0582


Train left Epoch 7:   1%|▏         | 1/74 [00:00<00:14,  5.06it/s]

Dataset part_kp.shape: (64, 21, 3)


Val left Epoch 7:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                

[left] Epoch 7/10 | Train Loss: 4.0316 Acc: 0.1393 | Val Loss: 5.0787 Acc: 0.1013
Best model saved at epoch 7, val_acc=0.1013


Train left Epoch 8:  58%|█████▊    | 43/74 [00:00<00:00, 61.60it/s]

Dataset part_kp.shape: (64, 21, 3)

Train left Epoch 8:  69%|██████▉   | 51/74 [00:00<00:00, 64.21it/s]




Val left Epoch 8:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                

[left] Epoch 8/10 | Train Loss: 3.8401 Acc: 0.1739 | Val Loss: 5.0418 Acc: 0.0905


Train left Epoch 9:  68%|██████▊   | 50/74 [00:00<00:00, 64.62it/s]

Dataset part_kp.shape: (64, 21, 3)


Val left Epoch 9:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                

[left] Epoch 9/10 | Train Loss: 3.6503 Acc: 0.1983 | Val Loss: 5.3098 Acc: 0.0754


Train left Epoch 10:  82%|████████▏ | 61/74 [00:01<00:00, 65.79it/s]

Dataset part_kp.shape: (64, 21, 3)


Val left Epoch 10:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                 

[left] Epoch 10/10 | Train Loss: 3.5234 Acc: 0.2286 | Val Loss: 4.9160 Acc: 0.1121
Best model saved at epoch 10, val_acc=0.1121
Done left. Best val acc: 0.1121

--- Pretraining right ---
Số mẫu train: 2340, val: 464
Số batch train: 74, val: 15
Using DataParallel with 2 GPUs


Train right Epoch 1:  20%|██        | 15/74 [00:00<00:01, 42.67it/s]

Dataset part_kp.shape: (64, 21, 3)


Val right Epoch 1:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                 

[right] Epoch 1/10 | Train Loss: 7.5902 Acc: 0.0081 | Val Loss: 6.2549 Acc: 0.0086
Best model saved at epoch 1, val_acc=0.0086


Train right Epoch 2:  45%|████▍     | 33/74 [00:00<00:00, 54.38it/s]

Dataset part_kp.shape: (64, 21, 3)


Val right Epoch 2:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                 

[right] Epoch 2/10 | Train Loss: 6.0033 Acc: 0.0192 | Val Loss: 6.0601 Acc: 0.0216
Best model saved at epoch 2, val_acc=0.0216


Train right Epoch 3:  47%|████▋     | 35/74 [00:00<00:00, 58.67it/s]

Dataset part_kp.shape: (64, 21, 3)


Val right Epoch 3:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                 

[right] Epoch 3/10 | Train Loss: 5.5261 Acc: 0.0329 | Val Loss: 5.5489 Acc: 0.0280
Best model saved at epoch 3, val_acc=0.0280


Train right Epoch 4:  31%|███       | 23/74 [00:00<00:00, 52.14it/s]

Dataset part_kp.shape: (64, 21, 3)


Val right Epoch 4:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                 

[right] Epoch 4/10 | Train Loss: 5.1882 Acc: 0.0423 | Val Loss: 5.5168 Acc: 0.0302
Best model saved at epoch 4, val_acc=0.0302


Train right Epoch 5:  19%|█▉        | 14/74 [00:00<00:01, 41.47it/s]

Dataset part_kp.shape: (64, 21, 3)


Val right Epoch 5:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                 

[right] Epoch 5/10 | Train Loss: 4.9386 Acc: 0.0543 | Val Loss: 5.3314 Acc: 0.0453
Best model saved at epoch 5, val_acc=0.0453


Train right Epoch 6:  49%|████▊     | 36/74 [00:00<00:00, 58.85it/s]

Dataset part_kp.shape: (64, 21, 3)


Val right Epoch 6:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                 

[right] Epoch 6/10 | Train Loss: 4.7297 Acc: 0.0731 | Val Loss: 5.2077 Acc: 0.0474
Best model saved at epoch 6, val_acc=0.0474


Train right Epoch 7:  68%|██████▊   | 50/74 [00:00<00:00, 63.78it/s]

Dataset part_kp.shape: (64, 21, 3)


Val right Epoch 7:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                  

[right] Epoch 7/10 | Train Loss: 4.6145 Acc: 0.0705 | Val Loss: 5.4570 Acc: 0.0388


Train right Epoch 8:   9%|▉         | 7/74 [00:00<00:02, 27.78it/s]

Dataset part_kp.shape: (64, 21, 3)


Val right Epoch 8:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                 

[right] Epoch 8/10 | Train Loss: 4.4385 Acc: 0.0915 | Val Loss: 5.2350 Acc: 0.0647
Best model saved at epoch 8, val_acc=0.0647


Train right Epoch 9:  86%|████████▋ | 64/74 [00:01<00:00, 67.78it/s]

Dataset part_kp.shape: (64, 21, 3)


Val right Epoch 9:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                 

[right] Epoch 9/10 | Train Loss: 4.3654 Acc: 0.0979 | Val Loss: 5.1947 Acc: 0.0647


Train right Epoch 10:  36%|███▋      | 27/74 [00:00<00:00, 55.61it/s]

Dataset part_kp.shape: (64, 21, 3)


Val right Epoch 10:   0%|          | 0/15 [00:00<?, ?it/s]           

Dataset part_kp.shape: (64, 21, 3)


                                                                   

[right] Epoch 10/10 | Train Loss: 4.2778 Acc: 0.1188 | Val Loss: 5.1152 Acc: 0.0560
Done right. Best val acc: 0.0647


