# Dataset visualize
Labels with fewer than 6 videos are skipped.

# Keypoints

In [1]:
import os
import shutil

# Spatial-Temporal Graph Convolutional Network (ST-GCN)

In [2]:
import torch
import numpy as np
import torch.nn as nn
import pdb
import math
import copy

In [3]:
class Graph:
    """Skeleton graph for the MediaPipe keypoint layouts used in this notebook.

    Builds ``self.A``, the stacked normalized adjacency consumed by the ST-GCN
    layers, with shape (num_partitions, num_node, num_node).

    Args:
        layout (str): one of
            - 'body': 25 pose joints (edges follow the MediaPipe pose topology).
            - 'left' / 'right': 21 hand joints (joint 0 is the hub linked to
              every finger chain).
            The default 'custom' is not a supported layout; always pass one.
        strategy (str): one of
            - 'uniform': Uniform Labeling (1 partition)
            - 'distance': Distance Partitioning (one partition per valid hop)
            - 'spatial': Spatial Configuration
            For more information, see the section 'Partition Strategies' of the
            ST-GCN paper (https://arxiv.org/abs/1801.07455).
        max_hop (int): the maximal distance between two connected nodes
        dilation (int): controls the spacing between the kernel points

    Raises:
        ValueError: if ``layout`` or ``strategy`` is not supported.
    """

    def __init__(self, layout='custom', strategy='uniform', max_hop=1, dilation=1):
        self.max_hop = max_hop
        self.dilation = dilation

        self.get_edge(layout)
        self.hop_dis = get_hop_distance(self.num_node, self.edge, max_hop=max_hop)
        self.get_adjacency(strategy)

    def __str__(self):
        # __str__ must return a str; self.A is a numpy array.
        return str(self.A)

    def get_edge(self, layout):
        """Set ``num_node``, ``edge`` (self-loops + bones) and ``center`` for *layout*.

        Raises:
            ValueError: for an unknown layout, instead of leaving ``num_node``
                unset and failing later with an AttributeError.
        """
        if layout in ('left', 'right'):
            num_node = 21
            # Joint 0 connects to the base of each of the five finger chains.
            neighbor_link = [
                [0, 1], [1, 2], [2, 3], [3, 4],
                [0, 5], [5, 6], [6, 7], [7, 8],
                [0, 9], [9, 10], [10, 11], [11, 12],
                [0, 13], [13, 14], [14, 15], [15, 16],
                [0, 17], [17, 18], [18, 19], [19, 20],
            ]
        elif layout == 'body':
            num_node = 25
            # NOTE(review): joints 0-8, 9-10 and 11-24 form three separate
            # components; hop distances across them are inf, which matters for
            # the 'spatial' strategy (it measures distance to center joint 0).
            neighbor_link = [
                [0, 1], [1, 2], [2, 3], [3, 7],
                [0, 4], [4, 5], [5, 6], [6, 8],
                [9, 10],
                [11, 12],
                [11, 13], [13, 15], [15, 21], [15, 19], [15, 17],
                [17, 19],
                [11, 23],
                [12, 14], [14, 16], [16, 18], [16, 20], [16, 22],
                [18, 20],
                [12, 24],
                [23, 24],
            ]
        else:
            raise ValueError(
                f"Unsupported layout {layout!r}; expected 'body', 'left' or 'right'"
            )
        self.num_node = num_node
        self_link = [(i, i) for i in range(num_node)]
        self.edge = self_link + neighbor_link
        self.center = 0

    def get_adjacency(self, strategy):
        """Build ``self.A`` from ``self.hop_dis`` according to *strategy*.

        Raises:
            ValueError: if *strategy* is not 'uniform', 'distance' or 'spatial'.
        """
        valid_hop = range(0, self.max_hop + 1, self.dilation)
        adjacency = np.zeros((self.num_node, self.num_node))
        for hop in valid_hop:
            adjacency[self.hop_dis == hop] = 1
        normalize_adjacency = normalize_digraph(adjacency)

        if strategy == 'uniform':
            A = np.zeros((1, self.num_node, self.num_node))
            A[0] = normalize_adjacency
            self.A = A
        elif strategy == 'distance':
            # One partition per hop distance.
            A = np.zeros((len(valid_hop), self.num_node, self.num_node))
            for i, hop in enumerate(valid_hop):
                A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis == hop]
            self.A = A
        elif strategy == 'spatial':
            # Split neighbours by whether they are closer to / as far as / further
            # from the center joint than the root node.
            A = []
            for hop in valid_hop:
                a_root = np.zeros((self.num_node, self.num_node))
                a_close = np.zeros((self.num_node, self.num_node))
                a_further = np.zeros((self.num_node, self.num_node))
                for i in range(self.num_node):
                    for j in range(self.num_node):
                        if self.hop_dis[j, i] == hop:
                            if (
                                self.hop_dis[j, self.center]
                                == self.hop_dis[i, self.center]
                            ):
                                a_root[j, i] = normalize_adjacency[j, i]
                            elif (
                                self.hop_dis[j, self.center]
                                > self.hop_dis[i, self.center]
                            ):
                                a_close[j, i] = normalize_adjacency[j, i]
                            else:
                                a_further[j, i] = normalize_adjacency[j, i]
                if hop == 0:
                    A.append(a_root)
                else:
                    A.append(a_root + a_close)
                    A.append(a_further)
            A = np.stack(A)
            self.A = A
        else:
            raise ValueError("Do Not Exist This Strategy")


def get_hop_distance(num_node, edge, max_hop=1):
    """Return the (num_node, num_node) matrix of hop distances between joints.

    Entry [i, j] is the smallest d in 0..max_hop such that the d-th power of the
    symmetric adjacency built from *edge* is positive at [i, j]; pairs that are
    not reached within ``max_hop`` steps stay ``inf``.
    """
    adjacency = np.zeros((num_node, num_node))
    for a, b in edge:
        adjacency[b, a] = 1
        adjacency[a, b] = 1

    distances = np.full((num_node, num_node), np.inf)
    powers = [np.linalg.matrix_power(adjacency, step) for step in range(max_hop + 1)]
    reachable = np.stack(powers) > 0
    # Walk from the largest step down so smaller distances overwrite larger ones.
    for step in reversed(range(max_hop + 1)):
        distances[reachable[step]] = step
    return distances


def normalize_digraph(A):
    """Column-normalize *A*: return ``A @ D^-1`` where D holds the column sums.

    Columns that sum to zero are left as zeros instead of being divided.
    """
    col_sum = np.sum(A, 0)
    size = A.shape[0]
    inv_degree = np.zeros((size, size))
    for idx in range(size):
        degree = col_sum[idx]
        if degree > 0:
            inv_degree[idx, idx] = degree ** (-1)
    return np.dot(A, inv_degree)

In [4]:
class GCN_unit(nn.Module):
    """Spatial graph convolution followed by BatchNorm and ReLU.

    A (t_kernel_size x 1) convolution produces ``kernel_size`` groups of
    ``out_channels`` channels. Group k is aggregated over joints with the
    adjacency slice ``A[k]``, and the groups are summed.

    Args:
        in_channels, out_channels: channel counts.
        kernel_size: number of adjacency partitions; must equal ``A.size(0)``.
        A: adjacency tensor of shape (kernel_size, V, V).
        adaptive: if True, ``A`` is a learnable ``nn.Parameter``; otherwise it is
            registered as a buffer.
        t_kernel_size, t_stride, t_padding, t_dilation: temporal conv settings.
        bias: whether the conv has a bias term.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        A,
        adaptive=True,
        t_kernel_size=1,
        t_stride=1,
        t_padding=0,
        t_dilation=1,
        bias=True,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        assert A.size(0) == self.kernel_size
        conv_out = out_channels * kernel_size
        self.conv = nn.Conv2d(
            in_channels,
            conv_out,
            kernel_size=(t_kernel_size, 1),
            padding=(t_padding, 0),
            stride=(t_stride, 1),
            dilation=(t_dilation, 1),
            bias=bias,
        )
        self.adaptive = adaptive
        if adaptive:
            # A learnable copy, so the caller's tensor is never updated.
            self.A = nn.Parameter(A.clone())
        else:
            # A fixed adjacency that still moves with .to(device) and is saved in state_dict.
            self.register_buffer('A', A)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, len_x):
        """Map x: (N, C_in, T, V) to (N, C_out, T', V). ``len_x`` is accepted but unused."""
        feats = self.conv(x)
        batch, grouped_ch, frames, joints = feats.size()
        per_group = grouped_ch // self.kernel_size
        feats = feats.view(batch, self.kernel_size, per_group, frames, joints)
        # Sum over partitions k while mixing joints v -> w with A[k].
        mixed = torch.einsum('nkctv,kvw->nctw', (feats, self.A)).contiguous()
        return self.relu(self.bn(mixed))

class STGCN_block(nn.Module):
    """One ST-GCN layer: spatial ``GCN_unit``, then temporal conv, plus residual, then ReLU.

    Args:
        in_channels, out_channels: channel counts.
        kernel_size: (temporal_kernel, num_partitions). temporal_kernel must be
            odd ("same" padding along T); a value of 1 skips the temporal conv.
        A: adjacency tensor forwarded to ``GCN_unit``.
        adaptive: whether ``GCN_unit`` learns ``A``.
        stride: temporal stride of the temporal conv and the residual projection.
        dropout: dropout probability after the temporal conv.
        residual: False gives no shortcut; an identity is used when the shapes
            match; otherwise a 1x1 conv + BatchNorm projection is used.

    NOTE(review): with kernel_size[0] == 1 the temporal path is an identity and
    ignores ``stride``, so stride > 1 would not match the strided residual.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        A,
        adaptive=True,
        stride=1,
        dropout=0,
        residual=True,
    ):
        super().__init__()

        assert len(kernel_size) == 2
        t_kernel, spatial_kernel = kernel_size
        assert t_kernel % 2 == 1
        same_pad = ((t_kernel - 1) // 2, 0)

        self.gcn = GCN_unit(in_channels, out_channels, spatial_kernel, A, adaptive=adaptive)

        if t_kernel == 1:
            self.tcn = nn.Identity()
        else:
            self.tcn = nn.Sequential(
                nn.Conv2d(out_channels, out_channels, (t_kernel, 1), (stride, 1), same_pad),
                nn.BatchNorm2d(out_channels),
                nn.Dropout(dropout, inplace=True),
            )

        shapes_match = in_channels == out_channels and stride == 1
        if residual and shapes_match:
            self.residual = lambda x: x
        elif residual:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=(stride, 1)),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.residual = lambda x: 0

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, len_x=None):
        """Map x: (N, C_in, T, V) to (N, C_out, T', V)."""
        shortcut = self.residual(x)
        out = self.tcn(self.gcn(x, len_x))
        return self.relu(out + shortcut)

class STGCNChain(nn.Sequential):
    """Sequential stack of ``STGCN_block`` layers built from ``[channels, depth]`` specs.

    Layers are registered as ``layer{i}_{j}`` (i = spec index, j = repeat
    index). Each block gets its own clone of ``A``.
    """

    def __init__(self, in_dim, block_args, kernel_size, A, adaptive):
        super().__init__()
        prev_channels = in_dim
        for stage, (channels, depth) in enumerate(block_args):
            for rep in range(depth):
                block = STGCN_block(prev_channels, channels, kernel_size, A.clone(), adaptive)
                self.add_module(f'layer{stage}_{rep}', block)
                prev_channels = channels

def get_stgcn_chain(in_dim, level, kernel_size, A, adaptive):
    """Build the ST-GCN stack for *level* and return it with its output width.

    Args:
        in_dim: input channel count.
        level: 'spatial' gives blocks of 64, 128 and 256 channels (one each);
            'temporal' gives three 256-channel blocks.
        kernel_size: (temporal_kernel, num_partitions) passed to every block.
        A: adjacency tensor (cloned per block).
        adaptive: whether each block learns its adjacency.

    Returns:
        (STGCNChain, int): the chain and the channel count of its last block.

    Raises:
        NotImplementedError: for an unknown *level*.
    """
    level_specs = {
        'spatial': [[64, 1], [128, 1], [256, 1]],
        'temporal': [[256, 3]],
    }
    if level not in level_specs:
        raise NotImplementedError(
            f"Unknown ST-GCN level {level!r}; expected one of {sorted(level_specs)}"
        )
    block_args = level_specs[level]
    return STGCNChain(in_dim, block_args, kernel_size, A, adaptive), block_args[-1][0]

# Dataloader

Keypoints: (T = 64, N_keypoint = 21/25, 3), where the last axis holds (x, y, z)

Images: (112, 112, 3)

In [5]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset

class SignKeypointDataset(Dataset):
    """Per-video keypoint dataset that returns the joints of one body part.

    Each video id ``vid`` is expected to have ``<data_root>/<vid>/keypoints.npy``
    holding an array of shape (T, 67, 3): joints 0-24 are the body, 25-45 the
    left hand and 46-66 the right hand.
    """

    # Joint-axis slice for each part.
    _PART_SLICES = {
        'body': slice(0, 25),      # (T, 25, 3)
        'left': slice(25, 46),     # (T, 21, 3)
        'right': slice(46, None),  # (T, 21, 3) for 67-joint files
    }

    def __init__(self, data_root, video_list_txt, part='body'):
        """
        Args:
            data_root (str): directory containing one sub-directory per video_id.
            video_list_txt (str): text file listing the video ids to use, one per
                line. Anything after the first whitespace-separated token (e.g. the
                label written by save_list) is ignored.
            part (str): 'body', 'left' or 'right', so each part can be pretrained
                separately.
        """
        assert part in ['body', 'left', 'right'], "part must be 'body', 'left', or 'right'"
        self.data_root = data_root
        self.part = part

        # Lines look like "video_id label", so keep only the id; skip blank lines.
        with open(video_list_txt, 'r', encoding='utf-8') as f:
            self.video_ids = [line.split()[0] for line in f if line.strip()]

        self.paths = [os.path.join(data_root, vid, "keypoints.npy") for vid in self.video_ids]

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Return the selected part's keypoints as a float32 tensor of shape (T, J, 3)."""
        keypoints = np.load(self.paths[idx])  # expected shape: (T, 67, 3)
        part_kp = keypoints[:, self._PART_SLICES[self.part], :]
        return torch.tensor(part_kp, dtype=torch.float32)


# PGF Module

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

In [7]:
class DeformableAttention2D(nn.Module):
    """Point-query deformable attention over a 2-D feature map.

    Each of the N query points predicts ``offset_groups`` sampling offsets
    around its reference location. Keys and values are bilinearly sampled from
    the context map at those locations, and each query attends (per head) over
    its ``offset_groups`` samples.

    Args:
        dim: channel count C of both the query and the context features.
        dim_head: channels per attention head.
        heads: number of attention heads.
        dropout: dropout applied after the output projection.
        offset_groups: number of sampling points per query.
        downsample_factor: accepted for call-site compatibility; only 1 is supported.
        offset_kernel_size: kernel size (along the N axis) of the first offset
            conv; must be odd so that N is preserved. 1 gives a per-point MLP.
    """

    def __init__(self, dim, dim_head=64, heads=8, dropout=0.1, offset_groups=8,
                 downsample_factor=1, offset_kernel_size=1):
        super().__init__()
        if downsample_factor != 1:
            raise NotImplementedError("DeformableAttention2D only supports downsample_factor=1")
        if offset_kernel_size % 2 != 1:
            raise ValueError("offset_kernel_size must be odd to preserve the number of query points")
        self.heads = heads
        self.dim_head = dim_head
        self.scale = dim_head ** -0.5
        self.offset_groups = offset_groups

        inner_dim = dim_head * heads
        self.to_q = nn.Conv1d(dim, inner_dim, 1, bias=False)
        self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias=False)

        self.to_offsets = nn.Sequential(
            nn.Conv1d(dim, dim, offset_kernel_size, padding=offset_kernel_size // 2),
            nn.ReLU(),
            nn.Conv1d(dim, offset_groups * 2, 1),
        )

        self.proj = nn.Conv1d(inner_dim, dim, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query_feat, context_feat, ref_points):
        """
        Args:
            query_feat: (B, C, N) query features (e.g. from the pose encoder).
            context_feat: (B, C, H, W) feature map (e.g. from the vision encoder).
            ref_points: (B, N, 2) reference points normalized to [0, 1], in
                (x, y) order (x is divided by W, y by H).

        Returns:
            (B, C, N) attended features.
        """
        B, _, N = query_feat.shape
        H, W = context_feat.shape[-2:]
        heads, dim_head, groups = self.heads, self.dim_head, self.offset_groups

        q = self.to_q(query_feat).view(B, heads, dim_head, N)   # (B, h, d, N)
        k, v = self.to_kv(context_feat).chunk(2, dim=1)          # each (B, h*d, H, W)

        # Per-point offsets in pixel units: (B, G*2, N) -> (B, G, N, 2).
        offsets = self.to_offsets(query_feat).view(B, groups, 2, N).permute(0, 1, 3, 2)
        pixel_scale = torch.tensor([W, H], device=offsets.device, dtype=offsets.dtype)
        coords = ref_points.unsqueeze(1) + offsets / pixel_scale  # (B, G, N, 2)
        coords = coords.clamp(0, 1) * 2 - 1                       # grid_sample uses [-1, 1]
        coords = coords.to(k.dtype)

        # A 4-D grid (B, G, N, 2) samples every point of every group in one call,
        # giving (B, h*d, G, N); all heads share the same sampling locations.
        k = F.grid_sample(k, coords, align_corners=True).reshape(B, heads, dim_head, groups, N)
        v = F.grid_sample(v, coords, align_corners=True).reshape(B, heads, dim_head, groups, N)

        # Each query attends over its G sampled keys.
        attn = (q.unsqueeze(3) * k).sum(dim=2) * self.scale       # (B, h, G, N)
        attn = attn.softmax(dim=2)
        out = (attn.unsqueeze(2) * v).sum(dim=3)                  # (B, h, d, N)

        out = out.reshape(B, heads * dim_head, N)
        return self.dropout(self.proj(out))                       # (B, C, N)


In [8]:
class PGFModule(nn.Module):
    """Pose-guided fusion of a pose stream with a vision feature map.

    1. Pose tokens cross-attend to the flattened vision map (``nn.MultiheadAttention``).
    2. A sigmoid gate blends the pose features with that attention output.
    3. ``DeformableAttention2D`` refines the blend by sampling the vision map
       around the keypoint priors ``J``.
    """

    def __init__(self, dim, dim_head=64, heads=8):
        super().__init__()
        self.dim = dim
        self.heads = heads

        # Gate in (0, 1) per channel and point, computed from [pose, attended] features.
        self.gater = nn.Sequential(
            nn.Conv1d(dim * 2, dim, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(dim, dim, kernel_size=1),
            nn.Sigmoid(),
        )

        # Pose tokens are the queries; vision tokens are keys and values.
        self.self_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, batch_first=True)

        # Deformable 2-D attention (adapted from Uni-Sign).
        self.deform_attn = DeformableAttention2D(
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            dropout=0.1,
            offset_groups=heads,
        )

    def forward(self, Fp, Fr, J):
        """
        Args:
            Fp: pose features (B, C, N) from the pose encoder.
            Fr: vision features (B, C, H, W) from the vision encoder.
            J: hand keypoint priors (B, N, 2), scaled to [0, 1].

        Returns:
            Fused features (B, C, N).
        """
        pose_tokens = Fp.transpose(1, 2)               # (B, N, C)
        vision_tokens = Fr.flatten(2).transpose(1, 2)  # (B, H*W, C)

        attended, _ = self.self_attn(pose_tokens, vision_tokens, vision_tokens)
        attended = attended.transpose(1, 2)            # (B, C, N)

        gate = self.gater(torch.cat([Fp, attended], dim=1))  # (B, C, N)
        blended = gate * Fp + (1 - gate) * attended

        return self.deform_attn(blended, Fr, J)


In [9]:
class RGBFusion(nn.Module):
    """Fuse per-frame hand RGB features into ST-GCN joint features.

    An EfficientNet-B0 trunk (torchvision, pretrained weights) provides
    1280-channel maps, which ``rgb_proj`` maps to ``hidden_dim``. For each video,
    the joint features of the frames that have an RGB crop are refined with
    ``DeformableAttention2D`` (sampling around the joint coordinates) and blended
    back through a zero-initialised gate. At initialisation the gate is 0, so
    the output equals the input pose features.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        import torchvision  # local import: this cell can run before the one importing torchvision

        self.backbone = nn.Sequential(*list(torchvision.models.efficientnet_b0(pretrained=True).children())[:-2])
        self.rgb_proj = nn.Conv2d(1280, hidden_dim, kernel_size=1)

        # NOTE(review): not used in forward(); kept so existing checkpoints still load.
        self.attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=8, batch_first=True)
        self.deform_attn = DeformableAttention2D(dim=hidden_dim, dim_head=32, heads=8, dropout=0.)
        self.gate = nn.Sequential(
            nn.Conv1d(hidden_dim * 2, hidden_dim, 1),
            nn.GELU(),
            nn.Conv1d(hidden_dim, hidden_dim, 1),
            nn.Tanh(),
            nn.ReLU(),
        )
        # Zero init makes relu(tanh(0)) == 0, i.e. the gate starts fully closed.
        for layer in self.gate:
            if isinstance(layer, nn.Conv1d):
                nn.init.constant_(layer.weight, 0)
                nn.init.constant_(layer.bias, 0)

    def forward(self, query_feat, rgb_feat, indices, rgb_len, pose_coords):
        """
        Args:
            query_feat: [B, C, T, N] joint features (from ST-GCN).
            rgb_feat: [sum(L), 1280, H, W] backbone maps for all crops in the batch.
            indices: [sum(L)] frame index of each crop within its video; a video
                whose only entry is -1 has no valid crop and is skipped.
            rgb_len: number of crops L per video (length B).
            pose_coords: [sum(L), N, 2] joint coordinates in [0, 1], used as
                deformable-attention reference points.

        Returns:
            [B, C, T, N] features. Frames without a crop are unchanged, and
            ``query_feat`` itself is not modified.
        """
        batch_size = query_feat.shape[0]
        rgb_feat = self.rgb_proj(rgb_feat)  # [sum(L), C, H, W]
        # Write into a copy: in-place writes would mutate the caller's tensor and
        # can break autograd when query_feat is saved for backward.
        fused_feat = query_feat.clone()
        start = 0

        for batch in range(batch_size):
            length = int(rgb_len[batch])
            end = start + length
            frame_idx = indices[start:end].to(torch.long)
            if length == 1 and -1 in frame_idx:
                start = end
                continue

            query = query_feat[batch, :, frame_idx].permute(1, 0, 2)  # [L, C, N]
            rgb_patch = rgb_feat[start:end]                           # [L, C, H, W]
            coords = pose_coords[start:end][..., :2]                  # [L, N, 2] (x, y)

            # Deformable attention of joint queries over the crop features.
            with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
                kv_feat = self.deform_attn(query, rgb_patch, coords)  # [L, C, N]
            kv_feat = kv_feat.to(query.dtype)

            gate_score = self.gate(torch.cat([kv_feat, query], dim=1))  # [L, C, N]
            fused = gate_score * kv_feat + (1 - gate_score) * query     # [L, C, N]
            fused_feat[batch, :, frame_idx] = fused.permute(1, 0, 2)    # back to [C, L, N]

            start = end

        return fused_feat


In [10]:
# Embed keypoints (T, N, 3) -> (T, N, 64) with a linear layer before the deeper ST-GCN encoder.

class PoseBranch(nn.Module):
    """Pose encoder: per-joint linear embedding, spatial ST-GCN, temporal ST-GCN, joint mean, projection.

    Input ``x``: (B, T, N, input_dim). Output: (B, T, hidden_dim).
    """

    def __init__(self, A, input_dim=3, embed_dim=64, hidden_dim=256):
        super().__init__()
        partitions = A.size(0)
        self.proj = nn.Linear(input_dim, embed_dim)
        self.spatial_gcn, final_dim = get_stgcn_chain(embed_dim, 'spatial', (1, partitions), A.clone(), True)
        self.temporal_gcn, _ = get_stgcn_chain(final_dim, 'temporal', (5, partitions), A.clone(), True)
        self.out_proj = nn.Linear(final_dim, hidden_dim)

    def forward(self, x):
        feats = self.proj(x)                         # (B, T, N, E)
        feats = feats.permute(0, 3, 1, 2)            # (B, E, T, N)
        feats = self.temporal_gcn(self.spatial_gcn(feats))
        pooled = feats.mean(-1).transpose(1, 2)      # average joints -> (B, T, C)
        return self.out_proj(pooled)

# Train/Test split

import os
import json
import random
from collections import defaultdict

# Fix the seed so the split is reproducible.
random.seed(42)

# Paths
json_path = "/kaggle/input/wlasl-processed/WLASL_v0.3.json"
data_dir = "/kaggle/input/data-wlasl/DATA"

# Video ids (folder names) that actually exist under DATA/
used_video_ids = set(os.listdir(data_dir))

# Step 1: group video ids by gloss label, keeping only videos present in DATA/
with open(json_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

label_to_videos = defaultdict(list)
video_to_label = dict()

for entry in data:
    label = entry['gloss']
    for instance in entry['instances']:
        video_id = instance['video_id']
        if video_id in used_video_ids:
            label_to_videos[label].append(video_id)
            video_to_label[video_id] = label

# Step 2: split each label 6:2:2 (train:val:test)
MIN_VIDEOS_PER_LABEL = 3  # labels with fewer videos are dropped entirely
train_list, val_list, test_list = [], [], []

for label, videos in label_to_videos.items():
    n = len(videos)
    if n < MIN_VIDEOS_PER_LABEL:
        continue  # not enough videos to split

    random.shuffle(videos)
    # int(n * 0.2) is 0 for n < 5, which would put every video of a small label
    # into train; reserve at least one video each for val and test.
    n_test = max(1, int(n * 0.2))
    n_val = max(1, int(n * 0.2))

    test_list.extend(videos[:n_test])
    val_list.extend(videos[n_test:n_test + n_val])
    train_list.extend(videos[n_test + n_val:])


# Step 3: write each split to disk, one "video_id label" per line
def save_list(path, video_ids, video_to_label):
    """Write a ``video_id label`` line for every id in *video_ids* to *path*."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(f"{vid} {video_to_label[vid]}\n" for vid in video_ids)


save_list("/kaggle/working/train.txt", train_list, video_to_label)
save_list("/kaggle/working/val.txt", val_list, video_to_label)
save_list("/kaggle/working/test.txt", test_list, video_to_label)

print(f"[Done] Train: {len(train_list)} | Val: {len(val_list)} | Test: {len(test_list)}")


# Main

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import torchvision

class Classify_Sign(nn.Module):
    """Isolated-sign classifier over body, left-hand and right-hand keypoint streams.

    For each part, every joint is embedded (Linear 3 -> 64) and passed through a
    spatial ST-GCN chain and a temporal ST-GCN chain. The result is averaged over
    joints, projected to ``args.hidden_dim`` and averaged over time. The three
    part vectors are concatenated and classified. With ``args.rgb_support``, the
    hand streams are also fused with EfficientNet-B0 features of hand crops
    through deformable attention.

    Expected ``args`` attributes: ``hidden_dim``, ``num_classes``, ``rgb_support``.
    """

    # Body joint whose (detached) feature is added to each hand stream. In the
    # 'body' Graph layout (MediaPipe pose order), joint 15 links to 17/19/21 and
    # 16 to 18/20/22, i.e. they are the left/right wrists. Offsets -2/-1 would
    # select joints 23/24 (the hips) in this 25-joint layout.
    _WRIST_JOINT = {'left': 15, 'right': 16}

    def __init__(self, args):
        super().__init__()
        self.args = args

        self.modes = ['body', 'left', 'right']

        self.graph, A = {}, []
        hidden_dim = args.hidden_dim
        self.proj_linear = nn.ModuleDict()
        for mode in self.modes:
            self.graph[mode] = Graph(layout=mode, strategy='distance', max_hop=1)
            A.append(torch.tensor(self.graph[mode].A, dtype=torch.float32, requires_grad=False))
            self.proj_linear[mode] = nn.Linear(3, 64)

        self.gcn_modules = nn.ModuleDict()
        self.fusion_gcn_modules = nn.ModuleDict()
        spatial_kernel_size = A[0].size(0)
        for index, mode in enumerate(self.modes):
            self.gcn_modules[mode], final_dim = get_stgcn_chain(64, 'spatial', (1, spatial_kernel_size), A[index].clone(), True)
            self.fusion_gcn_modules[mode], _ = get_stgcn_chain(final_dim, 'temporal', (5, spatial_kernel_size), A[index].clone(), True)

        self.proj_pose_body = nn.Linear(final_dim, hidden_dim)
        self.proj_pose_left = nn.Linear(final_dim, hidden_dim)
        self.proj_pose_right = nn.Linear(final_dim, hidden_dim)

        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, args.num_classes)
        )

        if self.args.rgb_support:
            # The deformable attention and the gate mix RGB features (hidden_dim
            # channels) with spatial-GCN features (final_dim channels).
            if hidden_dim != final_dim:
                raise ValueError(
                    f"rgb_support requires args.hidden_dim == {final_dim} "
                    f"(the spatial ST-GCN output width); got {hidden_dim}"
                )
            self.rgb_support_backbone = torch.nn.Sequential(
                *list(torchvision.models.efficientnet_b0(pretrained=True).children())[:-2]
            )
            self.rgb_proj = nn.Conv2d(1280, hidden_dim, kernel_size=1)

            # NOTE(review): not used in forward(); kept so existing checkpoints still load.
            self.fusion_pose_rgb_linear = nn.Linear(hidden_dim, hidden_dim)

            self.fusion_pose_rgb_DA = DeformableAttention2D(
                dim=hidden_dim, dim_head=32, heads=8, dropout=0.
            )

            self.fusion_gate = nn.Sequential(
                nn.Conv1d(hidden_dim * 2, hidden_dim, 1),
                nn.GELU(),
                nn.Conv1d(hidden_dim, 1, 1),
                nn.Tanh(),
                nn.ReLU()
            )
            # Zero init: relu(tanh(0)) == 0, so fusion starts as an identity.
            for layer in self.fusion_gate:
                if isinstance(layer, nn.Conv1d):
                    nn.init.constant_(layer.weight, 0)
                    nn.init.constant_(layer.bias, 0)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights, zero biases, unit/zero LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def maybe_autocast(self, dtype=torch.float32):
        """Return a CUDA autocast context when CUDA is available, else a no-op context."""
        import contextlib  # not imported at module level in this notebook
        return torch.cuda.amp.autocast(dtype=dtype) if torch.cuda.is_available() else contextlib.nullcontext()

    def gather_feat_pose_rgb(self, gcn_feat, rgb_feat, indices, rgb_len, pose_init):
        """Fuse hand-crop RGB features into one hand's joint features.

        Args:
            gcn_feat: (B, C, T, N) spatial-GCN features of one hand.
            rgb_feat: (sum(L), 1280, H, W) backbone maps of all crops in the batch.
            indices: (sum(L),) frame index of each crop within its video; a video
                whose only entry is -1 has no valid crop and is skipped.
            rgb_len: number of crops per video (length B).
            pose_init: per-crop joint coordinates, assumed (sum(L), N, >=2) with
                (x, y) in [0, 1] first. TODO confirm against the data pipeline.

        Returns:
            (B, C, T, N) tensor; frames without a crop keep their features.
        """
        b = gcn_feat.shape[0]
        rgb_feat = self.rgb_proj(rgb_feat)
        # Write into a copy: gcn_feat is the output of an in-place ReLU, which
        # autograd saves for backward, so indexing into it in place breaks backprop.
        out_feat = gcn_feat.clone()
        start = 0
        for batch in range(b):
            length = int(rgb_len[batch])
            end = start + length
            index = indices[start:end].to(torch.long)
            if length == 1 and -1 in index:
                start = end
                continue
            gcn_selected = rearrange(gcn_feat[batch, :, index], 'c t n -> t c n')  # (L, C, N)
            rgb_selected = rgb_feat[start:end]                                   # (L, C, H, W)
            # DeformableAttention2D takes reference points as (L, N, 2).
            pose_init_selected = pose_init[start:end][..., :2]
            with self.maybe_autocast():
                fused = self.fusion_pose_rgb_DA(gcn_selected, rgb_selected, pose_init_selected)
            fused = fused.to(gcn_feat.dtype)
            gate_input = torch.cat([fused, gcn_selected], dim=-2)                # (L, 2C, N)
            gate_score = self.fusion_gate(gate_input)                            # (L, 1, N)
            fused = gate_score * fused + (1 - gate_score) * gcn_selected
            out_feat[batch, :, index] = rearrange(fused, 't c n -> c t n')
            start = end
        return out_feat

    def forward(self, src_input):
        """
        Args:
            src_input: dict with, for each part in ``self.modes``, a keypoint
                tensor assumed to be (B, T, N, 3) (``proj_linear`` maps the last
                dim 3 -> 64). With ``args.rgb_support`` it also needs
                '{left,right}_hands', '{left,right}_sampled_indices',
                '{left,right}_rgb_len' and '{left,right}_skeletons_norm'.

        Returns:
            (B, num_classes) logits.
        """
        rgb_support = {}
        if self.args.rgb_support:
            for part in ('left', 'right'):
                rgb_support[f'{part}_hands'] = self.rgb_support_backbone(src_input[f'{part}_hands'])
                rgb_support[f'{part}_sampled_indices'] = src_input[f'{part}_sampled_indices']

        part_proj = {
            'body': self.proj_pose_body,
            'left': self.proj_pose_left,
            'right': self.proj_pose_right,
        }
        part_feats = {}
        body_feat = None

        # self.modes lists 'body' first, so body_feat is set before the hands use it.
        for part in self.modes:
            proj_feat = self.proj_linear[part](src_input[part]).permute(0, 3, 1, 2)  # (B, 64, T, N)
            gcn_feat = self.gcn_modules[part](proj_feat)

            if part == 'body':
                body_feat = gcn_feat
            elif part in self._WRIST_JOINT:
                if self.args.rgb_support:
                    gcn_feat = self.gather_feat_pose_rgb(
                        gcn_feat,
                        rgb_support[f'{part}_hands'],
                        rgb_support[f'{part}_sampled_indices'],
                        src_input[f'{part}_rgb_len'],
                        src_input[f'{part}_skeletons_norm']
                    )
                # Broadcast the (detached) body wrist feature onto every hand joint.
                wrist = self._WRIST_JOINT[part]
                gcn_feat = gcn_feat + body_feat[..., wrist][..., None].detach()
            else:
                raise NotImplementedError

            gcn_feat = self.fusion_gcn_modules[part](gcn_feat)
            pooled_feat = gcn_feat.mean(-1).transpose(1, 2)  # [B, T, C]
            part_feats[part] = part_proj[part](pooled_feat)

        # Average over time and concatenate in body, left, right order.
        feat_all = torch.cat([part_feats[part].mean(dim=1) for part in self.modes], dim=-1)
        return self.classifier(feat_all)


import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import models, transforms
from tqdm import tqdm
from stgcn import get_stgcn_chain  # assumed your ST-GCN code is in stgcn.py
from dataset_pose import PoseDataset  # custom dataset loading keypoints
from dataset_rgb import RGBHandDataset  # custom dataset loading crop hand images

# ----------- CONFIG -----------
# Module-level settings read by the training/pretraining helpers below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_root = "Data"
label_file = "class_name.txt"  # NOTE(review): not referenced by the code in this cell
train_list = "train.txt"  # NOTE: rebinds the train_list built by the train/test split cell
num_classes = 100  # or however many classes you have
pose_parts = ["body", "left", "right"]
model_save_path = "pretrained_models"
os.makedirs(model_save_path, exist_ok=True)

# ----------- ST-GCN TRAINING UTILS -----------
def train_one_epoch_stgcn(model, dataloader, optimizer, criterion):
    """Run one training epoch over ``(keypoints, labels)`` batches on ``device``.

    Returns:
        (mean loss per sample, accuracy) over ``dataloader.dataset``, or
        (0.0, 0.0) for an empty dataset instead of dividing by zero. The loss
        average assumes ``criterion`` uses mean reduction.
    """
    model.train()
    total_loss, total_correct = 0.0, 0
    for keypoints, labels in tqdm(dataloader):
        keypoints, labels = keypoints.to(device), labels.to(device)
        logits = model(keypoints)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * keypoints.size(0)
        total_correct += (logits.argmax(1) == labels).sum().item()

    num_samples = len(dataloader.dataset)
    if num_samples == 0:
        return 0.0, 0.0
    return total_loss / num_samples, total_correct / num_samples

# ----------- RGB TRAINING UTILS -----------
def train_one_epoch_rgb(model, dataloader, optimizer, criterion):
    """Run one training epoch over ``(images, labels)`` batches on ``device``.

    Returns:
        (mean loss per sample, accuracy) over ``dataloader.dataset``, or
        (0.0, 0.0) for an empty dataset instead of dividing by zero. The loss
        average assumes ``criterion`` uses mean reduction.
    """
    model.train()
    total_loss, total_correct = 0.0, 0
    for imgs, labels in tqdm(dataloader):
        imgs, labels = imgs.to(device), labels.to(device)
        logits = model(imgs)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * imgs.size(0)
        total_correct += (logits.argmax(1) == labels).sum().item()

    num_samples = len(dataloader.dataset)
    if num_samples == 0:
        return 0.0, 0.0
    return total_loss / num_samples, total_correct / num_samples

# ----------- PRETRAIN ST-GCN -----------
def pretrain_stgcn():
    """Pretrain a spatial ST-GCN chain plus a linear head for each pose part.

    Each part's model state_dict is saved to
    ``{model_save_path}/stgcn_{part}.pth``.

    NOTE(review): these checkpoints do not match ``Classify_Sign``'s chains.
    That model builds its graphs with strategy='distance' (2 partitions at
    max_hop=1) and uses a temporal kernel of 1 for its spatial chain, while
    here the graph is 'uniform' and the temporal kernel is 9. The keys also
    carry the ``nn.Sequential`` prefix ``0.``.
    """
    for part in pose_parts:
        print(f"\n[INFO] Pretraining ST-GCN on {part} keypoints")
        graph = Graph(layout=part)
        A = torch.tensor(graph.A, dtype=torch.float32)
        # NOTE(review): in_dim=2 uses two channels, but the keypoint files hold
        # (x, y, z) per joint, and the chain expects (B, C, T, V) input.
        # Confirm that PoseDataset yields that layout.
        backbone, out_dim = get_stgcn_chain(in_dim=2, level='spatial', kernel_size=(9, A.size(0)), A=A, adaptive=True)
        model = nn.Sequential(
            backbone,
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(out_dim, num_classes),  # head width follows the chain's real output width
        ).to(device)

        dataset = PoseDataset(data_root, train_list, part=part)
        dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(10):
            loss, acc = train_one_epoch_stgcn(model, dataloader, optimizer, criterion)
            print(f"Epoch {epoch+1}: Loss = {loss:.4f}, Acc = {acc*100:.2f}%")

        torch.save(model.state_dict(), f"{model_save_path}/stgcn_{part}.pth")
        print(f"[SAVED] ST-GCN {part} to stgcn_{part}.pth")

# ----------- PRETRAIN VISION ENCODER -----------
def pretrain_rgb_encoder():
    """Fine-tune a pretrained EfficientNet-B0 classifier on cropped hand images, one per side.

    Each side's state_dict is saved to ``{model_save_path}/efficientnet_{side}.pth``.
    """
    print("\n[INFO] Pretraining Vision Encoder on cropped RGB hands")
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    for side in ("left", "right"):
        print(f"\n[INFO] Pretraining for {side} hand")
        loader = DataLoader(
            RGBHandDataset(data_root, train_list, side=side, transform=transform),
            batch_size=32,
            shuffle=True,
            num_workers=4,
        )

        model = models.efficientnet_b0(pretrained=True)
        # Replace the final classification layer with one sized for our classes.
        head_in = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(head_in, num_classes)
        model = model.to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(10):
            loss, acc = train_one_epoch_rgb(model, loader, optimizer, criterion)
            print(f"Epoch {epoch+1}: Loss = {loss:.4f}, Acc = {acc*100:.2f}%")

        torch.save(model.state_dict(), f"{model_save_path}/efficientnet_{side}.pth")
        print(f"[SAVED] EfficientNet {side} to efficientnet_{side}.pth")

# ----------- MAIN -----------
if __name__ == '__main__':
    # Run both pretraining stages. Inside a Jupyter/IPython kernel __name__ is
    # also '__main__', so executing this cell starts training immediately.
    pretrain_stgcn()
    pretrain_rgb_encoder()


# Pretrain RGB

In [12]:
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset

class HandRGBDataset(Dataset):
    """Frame-level dataset of RGB hand images, labelled through a split file.

    Each usable line of ``label_file`` has the form ``<video_id> <label words...>``;
    the label may contain spaces. Every file inside an immediate sub-directory
    of ``data_dir`` whose name starts with ``video_id`` becomes one sample with
    that label.

    Args:
        label_file: path to the split file (e.g. train.txt / val.txt).
        data_dir: root directory whose sub-directories hold the image files.
        transform: optional callable applied to each loaded PIL image.
        label2idx: optional label -> class-index mapping to reuse, typically
            ``train_dataset.label2idx``. Reusing it keeps every split in the same
            index space. When omitted, the mapping is built from this split's own
            sorted labels. That is only safe for the training split: a split
            missing any class would get shifted indices. When a mapping is given,
            samples whose label is absent from it are dropped with a warning.

    Attributes:
        data: list of ``(image_path, label)`` pairs.
        classes: labels ordered by class index.
        label2idx: label -> class index used by ``__getitem__``.
    """

    def __init__(self, label_file, data_dir, transform=None, label2idx=None):
        self.data_dir = data_dir
        self.transform = transform

        # List the data tree once, up front, rather than once per label line.
        # Sorting gives a reproducible sample order. Non-directory entries at the
        # top level are skipped instead of crashing os.listdir.
        all_files = [
            (os.path.join(data_dir, sub_dir, fname), fname)
            for sub_dir in sorted(os.listdir(data_dir))
            if os.path.isdir(os.path.join(data_dir, sub_dir))
            for fname in sorted(os.listdir(os.path.join(data_dir, sub_dir)))
        ]

        # Load labels
        self.data = []
        with open(label_file, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) < 2:
                    continue  # blank or malformed line
                video_id = parts[0]
                label = " ".join(parts[1:])  # labels may contain spaces
                # NOTE(review): this is a prefix match, so id "123" would also
                # match files named "1234..."; confirm the naming scheme rules this out.
                self.data.extend(
                    (path, label) for path, fname in all_files if fname.startswith(video_id)
                )

        # Encode label
        if label2idx is None:
            self.classes = sorted({label for _, label in self.data})
            self.label2idx = {label: idx for idx, label in enumerate(self.classes)}
        else:
            self.label2idx = dict(label2idx)
            self.classes = sorted(self.label2idx, key=self.label2idx.get)
            unknown = {label for _, label in self.data if label not in self.label2idx}
            if unknown:
                print(f"[WARN] HandRGBDataset: dropping samples for {len(unknown)} "
                      f"label(s) not present in label2idx")
                self.data = [(path, label) for path, label in self.data
                             if label in self.label2idx]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for sample ``idx``."""
        img_path, label = self.data[idx]
        # The context manager releases the file handle; convert() returns a new,
        # fully loaded image that outlives the `with` block.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")

        if self.transform:
            img = self.transform(img)

        label_idx = self.label2idx[label]
        return img, label_idx


Number of labels:
Videos per label: 8-16

Data samples per split:
- Train: 2341
- Validation: 465
- Test: 465

In [13]:
import torch
import torch.nn as nn
from torchvision import models
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm

# Dataset & Transform
image_size = 112
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = HandRGBDataset(
    label_file='/kaggle/input/train-test-valid/train.txt',
    data_dir='/kaggle/input/data-wlasl/DATA',
    transform=transform
)
val_dataset = HandRGBDataset(
    label_file='/kaggle/input/train-test-valid/val.txt',
    data_dir='/kaggle/input/data-wlasl/DATA',
    transform=transform
)

# HandRGBDataset builds label2idx from the sorted labels present in its OWN split,
# while the classifier head is sized from train_dataset.classes. If the val split
# lacks any training class, every later val index shifts and validation is scored
# against the wrong targets. Re-key val to the training mapping, and drop val
# samples whose label the model was never trained on.
train_label2idx = train_dataset.label2idx
unseen_val_labels = {label for _, label in val_dataset.data if label not in train_label2idx}
if unseen_val_labels:
    print(f"[WARN] Dropping validation samples for {len(unseen_val_labels)} "
          f"label(s) absent from the training split")
    val_dataset.data = [(path, label) for path, label in val_dataset.data
                        if label in train_label2idx]
val_dataset.classes = train_dataset.classes
val_dataset.label2idx = train_label2idx

In [14]:
# Dataloader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
# Evaluation needs no shuffling. A fixed order keeps validation passes reproducible,
# and the summed loss/accuracy are unaffected by sample order.
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4)



In [15]:
# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Size the classification head from the classes found in the training split.
num_classes = len(train_dataset.classes)

# ImageNet-initialised EfficientNet-B0. Its last classifier layer is replaced by a
# freshly initialised Linear that emits `num_classes` logits, then the model is
# moved to `device`.
model = models.efficientnet_b0(weights='IMAGENET1K_V1')
model.classifier[1] = nn.Linear(
    in_features=model.classifier[1].in_features,
    out_features=num_classes,
)
model = model.to(device)

# Loss & Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=3e-4)

Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth
100%|██████████| 20.5M/20.5M [00:00<00:00, 157MB/s]


In [16]:
EPOCHS = 20
best_val_acc = 0.0
best_model_path = "/kaggle/working/best_efficientnet_rgb_pretrain.pth"


def _run_epoch(loader, desc, train):
    """Run one pass of ``model`` over ``loader`` and return ``(mean_loss, accuracy)``.

    When ``train`` is true, each batch also takes an optimizer step. The caller
    sets train/eval mode and any ``torch.no_grad`` context. Uses the notebook-level
    ``model``, ``criterion``, ``optimizer`` and ``device``.
    """
    loss_sum, n_correct, n_seen = 0, 0, 0
    for batch_imgs, batch_labels in tqdm(loader, desc=desc):
        batch_imgs, batch_labels = batch_imgs.to(device), batch_labels.to(device)
        if train:
            optimizer.zero_grad()
        logits = model(batch_imgs)
        batch_loss = criterion(logits, batch_labels)
        if train:
            batch_loss.backward()
            optimizer.step()

        # Weight the mean batch loss by batch size so the epoch average is per sample.
        loss_sum += batch_loss.item() * batch_imgs.size(0)
        n_seen += batch_labels.size(0)
        n_correct += logits.max(1)[1].eq(batch_labels).sum().item()
    return loss_sum / n_seen, n_correct / n_seen


for epoch in range(EPOCHS):
    # --- training pass ---
    model.train()
    train_avg_loss, train_acc = _run_epoch(
        train_loader, f"Training Epoch {epoch+1}/{EPOCHS}", train=True
    )

    # --- validation pass (no gradient tracking) ---
    model.eval()
    with torch.no_grad():
        val_avg_loss, val_acc = _run_epoch(
            val_loader, f"Validation Epoch {epoch+1}/{EPOCHS}", train=False
        )

    print(f"Epoch {epoch+1}: "
          f"Train Loss={train_avg_loss:.4f}, Train Acc={train_acc:.4f} | "
          f"Val Loss={val_avg_loss:.4f}, Val Acc={val_acc:.4f}")

    # Checkpoint only when validation accuracy strictly improves.
    if val_acc <= best_val_acc:
        continue
    best_val_acc = val_acc
    checkpoint_payload = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'label2idx': train_dataset.label2idx,
        'classes': train_dataset.classes,
        'epoch': epoch + 1,
        'val_acc': val_acc,
    }
    torch.save(checkpoint_payload, best_model_path)
    print(f"✅ Saved best model (Epoch {epoch+1}, Val Acc: {val_acc:.4f})")


Training Epoch 1/20: 100%|██████████| 404/404 [00:25<00:00, 15.90it/s]
Validation Epoch 1/20: 100%|██████████| 320/320 [00:04<00:00, 68.05it/s]


Epoch 1: Train Loss=5.3023, Train Acc=0.0631 | Val Loss=5.5614, Val Acc=0.0500
✅ Saved best model (Epoch 1, Val Acc: 0.0500)


Training Epoch 2/20: 100%|██████████| 404/404 [00:22<00:00, 17.69it/s]
Validation Epoch 2/20: 100%|██████████| 320/320 [00:04<00:00, 70.52it/s]


Epoch 2: Train Loss=3.6943, Train Acc=0.2604 | Val Loss=5.3495, Val Acc=0.0735
✅ Saved best model (Epoch 2, Val Acc: 0.0735)


Training Epoch 3/20: 100%|██████████| 404/404 [00:22<00:00, 17.63it/s]
Validation Epoch 3/20: 100%|██████████| 320/320 [00:04<00:00, 70.90it/s]


Epoch 3: Train Loss=2.6125, Train Acc=0.4417 | Val Loss=5.4563, Val Acc=0.0742
✅ Saved best model (Epoch 3, Val Acc: 0.0742)


Training Epoch 4/20: 100%|██████████| 404/404 [00:23<00:00, 17.23it/s]
Validation Epoch 4/20: 100%|██████████| 320/320 [00:04<00:00, 73.83it/s]


Epoch 4: Train Loss=1.8433, Train Acc=0.5929 | Val Loss=5.5849, Val Acc=0.0942
✅ Saved best model (Epoch 4, Val Acc: 0.0942)


Training Epoch 5/20: 100%|██████████| 404/404 [00:23<00:00, 16.93it/s]
Validation Epoch 5/20: 100%|██████████| 320/320 [00:04<00:00, 73.02it/s]


Epoch 5: Train Loss=1.2603, Train Acc=0.7175 | Val Loss=5.9976, Val Acc=0.0922


Training Epoch 6/20: 100%|██████████| 404/404 [00:24<00:00, 16.66it/s]
Validation Epoch 6/20: 100%|██████████| 320/320 [00:04<00:00, 71.52it/s]


Epoch 6: Train Loss=0.8388, Train Acc=0.8110 | Val Loss=6.2146, Val Acc=0.0930


Training Epoch 7/20: 100%|██████████| 404/404 [00:24<00:00, 16.48it/s]
Validation Epoch 7/20: 100%|██████████| 320/320 [00:04<00:00, 72.76it/s]


Epoch 7: Train Loss=0.5653, Train Acc=0.8711 | Val Loss=6.5565, Val Acc=0.0950
✅ Saved best model (Epoch 7, Val Acc: 0.0950)


Training Epoch 8/20: 100%|██████████| 404/404 [00:24<00:00, 16.68it/s]
Validation Epoch 8/20: 100%|██████████| 320/320 [00:04<00:00, 73.51it/s]


Epoch 8: Train Loss=0.3949, Train Acc=0.9081 | Val Loss=6.7992, Val Acc=0.0993
✅ Saved best model (Epoch 8, Val Acc: 0.0993)


Training Epoch 9/20: 100%|██████████| 404/404 [00:24<00:00, 16.70it/s]
Validation Epoch 9/20: 100%|██████████| 320/320 [00:04<00:00, 72.07it/s]


Epoch 9: Train Loss=0.3032, Train Acc=0.9279 | Val Loss=6.9644, Val Acc=0.0934


Training Epoch 10/20: 100%|██████████| 404/404 [00:24<00:00, 16.46it/s]
Validation Epoch 10/20: 100%|██████████| 320/320 [00:04<00:00, 72.15it/s]


Epoch 10: Train Loss=0.2422, Train Acc=0.9400 | Val Loss=7.2150, Val Acc=0.0930


Training Epoch 11/20: 100%|██████████| 404/404 [00:24<00:00, 16.62it/s]
Validation Epoch 11/20: 100%|██████████| 320/320 [00:04<00:00, 71.51it/s]


Epoch 11: Train Loss=0.1929, Train Acc=0.9499 | Val Loss=7.4496, Val Acc=0.0953


Training Epoch 12/20: 100%|██████████| 404/404 [00:24<00:00, 16.53it/s]
Validation Epoch 12/20: 100%|██████████| 320/320 [00:04<00:00, 68.60it/s]


Epoch 12: Train Loss=0.1988, Train Acc=0.9446 | Val Loss=7.4286, Val Acc=0.0879


Training Epoch 13/20: 100%|██████████| 404/404 [00:24<00:00, 16.62it/s]
Validation Epoch 13/20: 100%|██████████| 320/320 [00:04<00:00, 73.11it/s]


Epoch 13: Train Loss=0.1863, Train Acc=0.9472 | Val Loss=7.6679, Val Acc=0.0957


Training Epoch 14/20: 100%|██████████| 404/404 [00:24<00:00, 16.52it/s]
Validation Epoch 14/20: 100%|██████████| 320/320 [00:04<00:00, 72.03it/s]


Epoch 14: Train Loss=0.1678, Train Acc=0.9520 | Val Loss=7.8634, Val Acc=0.0973


Training Epoch 15/20: 100%|██████████| 404/404 [00:24<00:00, 16.58it/s]
Validation Epoch 15/20: 100%|██████████| 320/320 [00:04<00:00, 74.15it/s]


Epoch 15: Train Loss=0.1705, Train Acc=0.9484 | Val Loss=8.2023, Val Acc=0.1000
✅ Saved best model (Epoch 15, Val Acc: 0.1000)


Training Epoch 16/20: 100%|██████████| 404/404 [00:24<00:00, 16.63it/s]
Validation Epoch 16/20: 100%|██████████| 320/320 [00:04<00:00, 75.67it/s]


Epoch 16: Train Loss=0.1724, Train Acc=0.9452 | Val Loss=8.2419, Val Acc=0.0946


Training Epoch 17/20: 100%|██████████| 404/404 [00:24<00:00, 16.52it/s]
Validation Epoch 17/20: 100%|██████████| 320/320 [00:04<00:00, 73.58it/s]


Epoch 17: Train Loss=0.1567, Train Acc=0.9509 | Val Loss=8.3370, Val Acc=0.0903


Training Epoch 18/20: 100%|██████████| 404/404 [00:24<00:00, 16.64it/s]
Validation Epoch 18/20: 100%|██████████| 320/320 [00:04<00:00, 73.11it/s]


Epoch 18: Train Loss=0.1541, Train Acc=0.9498 | Val Loss=8.6505, Val Acc=0.0911


Training Epoch 19/20: 100%|██████████| 404/404 [00:24<00:00, 16.62it/s]
Validation Epoch 19/20: 100%|██████████| 320/320 [00:04<00:00, 72.00it/s]


Epoch 19: Train Loss=0.1382, Train Acc=0.9555 | Val Loss=8.5736, Val Acc=0.0891


Training Epoch 20/20: 100%|██████████| 404/404 [00:24<00:00, 16.69it/s]
Validation Epoch 20/20: 100%|██████████| 320/320 [00:04<00:00, 71.05it/s]

Epoch 20: Train Loss=0.1501, Train Acc=0.9485 | Val Loss=8.8041, Val Acc=0.0903





from torchvision import models

# Rebuild the EfficientNet-B0 architecture. No pretrained weights are needed
# because the fine-tuned weights come from our own checkpoint below.
model = models.efficientnet_b0(weights=None)

# Load the checkpoint written by the training loop. It is a dict holding
# 'model_state_dict', 'optimizer_state_dict', 'label2idx', 'classes', 'epoch'
# and 'val_acc'.
checkpoint = torch.load("/kaggle/working/best_efficientnet_rgb_pretrain.pth", map_location=device)

# Load only the feature extractor. 'model_state_dict' is the state of the WHOLE
# model, so backbone keys carry a "features." prefix (e.g. "features.0.0.weight").
# model.features expects the same keys without it ("0.0.weight"). Passing the raw
# dict with strict=False matched no keys and silently kept random weights. Strip
# the prefix and load strictly so any mismatch raises instead.
feature_prefix = "features."
feature_state_dict = {
    key[len(feature_prefix):]: value
    for key, value in checkpoint['model_state_dict'].items()
    if key.startswith(feature_prefix)
}
model.features.load_state_dict(feature_state_dict, strict=True)
