In [None]:
import os
import pickle

# Root directories to scan; each is laid out as <base>/<user_id>/<gesture_id>/*.pkl.
# Fill in the path(s) to pantomime/primary_exp/office (and open)/1/0/normal before running.
base_directories = []

def load_all_pkl_files(base_dirs):
    """Load every ``.pkl`` sample found under ``<base_dir>/<user_id>/<gesture_id>/``.

    Args:
        base_dirs: iterable of root directories. Missing roots are reported and skipped.

    Returns:
        Nested dict ``{user_id: {gesture_id: [unpickled_sample, ...]}}`` with directory
        names as keys. Entries from several roots are merged under the same user/gesture.
        Directories and files are visited in sorted order. The resulting sample order
        (which later decides K-fold membership) is therefore reproducible across
        filesystems; ``os.listdir`` order is arbitrary.

    Files that cannot be unpickled (corrupt, undecodable, or truncated/empty) are
    reported and skipped.
    """
    pkl_files = {}
    for base_dir in base_dirs:
        if not os.path.exists(base_dir):
            print(f"Warning: {base_dir} does not exist.")
            continue

        for user_id in sorted(os.listdir(base_dir)):
            user_path = os.path.join(base_dir, user_id)
            if not os.path.isdir(user_path):
                continue
            user_entry = pkl_files.setdefault(user_id, {})

            for gesture_id in sorted(os.listdir(user_path)):
                gesture_path = os.path.join(user_path, gesture_id)
                if not os.path.isdir(gesture_path):
                    continue
                gesture_entry = user_entry.setdefault(gesture_id, [])

                for filename in sorted(os.listdir(gesture_path)):
                    if not filename.endswith(".pkl"):
                        continue
                    file_path = os.path.join(gesture_path, filename)
                    # SECURITY: pickle.load can execute arbitrary code -- only point this at
                    # trusted dataset files. latin1 lets Python-2-era pickles load.
                    try:
                        with open(file_path, "rb") as f:
                            data = pickle.load(f, encoding="latin1")
                    except (pickle.UnpicklingError, UnicodeDecodeError, EOFError) as e:
                        # EOFError: empty or truncated file.
                        print(f"Error loading {file_path}: {e}")
                        continue
                    gesture_entry.append(data)
    return pkl_files

loaded_data = load_all_pkl_files(base_directories)

# Report how many samples were loaded for every (user, gesture) pair.
for user, gestures in loaded_data.items():
    print(f"User {user}:")
    for gesture, files in gestures.items():
        n_files = len(files)
        print(f"  Gesture {gesture}: {n_files} pkl files loaded")

# labeling
def label_data(data):
    """Flatten the nested loader output into a list of labeled samples.

    Args:
        data: ``{user_id: {gesture_id: [sample, ...]}}``, e.g. the output of
            ``load_all_pkl_files``. Gesture keys must be 1-based integer strings.

    Returns:
        (labeled_data, user_label_map)
            labeled_data: list of ``{"user_label", "gesture_label", "data"}`` dicts,
                where ``gesture_label = int(gesture_id) - 1``.
            user_label_map: ``{user_label: user_id}``.

    User labels follow the natural sort order of ``user_id`` ("2" before "10"). They
    are assigned only to users with at least one sample. Labels are therefore
    reproducible regardless of directory-listing order and contiguous in 0..U-1,
    which is what a classifier with U outputs expects.
    """
    def _natural_key(name):
        # Decimal ids sort numerically; any other names sort after them as text.
        text = str(name)
        return (0, int(text), "") if text.isdecimal() else (1, 0, text)

    labeled_data = []
    user_label_map = {}  # user_label -> user_id mapping
    for user_id in sorted(data, key=_natural_key):
        gestures = data[user_id]
        if not any(gestures.values()):
            continue  # no samples: do not burn a label on this user
        user_idx = len(user_label_map)
        user_label_map[user_idx] = user_id
        for gesture_id in sorted(gestures, key=_natural_key):
            for file_data in gestures[gesture_id]:
                labeled_data.append({
                    "user_label": user_idx,
                    "gesture_label": int(gesture_id) - 1,
                    "data": file_data,
                })
    return labeled_data, user_label_map

labeled_data, user_label_map = label_data(loaded_data)

print(f"Total labeled samples: {len(labeled_data)}")
# Count distinct user labels that actually appear among the samples.
num_unique_users = len({sample['user_label'] for sample in labeled_data})
print(f"num_unique_useres: {num_unique_users}")



In [None]:
from collections import Counter, defaultdict

# Guard: this cell needs the output of the pantomime load / label_data cell.
assert "labeled_data" in globals(), "labeled_dataÍ∞Ä ÏóÜÏäµÎãàÎã§. Î®ºÏ†Ä pantomime Î°úÎìú/label_data ÏÖÄÏùÑ Ïã§ÌñâÌïòÏÑ∏Ïöî."
assert isinstance(labeled_data, list) and len(labeled_data) > 0, "labeled_dataÍ∞Ä ÎπÑÏñ¥ÏûàÏäµÎãàÎã§."

# --- samples per gesture (action) label ---
gesture_counts = Counter(int(sample["gesture_label"]) for sample in labeled_data)
print("=== Gesture (action) counts ===")
for g, cnt in sorted(gesture_counts.items()):
    print(f"gesture {g}: {cnt}")
print("TOTAL:", sum(gesture_counts.values()), "\n")

# --- samples per user label (annotated with the original user_id when available) ---
user_counts = Counter(int(sample["user_label"]) for sample in labeled_data)
print("=== User counts ===")
has_map = ("user_label_map" in globals()) and isinstance(user_label_map, dict)
for u, cnt in sorted(user_counts.items()):
    suffix = f" ({user_label_map[u]})" if has_map and u in user_label_map else ""
    print(f"user {u}{suffix}: {cnt}")
# Observed on the full dataset: 41 users with 210~420 samples, 21 gestures with 410~440 samples.
print("TOTAL:", sum(user_counts.values()), "\n")

--------

PROPOSED

In [None]:
# GRAPH CONSTRUCTION #

import numpy as np
import torch
import time
import math
from torch_geometric.data import Data
from collections import Counter

# ------------------------------------------------------------------
# Edge-weight decay: w_ij = exp(-DIST_WEIGHT * ||xyz_i - xyz_j||)
# ------------------------------------------------------------------
DIST_WEIGHT = 4

def positional_encoding_1d(time_steps, max_time_steps):
    """Encode integer time steps as one sinusoidal feature, sin(t * pi / max_time_steps).

    The input tensor is copied (not modified). For 1-D input the result has shape (N, 1)
    and dtype float32.
    """
    angle_per_step = math.pi / max_time_steps
    t = time_steps.clone().detach().float()
    return torch.sin(t * angle_per_step).unsqueeze(1)

# =========================
# Pantomime sample -> frames(list of (Pi,>=3)) normalize
# =========================
def _get_x_from_item(item):
    x = item.get("data", item)
    if isinstance(x, dict):
        for k in ["pos", "points", "pc", "xyz", "coords", "data", "frames"]:
            if k in x:
                x = x[k]
                break
    return x

def _frame_to_array(fr):
    """Convert one raw frame to a float32 (Pi, C>=3) array, or an empty (0, 3) array if unusable."""
    if fr is None:
        return np.zeros((0, 3), dtype=np.float32)
    if isinstance(fr, dict):
        for k in ["pos", "points", "pc", "xyz", "coords", "data"]:
            if k in fr:
                fr = fr[k]
                break
    if torch.is_tensor(fr):
        # np.asarray fails on CUDA tensors and on tensors that require grad.
        fr = fr.detach().cpu().numpy()
    fr = np.asarray(fr, dtype=np.float32)
    if fr.ndim != 2 or fr.shape[1] < 3:
        return np.zeros((0, 3), dtype=np.float32)
    return fr


def to_frames_pointlist(x):
    """
    Normalize a sample's point data into a list of per-frame point arrays.

    Return: frames (list of np.ndarray); frames[t] has shape (Pi, C>=3), dtype float32.
    Supports:
      - list/tuple of frames (a None or malformed frame becomes an empty (0, 3) array)
      - 1-D object ndarray of (ragged) frames, handled like a list
      - (T,P,C) tensor/ndarray -> T frames
      - (N,C) tensor/ndarray -> single frame
      - dict (key search: pos, points, pc, xyz, coords, data, frames)
    Raises:
      ValueError for any other shape (e.g. 1-D numeric input or C < 3).
    """
    if isinstance(x, dict):
        for k in ["pos", "points", "pc", "xyz", "coords", "data", "frames"]:
            if k in x:
                x = x[k]
                break

    # Sequence of frames: a list/tuple, or a 1-D object array, which is presumably how
    # ragged per-frame arrays get stored in a pickle (torch.as_tensor rejects it).
    is_object_seq = (isinstance(x, np.ndarray) and x.dtype == object
                     and x.ndim == 1 and x.size > 0)
    if is_object_seq or (isinstance(x, (list, tuple)) and len(x) > 0):
        return [_frame_to_array(fr) for fr in x]

    input_type = type(x)  # report the caller's type, not the converted Tensor type
    x = torch.as_tensor(x, dtype=torch.float32)

    # (T,P,C)
    if x.ndim == 3 and x.size(-1) >= 3:
        return [x[t].detach().cpu().numpy().astype(np.float32) for t in range(x.size(0))]

    # (N,C) -> single frame
    if x.ndim == 2 and x.size(-1) >= 3:
        return [x.detach().cpu().numpy().astype(np.float32)]

    raise ValueError(f"Unsupported x type/shape: type={input_type} shape={getattr(x, 'shape', None)}")

# =========================
# PyG graph generation 
# =========================
def create_pyg_graph(point_data_frames, time_interval=1/24):
    """Build a spatio-temporal point-cloud graph from a sequence of frames.

    Nodes: one per point of every non-empty frame, in frame order. Features are
    ``[x, y, z, sin(pi * t / T)]``, where ``t`` is the frame index and ``T`` is
    (last non-empty frame index + 1).
    Edges: directed, all-to-all from every point of a non-empty frame to every point of
    the next non-empty frame (empty frames are skipped). Weight is
    ``exp(-DIST_WEIGHT * ||xyz_i - xyz_j||)``.

    Args:
        point_data_frames: sequence of per-frame point arrays. A non-empty frame is
            treated as (P_t, C) with C >= 3; only the first three columns are used.
        time_interval: unused; kept for backward compatibility.

    Returns:
        (Data(x, edge_index, edge_weight, time_steps), elapsed_seconds). A graph with
        no points has x of shape (0, 4) and no edges.
    """
    start_time = time.time()

    non_empty_steps = [t for t in range(len(point_data_frames)) if len(point_data_frames[t]) > 0]
    max_time_steps = max(non_empty_steps) + 1 if non_empty_steps else 1

    if not non_empty_steps:
        empty = Data(
            x=torch.zeros((0, 4), dtype=torch.float32),
            edge_index=torch.zeros((2, 0), dtype=torch.long),
            edge_weight=torch.zeros((0,), dtype=torch.float32),
            time_steps=torch.zeros((0,), dtype=torch.long),
        )
        return empty, time.time() - start_time

    # Convert each non-empty frame once; only XYZ is used for features and distances.
    xyz = [np.asarray(point_data_frames[t], dtype=np.float32)[:, :3] for t in non_empty_steps]
    counts = [len(p) for p in xyz]
    # Global node index of the first point of each non-empty frame.
    offsets = np.concatenate([[0], np.cumsum(counts)[:-1]]).astype(np.int64)

    node_features = np.concatenate(xyz, axis=0)
    time_step_list = np.repeat(np.asarray(non_empty_steps, dtype=np.int64), counts)

    # Edges between consecutive non-empty frames (all-to-all), built vectorized.
    src_chunks, dst_chunks, w_chunks = [], [], []
    for k in range(len(xyz) - 1):
        pts_a, pts_b = xyz[k], xyz[k + 1]
        # (P_a, 3, 1) - (3, P_b) -> (P_a, 3, P_b); norm over the coordinate axis -> (P_a, P_b)
        distances = np.linalg.norm(pts_a[:, :, np.newaxis] - pts_b.T, axis=1)
        weights = np.exp(-DIST_WEIGHT * distances)
        # Row-major (i, j) order, identical to meshgrid(..., indexing="ij").flatten().
        src_chunks.append(np.repeat(np.arange(len(pts_a)), len(pts_b)) + offsets[k])
        dst_chunks.append(np.tile(np.arange(len(pts_b)), len(pts_a)) + offsets[k + 1])
        w_chunks.append(weights.reshape(-1))

    x = torch.tensor(node_features, dtype=torch.float32)
    time_steps_tensor = torch.tensor(time_step_list, dtype=torch.long)
    pos_enc = positional_encoding_1d(time_steps_tensor, max_time_steps=max_time_steps)
    x = torch.cat([x, pos_enc], dim=1)  # (num_nodes, 4)

    if src_chunks:
        edge_index = torch.tensor(
            np.stack([np.concatenate(src_chunks), np.concatenate(dst_chunks)]), dtype=torch.long
        )
        edge_weight = torch.tensor(np.concatenate(w_chunks), dtype=torch.float32)
    else:
        edge_index = torch.zeros((2, 0), dtype=torch.long)
        edge_weight = torch.zeros((0,), dtype=torch.float32)

    return Data(x=x, edge_index=edge_index, edge_weight=edge_weight, time_steps=time_steps_tensor), (time.time() - start_time)


# Guard: requires the output of the label_data cell.
assert ("labeled_data" in globals()
        and isinstance(labeled_data, list)
        and len(labeled_data) > 0), "no labeld data."

filtered_graphs, execution_times = [], []
filtered_g_labels, filtered_u_labels = [], []

for item in labeled_data:
    x_raw = _get_x_from_item(item)
    frames = to_frames_pointlist(x_raw)
    graph, exec_time = create_pyg_graph(frames)

    action_label = int(item["gesture_label"])  # expected 0..20 (21 gestures)
    user_label = int(item["user_label"])       # expected 0..40 (41 users)

    graph.y_action = torch.tensor([action_label], dtype=torch.long)
    graph.y_user = torch.tensor([user_label], dtype=torch.long)

    filtered_graphs.append(graph)
    execution_times.append(exec_time)
    filtered_g_labels.append(action_label)
    filtered_u_labels.append(user_label)

avg_time = 0.0 if not execution_times else float(np.mean(execution_times))
print(f"üìå Total graphs: {len(filtered_graphs)}")
print(f"üìå Avg Preprocessing time: {avg_time:.4f} sec")

act_cnt, usr_cnt = Counter(filtered_g_labels), Counter(filtered_u_labels)
print(f"üìå num act: {len(act_cnt)} | ex: {act_cnt.most_common(5)}")
print(f"üìå num user: {len(usr_cnt)} | ex: {usr_cnt.most_common(5)}")


In [None]:
import torch
import torch.nn as nn
import torch_geometric.nn as pyg_nn
from torch_geometric.data import Data
import math

class MLP(nn.Module):
    """Two-layer perceptron, (Linear -> LayerNorm -> ReLU) twice, mapping in_channels -> out_channels.

    The ``bn1``/``bn2`` attributes are LayerNorms despite their names. The names (and the
    creation order) are kept so state_dict keys of existing checkpoints still match.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.fc1 = nn.Linear(in_channels, out_channels)
        self.bn1 = nn.LayerNorm(out_channels)
        self.fc2 = nn.Linear(out_channels, out_channels)
        self.bn2 = nn.LayerNorm(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.bn1(self.fc1(x)))
        return self.relu(self.bn2(self.fc2(hidden)))


class CustomGINConv(pyg_nn.MessagePassing):
    """Edge-weighted GIN layer: ``out = nn((1 + eps) * x_i + sum_{j -> i} w_ji * x_j)``.

    Args:
        nn_model: module applied after aggregation (e.g. an MLP).
        eps: self-loop weight offset. With ``train_eps=True`` it becomes a learnable
            parameter named ``eps``; otherwise it is a fixed Python number.
    """
    def __init__(self, nn_model, eps=0, train_eps=False):
        super(CustomGINConv, self).__init__(aggr="add")
        self.nn = nn_model
        self.eps = torch.nn.Parameter(torch.Tensor([eps])) if train_eps else eps

    def forward(self, x, edge_index, edge_weight):
        """
        Args:
            x: node features; indexed by edge endpoints, so assumed (num_nodes, F).
            edge_index: (2, E) tensor of [src; dst] node indices.
            edge_weight: (E,) per-edge weights, or None for an unweighted graph.
        Returns:
            Transformed node features, cast back to ``x.dtype``.
        """
        src, dst = edge_index
        if edge_weight is None:
            edge_weight = torch.ones(src.size(0), dtype=x.dtype, device=x.device)
        weighted_messages = x[src] * edge_weight.view(-1, 1)
        # Sum messages per destination node in float32 (index_add_ needs matching dtypes).
        aggregated = torch.zeros_like(x, dtype=torch.float32)
        aggregated.index_add_(0, dst, weighted_messages.to(dtype=torch.float32))
        # GIN update. The stored eps was previously ignored; with the default eps=0,
        # (1 + eps) * x == x, so existing results are unchanged.
        out = (1 + self.eps) * x.to(dtype=torch.float32) + aggregated
        out = self.nn(out)
        return out.to(x.dtype)

class ReadoutWithMLP(nn.Module):
    """Graph readout: concatenate per-graph node statistics, then apply an MLP.

    Args:
        input_dim: node embedding size.
        output_dim: currently unused (see note below).
        num_hidden_layers: currently unused (see note below).
        method: '+'-joined subset of {"sum", "mean", "max", "std"}, e.g. "sum+max"
            (case-insensitive). Concatenation follows the order given.

    NOTE(review): the MLP is hard-coded to readout_dim -> 256 -> 1024, with a ReLU after
    each Linear, regardless of ``output_dim`` / ``num_hidden_layers``.
    """

    def __init__(self, input_dim, output_dim, num_hidden_layers=2, method="sum"):
        super().__init__()
        self.method = method.lower().split("+")
        self.supported = {'sum', 'mean', 'max', 'std'}
        assert set(self.method).issubset(self.supported), f"not supported Readout: {self.method}"

        self.readout_dim = input_dim * len(self.method)
        widths = [self.readout_dim, 256, 1024]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        self.mlp = nn.Sequential(*layers)

    def forward(self, output, batch_idx, valid_mask=None):
        """
        Args:
            output: (num_nodes, feature_dim) node embeddings.
            batch_idx: (num_nodes,) graph id of each node.
            valid_mask: optional (num_nodes,) bool mask; only True nodes are pooled.
        Returns:
            MLP output for each graph id 0..max(batch_idx). A graph with no valid nodes
            gets an all-zero statistics vector.
        """
        num_graphs = batch_idx.max().item() + 1
        reducers = {
            "sum": lambda t: t.sum(dim=0),
            "mean": lambda t: t.mean(dim=0),
            "max": lambda t: t.max(dim=0)[0],
            "std": lambda t: t.std(dim=0, unbiased=False),
        }

        pooled = []
        for stat in self.method:
            reduce_fn = reducers.get(stat)
            table = torch.zeros((num_graphs, output.size(1)), device=output.device)
            for gid in range(num_graphs):
                sel = batch_idx == gid
                if valid_mask is not None:
                    sel = sel & valid_mask
                if sel.sum() == 0:
                    continue
                if reduce_fn is not None:
                    table[gid] = reduce_fn(output[sel])
            pooled.append(table)
        return self.mlp(torch.cat(pooled, dim=1))

class GINNet(nn.Module):
    """Single-task GIN model: input MLP -> GIN layers -> ReadoutWithMLP.

    Args:
        in_channels: node feature size.
        hidden_channels: width of the input MLP and of every GIN layer.
        out_channels: forwarded to ReadoutWithMLP as ``output_dim``.
        num_layers: number of CustomGINConv layers.
        num_readout_layers: forwarded to ReadoutWithMLP as ``num_hidden_layers``.
        method: readout statistic spec, e.g. "sum" or "mean+max".
        drop_prob: per-node drop probability used only in training mode.
    """

    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_layers=2, num_readout_layers=2, method="sum", drop_prob=0.2):
        super().__init__()
        self.input_mlp = MLP(in_channels, hidden_channels)
        self.layers = nn.ModuleList(
            [CustomGINConv(MLP(hidden_channels, hidden_channels)) for _ in range(num_layers)]
        )
        self.readout = ReadoutWithMLP(
            hidden_channels,
            out_channels,
            num_hidden_layers=num_readout_layers,
            method=method,
        )
        self.drop_prob = drop_prob

    def forward(self, x, edge_index, edge_weight, batch_idx, frame_idx=None):
        """Return graph-level outputs. ``frame_idx`` is accepted but unused."""
        num_nodes = x.size(0)
        if not self.training:
            keep = torch.ones(num_nodes, dtype=torch.bool, device=x.device)
        else:
            # Randomly drop a subset of all points (each kept with prob. 1 - drop_prob):
            # their input features are zeroed and they are excluded from the readout.
            keep = torch.rand(num_nodes, device=x.device) > self.drop_prob

        h = self.input_mlp(x * keep.unsqueeze(1).float())
        for conv in self.layers:
            h = conv(h, edge_index, edge_weight)
        return self.readout(h, batch_idx, valid_mask=keep)

class GraphClassifier(nn.Module):
    """MLP classification head with halving hidden widths.

    Starting from ``in_channels``, up to ``num_hidden_layers`` hidden layers are added,
    each half the previous width (integer division). Halving stops right after the first
    width that drops below ``num_classes``; that width is still used. Each hidden layer
    is Linear -> LayerNorm -> ReLU -> Dropout(dropout_rate), followed by a final
    Linear to ``num_classes`` logits.
    """

    def __init__(self, in_channels, num_classes, num_hidden_layers=2, dropout_rate=0.5):
        super().__init__()
        widths = []
        width = in_channels
        for _ in range(num_hidden_layers):
            width //= 2
            widths.append(width)
            if width < num_classes:
                break

        blocks = []
        fan_in = in_channels
        for fan_out in widths:
            blocks += [
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            ]
            fan_in = fan_out
        blocks.append(nn.Linear(fan_in, num_classes))
        self.mlp = nn.Sequential(*blocks)

    def forward(self, graph_embedding):
        return self.mlp(graph_embedding)



In [None]:
# ============================================
# 5-FOLD CV: Multi-Task (Action + User) with Shared Backbone
#  - Backbone: GINBackbone
#  - Action head: TemporalGatedReadout + Classifier
#  - User   head: StatsRFFReadout     + Classifier
#  - Stratified K-fold on ACTION labels 
# ============================================
import os, time, copy, math, numpy as np, torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import OneCycleLR
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool
from torch_geometric.loader import DataLoader as PYGLoader

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---------- Utils ----------
def set_seed(seed=42):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (plus all CUDA devices if present)."""
    import random
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def normalize_time_per_graph(time_steps, batch):
    """Min-max normalize time steps within each graph of a batch.

    Args:
        time_steps: (N,) per-node frame indices.
        batch: (N,) graph id per node.
    Returns:
        (N, 1) float tensor on ``batch.device``: (t - min) / (max - min + 1e-6) per graph.
        Graphs whose nodes share a single time step map to 0.
    """
    dev = batch.device
    t = time_steps.to(dev).float()
    normed = torch.zeros_like(t, device=dev)
    num_graphs = int(batch.max().item()) + 1 if batch.numel() else 1
    for gid in range(num_graphs):
        sel = batch == gid
        if not sel.any():
            continue
        seg = t[sel]
        lo, hi = seg.min(), seg.max()
        normed[sel] = (seg - lo) / (hi - lo + 1e-6)
    return normed.unsqueeze(1)

def per_graph_softmax(scores: torch.Tensor, batch: torch.Tensor, topk_ratio: float = None) -> torch.Tensor:
    """Softmax node scores separately within each graph.

    Args:
        scores: (N, 1) node scores.
        batch: (N,) graph id per node.
        topk_ratio: if in (0, 1), only the top ceil(ratio * n_g) nodes of each graph
            (at least one) take part in the softmax; all other nodes get weight 0.
    Returns:
        Tensor shaped like ``scores`` whose entries sum to 1 within each graph.
    """
    result = torch.zeros_like(scores)
    num_graphs = int(batch.max().item()) + 1 if batch.numel() else 1
    use_topk = topk_ratio is not None and 0.0 < topk_ratio < 1.0

    for gid in range(num_graphs):
        sel = batch == gid
        seg = scores[sel]
        if seg.numel() == 0:
            continue
        if not use_topk:
            result[sel] = torch.softmax(seg, dim=0)
            continue
        flat = seg.squeeze(1)
        k = max(1, int(math.ceil(topk_ratio * flat.numel())))
        chosen = torch.zeros_like(flat, dtype=torch.bool)
        chosen[torch.topk(flat, k).indices] = True
        dense = torch.zeros_like(flat)
        dense[chosen] = torch.softmax(flat[chosen], dim=0)
        result[sel] = dense.unsqueeze(1)
    return result

# ---------- Backbone ----------
class GINBackbone(nn.Module):
    """Shared node encoder: input MLP followed by ``num_layers`` edge-weighted GIN layers.

    In training mode with ``drop_prob > 0``, each node's input features are zeroed with
    probability ``drop_prob`` before encoding.
    """

    def __init__(self, in_channels=4, hidden_channels=64, num_layers=2, drop_prob=0.0):
        super().__init__()
        self.input_mlp = MLP(in_channels, hidden_channels)
        convs = [CustomGINConv(MLP(hidden_channels, hidden_channels)) for _ in range(num_layers)]
        self.layers = nn.ModuleList(convs)
        self.drop_prob = drop_prob

    def forward(self, x, edge_index, edge_weight):
        """Return per-node embeddings of width ``hidden_channels``."""
        if self.training and self.drop_prob > 0.0:
            keep = torch.rand(x.size(0), device=x.device) > self.drop_prob
            x = x * keep.unsqueeze(1).float()
        h = self.input_mlp(x)
        for conv in self.layers:
            h = conv(h, edge_index, edge_weight)
        return h

# ---------- Readouts ----------
class TemporalGatedReadout(nn.Module):
    """Action-head readout: time-aware attention pooling plus mean and max pooling.

    Each node is scored by ``score(tanh(W_h h + W_t t_norm))``, where ``t_norm`` is the
    node's per-graph min-max normalized time step. The scores are softmaxed per graph,
    restricted to the top ``topk_ratio`` fraction of nodes when 0 < topk_ratio < 1, and
    used to weight-sum the node embeddings. ``[attn, mean, max]`` then goes through
    Linear -> ReLU -> BatchNorm1d.
    NOTE: BatchNorm1d in training mode needs more than one graph per batch.
    """

    def __init__(self, in_dim, out_dim, topk_ratio=0.25):
        super().__init__()
        self.wh = nn.Linear(in_dim, in_dim)
        self.wt = nn.Linear(1, in_dim)
        self.score = nn.Linear(in_dim, 1)
        self.topk_ratio = topk_ratio
        self.proj = nn.Sequential(
            nn.Linear(3 * in_dim, out_dim),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(out_dim),
        )

    def forward(self, node_h, batch, time_steps):
        t_norm = normalize_time_per_graph(time_steps, batch)
        gate = torch.tanh(self.wh(node_h) + self.wt(t_norm))
        alpha = per_graph_softmax(self.score(gate), batch, self.topk_ratio)
        pooled = [
            global_add_pool(alpha * node_h, batch),  # attention-weighted sum
            global_mean_pool(node_h, batch),
            global_max_pool(node_h, batch),
        ]
        return self.proj(torch.cat(pooled, dim=1))

class StatsRFFReadout(nn.Module):
    """User-head readout built from per-graph node statistics.

    Concatenates three per-graph quantities:
      - the mean of the node embeddings;
      - their population standard deviation (with +1e-6 inside the sqrt);
      - a kernel mean embedding, ``mean_i(sqrt(2 / rff_dim) * cos(h_i @ B + b))``, with
        random-Fourier-feature frequencies ``B`` and phases ``b`` that are trainable.
    The result then goes through Linear -> ReLU -> BatchNorm1d.
    """

    def __init__(self, in_dim, out_dim, rff_dim=64):
        super().__init__()
        # Init: B ~ N(0, 0.5^2), b ~ U(0, 2*pi). B is created before b (RNG order matters).
        self.B = nn.Parameter(torch.randn(in_dim, rff_dim) * 0.5)
        self.b = nn.Parameter(torch.rand(rff_dim) * 2 * math.pi)
        self.scale = math.sqrt(2.0 / rff_dim)
        self.proj = nn.Sequential(
            nn.Linear(2 * in_dim + rff_dim, out_dim),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(out_dim),
        )

    def forward(self, node_h, batch, time_steps=None):
        # time_steps is accepted for a signature shared with the action readout; unused here.
        mean = global_mean_pool(node_h, batch)
        centered = node_h - mean[batch]
        std = torch.sqrt(global_mean_pool(centered * centered, batch) + 1e-6)
        rff = torch.cos(node_h @ self.B + self.b) * self.scale
        kme = global_mean_pool(rff, batch)
        return self.proj(torch.cat([mean, std, kme], dim=1))

class MultiTaskModel(nn.Module):
    """Shared GIN backbone with two heads: action (gesture) and user identification.

    The action head uses TemporalGatedReadout and the user head uses StatsRFFReadout.
    Each readout maps to ``r_dim`` and feeds its own GraphClassifier.
    """

    def __init__(self, in_dim=4, h_dim=64, r_dim=1024, num_layers=2,
                 drop_prob=0.0, topk_ratio=0.25, rff_dim=64,
                 num_classes_act=21, num_classes_usr=41):
        super().__init__()
        # Construction order is kept fixed so seeded initialization stays reproducible.
        self.backbone = GINBackbone(in_dim, h_dim, num_layers=num_layers, drop_prob=drop_prob)
        self.ro_action = TemporalGatedReadout(h_dim, r_dim, topk_ratio=topk_ratio)
        self.ro_user = StatsRFFReadout(h_dim, r_dim, rff_dim=rff_dim)
        self.cls_action = GraphClassifier(
            in_channels=r_dim, num_classes=num_classes_act, num_hidden_layers=2, dropout_rate=0.5
        )
        self.cls_user = GraphClassifier(
            in_channels=r_dim, num_classes=num_classes_usr, num_hidden_layers=2, dropout_rate=0.5
        )

    def forward(self, x, edge_index, edge_weight, batch, time_steps=None):
        """Return ``{'act': action_logits, 'user': user_logits}``."""
        node_h = self.backbone(x, edge_index, edge_weight)
        emb_act = self.ro_action(node_h, batch, time_steps)
        emb_usr = self.ro_user(node_h, batch, time_steps)
        return {'act': self.cls_action(emb_act), 'user': self.cls_user(emb_usr)}

    def forward_act(self, x, edge_index, edge_weight, batch, time_steps=None):
        """Action logits only."""
        node_h = self.backbone(x, edge_index, edge_weight)
        return self.cls_action(self.ro_action(node_h, batch, time_steps))

    def forward_user(self, x, edge_index, edge_weight, batch, time_steps=None):
        """User logits only."""
        node_h = self.backbone(x, edge_index, edge_weight)
        return self.cls_user(self.ro_user(node_h, batch, time_steps))

# ---------- Dataset build ----------
def build_multitask_graphs_all(graphs, g_labels, u_labels, missing_user=-100, seed=42):
    """Return deep copies of ``graphs`` carrying scalar ``y_act`` / ``y_user`` label tensors.

    A ``None`` user label is replaced by ``missing_user``. The input graphs are not
    modified. The length-mismatch assertion message is Korean ("lengths do not match").
    """
    set_seed(seed)
    assert len(graphs) == len(g_labels) == len(u_labels), "Í∏∏Ïù¥ Ïïà ÎßûÏùå"

    def _with_labels(graph, action, user):
        labelled = copy.deepcopy(graph)
        labelled.y_act = torch.tensor(int(action), dtype=torch.long)
        user_value = missing_user if user is None else int(user)
        labelled.y_user = torch.tensor(user_value, dtype=torch.long)
        return labelled

    return [_with_labels(g, ya, yu) for g, ya, yu in zip(graphs, g_labels, u_labels)]

def make_loaders(train_ds, val_ds, test_ds, bs=32, workers=0, drop_singleton_last=True):
    """Build PyG DataLoaders for the train / val / test splits (only train is shuffled).

    The multi-task readouts end in ``nn.BatchNorm1d`` over the graph dimension. In
    training mode that layer raises on a batch containing a single graph. When
    ``len(train_ds) % bs == 1``, the trailing one-graph training batch is therefore
    dropped (``drop_last=True``) unless ``drop_singleton_last=False``. In every other
    case all training samples are used, exactly as before. Val/test loaders run in
    eval mode and are never truncated.
    """
    drop_last = bool(drop_singleton_last) and len(train_ds) > 1 and len(train_ds) % bs == 1
    return (
        PYGLoader(train_ds, batch_size=bs, shuffle=True,  num_workers=workers, drop_last=drop_last),
        PYGLoader(val_ds,   batch_size=bs, shuffle=False, num_workers=workers),
        PYGLoader(test_ds,  batch_size=bs, shuffle=False, num_workers=workers),
    )

# ---------- Stratified K-fold (on action labels) ----------
def stratified_kfold_indices(y, k=5, seed=42):
    """Split sample indices into ``k`` folds stratified by label ``y``.

    For each class (ascending label order), its indices are shuffled with a
    ``default_rng(seed)`` generator and dealt round-robin starting at fold 0. Fold 0
    therefore receives the extra sample of every class whose size is not divisible by k.
    Must stay identical to the copy used by the profiling cell so folds match.

    Returns:
        list of ``k`` int64 arrays.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(y, dtype=np.int64)
    buckets = [[] for _ in range(k)]
    for cls in np.unique(labels):
        members = np.where(labels == cls)[0]
        rng.shuffle(members)
        for position, sample in enumerate(members):
            buckets[position % k].append(int(sample))
    return [np.array(bucket, dtype=np.int64) for bucket in buckets]

def stratified_train_val_split(indices, y, val_ratio=0.1, seed=123):
    """Split ``indices`` into (train, val), stratified by ``y[indices]``.

    Each class contributes ``max(1, int(n_class * val_ratio))`` samples to val, so a
    class with a single sample ends up entirely in val. Both outputs are shuffled and
    returned as int64 arrays.
    """
    rng = np.random.default_rng(seed)
    pool = np.asarray(indices, dtype=np.int64)
    pool_labels = y[pool]
    train_part, val_part = [], []
    for cls in np.unique(pool_labels):
        members = pool[pool_labels == cls].copy()
        rng.shuffle(members)
        cut = max(1, int(len(members) * val_ratio))
        val_part += members[:cut].tolist()
        train_part += members[cut:].tolist()
    rng.shuffle(train_part)
    rng.shuffle(val_part)
    return np.array(train_part, dtype=np.int64), np.array(val_part, dtype=np.int64)

# ---------- Eval helper ----------
@torch.no_grad()
def eval_multitask(model, loader, missing_user=-100):
    """Compute (action_accuracy, user_accuracy) of ``model`` over ``loader``.

    Batches are moved to the module-level ``device``. A missing ``edge_weight`` defaults
    to ones and a missing ``time_steps`` to zeros. User accuracy counts only samples
    whose user label differs from ``missing_user`` and is 0.0 if there are none.
    """
    model.eval()
    correct_a = total_a = 0
    correct_u = total_u = 0
    for batch in loader:
        batch = batch.to(device)
        x, ei = batch.x, batch.edge_index
        has_ew = hasattr(batch, 'edge_weight') and batch.edge_weight is not None
        ew = batch.edge_weight if has_ew else torch.ones(ei.size(1), device=x.device)
        if hasattr(batch, 'time_steps'):
            tsteps = batch.time_steps
        else:
            tsteps = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        outs = model(x, ei, ew, batch.batch, tsteps)

        y_act = batch.y_act.view(-1)
        correct_a += (outs['act'].argmax(1) == y_act).sum().item()
        total_a += y_act.size(0)

        y_user = batch.y_user.view(-1)
        labelled = y_user != missing_user
        if labelled.any():
            pred_u = outs['user'].argmax(1)
            correct_u += (pred_u[labelled] == y_user[labelled]).sum().item()
            total_u += int(labelled.sum().item())

    acc_a = correct_a / max(total_a, 1)
    acc_u = correct_u / max(total_u, 1) if total_u > 0 else 0.0
    return acc_a, acc_u

# ---------- Train one fold (with early stop on combo) ----------
def train_one_fold(train_loader, val_loader, test_loader,
                   epochs=90, lr=5e-4, max_lr=3e-3, lambda_user=1.0,
                   h_dim=64, r_dim=1024, num_layers=2,
                   topk_ratio=0.25, rff_dim=64,
                   num_classes_act=49, num_classes_usr=11,
                   missing_user=-100,
                   patience=12, min_delta=1e-4, warmup_epochs=5,
                   ckpt_path="prop_multitask_fold0.pth",
                   seed=42):
    """Train one MultiTaskModel on a single CV fold and test its best checkpoint.

    Each epoch trains on ``train_loader`` with loss
    ``CE(action) + lambda_user * CE(user)``. The user term covers only samples whose user
    label != ``missing_user``; it is 0 if a batch has none. Validation accuracies are
    then computed, and the state with the best ``combo = 0.5 * (val_acc_action +
    val_acc_user)`` (improvement > ``min_delta``) is saved to ``ckpt_path``. Once more
    than ``warmup_epochs`` epochs have run, training stops after ``patience`` consecutive
    non-improving epochs. Finally the saved checkpoint is reloaded and evaluated on
    ``test_loader``.

    Args:
        train_loader, val_loader, test_loader: PyG loaders whose batches expose ``x``,
            ``edge_index``, ``batch``, ``y_act``, ``y_user`` and optionally
            ``edge_weight`` / ``time_steps``.
        epochs: maximum number of epochs (also the OneCycleLR schedule length).
        lr: learning rate passed to Adam. NOTE(review): OneCycleLR sets its own initial
            LR (max_lr / div_factor per the PyTorch docs), so this value likely has no
            effect -- confirm.
        max_lr: peak learning rate of the OneCycleLR schedule.
        lambda_user: weight of the user-identification loss.
        h_dim, r_dim, num_layers, topk_ratio, rff_dim: MultiTaskModel hyper-parameters.
        num_classes_act, num_classes_usr: head output sizes. NOTE: the defaults (49 / 11)
            differ from the 21 / 41 used by the CV wrapper, which always passes them.
        missing_user: user-label sentinel excluded from the user loss and accuracy.
        patience, min_delta, warmup_epochs: early-stopping controls.
        ckpt_path: file the best state_dict is written to and reloaded from.
        seed: seed applied via ``set_seed`` before the model is built.

    Returns:
        dict with ``best_combo``, ``test_acc_action``, ``test_acc_user`` and ``ckpt``.
    """
    set_seed(seed)

    # in_dim=4 matches the [x, y, z, positional-encoding] node features of create_pyg_graph.
    model = MultiTaskModel(
        in_dim=4, h_dim=h_dim, r_dim=r_dim, num_layers=num_layers,
        topk_ratio=topk_ratio, rff_dim=rff_dim,
        num_classes_act=num_classes_act, num_classes_usr=num_classes_usr
    ).to(device)

    opt   = Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    # Per-batch schedule: sched.step() is called once per training batch below.
    sched = OneCycleLR(opt, max_lr=max_lr, steps_per_epoch=len(train_loader), epochs=epochs)
    ce    = nn.CrossEntropyLoss(reduction='mean')

    # combo >= 0 always, so the first epoch is guaranteed to be saved.
    best_combo = -1.0
    best_state = None
    no_improve = 0

    for ep in range(epochs):
        model.train()
        for batch in train_loader:
            batch = batch.to(device)

            y_act  = batch.y_act.view(-1)
            y_user = batch.y_user.view(-1)

            x, ei = batch.x, batch.edge_index
            # Fall back to unit edge weights / zero time steps if the graphs lack them.
            ew = batch.edge_weight if hasattr(batch,'edge_weight') and batch.edge_weight is not None \
                 else torch.ones(ei.size(1), device=x.device)
            bidx   = batch.batch
            tsteps = batch.time_steps if hasattr(batch,'time_steps') else torch.zeros(x.size(0), dtype=torch.long, device=x.device)

            outs = model(x, ei, ew, bidx, tsteps)
            logits_a, logits_u = outs['act'], outs['user']

            loss_a = ce(logits_a, y_act)

            # User loss only over samples that actually carry a user label.
            mask_u = (y_user != missing_user)
            if mask_u.any():
                loss_u = ce(logits_u[mask_u], y_user[mask_u])
            else:
                loss_u = torch.tensor(0.0, device=device)

            loss = loss_a + lambda_user * loss_u

            opt.zero_grad()
            loss.backward()
            opt.step()
            sched.step()

        # val
        val_acc_a, val_acc_u = eval_multitask(model, val_loader, missing_user=missing_user)
        combo = 0.5 * (val_acc_a + val_acc_u)

        print(f"[FOLD] Ep {ep+1:02d} | Val A:{val_acc_a:.4f} U:{val_acc_u:.4f} | Combo:{combo:.4f}")

        if (combo - best_combo) > min_delta:
            best_combo = combo
            best_state = copy.deepcopy(model.state_dict())
            torch.save(best_state, ckpt_path)
            no_improve = 0
            print("‚úÖ Saved:", ckpt_path)
        else:
            # Non-improving epochs only count toward patience after the warm-up.
            if (ep + 1) > warmup_epochs:
                no_improve += 1

        if (ep + 1) > warmup_epochs and no_improve >= patience:
            print(f"‚èπÔ∏è Early stop (patience={patience})")
            break

    # test (best): reload the best checkpoint from disk rather than using the last epoch.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    test_acc_a, test_acc_u = eval_multitask(model, test_loader, missing_user=missing_user)
    print(f"üî• [FOLD] Test Acc ‚Äî Action: {test_acc_a:.4f} | User: {test_acc_u:.4f}")

    return {
        "best_combo": float(best_combo),
        "test_acc_action": float(test_acc_a),
        "test_acc_user": float(test_acc_u),
        "ckpt": ckpt_path
    }

# ---------- 5-fold CV wrapper ----------
def train_multitask_5fold(
    k_folds=5,
    epochs=90, lr=5e-4, max_lr=3e-3, lambda_user=1.0,
    h_dim=64, r_dim=1024, num_layers=2,
    bs=32, workers=0,
    seed=42,
    topk_ratio=0.25, rff_dim=64,
    num_classes_act=21, num_classes_usr=41,
    missing_user=-100,
    val_ratio=0.1,
    fold_seed=42, val_seed=123,
    patience=12, min_delta=1e-4, warmup_epochs=5,
    ckpt_dir="./ckpt_multitask_5fold",
    graphs=None, g_labels=None, u_labels=None,
):
    """Run stratified K-fold cross-validation of the multi-task (action + user) model.

    Folds are stratified on the action label. Within each fold's train+val pool, a
    further action-stratified split carves out ``val_ratio`` for validation (seeded with
    ``val_seed + fold``). Each fold is trained by ``train_one_fold`` with seed
    ``seed + fold``, and its checkpoint is written to
    ``<ckpt_dir>/prop_multitask_shared_backbone_fold<k>.pth``.

    Args:
        graphs, g_labels, u_labels: the dataset. When omitted, the notebook globals
            ``filtered_graphs`` / ``filtered_g_labels`` / ``filtered_u_labels`` from the
            graph-construction cell are used (the original behaviour). Passing them
            explicitly removes that hidden dependency.
        Remaining arguments are forwarded to the split helpers and ``train_one_fold``.

    Returns:
        dict with per-fold results (``folds``), mean/std test accuracies for both tasks
        and ``ckpt_dir``.
    """
    os.makedirs(ckpt_dir, exist_ok=True)
    set_seed(seed)

    if graphs is None:
        graphs = filtered_graphs
    if g_labels is None:
        g_labels = filtered_g_labels
    if u_labels is None:
        u_labels = filtered_u_labels

    # 0) full dataset (labels embedded)
    gs = build_multitask_graphs_all(
        graphs, g_labels, u_labels,
        missing_user=missing_user, seed=seed
    )

    # 1) folds stratified by ACTION labels
    y_act_all = np.array([int(g.y_act.item()) for g in gs], dtype=np.int64)
    folds = stratified_kfold_indices(y_act_all, k=k_folds, seed=fold_seed)

    all_idx = np.arange(len(gs), dtype=np.int64)

    fold_results = []
    for fold in range(k_folds):
        test_idx = folds[fold]
        trainval_idx = np.setdiff1d(all_idx, test_idx, assume_unique=False)

        # train/val split stratified by ACTION within trainval
        train_idx, val_idx = stratified_train_val_split(
            trainval_idx, y_act_all, val_ratio=val_ratio, seed=val_seed + fold
        )

        train_ds = [gs[i] for i in train_idx]
        val_ds   = [gs[i] for i in val_idx]
        test_ds  = [gs[i] for i in test_idx]

        train_loader, val_loader, test_loader = make_loaders(train_ds, val_ds, test_ds, bs=bs, workers=workers)

        ckpt_path = os.path.join(ckpt_dir, f"prop_multitask_shared_backbone_fold{fold}.pth")

        print(f"\n==================== FOLD {fold}/{k_folds-1} ====================")
        print(f"[SPLIT] train/val/test = {len(train_ds)}/{len(val_ds)}/{len(test_ds)} | bs={bs}")

        r = train_one_fold(
            train_loader, val_loader, test_loader,
            epochs=epochs, lr=lr, max_lr=max_lr, lambda_user=lambda_user,
            h_dim=h_dim, r_dim=r_dim, num_layers=num_layers,
            topk_ratio=topk_ratio, rff_dim=rff_dim,
            num_classes_act=num_classes_act, num_classes_usr=num_classes_usr,
            missing_user=missing_user,
            patience=patience, min_delta=min_delta, warmup_epochs=warmup_epochs,
            ckpt_path=ckpt_path,
            seed=seed + fold
        )

        r.update({
            "fold": int(fold),
            "sizes": {"train": int(len(train_ds)), "val": int(len(val_ds)), "test": int(len(test_ds))}
        })
        fold_results.append(r)

    # summary
    a_list = [r["test_acc_action"] for r in fold_results]
    u_list = [r["test_acc_user"] for r in fold_results]
    a_mean, a_std = float(np.mean(a_list)), float(np.std(a_list))
    u_mean, u_std = float(np.mean(u_list)), float(np.std(u_list))

    print("\n==================== 5-FOLD SUMMARY ====================")
    for r in fold_results:
        sz = r["sizes"]
        print(f"fold{r['fold']} | TestA {r['test_acc_action']:.4f} | TestU {r['test_acc_user']:.4f} | "
              f"train/val/test={sz['train']}/{sz['val']}/{sz['test']} | ckpt={r['ckpt']}")
    print(f"MEAN¬±STD Action = {a_mean:.4f} ¬± {a_std:.4f}")
    print(f"MEAN¬±STD User   = {u_mean:.4f} ¬± {u_std:.4f}")

    return {
        "folds": fold_results,
        "mean_action": a_mean, "std_action": a_std,
        "mean_user": u_mean, "std_user": u_std,
        "ckpt_dir": ckpt_dir
    }

set_seed(42)

# Hyper-parameters for the multi-task K-fold CV run.
CV_CONFIG = dict(
    k_folds=5,
    epochs=100, lr=5e-4, max_lr=1e-3, lambda_user=1.0,
    h_dim=64, r_dim=1024, num_layers=2,
    bs=16, workers=0,
    seed=42,
    topk_ratio=0.25, rff_dim=64,
    num_classes_act=21, num_classes_usr=41,
    missing_user=-100,
    val_ratio=0.1,
    fold_seed=42, val_seed=123,
    patience=12, min_delta=1e-4, warmup_epochs=5,
    ckpt_dir="./ckpt_multitask_5fold",
)
res_cv = train_multitask_5fold(**CV_CONFIG)

print("\n===== [MT-5F] Summary =====")
print(f"Action mean¬±std = {res_cv['mean_action']:.4f} ¬± {res_cv['std_action']:.4f}")
print(f"User   mean¬±std = {res_cv['mean_user']:.4f} ¬± {res_cv['std_user']:.4f}")
print("CKPT dir:", res_cv["ckpt_dir"])


Model profiling

In [None]:
# ============================================
# TEST/PROFILE CELL: MultiTask 5-Fold (Action+User) CKPT -> ACC + SIZE/PARAMS/FLOPs/LAT
#  - Loads each fold checkpoint from ckpt_dir
#  - Rebuilds same folds (stratified on ACTION)
#  - Reports per-fold TestAcc(Action/User), ckpt size, params, FLOPs (if thop works), latency (ms)
#  - Copy & Paste ready
# ============================================
import os, time, math, copy, numpy as np, torch
import torch.nn as nn
from collections import OrderedDict
from torch_geometric.loader import DataLoader as PYGLoader

# Device used for evaluation and profiling: CUDA when available, else CPU.
# The reported latencies therefore depend on which one was picked.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------------------
# MUST MATCH TRAIN SETTINGS
# ---------------------------
CKPT_DIR = "./ckpt_multitask_5fold"   # same as ckpt_dir passed to train_multitask_5fold
K_FOLDS  = 5                          # same as k_folds used for training
BS_TEST  = 1          # profiling/eval batch size (can differ from train); FLOPs/latency are measured on one batch of this size
WORKERS  = 0          # loader worker processes (0 = load in the main process)
MISSING_USER = -100   # user-label sentinel; samples carrying it are skipped in user accuracy

# model hyperparams must match train: a mismatch either breaks strict checkpoint
# loading or silently changes the model
IN_DIM = 4            # NOTE(review): not passed in the training call; presumably the width of graph.x — confirm
H_DIM  = 64
R_DIM  = 1024
NUM_LAYERS = 2
TOPK_RATIO = 0.25
RFF_DIM = 64

# If you changed these to "auto" in training, set them here too (recommended).
# The training call passes num_classes_act=21 and num_classes_usr=41.
NUM_CLASSES_ACT = 21
NUM_CLASSES_USR = 41

# thop settings
# NOTE(review): THOP_WARMUP_BATCHES is not referenced anywhere in this cell;
# try_get_flops_m profiles a single example batch.
THOP_WARMUP_BATCHES = 1

# latency settings (forward only)
LAT_WARMUP = 30       # untimed forward passes run before timing starts
LAT_RUNS   = 200      # timed forward passes; the reported latency is their mean

# ---------------------------
# Seed / folds (same as train)
# ---------------------------
def set_seed(seed=42):
    """Seed Python's `random`, NumPy's global RNG and torch (CPU, plus every CUDA device if present).

    Args:
        seed: integer seed applied to each generator.
    """
    import random

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def stratified_kfold_indices(y, k=5, seed=42):
    """
    Split sample positions into ``k`` label-stratified folds.

    Classes are visited in ascending label order. The positions of each class are
    shuffled with a NumPy Generator seeded by ``seed`` and then dealt round-robin,
    always starting at fold 0. Because every class starts at fold 0, classes with
    fewer than ``k`` samples all land in the low-numbered folds, so fold sizes can
    be uneven. This is kept deliberately: the folds must be identical to the ones
    built for training.

    Args:
        y: 1-D sequence of integer class labels.
        k: number of folds.
        seed: seed for the shuffling Generator.

    Returns:
        list of ``k`` int64 arrays of positions into ``y``.
    """
    labels = np.asarray(y, dtype=np.int64)
    generator = np.random.default_rng(seed)
    buckets = [list() for _ in range(k)]
    for cls in np.unique(labels):
        members = np.where(labels == cls)[0]
        generator.shuffle(members)
        for pos, sample in enumerate(members.tolist()):
            buckets[pos % k].append(sample)
    return [np.asarray(bucket, dtype=np.int64) for bucket in buckets]

def stratified_train_val_split(indices, y, val_ratio=0.1, seed=123):
    """
    Split ``indices`` into (train, val) per class of ``y[indices]``.

    Each class (ascending label order) is shuffled with a Generator seeded by
    ``seed``. The first ``max(1, int(n_class * val_ratio))`` members go to
    validation and the rest to training, so a class with a single sample ends up
    entirely in validation. Both outputs are shuffled again at the end.
    Not called by the profiling code in this cell; presumably kept to mirror the
    training cell's split logic.

    Args:
        indices: positions to split.
        y: label array indexable by an integer index array (e.g. a NumPy array).
        val_ratio: fraction of each class sent to validation.
        seed: seed for the shuffling Generator.

    Returns:
        (train_idx, val_idx) as int64 arrays.
    """
    gen = np.random.default_rng(seed)
    pool = np.asarray(indices, dtype=np.int64)
    pool_labels = y[pool]
    tr, va = [], []
    for cls in np.unique(pool_labels):
        members = pool[pool_labels == cls].copy()
        gen.shuffle(members)
        cut = max(1, int(len(members) * val_ratio))
        va += members[:cut].tolist()
        tr += members[cut:].tolist()
    # shuffle the Python lists (not arrays) so the RNG draws match the training split
    gen.shuffle(tr)
    gen.shuffle(va)
    return np.array(tr, dtype=np.int64), np.array(va, dtype=np.int64)

# ---------------------------
# (Re)build dataset with labels embedded exactly like train
# ---------------------------
def build_multitask_graphs_all(graphs, g_labels, u_labels, missing_user=-100, seed=42):
    """
    Return deep copies of ``graphs`` with both task labels attached.

    Each copy gets ``y_act`` (action label) and ``y_user`` (user label) as 0-d
    long tensors. A ``None`` user label is replaced by ``missing_user``. The input
    graphs are not modified. ``set_seed(seed)`` is called first, and an
    AssertionError is raised if the three inputs differ in length.
    """
    set_seed(seed)
    assert len(graphs) == len(g_labels) == len(u_labels), "length mismatch"

    def _with_labels(graph, act_label, user_label):
        clone = copy.deepcopy(graph)
        clone.y_act = torch.tensor(int(act_label), dtype=torch.long)
        user_value = missing_user if user_label is None else int(user_label)
        clone.y_user = torch.tensor(user_value, dtype=torch.long)
        return clone

    return [_with_labels(g, ya, yu) for g, ya, yu in zip(graphs, g_labels, u_labels)]

def make_loader(ds, bs=64, workers=0, shuffle=False):
    """
    Wrap a sequence of graphs in a torch_geometric DataLoader.

    Args:
        ds: sequence of graph objects to batch.
        bs: graphs per batch.
        workers: number of loader worker processes.
        shuffle: reshuffle every epoch (leave False for deterministic evaluation).
    """
    options = {"batch_size": bs, "shuffle": shuffle, "num_workers": workers}
    return PYGLoader(ds, **options)

# ---------------------------
# Model (must already be defined in previous cell)
#   - GINBackbone, TemporalGatedReadout, StatsRFFReadout, GraphClassifier, MultiTaskModel
# ---------------------------
def build_model(num_classes_act, num_classes_usr):
    """
    Instantiate MultiTaskModel with this cell's architecture constants and move it to ``device``.

    IN_DIM, H_DIM, R_DIM, NUM_LAYERS, TOPK_RATIO and RFF_DIM must equal the values
    used at training time. A mismatch either breaks strict checkpoint loading or
    silently changes the model.
    """
    arch_kwargs = dict(
        in_dim=IN_DIM, h_dim=H_DIM, r_dim=R_DIM, num_layers=NUM_LAYERS,
        topk_ratio=TOPK_RATIO, rff_dim=RFF_DIM,
    )
    head_kwargs = dict(num_classes_act=num_classes_act, num_classes_usr=num_classes_usr)
    net = MultiTaskModel(**arch_kwargs, **head_kwargs)
    return net.to(device)

def load_state_1gpu(model, path):
    """
    Load a saved state_dict into ``model`` for single-device inference.

    Checkpoints saved from an ``nn.DataParallel`` / DistributedDataParallel wrapper
    carry a leading ``"module."`` on their keys, which is stripped so the weights
    fit the unwrapped model. Only the *leading* prefix is removed. A submodule that
    is itself named ``module`` (e.g. key ``"enc.module.weight"``) keeps its name,
    which a blanket ``str.replace`` would corrupt and then fail the strict load.

    Args:
        model: freshly built model with the same architecture as the saved one.
        path: path to a ``torch.save``-d state_dict.

    Returns:
        The same ``model`` object with weights loaded, switched to eval mode.

    Raises:
        RuntimeError: from ``load_state_dict(strict=True)`` on missing/unexpected
            keys or shape mismatches.
    """
    prefix = "module."
    sd = torch.load(path, map_location=device)
    if isinstance(sd, dict) and any(k.startswith(prefix) for k in sd.keys()):
        # strip the wrapper prefix only where it actually leads the key
        sd = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in sd.items()
        )
    model.load_state_dict(sd, strict=True)
    model.eval()
    return model

# ---------------------------
# Accuracy eval (Action/User on test split)
# ---------------------------
@torch.no_grad()
def eval_multitask(model, loader, missing_user=-100):
    """
    Top-1 accuracy of both heads over every batch in ``loader``.

    Each batch is moved to ``device`` and fed to
    ``model(x, edge_index, edge_weight, batch, time_steps)``, which must return a
    dict with "act" and "user" logits. A batch without (or with a None)
    ``edge_weight`` gets all-ones edge weights. A batch without ``time_steps``
    gets zeros. Samples whose user label equals ``missing_user`` are excluded from
    the user accuracy. If no labelled user remains at all, user accuracy is
    reported as 0.0, which is indistinguishable from a fully wrong user head.

    Returns:
        (action_accuracy, user_accuracy) as floats.
    """
    model.eval()
    act_hits = act_seen = 0
    usr_hits = usr_seen = 0
    for raw in loader:
        b = raw.to(device)
        target_act = b.y_act.view(-1)
        target_usr = b.y_user.view(-1)

        x = b.x
        edge_index = b.edge_index
        if hasattr(b, "edge_weight") and b.edge_weight is not None:
            edge_weight = b.edge_weight
        else:
            edge_weight = torch.ones(edge_index.size(1), device=x.device)
        node_batch = b.batch
        if hasattr(b, "time_steps"):
            steps = b.time_steps
        else:
            steps = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        logits = model(x, edge_index, edge_weight, node_batch, steps)
        act_logits, usr_logits = logits["act"], logits["user"]

        act_hits += (act_logits.argmax(1) == target_act).sum().item()
        act_seen += target_act.numel()

        labeled = target_usr != missing_user
        if labeled.any():
            usr_pred = usr_logits.argmax(1)
            usr_hits += (usr_pred[labeled] == target_usr[labeled]).sum().item()
            usr_seen += int(labeled.sum().item())

    acc_act = act_hits / max(act_seen, 1)
    acc_usr = usr_hits / usr_seen if usr_seen > 0 else 0.0
    return float(acc_act), float(acc_usr)

# ---------------------------
# Size / Params / FLOPs / Latency
# ---------------------------
def get_model_size_mb(path):
    """Size of the file at ``path`` on disk, in MiB (bytes / 2**20)."""
    size_bytes = os.path.getsize(path)
    return size_bytes / (1 << 20)

def count_params_m(model):
    """Number of parameter elements in ``model`` (as yielded by ``model.parameters()``), in millions."""
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total / 1e6

def try_get_flops_m(model, example_batch):
    """
    Estimate FLOPs (in millions) of one MultiTaskModel forward pass with thop.

    thop returns MACs; FLOPs are reported as 2 * MACs. thop only counts ops it
    has a handler for, so ops it does not recognise contribute nothing.
    NOTE(review): message-passing / scatter ops in the GNN layers may be
    under-counted; confirm against thop's report_missing output.

    thop attaches forward hooks and ``total_ops`` / ``total_params`` buffers to
    the modules it profiles. Profiling therefore runs on a deep copy of
    ``model``: if thop raises part-way, nothing is left attached to the real
    model, which is timed for latency right afterwards.

    Args:
        model: model to profile (left unmodified).
        example_batch: a PyG batch; moved to ``device`` before profiling.

    Returns:
        FLOPs in millions as a float, or None when thop is not installed or
        profiling fails. The reason is emitted as a warning rather than being
        silently discarded.
    """
    import warnings

    try:
        from thop import profile
    except ImportError as exc:
        warnings.warn(f"thop unavailable, FLOPs reported as n/a ({exc})")
        return None

    try:
        batch = example_batch.to(device)
        x, ei = batch.x, batch.edge_index
        ew = batch.edge_weight if (hasattr(batch, "edge_weight") and batch.edge_weight is not None) \
             else torch.ones(ei.size(1), device=x.device)
        bidx = batch.batch
        tsteps = batch.time_steps if hasattr(batch, "time_steps") else torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        # Profile a throwaway copy so thop's hooks/buffers never touch `model`.
        probe = copy.deepcopy(model)
        # thop expects plain tensors as inputs and returns (MACs, params).
        macs, _ = profile(probe, inputs=(x, ei, ew, bidx, tsteps), verbose=False)
        return float(macs * 2.0 / 1e6)
    except Exception as exc:  # best-effort: FLOPs are optional in the report
        warnings.warn(
            f"thop profiling failed, FLOPs reported as n/a ({type(exc).__name__}: {exc})"
        )
        return None

@torch.inference_mode()
def measure_latency_ms(model, example_batch, warmup=30, runs=200):
    """
    Mean forward-pass latency of ``model`` on one fixed batch, in milliseconds.

    The batch is moved to ``device`` once. ``warmup`` untimed forwards run first,
    then ``runs`` forwards are timed with ``time.perf_counter``. On CUDA the device
    is synchronized before the clock starts and after the timed loop, so queued
    kernels are included. Everything runs under ``torch.inference_mode``.
    ``runs`` must be positive (0 raises ZeroDivisionError).
    """
    b = example_batch.to(device)
    x = b.x
    edge_index = b.edge_index
    if hasattr(b, "edge_weight") and b.edge_weight is not None:
        edge_weight = b.edge_weight
    else:
        edge_weight = torch.ones(edge_index.size(1), device=x.device)
    node_batch = b.batch
    if hasattr(b, "time_steps"):
        steps = b.time_steps
    else:
        steps = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
    forward_args = (x, edge_index, edge_weight, node_batch, steps)
    on_cuda = device.type == "cuda"

    for _ in range(warmup):
        _ = model(*forward_args)

    if on_cuda:
        torch.cuda.synchronize()
    started = time.perf_counter()
    for _ in range(runs):
        _ = model(*forward_args)
    if on_cuda:
        torch.cuda.synchronize()
    finished = time.perf_counter()
    return (finished - started) * 1000.0 / runs

def print_table(rows, title="MULTITASK 5F (TEST/PROFILE)"):
    """
    Print a fixed-width per-fold results table.

    Args:
        rows: iterable of 7-tuples
            (fold, acc_action, acc_user, size_mb, flops_m, params_m, lat_ms).
            ``flops_m`` may be None, in which case it is shown as "n/a".
        title: text placed inside the ``====`` banner above the table.
    """
    columns = (("Fold", 4), ("AccA", 8), ("AccU", 8), ("Size(MB)", 10),
               ("FLOPs(M)", 10), ("Params(M)", 10), ("Lat(ms)", 10))
    header = " ".join(f"{label:>{width}}" for label, width in columns)

    print(f"\n==================== {title} ====================")
    print(header)
    print("-" * len(header))
    for fold, acc_a, acc_u, size_mb, flops_m, params_m, lat_ms in rows:
        flops_cell = f"{'n/a':>10}" if flops_m is None else f"{flops_m:10.2f}"
        cells = [f"{fold:>4d}", f"{acc_a:8.4f}", f"{acc_u:8.4f}", f"{size_mb:10.2f}",
                 flops_cell, f"{params_m:10.2f}", f"{lat_ms:10.2f}"]
        print(" ".join(cells))

# ---------------------------
# MAIN: rebuild the action-stratified folds, load each fold's checkpoint and
# report test accuracy (action/user) plus checkpoint size, params, FLOPs and
# latency per fold.
#   Requires these globals from the training cell:
#     filtered_graphs, filtered_g_labels, filtered_u_labels
# ---------------------------
# Fail fast if the training cell has not been run in this kernel.
# (The Korean assertion message says these three come from the training cell.)
assert "filtered_graphs" in globals() and "filtered_g_labels" in globals() and "filtered_u_labels" in globals(), \
    "ÌïÑÏöî: filtered_graphs / filtered_g_labels / filtered_u_labels (ÌïôÏäµ ÏÖÄÏóêÏÑú ÏÉùÏÑ±Îêú Í≤É)"

set_seed(42)

# dataset with labels embedded (deep copies carrying y_act / y_user)
gs = build_multitask_graphs_all(
    filtered_graphs, filtered_g_labels, filtered_u_labels,
    missing_user=MISSING_USER, seed=42
)

# Folds stratified by action label. The seed equals fold_seed=42 from the
# training call. NOTE(review): assumes train_multitask_5fold builds its folds
# with this same stratified_kfold_indices routine on the same sample order — confirm,
# otherwise test folds would overlap the training data.
y_act_all = np.array([int(g.y_act.item()) for g in gs], dtype=np.int64)
folds = stratified_kfold_indices(y_act_all, k=K_FOLDS, seed=42)

# NOTE(review): all_idx is not used further in this cell.
all_idx = np.arange(len(gs), dtype=np.int64)

rows = []
accA_list, accU_list = [], []

for fold in range(K_FOLDS):
    # Checkpoint files are indexed 0..K_FOLDS-1; the table below reports fold+1.
    # NOTE(review): file name presumably matches what train_multitask_5fold saves — confirm.
    ckpt_path = os.path.join(CKPT_DIR, f"prop_multitask_shared_backbone_fold{fold}.pth")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Missing ckpt: {ckpt_path}")

    # Only the held-out fold is needed here; the train/val split is not rebuilt.
    test_idx = folds[fold]
    test_ds  = [gs[i] for i in test_idx]
    test_loader = make_loader(test_ds, bs=BS_TEST, workers=WORKERS, shuffle=False)

    # First test batch (BS_TEST graphs) is the fixed input for FLOPs and latency.
    # next() raises StopIteration if this fold is empty.
    example_batch = next(iter(test_loader))
    # (optional) warmup batch for CUDA init. The profiling helpers call
    # .to(device) again themselves, so this line is presumably redundant.
    _ = example_batch.to(device)

    model = build_model(NUM_CLASSES_ACT, NUM_CLASSES_USR)
    model = load_state_1gpu(model, ckpt_path)

    # accuracy over the whole test fold; user accuracy is 0.0 when the fold has
    # no labelled users, which also pulls down the mean below
    accA, accU = eval_multitask(model, test_loader, missing_user=MISSING_USER)

    # size = checkpoint file size on disk; params = parameter count;
    # flops = None when thop is unavailable/fails; latency = mean ms per forward
    size_mb  = get_model_size_mb(ckpt_path)
    params_m = count_params_m(model)
    flops_m  = try_get_flops_m(model, example_batch)
    lat_ms   = measure_latency_ms(model, example_batch, warmup=LAT_WARMUP, runs=LAT_RUNS)

    rows.append((fold+1, accA, accU, size_mb, flops_m, params_m, lat_ms))
    accA_list.append(accA); accU_list.append(accU)

print_table(rows)

# np.std is the population std (ddof=0), same as the training-cell summary.
print("\n===== SUMMARY =====")
print(f"Action: mean¬±std = {float(np.mean(accA_list)):.4f} ¬± {float(np.std(accA_list)):.4f}")
print(f"User  : mean¬±std = {float(np.mean(accU_list)):.4f} ¬± {float(np.std(accU_list)):.4f}")
print("CKPT dir:", CKPT_DIR)