In [1]:
import cv2
import os
import numpy as np
import random
import shutil
from scipy.ndimage import gaussian_filter, map_coordinates
from PIL import Image
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import random
from torch.cuda.amp import GradScaler, autocast
import matplotlib.pyplot as plt
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_curve,
    precision_recall_curve,
    auc,
    confusion_matrix,
    classification_report
)


import json
from datetime import datetime


In [2]:
# --- 改进版 ResNet 模块（使用空洞卷积）---
class DilatedBottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(DilatedBottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride,
            padding=dilation, dilation=dilation, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out

class CustomResNetDilated(nn.Module):
    def __init__(self, block, layers, dilation_settings=None):
        super().__init__()
        self.inplanes = 64
        if dilation_settings is None:
            dilation_settings = {
                'layer1': 1,
                'layer2': 1,
                'layer3': 2,
                'layer4': 2
            }
        
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0], stride=1, 
                                     dilation=dilation_settings['layer1'])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 
                                     dilation=dilation_settings['layer2'])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 
                                     dilation=dilation_settings['layer3'])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 
                                     dilation=dilation_settings['layer4'])

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                         kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion)
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation=dilation, 
                          downsample=downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x

class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = build_dilated_resnet()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * DilatedBottleneck.expansion, 2048)
        
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # 处理3D输入 [B, T, C, H, W] or [B*T, C, H, W]
        if len(x.shape) == 5:
            B, T, C, H, W = x.shape
            x = x.view(B * T, C, H, W)
        
        x = self.backbone(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

class TemporalFPN(nn.Module):
    def __init__(self, in_channels=3, fpn_dims=[256, 128, 64]):
        super().__init__()
        
        # Bottom-up pathway
        self.conv1 = nn.Sequential(
            nn.Conv3d(in_channels, 16, kernel_size=(3, 7, 7), stride=(1, 2, 2), 
                     padding=(1, 3, 3)),
            nn.BatchNorm3d(16),
            nn.ReLU(inplace=True)
        )
        
        self.conv2 = nn.Sequential(
            nn.Conv3d(16, 32, kernel_size=(3, 5, 5), stride=(1, 2, 2), 
                     padding=(1, 2, 2)),
            nn.BatchNorm3d(32),
            nn.ReLU(inplace=True)
        )
        
        self.conv3 = nn.Sequential(
            nn.Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 2, 2), 
                     padding=(1, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True)
        )

        # Top-down pathway
        self.lateral_conv1 = nn.Conv3d(64, fpn_dims[0], 1)
        self.lateral_conv2 = nn.Conv3d(32, fpn_dims[1], 1)
        self.lateral_conv3 = nn.Conv3d(16, fpn_dims[2], 1)

        self.up_conv1 = nn.Sequential(
            nn.ConvTranspose3d(fpn_dims[0], fpn_dims[1], kernel_size=(1, 2, 2), 
                             stride=(1, 2, 2)),
            nn.BatchNorm3d(fpn_dims[1]),
            nn.ReLU(inplace=True)
        )
        
        self.up_conv2 = nn.Sequential(
            nn.ConvTranspose3d(fpn_dims[1], fpn_dims[2], kernel_size=(1, 2, 2), 
                             stride=(1, 2, 2)),
            nn.BatchNorm3d(fpn_dims[2]),
            nn.ReLU(inplace=True)
        )

        # Final fusion layer
        self.fusion = nn.Sequential(
            nn.Conv3d(sum(fpn_dims), 128, 1),
            nn.BatchNorm3d(128),
            nn.ReLU(inplace=True),
            nn.Conv3d(128, 3, 1)
        )

    def forward(self, x):
        # 保存原始形状
        B, T, C, H, W = x.shape
        
        # 调整维度顺序：[B, T, C, H, W] -> [B, C, T, H, W]
        x = x.permute(0, 2, 1, 3, 4)
        
        # Bottom-up pathway
        c1 = self.conv1(x)
        c2 = self.conv2(c1)
        c3 = self.conv3(c2)

        # Top-down pathway
        p3 = self.lateral_conv1(c3)
        p2 = self.up_conv1(p3) + self.lateral_conv2(c2)
        p1 = self.up_conv2(p2) + self.lateral_conv3(c1)

        # Feature pyramid fusion
        p3_up = F.interpolate(p3, size=p1.shape[2:], mode='trilinear', 
                            align_corners=True)
        p2_up = F.interpolate(p2, size=p1.shape[2:], mode='trilinear', 
                            align_corners=True)
        fused = torch.cat([p1, p2_up, p3_up], dim=1)
        
        # Final processing
        out = self.fusion(fused)
        out = F.interpolate(out, size=x.shape[2:], mode='trilinear', 
                          align_corners=True)
        
        # 调整回原始维度顺序并展平
        out = out.permute(0, 2, 1, 3, 4)  # [B, T, C, H, W]
        out = out.reshape(B * T, C, H, W)  # [B*T, C, H, W]
        
        return out


class DilatedGatedFusion(nn.Module):
    def __init__(self, feature_dim=2048, balance_threshold=0.3):
        super().__init__()
        self.balance_threshold = balance_threshold
        self.gate = nn.Sequential(
            nn.Linear(feature_dim * 2, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, 2)
        )
        
    def forward(self, spatial, temporal):
        logits = self.gate(torch.cat([spatial, temporal], dim=1))
        # 动态约束：强制最小权重不低于threshold
        alpha = F.softmax(logits, dim=1)
        alpha = torch.clamp(alpha, min=self.balance_threshold, max=1-self.balance_threshold)
        alpha = alpha / alpha.sum(dim=1, keepdim=True)  # 重新归一化
        fused = alpha[:, 0:1] * spatial + alpha[:, 1:2] * temporal
        print(f"[Gate] Mean Weights - Spatial: {alpha[:, 0].mean():.4f}, "
              f"Temporal: {alpha[:, 1].mean():.4f}, ")

        return fused, alpha


# 多任务损失函数
class EnhancedMultiTaskLoss(nn.Module):
    def __init__(self, num_classes=2, alpha=2, beta=1, gamma=1, kesal=1):
        super().__init__()
        self.num_classes = num_classes
         
        self.alpha = alpha  # Consistency weight
        self.beta = beta    # Attention sparsity weight
        self.gamma = gamma  # Gate balance weight
        self.kesal = kesal
        # Base losses
        self.cls_loss = nn.CrossEntropyLoss()
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')
        
    def forward(self, outputs, labels, spatial_feats=None, temporal_feats=None, attn_weights=None, gate_weights=None):
        """
        Args:
            outputs: Model predictions [B, C]
            labels: Ground truth labels [B]
            spatial_feats: Spatial features [B*T, D]
            temporal_feats: Temporal features [B*T, D]
            attn_weights: Attention weights [B, T, 1]
            gate_weights: Gate weights [B*T, 2]
        """
        # 1. Classification Loss (always present)
        loss_cls = self.cls_loss(outputs['logits'], labels)
        
        # Initialize other losses with zero
        loss_consist = torch.tensor(0.0).to(loss_cls.device)
        loss_attn = torch.tensor(0.0).to(loss_cls.device)
        loss_gate = torch.tensor(0.0).to(loss_cls.device)
        
        # 2. Stream Consistency Loss (only if both features are provided)
        if spatial_feats is not None and temporal_feats is not None:
            spatial_probs = F.softmax(spatial_feats.mean(dim=0), dim=0)
            temporal_probs = F.softmax(temporal_feats.mean(dim=0), dim=0)
            loss_consist = (self.kl_loss(spatial_probs.log(), temporal_probs) + 
                          self.kl_loss(temporal_probs.log(), spatial_probs)) / 2
        
        # 3. Attention Sparsity Loss (only if attn_weights provided)
        if attn_weights is not None:
            attn_weights = attn_weights.squeeze(-1)  # [B, T]
            loss_attn = -torch.mean(attn_weights * torch.log(attn_weights + 1e-10))
        
        # 4. Gate Balance Loss (only if gate_weights provided)
        if gate_weights is not None:
            gate_mean = gate_weights.mean(dim=0)  # [2]
            loss_gate = torch.sum((gate_mean - 0.5)**2)
        
        # Weighted total loss
        total_loss = (self.alpha * loss_cls + 
                     self.beta * loss_consist +
                     self.gamma * loss_attn +
                     self.kesal * loss_gate)
        
        return {
            'total': total_loss,
            'classification': loss_cls,
            'consistency': loss_consist,
            'attention_sparsity': loss_attn,
            'gate_balance': loss_gate
        }

# 时间特征金字塔网络 (FPN)
class TemporalFPN(nn.Module):
    def __init__(self, in_channels=3, fpn_dims=[256, 128, 64]):
        super().__init__()
        
        # Bottom-up pathway (光流特征提取backbone)
        self.conv1 = nn.Sequential(
            nn.Conv3d(in_channels, 16, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3)),
            nn.BatchNorm3d(16),
            nn.ReLU(inplace=True)
        )
        
        self.conv2 = nn.Sequential(
            nn.Conv3d(16, 32, kernel_size=(3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2)),
            nn.BatchNorm3d(32),
            nn.ReLU(inplace=True)
        )
        
        self.conv3 = nn.Sequential(
            nn.Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True)
        )

        # Top-down pathway (上采样和特征融合)
        self.lateral_conv1 = nn.Conv3d(64, fpn_dims[0], 1)
        self.lateral_conv2 = nn.Conv3d(32, fpn_dims[1], 1)
        self.lateral_conv3 = nn.Conv3d(16, fpn_dims[2], 1)

        self.up_conv1 = nn.Sequential(
            nn.ConvTranspose3d(fpn_dims[0], fpn_dims[1], kernel_size=(1, 2, 2), stride=(1, 2, 2)),
            nn.BatchNorm3d(fpn_dims[1]),
            nn.ReLU(inplace=True)
        )
        
        self.up_conv2 = nn.Sequential(
            nn.ConvTranspose3d(fpn_dims[1], fpn_dims[2], kernel_size=(1, 2, 2), stride=(1, 2, 2)),
            nn.BatchNorm3d(fpn_dims[2]),
            nn.ReLU(inplace=True)
        )

        # Final fusion layer
        self.fusion = nn.Sequential(
            nn.Conv3d(sum(fpn_dims), 128, 1),
            nn.BatchNorm3d(128),
            nn.ReLU(inplace=True),
            nn.Conv3d(128, 3, 1)  # 输出通道数改回3，与原始模型保持一致
        )

    def forward(self, x):
        # 调整输入维度顺序：从[B, T, C, H, W]到[B, C, T, H, W]
        x = x.permute(0, 2, 1, 3, 4)
        
        # Bottom-up pathway
        c1 = self.conv1(x)      # [B, 16, T, H/2, W/2]
        c2 = self.conv2(c1)     # [B, 32, T, H/4, W/4]
        c3 = self.conv3(c2)     # [B, 64, T, H/8, W/8]

        # Top-down pathway
        p3 = self.lateral_conv1(c3)              # [B, 256, T, H/8, W/8]
        p2 = self.up_conv1(p3) + self.lateral_conv2(c2)  # [B, 128, T, H/4, W/4]
        p1 = self.up_conv2(p2) + self.lateral_conv3(c1)  # [B, 64, T, H/2, W/2]

        # 上采样所有特征图到相同大小
        p3_up = F.interpolate(p3, size=p1.shape[2:], mode='trilinear', align_corners=True)
        p2_up = F.interpolate(p2, size=p1.shape[2:], mode='trilinear', align_corners=True)

        # 特征融合
        fused = torch.cat([p1, p2_up, p3_up], dim=1)
        out = self.fusion(fused)

        # 调整输出尺寸以匹配输入
        out = F.interpolate(out, size=x.shape[2:], mode='trilinear', align_corners=True)
        
        # 调整输出维度顺序：从[B, C, T, H, W]回到[B, T, C, H, W]
        out = out.permute(0, 2, 1, 3, 4)
        
        return out

    
class ImprovedTwoStreamModel(nn.Module):
    def __init__(self, num_classes=2, seg_num=45, d_model=2048, nhead=8):
        super().__init__()
        if d_model != 2048:
            print(f"Warning: d_model ({d_model}) does not match ResNet50 output (2048)")
        if d_model % nhead != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by nhead ({nhead})")

        self.seg_num = seg_num
        self.spatial_stream = build_dilated_resnet()
        self.temporal_stream = nn.Sequential(
            TemporalFPN(in_channels=3, fpn_dims=[256, 128, 64]),
            FeatureExtractor()
        )
        
        self.gated_fusion = DilatedGatedFusion(feature_dim=d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, 
                                                 batch_first=True)
        self.attention = nn.TransformerEncoder(encoder_layer, num_layers=1)
        self.temporal_attn = nn.Sequential(
            nn.Linear(d_model, d_model // 8),
            nn.ReLU(),
            nn.Linear(d_model // 8, 1),
            nn.Softmax(dim=1)
        )
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, rgb, flow, return_cam=False):
        B = rgb.shape[0]
        T = self.seg_num
        rgb_in = rgb.view(-1, 3, 224, 224)

        # Spatial stream
        spatial_feat_map = self.spatial_stream(rgb_in)
        spatial_feat = F.adaptive_avg_pool2d(spatial_feat_map, (1, 1))
        spatial_feat = spatial_feat.view(spatial_feat_map.size(0), -1)

        # Store feature map for CAM
        spatial_features_cam = spatial_feat_map

        # Temporal stream
        temporal_feat = self.temporal_stream(flow)
        if temporal_feat.shape[0] != B * T:
            temporal_feat = temporal_feat.view(B * T, -1)

        # Feature fusion
        fused_feat, alpha = self.gated_fusion(spatial_feat, temporal_feat)
        fused_feat = fused_feat.view(B, T, -1)

        # Transformer attention
        attended_feat = self.attention(fused_feat)
        weights = self.temporal_attn(attended_feat)
        out = (attended_feat * weights).sum(dim=1)

        # Classification
        final_output = self.classifier(out)

        if return_cam:
            return final_output, spatial_features_cam
        return final_output



def build_dilated_resnet():
    dilation_settings = {
        'layer1': 2,
        'layer2': 2,
        'layer3': 2,
        'layer4': 1
    }
    return CustomResNetDilated(DilatedBottleneck, [1, 1, 1, 1], dilation_settings)
    #[3, 4, 6, 3]
    
# 带损失函数的双流模型
class ImprovedTwoStreamModelWithLoss(nn.Module):
    def __init__(self, num_classes=2, seg_num=45, d_model=2048, nhead=8,
                 alpha=0.5, beta=0.1, gamma=0.01, lambda_cam=0.1, use_cam_loss=False):
        super().__init__()
        
        # Two-stream model
        self.model = ImprovedTwoStreamModel(num_classes=num_classes, seg_num=seg_num, d_model=d_model, nhead=nhead)

        # Losses
        self.loss_fn = EnhancedMultiTaskLoss(num_classes=num_classes, alpha=alpha, beta=beta, gamma=gamma)
        self.use_cam_loss = use_cam_loss
        if use_cam_loss:
            self.cam_loss_fn = CAMLoss(lambda_cam=lambda_cam)

    def forward(self, rgb, flow, labels=None, gt_boxes=None, mode='train'):
        if mode == 'train':
            logits, spatial_feats_map = self.model(rgb, flow, return_cam=True)
        else:
            logits = self.model(rgb, flow, return_cam=False)

        if mode == 'infer':
            return logits

        # Extract features for loss calculation
        B, T = rgb.shape[0], self.model.seg_num
        rgb_in = rgb.view(-1, 3, 224, 224)
        
        # Get spatial features
        spatial_feat_map = self.model.spatial_stream(rgb_in)
        spatial_feat = F.adaptive_avg_pool2d(spatial_feat_map, (1, 1)).view(spatial_feat_map.size(0), -1)

        # Get temporal features
        temporal_feat = self.model.temporal_stream(flow)

        # Get attention weights
        fused_feat, gate_weights = self.model.gated_fusion(spatial_feat, temporal_feat)
        fused_feat = fused_feat.view(B, T, -1)
        attended_feat = self.model.attention(fused_feat)
        attn_weights = self.model.temporal_attn(attended_feat)  # [B, T, 1]

        outputs = {
            'logits': logits,
            'spatial_feats': spatial_feat,
            'temporal_feats': temporal_feat,
            'attn_weights': attn_weights,
            'gate_weights': gate_weights
        }

        return outputs

In [3]:
class VideoDataset(Dataset):
    def __init__(self, root_dir, seg_num=3, transform=None, split='train', train_ratio=0.7, test_ratio=0.3):
        self.root_dir = root_dir
        self.seg_num = seg_num
        self.transform = transform
        self.classes = {'feixianhua': 0, 'xianhua': 1} # Map class names to labels
        self.split = split
        self.train_ratio = train_ratio
        self.test_ratio = test_ratio
        self.samples = []
        self._load_samples()

    def _load_samples(self):
        all_samples_by_class = {}
        # First pass: Collect all video folders per class
        for class_name, class_label in self.classes.items():
            all_samples_by_class[class_name] = []
            rgb_dir = os.path.join(self.root_dir, f'{class_name}_rgb')
            flow_dir = os.path.join(self.root_dir, f'{class_name}_flow')

            if not os.path.exists(rgb_dir) or not os.path.exists(flow_dir):
                print(f"警告: 目录不存在 {rgb_dir} 或 {flow_dir} for class {class_name}")
                continue

            # Find common video folders
            try:
                rgb_folders = set(os.listdir(rgb_dir))
                flow_folders = set(os.listdir(flow_dir))
            except FileNotFoundError:
                print(f"Error listing directories within {rgb_dir} or {flow_dir}")
                continue
            video_folders = sorted(list(rgb_folders.intersection(flow_folders)))


            for video_name in video_folders:
                rgb_video_dir = os.path.join(rgb_dir, video_name)
                flow_video_dir = os.path.join(flow_dir, video_name)

                # Check if frame files exist (assuming frames 0 to seg_num-1)
                frame_indices = list(range(self.seg_num))
                rgb_paths = [os.path.join(rgb_video_dir, f"{i}.jpg") for i in frame_indices]
                flow_paths = [os.path.join(flow_video_dir, f"{i}.jpg") for i in frame_indices]

                # Ensure all required frames exist
                if all(os.path.exists(p) for p in rgb_paths) and all(os.path.exists(p) for p in flow_paths):
                    all_samples_by_class[class_name].append({
                        'rgb_paths': rgb_paths,
                        'flow_paths': flow_paths,
                        'label': class_label
                    })
                # else:
                #     print(f"Skipping {video_name}: Missing frames (expected 0 to {self.seg_num - 1})")


        # Second pass: Split samples for train/test
        for class_name, class_samples in all_samples_by_class.items():
            num_samples = len(class_samples)
            if num_samples == 0:
                print(f"Warning: No valid samples found for class {class_name}")
                continue

            indices = list(range(num_samples))
            # IMPORTANT: Shuffle indices BEFORE splitting for randomization
            # Uses 'np' which requires 'import numpy as np'
            np.random.shuffle(indices)

            train_end = int(num_samples * self.train_ratio)

            if self.split == 'train':
                split_indices = indices[:train_end]
            elif self.split == 'test':
                split_indices = indices[train_end:]
            else:
                raise ValueError("split must be 'train' or 'test'")

            self.samples.extend([class_samples[i] for i in split_indices])

        if not self.samples:
            raise ValueError(f"未找到任何有效样本 for split '{self.split}'! 请检查数据目录结构、文件命名规则 (0.jpg to {self.seg_num-1}.jpg) 和 seg_num.")

        print(f"\n成功加载数据 ({self.split}):")
        print(f"总样本数 (videos): {len(self.samples)}")
        for class_name, class_label in self.classes.items():
            count = sum(1 for s in self.samples if s['label'] == class_label)
            print(f"{class_name} 样本: {count}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        try:
            # Load and transform RGB frames
            rgb = torch.stack([self.transform(Image.open(p).convert('RGB')) for p in sample['rgb_paths']])
            # Load and transform Flow frames (assuming they are stored as RGB images)
            flow = torch.stack([self.transform(Image.open(p).convert('RGB')) for p in sample['flow_paths']])
        except Exception as e:
            print(f"Error loading sample at index {idx}, paths: {sample['rgb_paths'][:1]}...")
            print(e)
            # Return a dummy sample or re-raise
            # Returning dummy data might hide issues but allows training to continue
            dummy_frame = torch.zeros((3, 224, 224))
            rgb = torch.stack([dummy_frame] * self.seg_num)
            flow = torch.stack([dummy_frame] * self.seg_num)
            # return rgb, flow, sample['label']
            raise e # Re-raise to stop execution and fix the issue

        # Return shape: [T, C, H, W], [T, C, H, W], label
        return rgb, flow, sample['label']


In [4]:
# trainer.py - 训练和评估相关函数
class Trainer:
    def __init__(self, config, model, train_loader, test_loader, criterion, device):
        self.config = config
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.criterion = criterion
        self.device = device
        
        # 初始化优化器和学习率调度器
        self.optimizer = torch.optim.AdamW(
            model.parameters(), 
            lr=config.BASE_LR,
            weight_decay=config.WEIGHT_DECAY
        )
        
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=config.LR_DECAY_EPOCHS,
            gamma=config.LR_DECAY_RATE
        )
        
        self.scaler = GradScaler()
        
        # 记录训练过程
        self.history = {
            'train_loss': [], 'train_acc': [], 'train_f1': [],
            'test_loss': [], 'test_acc': [], 'test_f1': [],
            'lr': []
        }
        
    def train_epoch(self, epoch):
        self.model.train()
        running_loss = 0.0
        all_labels = []
        all_preds = []
        
        for batch_idx, (rgb, flow, labels) in enumerate(self.train_loader):
            rgb, flow, labels = rgb.to(self.device), flow.to(self.device), labels.to(self.device)
            
            self.optimizer.zero_grad()
            
            with autocast():
                outputs = self.model(rgb, flow, labels=labels, mode='train')
                loss_inputs = {
                    'logits': outputs['logits'],
                    'spatial_feats': outputs.get('spatial_feats'),
                    'temporal_feats': outputs.get('temporal_feats'),
                    'attn_weights': outputs.get('attn_weights'),
                    'gate_weights': outputs.get('gate_weights')
                }
                loss = self.criterion(loss_inputs, labels)
            
            self.scaler.scale(loss['total']).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            
            running_loss += loss['total'].item()
            _, predicted = torch.max(outputs['logits'], 1)
            
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(predicted.cpu().numpy())
            
            if (batch_idx + 1) % 20 == 0:
                print(f'Epoch: {epoch}, Batch: {batch_idx+1}/{len(self.train_loader)}, '
                      f'Loss: {running_loss/(batch_idx+1):.4f}')
        
        epoch_loss = running_loss / len(self.train_loader)
        epoch_acc = accuracy_score(all_labels, all_preds)
        epoch_f1 = f1_score(all_labels, all_preds, average='binary')
        
        return epoch_loss, epoch_acc, epoch_f1

    def evaluate(self, mode='test'):
        self.model.eval()
        loader = self.test_loader if mode == 'test' else self.train_loader
        
        running_loss = 0.0
        all_labels = []
        all_preds = []
        all_probs = []
        
        with torch.no_grad():
            for rgb, flow, labels in loader:
                rgb, flow, labels = rgb.to(self.device), flow.to(self.device), labels.to(self.device)
                
                with autocast():
                    outputs = self.model(rgb, flow, mode='eval')
                    loss = self.criterion({'logits': outputs['logits']}, labels)
                
                running_loss += loss['total'].item()
                probs = torch.softmax(outputs['logits'], dim=1)
                _, predicted = torch.max(outputs['logits'], 1)
                
                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(predicted.cpu().numpy())
                all_probs.extend(probs[:, 1].cpu().numpy())
        
        metrics = {
            'loss': running_loss / len(loader),
            'accuracy': accuracy_score(all_labels, all_preds),
            'f1': f1_score(all_labels, all_preds, average='binary'),
            'precision': precision_score(all_labels, all_preds, average='binary'),
            'recall': recall_score(all_labels, all_preds, average='binary')
        }
        
        try:
            fpr, tpr, _ = roc_curve(all_labels, all_probs)
            metrics['roc_auc'] = auc(fpr, tpr)
            
            precision_pts, recall_pts, _ = precision_recall_curve(all_labels, all_probs)
            metrics['pr_auc'] = auc(recall_pts, precision_pts)
        except ValueError:
            metrics['roc_auc'] = 0.0
            metrics['pr_auc'] = 0.0
        
        return metrics

    def train(self):
        best_test_acc = 0.0
        best_model_path = 'best_model.pth'
        
        for epoch in range(self.config.NUM_EPOCHS):
            # 训练一个epoch
            train_loss, train_acc, train_f1 = self.train_epoch(epoch)
            
            # 评估
            train_metrics = self.evaluate(mode='train')
            test_metrics = self.evaluate(mode='test')
            
            # 更新学习率
            self.scheduler.step()
            current_lr = self.optimizer.param_groups[0]['lr']
            
            # 记录历史
            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_metrics['accuracy'])
            self.history['train_f1'].append(train_metrics['f1'])
            self.history['test_loss'].append(test_metrics['loss'])
            self.history['test_acc'].append(test_metrics['accuracy'])
            self.history['test_f1'].append(test_metrics['f1'])
            self.history['lr'].append(current_lr)
            
            # 保存最佳模型
            if test_metrics['accuracy'] > best_test_acc:
                best_test_acc = test_metrics['accuracy']
                torch.save(self.model.state_dict(), best_model_path)
                print(f"New best model saved with test accuracy: {best_test_acc:.4f}")
            
            # 打印epoch结果
            self._print_epoch_results(epoch, train_metrics, test_metrics, current_lr)
            
            # 绘制进度图
            if (epoch + 1) % 10 == 0:
                self.plot_training_progress()
        
        return best_model_path

    def _print_epoch_results(self, epoch, train_metrics, test_metrics, lr):
        print(f"\nEpoch {epoch+1}/{self.config.NUM_EPOCHS}")
        print(f"Learning Rate: {lr:.2e}")
        print("\nTraining Metrics:")
        for k, v in train_metrics.items():
            print(f"{k}: {v:.4f}")
        print("\nTest Metrics:")
        for k, v in test_metrics.items():
            print(f"{k}: {v:.4f}")

    def plot_training_progress(self):
        plt.figure(figsize=(15, 10))
        
        # Plot loss
        plt.subplot(2, 2, 1)
        plt.plot(self.history['train_loss'], label='Train')
        plt.plot(self.history['test_loss'], label='Test')
        plt.title('Loss vs. Epoch')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)
        
        # Plot accuracy
        plt.subplot(2, 2, 2)
        plt.plot(self.history['train_acc'], label='Train')
        plt.plot(self.history['test_acc'], label='Test')
        plt.title('Accuracy vs. Epoch')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True)
        
        # Plot F1 Score
        plt.subplot(2, 2, 3)
        plt.plot(self.history['train_f1'], label='Train')
        plt.plot(self.history['test_f1'], label='Test')
        plt.title('F1 Score vs. Epoch')
        plt.xlabel('Epoch')
        plt.ylabel('F1 Score')
        plt.legend()
        plt.grid(True)
        
        # Plot learning rate
        plt.subplot(2, 2, 4)
        plt.plot(self.history['lr'])
        plt.title('Learning Rate vs. Epoch')
        plt.xlabel('Epoch')
        plt.ylabel('Learning Rate')
        plt.yscale('log')
        plt.grid(True)
        
        plt.tight_layout()
        plt.savefig(f'training_progress.png')
        plt.close()


    def plot_final_curves(self, metrics, mode='test'):
        """
        绘制最终的评估曲线
        
        Args:
            metrics: 包含评估指标的字典，需要包含以下键：
                    'labels': 真实标签
                    'probs': 预测概率
                    'preds': 预测标签
                    'accuracy': 准确率
                    'precision': 精确率
                    'recall': 召回率
                    'f1': F1分数
                    'roc_auc': ROC曲线下面积
                    'pr_auc': PR曲线下面积
            mode: 'train' 或 'test'
        """
        plt.figure(figsize=(15, 5))
        
        # 1. ROC 曲线
        plt.subplot(1, 3, 1)
        fpr, tpr, _ = roc_curve(metrics['labels'], metrics['probs'])
        roc_auc = metrics['roc_auc']
        
        plt.plot(fpr, tpr, 'b-', label=f'ROC (AUC = {roc_auc:.3f})')
        plt.plot([0, 1], [0, 1], 'r--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title(f'Final ROC Curve ({mode})')
        plt.legend(loc="lower right")
        
        # 2. PR 曲线
        plt.subplot(1, 3, 2)
        precision, recall, _ = precision_recall_curve(metrics['labels'], 
                                                    metrics['probs'])
        pr_auc = metrics['pr_auc']
        
        plt.plot(recall, precision, 'g-', 
                 label=f'PR (AUC = {pr_auc:.3f})')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title(f'Final PR Curve ({mode})')
        plt.legend(loc="lower left")
        
        # 3. 混淆矩阵
        plt.subplot(1, 3, 3)
        cm = confusion_matrix(metrics['labels'], metrics['preds'])
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title('Confusion Matrix')
        
        plt.tight_layout()
        plt.savefig(os.path.join(self.output_dir, f'final_curves_{mode}.png'))
        plt.close()

        # 保存详细的评估指标
        metrics_summary = {
            'accuracy': metrics['accuracy'],
            'precision': metrics['precision'],
            'recall': metrics['recall'],
            'f1': metrics['f1'],
            'roc_auc': metrics['roc_auc'],
            'pr_auc': metrics['pr_auc']
        }
        
        # 将指标写入文本文件
        with open(os.path.join(self.output_dir, f'final_metrics_{mode}.txt'), 'w') as f:
            for metric_name, value in metrics_summary.items():
                f.write(f"{metric_name}: {value:.4f}\n")

In [5]:
# config.py - 配置文件
class Config:
    # 数据集参数
    ROOT_DATA_DIR = "hmdb_data_demo"
    SEG_NUM = 45
    TRAIN_RATIO = 0.7
    TEST_RATIO = 0.3
    
    # 训练参数
    BATCH_SIZE = 1
    NUM_EPOCHS = 30
    BASE_LR = 1e-4
    WEIGHT_DECAY = 1e-4
    NUM_WORKERS = 2
    
    # 模型参数
    NUM_CLASSES = 2
    D_MODEL = 2048
    NHEAD = 8
    
    # Loss权重
    ALPHA = 1  # 分类损失熵
    BETA = 0.5  # 流一致性损失权重
    GAMMA = 0.5  # 注意力稀疏性损失权重
    KESAL = 0.4 # 门控平衡损失权重
    
    # 学习率调度
    LR_WARMUP_EPOCHS = 20
    LR_DECAY_EPOCHS = [15, 20, 30]
    LR_DECAY_RATE = 0.1

    K_FOLDS = 3
    SEED = 42  # 添加随机种子

In [None]:
# main.py - 主程序
if __name__ == "__main__":
    # 加载配置
    config = Config()
    
    # 设置设备
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # 数据转换
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # 创建数据加载器
    try:
        train_dataset = VideoDataset(
            config.ROOT_DATA_DIR, 
            seg_num=config.SEG_NUM,
            transform=transform,
            split='train',
            train_ratio=config.TRAIN_RATIO,
            test_ratio=config.TEST_RATIO
        )
        
        test_dataset = VideoDataset(
            config.ROOT_DATA_DIR,
            seg_num=config.SEG_NUM,
            transform=transform,
            split='test',
            train_ratio=config.TRAIN_RATIO,
            test_ratio=config.TEST_RATIO
        )
        
        train_loader = DataLoader(
            train_dataset,
            batch_size=config.BATCH_SIZE,
            shuffle=True,
            num_workers=config.NUM_WORKERS,
            pin_memory=True
        )
        
        test_loader = DataLoader(
            test_dataset,
            batch_size=config.BATCH_SIZE,
            shuffle=False,
            num_workers=config.NUM_WORKERS,
            pin_memory=True
        )
    
    except Exception as e:
        print(f"Error creating datasets/dataloaders: {e}")
        exit(1)
    
    # 创建模型和损失函数
    model = ImprovedTwoStreamModelWithLoss(
        num_classes=config.NUM_CLASSES,
        seg_num=config.SEG_NUM,
        d_model=config.D_MODEL,
        nhead=config.NHEAD
    ).to(device)
    
    criterion = EnhancedMultiTaskLoss(
        alpha=config.ALPHA,
        beta=config.BETA,
        gamma=config.GAMMA,
        kesal=config.KESAL
    )
    
    # 创建训练器并开始训练
    trainer = Trainer(config, model, train_loader, test_loader, criterion, device)
    best_model_path = trainer.train()
    
    # 加载最佳模型并进行最终评估
    model.load_state_dict(torch.load(best_model_path))
    final_metrics = trainer.evaluate(mode='test')
    
    print("\nFinal Test Results:")
    for metric, value in final_metrics.items():
        print(f"{metric}: {value:.4f}")

   # 训练完成后的最终评估
    print("\nPerforming final evaluation...")
    model.load_state_dict(torch.load(best_model_path))
    final_train_metrics = trainer.evaluate(mode='train')
    final_test_metrics = trainer.evaluate(mode='test')
    
    # 绘制最终的评估曲线
    #trainer.plot_final_curves(final_train_metrics, mode='train')
   # trainer.plot_final_curves(final_test_metrics, mode='test')
    
    print("\nFinal Training Results:")
    for metric, value in final_train_metrics.items():
        if isinstance(value, (float, int)):
            print(f"{metric}: {value:.4f}")
    
    print("\nFinal Test Results:")
    for metric, value in final_test_metrics.items():
        if isinstance(value, (float, int)):
            print(f"{metric}: {value:.4f}")


成功加载数据 (train):
总样本数 (videos): 38
feixianhua 样本: 20
xianhua 样本: 18

成功加载数据 (test):
总样本数 (videos): 17
feixianhua 样本: 9
xianhua 样本: 8


  self.scaler = GradScaler()
  with autocast():


[Gate] Mean Weights - Spatial: 0.4799, Temporal: 0.5201, 
[Gate] Mean Weights - Spatial: 0.4799, Temporal: 0.5201, 
[Gate] Mean Weights - Spatial: 0.4805, Temporal: 0.5195, 
[Gate] Mean Weights - Spatial: 0.4805, Temporal: 0.5195, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean We

  with autocast():


[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean We

  with autocast():


[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean We

  with autocast():


[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean We

  with autocast():


[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean We

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  with autocast():


[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean We

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


New best model saved with test accuracy: 0.5294

Epoch 2/30
Learning Rate: 1.00e-04

Training Metrics:
loss: 1.1232
accuracy: 0.5263
f1: 0.0000
precision: 0.0000
recall: 0.0000
roc_auc: 0.5625
pr_auc: 0.6078

Test Metrics:
loss: 1.1156
accuracy: 0.5294
f1: 0.0000
precision: 0.0000
recall: 0.0000
roc_auc: 0.6806
pr_auc: 0.6760


  with autocast():


[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean We

  with autocast():


[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean We

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  with autocast():


[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean We

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  with autocast():


[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean We

  with autocast():


[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean We

  with autocast():


[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean We

  with autocast():


[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean We

  with autocast():


[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean We

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  with autocast():


[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean We

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  with autocast():


[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean We

  with autocast():


[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean Weights - Spatial: 0.3000, Temporal: 0.7000, 
[Gate] Mean We

In [4]:
    # 绘制最终的评估曲线
#trainer.plot_final_curves(final_train_metrics, mode='train')
# 评估模型
final_test_metrics = trainer.evaluate(mode='test')
trainer.plot_final_curves(final_test_metrics, mode='test')

NameError: name 'trainer' is not defined

In [None]:
#plot_final_curves(final_test_metrics, mode='test')
print("\nFinal Training Results:")
for metric, value in final_train_metrics.items():
    if isinstance(value, (float, int)):
        print(f"{metric}: {value:.4f}")
        
print("\nFinal Test Results:")
for metric, value in final_test_metrics.items():
        if isinstance(value, (float, int)):
            print(f"{metric}: {value:.4f}")

In [8]:
class CAMVisualizer:
    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.model.eval()
        
    def generate_cam(self, feature_maps, class_weights, class_idx):
        """
        生成类激活图
        
        Args:
            feature_maps: 特征图 [C, H, W]
            class_weights: 分类层权重 [num_classes, C]
            class_idx: 目标类别索引
        """
        # 获取目标类别的权重
        target_weights = class_weights[class_idx]  # [C]
        
        # 计算加权和
        cam = torch.einsum('chw,c->hw', feature_maps, target_weights)
        
        # ReLU处理，只保留正值
        cam = F.relu(cam)
        
        # 归一化到0-1
        if cam.max() != cam.min():
            cam = (cam - cam.min()) / (cam.max() - cam.min())
            
        return cam.cpu().numpy()

    def visualize_cam(self, rgb_frames, flow_frames, target_class=None, save_path=None):
        """
        为视频序列生成CAM可视化
        
        Args:
            rgb_frames: RGB帧序列 [T, C, H, W]
            flow_frames: 光流帧序列 [T, C, H, W]
            target_class: 目标类别（如果为None，则使用预测类别）
            save_path: 保存路径
        """
        rgb_frames = rgb_frames.to(self.device)
        flow_frames = flow_frames.to(self.device)
        
        B = 1  # 批次大小为1
        T = rgb_frames.shape[0]
        
        with torch.no_grad():
            # 获取模型预测和特征图
            outputs, spatial_features = self.model(
                rgb_frames.unsqueeze(0),  # 添加批次维度
                flow_frames.unsqueeze(0),
                return_cam=True
            )
            
            # 获取预测类别
            probs = F.softmax(outputs, dim=1)
            if target_class is None:
                target_class = torch.argmax(probs, dim=1).item()
            
            # 获取分类器权重
            class_weights = self.model.classifier.weight.data
            
            # 生成每一帧的CAM
            spatial_features = spatial_features.squeeze(0)  # [T*C, H', W']
            spatial_features = spatial_features.view(T, -1, *spatial_features.shape[-2:])
            
            # 创建图像网格
            num_cols = 5
            num_rows = (T + num_cols - 1) // num_cols
            plt.figure(figsize=(20, 4 * num_rows))
            
            for t in range(T):
                # 原始帧
                rgb_frame = rgb_frames[t].cpu().permute(1, 2, 0).numpy()
                rgb_frame = (rgb_frame - rgb_frame.min()) / (rgb_frame.max() - rgb_frame.min())
                
                # 生成CAM
                cam = self.generate_cam(
                    spatial_features[t],
                    class_weights,
                    target_class
                )
                
                # 将CAM调整到与原始图像相同大小
                cam_resized = cv2.resize(cam, (rgb_frame.shape[1], rgb_frame.shape[0]))
                
                # 创建热力图
                heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
                heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) / 255.0
                
                # 叠加原始图像和热力图
                overlay = 0.6 * rgb_frame + 0.4 * heatmap
                overlay = np.clip(overlay, 0, 1)
                
                # 显示图像
                plt.subplot(num_rows, num_cols, t + 1)
                plt.imshow(overlay)
                plt.title(f'Frame {t} (Class {target_class})')
                plt.axis('off')
            
            plt.tight_layout()
            if save_path:
                plt.savefig(save_path)
                plt.close()
            else:
                plt.show()
            
    def visualize_sequence(self, dataset, sequence_idx, save_dir='cam_visualizations'):
        """
        可视化数据集中的特定序列
        
        Args:
            dataset: VideoDataset实例
            sequence_idx: 序列索引
            save_dir: 保存目录
        """
        os.makedirs(save_dir, exist_ok=True)
        
        # 获取序列数据
        rgb_frames, flow_frames, true_label = dataset[sequence_idx]
        
        # 生成并保存CAM可视化
        save_path = os.path.join(save_dir, f'sequence_{sequence_idx}_class_{true_label}.png')
        self.visualize_cam(rgb_frames, flow_frames, target_class=true_label, save_path=save_path)
        
        print(f"CAM visualization saved to {save_path}")

def visualize_dataset_samples(model, dataset, num_samples=5, save_dir='cam_visualizations'):
    """
    可视化数据集中的多个样本
    
    Args:
        model: 训练好的模型
        dataset: VideoDataset实例
        num_samples: 要可视化的样本数量
        save_dir: 保存目录
    """
    device = next(model.parameters()).device
    visualizer = CAMVisualizer(model, device)
    
    # 随机选择样本
    indices = random.sample(range(len(dataset)), num_samples)
    
    for idx in indices:
        visualizer.visualize_sequence(dataset, idx, save_dir)


In [9]:
def filter_state_dict(saved_state_dict, model_state_dict):
    """
    Filters the saved state_dict to match the model's state_dict structure.
    Args:
        saved_state_dict (dict): The state_dict from the checkpoint.
        model_state_dict (dict): The state_dict of the current model.
    Returns:
        dict: Filtered state_dict compatible with the current model.
    """
    filtered_state_dict = {}
    for key, value in saved_state_dict.items():
        if key in model_state_dict:
            filtered_state_dict[key] = value
        else:
            print(f"Skipping unexpected key: {key}")
    return filtered_state_dict

if __name__ == "__main__":
    # Initialize the model
    model = ImprovedTwoStreamModel(num_classes=2, seg_num=45)
    model_state_dict = model.state_dict()

    # Load the saved state dict
    saved_state_dict = torch.load('best_model_test_acc.pth')

    # Filter the state dict to match the model
    filtered_state_dict = filter_state_dict(saved_state_dict, model_state_dict)
    model.load_state_dict(filtered_state_dict, strict=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    print("Model loaded with filtered state_dict.")

  saved_state_dict = torch.load('best_model_test_acc.pth')


Model loaded with filtered state_dict.


In [10]:
if __name__ == "__main__":
    # 1. 加载模型
    model = ImprovedTwoStreamModel(num_classes=2, seg_num=45)
    
    # 2. 加载权重（使用非严格模式）
    state_dict = torch.load('best_model_test_acc.pth')
    
    # 3. 处理权重键名不匹配的问题
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('gated_fusion.gate.2'):
            # 将旧的gate.2改为gate.3
            new_k = k.replace('gate.2', 'gate.3')
            new_state_dict[new_k] = v
        else:
            new_state_dict[k] = v
    
    # 4. 使用非严格模式加载
    model.load_state_dict(new_state_dict, strict=False)
    print("Model loaded successfully with non-strict mode")
    
    # 5. 移动到设备
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    
    # 6. 设置为评估模式
    model.eval()
    
    # 7. 创建数据集
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])
    
    test_dataset = VideoDataset(
        root_dir="hmdb_data_demo",
        seg_num=45,
        transform=transform,
        split='test'
    )
    
    # 8. 创建可视化器
    visualizer = CAMVisualizer(model, device)
    
    # 9. 可视化样本
    try:
        visualize_dataset_samples(model, test_dataset, num_samples=6)
        print("Visualization completed successfully")
    except Exception as e:
        print(f"Error during visualization: {e}")

  state_dict = torch.load('best_model_test_acc.pth')


Model loaded successfully with non-strict mode

成功加载数据 (test):
总样本数 (videos): 17
feixianhua 样本: 9
xianhua 样本: 8
[Gate] Mean Weights - Spatial: 0.5340, Temporal: 0.4660, 
CAM visualization saved to cam_visualizations/sequence_4_class_0.png
[Gate] Mean Weights - Spatial: 0.5279, Temporal: 0.4721, 
CAM visualization saved to cam_visualizations/sequence_15_class_1.png
[Gate] Mean Weights - Spatial: 0.5306, Temporal: 0.4694, 
CAM visualization saved to cam_visualizations/sequence_14_class_1.png
[Gate] Mean Weights - Spatial: 0.5247, Temporal: 0.4753, 
CAM visualization saved to cam_visualizations/sequence_10_class_1.png
[Gate] Mean Weights - Spatial: 0.5292, Temporal: 0.4708, 
CAM visualization saved to cam_visualizations/sequence_6_class_0.png
[Gate] Mean Weights - Spatial: 0.5280, Temporal: 0.4720, 
CAM visualization saved to cam_visualizations/sequence_12_class_1.png
Visualization completed successfully


In [10]:
from sklearn.model_selection import KFold
import torch
import numpy as np
from sklearn.metrics import roc_curve, auc, precision_recall_curve, accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import os
from torch.cuda.amp import autocast, GradScaler


class TrainerWithKFold:
    def __init__(self, config, model, dataset, criterion, device, output_dir="results"):
        self.config = config
        self.model = model
        self.dataset = dataset
        self.criterion = criterion
        self.device = device
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

        # 初始化优化器、学习率调度器和Scaler
        self.optimizer = torch.optim.AdamW(
            model.parameters(), 
            lr=config.BASE_LR,
            weight_decay=config.WEIGHT_DECAY
        )
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=config.LR_DECAY_EPOCHS,
            gamma=config.LR_DECAY_RATE
        )
        self.scaler = GradScaler()

        # 初始化历史记录
        self.history = {
            'train_loss': [], 'train_acc': [], 'train_f1': [],
            'test_loss': [], 'test_acc': [], 'test_f1': [],
            'lr': []
        }

    def train_and_evaluate_kfold(self, k=5):
        kfold = KFold(n_splits=k, shuffle=True, random_state=self.config.SEED)
        fold_results = []

        for fold, (train_indices, val_indices) in enumerate(kfold.split(self.dataset)):
            print(f"=== Fold {fold + 1}/{k} ===")
            
            # 创建训练和验证数据加载器
            train_loader = torch.utils.data.DataLoader(
                self.dataset, batch_size=self.config.BATCH_SIZE,
                sampler=torch.utils.data.SubsetRandomSampler(train_indices)
            )
            val_loader = torch.utils.data.DataLoader(
                self.dataset, batch_size=self.config.BATCH_SIZE,
                sampler=torch.utils.data.SubsetRandomSampler(val_indices)
            )

            # 训练和验证
            best_val_acc = 0.0
            for epoch in range(self.config.NUM_EPOCHS):
                train_loss, train_acc, train_f1 = self.train_epoch(train_loader, epoch)
                val_metrics = self.evaluate(val_loader, mode='val')

                # 更新学习率
                self.scheduler.step()

                # 保存最佳模型
                if val_metrics['accuracy'] > best_val_acc:
                    best_val_acc = val_metrics['accuracy']
                    torch.save(self.model.state_dict(), os.path.join(self.output_dir, f"best_model_fold_{fold + 1}.pth"))
                    print(f"New best model for Fold {fold + 1} saved with val accuracy: {best_val_acc:.4f}")

                # 打印结果
                self._print_epoch_results(epoch, train_loss, train_acc, train_f1, val_metrics)

            # 保存每个Fold的结果
            fold_results.append(val_metrics)

        # 汇总所有Fold的结果
        self._summarize_kfold_results(fold_results)

    def train_epoch(self, train_loader, epoch):
        self.model.train()
        running_loss = 0.0
        all_labels = []
        all_preds = []

        for batch_idx, (rgb, flow, labels) in enumerate(train_loader):
            rgb, flow, labels = rgb.to(self.device), flow.to(self.device), labels.to(self.device)
            self.optimizer.zero_grad()

            with autocast():
                outputs = self.model(rgb, flow, labels=labels, mode='train')
                loss_inputs = {
                    'logits': outputs['logits'],
                    'spatial_feats': outputs.get('spatial_feats'),
                    'temporal_feats': outputs.get('temporal_feats'),
                    'attn_weights': outputs.get('attn_weights'),
                    'gate_weights': outputs.get('gate_weights')
                }
                loss = self.criterion(loss_inputs, labels)

            self.scaler.scale(loss['total']).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            running_loss += loss['total'].item()
            _, predicted = torch.max(outputs['logits'], 1)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(predicted.cpu().numpy())

        epoch_loss = running_loss / len(train_loader)
        epoch_acc = accuracy_score(all_labels, all_preds)
        epoch_f1 = f1_score(all_labels, all_preds, average='binary')

        return epoch_loss, epoch_acc, epoch_f1

    def evaluate(self, loader, mode='val'):
        self.model.eval()
        running_loss = 0.0
        all_labels = []
        all_preds = []
        all_probs = []

        with torch.no_grad():
            for rgb, flow, labels in loader:
                rgb, flow, labels = rgb.to(self.device), flow.to(self.device), labels.to(self.device)

                with autocast():
                    outputs = self.model(rgb, flow, mode='eval')
                    loss = self.criterion({'logits': outputs['logits']}, labels)

                running_loss += loss['total'].item()
                probs = torch.softmax(outputs['logits'], dim=1)
                _, predicted = torch.max(outputs['logits'], 1)

                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(predicted.cpu().numpy())
                all_probs.extend(probs[:, 1].cpu().numpy())

        metrics = {
            'loss': running_loss / len(loader),
            'accuracy': accuracy_score(all_labels, all_preds),
            'f1': f1_score(all_labels, all_preds, average='binary'),
            'precision': precision_score(all_labels, all_preds, average='binary'),
            'recall': recall_score(all_labels, all_preds, average='binary')
        }

        # ROC and PR metrics
        try:
            fpr, tpr, _ = roc_curve(all_labels, all_probs)
            metrics['roc_auc'] = auc(fpr, tpr)

            precision_pts, recall_pts, _ = precision_recall_curve(all_labels, all_probs)
            metrics['pr_auc'] = auc(recall_pts, precision_pts)
        except ValueError:
            metrics['roc_auc'] = 0.0
            metrics['pr_auc'] = 0.0

        return metrics

    def _summarize_kfold_results(self, fold_results):
        print("\n=== K-Fold Cross Validation Results ===")
        summary = {key: [] for key in fold_results[0].keys()}

        # Collect metrics from all folds
        for fold_idx, fold_metrics in enumerate(fold_results):
            print(f"Fold {fold_idx + 1}:")
            for metric_name, value in fold_metrics.items():
                print(f"  {metric_name}: {value:.4f}")
                summary[metric_name].append(value)

        # Compute mean and std for each metric
        print("\nOverall Results:")
        for metric_name, values in summary.items():
            mean_val = np.mean(values)
            std_val = np.std(values)
            print(f"  {metric_name}: {mean_val:.4f} ± {std_val:.4f}")

    def _print_epoch_results(self, epoch, train_loss, train_acc, train_f1, val_metrics):
        print(f"\nEpoch {epoch + 1}/{self.config.NUM_EPOCHS}")
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Train F1: {train_f1:.4f}")
        print(f"Val Loss: {val_metrics['loss']:.4f}, Val Acc: {val_metrics['accuracy']:.4f}, Val F1: {val_metrics['f1']:.4f}")

    def plot_training_progress(self):
        plt.figure(figsize=(15, 10))
        
        # Plot loss
        plt.subplot(2, 2, 1)
        plt.plot(self.history['train_loss'], label='Train')
        plt.plot(self.history['test_loss'], label='Test')
        plt.title('Loss vs. Epoch')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)
        
        # Plot accuracy
        plt.subplot(2, 2, 2)
        plt.plot(self.history['train_acc'], label='Train')
        plt.plot(self.history['test_acc'], label='Test')
        plt.title('Accuracy vs. Epoch')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True)
        
        # Plot F1 Score
        plt.subplot(2, 2, 3)
        plt.plot(self.history['train_f1'], label='Train')
        plt.plot(self.history['test_f1'], label='Test')
        plt.title('F1 Score vs. Epoch')
        plt.xlabel('Epoch')
        plt.ylabel('F1 Score')
        plt.legend()
        plt.grid(True)
        
        # Plot learning rate
        plt.subplot(2, 2, 4)
        plt.plot(self.history['lr'])
        plt.title('Learning Rate vs. Epoch')
        plt.xlabel('Epoch')
        plt.ylabel('Learning Rate')
        plt.yscale('log')
        plt.grid(True)
        
        plt.tight_layout()
        plt.savefig(f'{self.output_dir}/training_progress.png')
        plt.close()

In [15]:
if __name__ == "__main__":
    # 加载配置
    config = Config()
    
    # 设置设备
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # 数据转换
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # 创建数据加载器
    try:
        dataset = VideoDataset(
            config.ROOT_DATA_DIR, 
            seg_num=config.SEG_NUM,
            transform=transform,
            train_ratio=config.TRAIN_RATIO,
            test_ratio=config.TEST_RATIO
        )
    except Exception as e:
        print(f"Error creating dataset: {e}")
        exit(1)
    
    # 创建模型和损失函数
    model = ImprovedTwoStreamModelWithLoss(
        num_classes=config.NUM_CLASSES,
        seg_num=config.SEG_NUM,
        d_model=config.D_MODEL,
        nhead=config.NHEAD
    ).to(device)
    
    criterion = EnhancedMultiTaskLoss(
        alpha=config.ALPHA,
        beta=config.BETA,
        gamma=config.GAMMA,
        kesal=config.KESAL
    )
    
    # 创建训练器并执行 K 折交叉验证
    trainer = TrainerWithKFold(config, model, dataset, criterion, device, output_dir="results")
    trainer.train_and_evaluate_kfold(k=config.K_FOLDS)
    
    print("\nK-Fold Cross Validation completed.")


成功加载数据 (train):
总样本数 (videos): 38
feixianhua 样本: 20
xianhua 样本: 18


  self.scaler = GradScaler()


=== Fold 1/3 ===


  with autocast():


Gate Weights - Spatial: 0.5931, Temporal: 0.4069, Balance: 0.5000
Gate Weights - Spatial: 0.5939, Temporal: 0.4061, Balance: 0.5000
Gate Weights - Spatial: 0.5944, Temporal: 0.4056, Balance: 0.5000
Gate Weights - Spatial: 0.5937, Temporal: 0.4063, Balance: 0.5000
Gate Weights - Spatial: 0.5961, Temporal: 0.4039, Balance: 0.5000
Gate Weights - Spatial: 0.5955, Temporal: 0.4045, Balance: 0.5000
Gate Weights - Spatial: 0.5922, Temporal: 0.4078, Balance: 0.5000
Gate Weights - Spatial: 0.5925, Temporal: 0.4075, Balance: 0.5000
Gate Weights - Spatial: 0.4073, Temporal: 0.5927, Balance: 0.5001
Gate Weights - Spatial: 0.4073, Temporal: 0.5927, Balance: 0.5001
Gate Weights - Spatial: 0.3679, Temporal: 0.6321, Balance: 0.5002
Gate Weights - Spatial: 0.3678, Temporal: 0.6322, Balance: 0.5002
Gate Weights - Spatial: 0.3679, Temporal: 0.6321, Balance: 0.5002
Gate Weights - Spatial: 0.3679, Temporal: 0.6321, Balance: 0.5002
Gate Weights - Spatial: 0.3537, Temporal: 0.6463, Balance: 0.5002
Gate Weigh

  with autocast():


New best model for Fold 1 saved with val accuracy: 0.3846

Epoch 1/30
Train Loss: 14.5717, Train Acc: 0.3600, Train F1: 0.3333
Val Loss: 2.1044, Val Acc: 0.3846, Val F1: 0.5000


  with autocast():


Gate Weights - Spatial: 0.3495, Temporal: 0.6505, Balance: 0.5007
Gate Weights - Spatial: 0.3495, Temporal: 0.6505, Balance: 0.5007
Gate Weights - Spatial: 0.3495, Temporal: 0.6505, Balance: 0.5006
Gate Weights - Spatial: 0.3495, Temporal: 0.6505, Balance: 0.5006
Gate Weights - Spatial: 0.3495, Temporal: 0.6505, Balance: 0.5006
Gate Weights - Spatial: 0.3495, Temporal: 0.6505, Balance: 0.5006
Gate Weights - Spatial: 0.3495, Temporal: 0.6505, Balance: 0.5006
Gate Weights - Spatial: 0.3495, Temporal: 0.6505, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weigh

  with autocast():


New best model for Fold 1 saved with val accuracy: 0.4615

Epoch 2/30
Train Loss: 2.5895, Train Acc: 0.5200, Train F1: 0.5714
Val Loss: 2.2817, Val Acc: 0.4615, Val F1: 0.6316


  with autocast():


Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weigh

  with autocast():
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


New best model for Fold 1 saved with val accuracy: 0.5385

Epoch 3/30
Train Loss: 2.5251, Train Acc: 0.4400, Train F1: 0.3636
Val Loss: 2.1664, Val Acc: 0.5385, Val F1: 0.0000


  with autocast():


Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weigh

  with autocast():
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 4/30
Train Loss: 3.5916, Train Acc: 0.4000, Train F1: 0.3478
Val Loss: 3.0646, Val Acc: 0.5385, Val F1: 0.0000


  with autocast():


Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5006
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weigh

  with autocast():
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 5/30
Train Loss: 2.5968, Train Acc: 0.4800, Train F1: 0.4348
Val Loss: 2.3214, Val Acc: 0.5385, Val F1: 0.0000


  with autocast():


Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weights - Spatial: 0.3496, Temporal: 0.6504, Balance: 0.5005
Gate Weigh

  with autocast():



Epoch 6/30
Train Loss: 2.4076, Train Acc: 0.3200, Train F1: 0.2609
Val Loss: 2.1282, Val Acc: 0.4615, Val F1: 0.6316


  with autocast():


Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5005
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5005
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5005
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5005
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5005
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5005
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5005
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5005
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5005
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5005
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5005
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5005
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5005
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5005
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5005
Gate Weigh

  with autocast():



Epoch 7/30
Train Loss: 2.2648, Train Acc: 0.4000, Train F1: 0.4444
Val Loss: 2.1267, Val Acc: 0.4615, Val F1: 0.6316


  with autocast():


Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5004
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5004
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5004
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5004
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5004
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5004
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5004
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5004
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5004
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5004
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5004
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5004
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5004
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5004
Gate Weights - Spatial: 0.3497, Temporal: 0.6503, Balance: 0.5004
Gate Weigh

  with autocast():
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 8/30
Train Loss: 2.4525, Train Acc: 0.4400, Train F1: 0.4167
Val Loss: 2.0941, Val Acc: 0.5385, Val F1: 0.0000


  with autocast():


Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5003
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5003
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5003
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5003
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5003
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5003
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5003
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5003
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5003
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5003
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5003
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5003
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5003
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5003
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5003
Gate Weigh

  with autocast():
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 9/30
Train Loss: 2.3910, Train Acc: 0.4400, Train F1: 0.3000
Val Loss: 2.1418, Val Acc: 0.5385, Val F1: 0.0000


  with autocast():


Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5002
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5002
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5002
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5002
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5002
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5002
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5002
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5002
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5002
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5002
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5002
Gate Weights - Spatial: 0.3498, Temporal: 0.6502, Balance: 0.5002
Gate Weights - Spatial: 0.3499, Temporal: 0.6501, Balance: 0.5002
Gate Weights - Spatial: 0.3499, Temporal: 0.6501, Balance: 0.5002
Gate Weights - Spatial: 0.3499, Temporal: 0.6501, Balance: 0.5002
Gate Weigh

  with autocast():
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 10/30
Train Loss: 2.7290, Train Acc: 0.4800, Train F1: 0.4800
Val Loss: 2.1190, Val Acc: 0.5385, Val F1: 0.0000


  with autocast():


Gate Weights - Spatial: 0.3499, Temporal: 0.6501, Balance: 0.5001
Gate Weights - Spatial: 0.3499, Temporal: 0.6501, Balance: 0.5001
Gate Weights - Spatial: 0.3499, Temporal: 0.6501, Balance: 0.5001
Gate Weights - Spatial: 0.3499, Temporal: 0.6501, Balance: 0.5001
Gate Weights - Spatial: 0.3499, Temporal: 0.6501, Balance: 0.5001
Gate Weights - Spatial: 0.3499, Temporal: 0.6501, Balance: 0.5001
Gate Weights - Spatial: 0.3499, Temporal: 0.6501, Balance: 0.5001
Gate Weights - Spatial: 0.3499, Temporal: 0.6501, Balance: 0.5001
Gate Weights - Spatial: 0.3499, Temporal: 0.6501, Balance: 0.5001
Gate Weights - Spatial: 0.3499, Temporal: 0.6501, Balance: 0.5001
Gate Weights - Spatial: 0.3499, Temporal: 0.6501, Balance: 0.5001
Gate Weights - Spatial: 0.3499, Temporal: 0.6501, Balance: 0.5001
Gate Weights - Spatial: 0.3499, Temporal: 0.6501, Balance: 0.5001
Gate Weights - Spatial: 0.3499, Temporal: 0.6501, Balance: 0.5001
Gate Weights - Spatial: 0.3500, Temporal: 0.6500, Balance: 0.5001
Gate Weigh

  with autocast():
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 11/30
Train Loss: 2.2625, Train Acc: 0.4400, Train F1: 0.5333
Val Loss: 2.4054, Val Acc: 0.5385, Val F1: 0.0000


  with autocast():


Gate Weights - Spatial: 0.3500, Temporal: 0.6500, Balance: 0.4999
Gate Weights - Spatial: 0.3500, Temporal: 0.6500, Balance: 0.4999
Gate Weights - Spatial: 0.3500, Temporal: 0.6500, Balance: 0.4999
Gate Weights - Spatial: 0.3500, Temporal: 0.6500, Balance: 0.4999
Gate Weights - Spatial: 0.3500, Temporal: 0.6500, Balance: 0.4999
Gate Weights - Spatial: 0.3500, Temporal: 0.6500, Balance: 0.4999
Gate Weights - Spatial: 0.3500, Temporal: 0.6500, Balance: 0.4999
Gate Weights - Spatial: 0.3500, Temporal: 0.6500, Balance: 0.4999
Gate Weights - Spatial: 0.3501, Temporal: 0.6499, Balance: 0.4999
Gate Weights - Spatial: 0.3501, Temporal: 0.6499, Balance: 0.4999
Gate Weights - Spatial: 0.3501, Temporal: 0.6499, Balance: 0.4999
Gate Weights - Spatial: 0.3501, Temporal: 0.6499, Balance: 0.4999
Gate Weights - Spatial: 0.3501, Temporal: 0.6499, Balance: 0.4999
Gate Weights - Spatial: 0.3501, Temporal: 0.6499, Balance: 0.4999
Gate Weights - Spatial: 0.3501, Temporal: 0.6499, Balance: 0.4999
Gate Weigh

  with autocast():
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 12/30
Train Loss: 2.4382, Train Acc: 0.4800, Train F1: 0.4348
Val Loss: 2.2440, Val Acc: 0.5385, Val F1: 0.0000


  with autocast():


Gate Weights - Spatial: 0.3502, Temporal: 0.6498, Balance: 0.4998
Gate Weights - Spatial: 0.3502, Temporal: 0.6498, Balance: 0.4998
Gate Weights - Spatial: 0.3502, Temporal: 0.6498, Balance: 0.4998
Gate Weights - Spatial: 0.3502, Temporal: 0.6498, Balance: 0.4998
Gate Weights - Spatial: 0.3502, Temporal: 0.6498, Balance: 0.4998
Gate Weights - Spatial: 0.3502, Temporal: 0.6498, Balance: 0.4998
Gate Weights - Spatial: 0.3502, Temporal: 0.6498, Balance: 0.4998
Gate Weights - Spatial: 0.3502, Temporal: 0.6498, Balance: 0.4998
Gate Weights - Spatial: 0.3502, Temporal: 0.6498, Balance: 0.4998
Gate Weights - Spatial: 0.3502, Temporal: 0.6498, Balance: 0.4998
Gate Weights - Spatial: 0.3502, Temporal: 0.6498, Balance: 0.4997
Gate Weights - Spatial: 0.3502, Temporal: 0.6498, Balance: 0.4997
Gate Weights - Spatial: 0.3502, Temporal: 0.6498, Balance: 0.4997
Gate Weights - Spatial: 0.3502, Temporal: 0.6498, Balance: 0.4997
Gate Weights - Spatial: 0.3502, Temporal: 0.6498, Balance: 0.4997
Gate Weigh

  with autocast():
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 13/30
Train Loss: 2.4902, Train Acc: 0.4400, Train F1: 0.4615
Val Loss: 2.7768, Val Acc: 0.5385, Val F1: 0.0000


  with autocast():


Gate Weights - Spatial: 0.3503, Temporal: 0.6497, Balance: 0.4995
Gate Weights - Spatial: 0.3503, Temporal: 0.6497, Balance: 0.4995
Gate Weights - Spatial: 0.3504, Temporal: 0.6496, Balance: 0.4995
Gate Weights - Spatial: 0.3504, Temporal: 0.6496, Balance: 0.4995
Gate Weights - Spatial: 0.3504, Temporal: 0.6496, Balance: 0.4995
Gate Weights - Spatial: 0.3504, Temporal: 0.6496, Balance: 0.4995
Gate Weights - Spatial: 0.3504, Temporal: 0.6496, Balance: 0.4995
Gate Weights - Spatial: 0.3504, Temporal: 0.6496, Balance: 0.4995
Gate Weights - Spatial: 0.3504, Temporal: 0.6496, Balance: 0.4995
Gate Weights - Spatial: 0.3504, Temporal: 0.6496, Balance: 0.4995
Gate Weights - Spatial: 0.3504, Temporal: 0.6496, Balance: 0.4995
Gate Weights - Spatial: 0.3504, Temporal: 0.6496, Balance: 0.4995
Gate Weights - Spatial: 0.3504, Temporal: 0.6496, Balance: 0.4995
Gate Weights - Spatial: 0.3504, Temporal: 0.6496, Balance: 0.4995
Gate Weights - Spatial: 0.3504, Temporal: 0.6496, Balance: 0.4995
Gate Weigh

  with autocast():
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 14/30
Train Loss: 2.8825, Train Acc: 0.4000, Train F1: 0.4000
Val Loss: 3.0692, Val Acc: 0.5385, Val F1: 0.0000


  with autocast():


Gate Weights - Spatial: 0.3506, Temporal: 0.6494, Balance: 0.4992
Gate Weights - Spatial: 0.3506, Temporal: 0.6494, Balance: 0.4992
Gate Weights - Spatial: 0.3506, Temporal: 0.6494, Balance: 0.4992
Gate Weights - Spatial: 0.3506, Temporal: 0.6494, Balance: 0.4992
Gate Weights - Spatial: 0.3506, Temporal: 0.6494, Balance: 0.4992
Gate Weights - Spatial: 0.3506, Temporal: 0.6494, Balance: 0.4992
Gate Weights - Spatial: 0.3506, Temporal: 0.6494, Balance: 0.4992
Gate Weights - Spatial: 0.3506, Temporal: 0.6494, Balance: 0.4992
Gate Weights - Spatial: 0.3506, Temporal: 0.6494, Balance: 0.4991
Gate Weights - Spatial: 0.3506, Temporal: 0.6494, Balance: 0.4991
Gate Weights - Spatial: 0.3506, Temporal: 0.6494, Balance: 0.4991
Gate Weights - Spatial: 0.3506, Temporal: 0.6494, Balance: 0.4991
Gate Weights - Spatial: 0.3506, Temporal: 0.6494, Balance: 0.4991
Gate Weights - Spatial: 0.3506, Temporal: 0.6494, Balance: 0.4991
Gate Weights - Spatial: 0.3506, Temporal: 0.6494, Balance: 0.4991
Gate Weigh

  with autocast():



Epoch 15/30
Train Loss: 2.2699, Train Acc: 0.6000, Train F1: 0.5000
Val Loss: 2.2586, Val Acc: 0.4615, Val F1: 0.6316


  with autocast():


Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4990
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4990
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4990
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4990
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4990
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4990
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4990
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4990
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4990
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4990
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4990
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4990
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4990
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4990
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4990
Gate Weigh

  with autocast():



Epoch 16/30
Train Loss: 2.0431, Train Acc: 0.4800, Train F1: 0.6486
Val Loss: 2.1295, Val Acc: 0.4615, Val F1: 0.6316


  with autocast():


Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4989
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4989
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4989
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4989
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4989
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4989
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4989
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4989
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4989
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4989
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4989
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4989
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4989
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4989
Gate Weights - Spatial: 0.3507, Temporal: 0.6493, Balance: 0.4989
Gate Weigh

  with autocast():



Epoch 17/30
Train Loss: 1.9391, Train Acc: 0.8400, Train F1: 0.8462
Val Loss: 2.0784, Val Acc: 0.4615, Val F1: 0.4615


  with autocast():


Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4989
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4989
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4989
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4989
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4989
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4989
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4989
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4989
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4989
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4989
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4989
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4989
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4989
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4989
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4989
Gate Weigh

  with autocast():
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 18/30
Train Loss: 1.8827, Train Acc: 1.0000, Train F1: 1.0000
Val Loss: 2.0646, Val Acc: 0.5385, Val F1: 0.0000


  with autocast():


Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4989
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4989
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4989
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4989
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4989
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4989
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weigh

  with autocast():
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 19/30
Train Loss: 1.8673, Train Acc: 0.9200, Train F1: 0.9091
Val Loss: 2.0699, Val Acc: 0.5385, Val F1: 0.0000


  with autocast():


Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weights - Spatial: 0.3508, Temporal: 0.6492, Balance: 0.4988
Gate Weigh

  with autocast():
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 20/30
Train Loss: 1.8242, Train Acc: 0.9200, Train F1: 0.9091
Val Loss: 2.0737, Val Acc: 0.5385, Val F1: 0.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4988
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4988
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4988
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4988
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4988
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4988
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4988
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4988
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4988
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4988
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4988
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4988
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4988
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4988
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4988
Gate Weigh

  with autocast():
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 21/30
Train Loss: 1.7734, Train Acc: 0.9200, Train F1: 0.9091
Val Loss: 2.0693, Val Acc: 0.5385, Val F1: 0.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 22/30
Train Loss: 1.7669, Train Acc: 0.9200, Train F1: 0.9091
Val Loss: 2.0725, Val Acc: 0.5385, Val F1: 0.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 23/30
Train Loss: 1.7537, Train Acc: 0.9600, Train F1: 0.9565
Val Loss: 2.0813, Val Acc: 0.5385, Val F1: 0.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 24/30
Train Loss: 1.7474, Train Acc: 0.9200, Train F1: 0.9091
Val Loss: 2.0713, Val Acc: 0.5385, Val F1: 0.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 25/30
Train Loss: 1.7334, Train Acc: 0.9600, Train F1: 0.9565
Val Loss: 2.0726, Val Acc: 0.5385, Val F1: 0.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 26/30
Train Loss: 1.7309, Train Acc: 0.9600, Train F1: 0.9565
Val Loss: 2.0740, Val Acc: 0.5385, Val F1: 0.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 27/30
Train Loss: 1.7237, Train Acc: 0.9600, Train F1: 0.9565
Val Loss: 2.0759, Val Acc: 0.5385, Val F1: 0.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 28/30
Train Loss: 1.7178, Train Acc: 0.9200, Train F1: 0.9091
Val Loss: 2.0799, Val Acc: 0.5385, Val F1: 0.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 29/30
Train Loss: 1.6967, Train Acc: 1.0000, Train F1: 1.0000
Val Loss: 2.0769, Val Acc: 0.5385, Val F1: 0.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 30/30
Train Loss: 1.6892, Train Acc: 1.0000, Train F1: 1.0000
Val Loss: 2.0752, Val Acc: 0.5385, Val F1: 0.2500
=== Fold 2/3 ===


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():


New best model for Fold 2 saved with val accuracy: 0.9231

Epoch 1/30
Train Loss: 1.8916, Train Acc: 0.6800, Train F1: 0.6667
Val Loss: 1.8887, Val Acc: 0.9231, Val F1: 0.8889


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 2/30
Train Loss: 1.8925, Train Acc: 0.6800, Train F1: 0.6667
Val Loss: 1.8957, Val Acc: 0.7692, Val F1: 0.5714


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 3/30
Train Loss: 1.8911, Train Acc: 0.6800, Train F1: 0.6667
Val Loss: 1.8874, Val Acc: 0.9231, Val F1: 0.8889


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 4/30
Train Loss: 1.8880, Train Acc: 0.7200, Train F1: 0.6957
Val Loss: 1.8849, Val Acc: 0.7692, Val F1: 0.5714


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 5/30
Train Loss: 1.8895, Train Acc: 0.6800, Train F1: 0.6667
Val Loss: 1.8793, Val Acc: 0.7692, Val F1: 0.5714


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 6/30
Train Loss: 1.8890, Train Acc: 0.6800, Train F1: 0.6667
Val Loss: 1.9019, Val Acc: 0.9231, Val F1: 0.8889


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 7/30
Train Loss: 1.8870, Train Acc: 0.6800, Train F1: 0.6667
Val Loss: 1.8913, Val Acc: 0.8462, Val F1: 0.7500


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 8/30
Train Loss: 1.8836, Train Acc: 0.6800, Train F1: 0.6667
Val Loss: 1.8827, Val Acc: 0.8462, Val F1: 0.7500


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 9/30
Train Loss: 1.8882, Train Acc: 0.6400, Train F1: 0.6400
Val Loss: 1.8916, Val Acc: 0.9231, Val F1: 0.8889


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 10/30
Train Loss: 1.8885, Train Acc: 0.6800, Train F1: 0.6667
Val Loss: 1.8961, Val Acc: 0.9231, Val F1: 0.8889


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():


New best model for Fold 2 saved with val accuracy: 1.0000

Epoch 11/30
Train Loss: 1.8911, Train Acc: 0.6800, Train F1: 0.6667
Val Loss: 1.9075, Val Acc: 1.0000, Val F1: 1.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 12/30
Train Loss: 1.8845, Train Acc: 0.6400, Train F1: 0.6400
Val Loss: 1.8948, Val Acc: 0.9231, Val F1: 0.8889


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 13/30
Train Loss: 1.8843, Train Acc: 0.6400, Train F1: 0.6400
Val Loss: 1.8882, Val Acc: 0.9231, Val F1: 0.8889


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 14/30
Train Loss: 1.8838, Train Acc: 0.6400, Train F1: 0.6400
Val Loss: 1.8981, Val Acc: 0.9231, Val F1: 0.8889


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 15/30
Train Loss: 1.8819, Train Acc: 0.6800, Train F1: 0.6923
Val Loss: 1.8981, Val Acc: 0.9231, Val F1: 0.8889


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 16/30
Train Loss: 1.8841, Train Acc: 0.6400, Train F1: 0.6667
Val Loss: 1.8842, Val Acc: 0.9231, Val F1: 0.8889


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 17/30
Train Loss: 1.8749, Train Acc: 0.7600, Train F1: 0.7692
Val Loss: 1.8860, Val Acc: 0.9231, Val F1: 0.8889


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 18/30
Train Loss: 1.8816, Train Acc: 0.6800, Train F1: 0.6923
Val Loss: 1.8983, Val Acc: 1.0000, Val F1: 1.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 19/30
Train Loss: 1.8794, Train Acc: 0.6800, Train F1: 0.6923
Val Loss: 1.9037, Val Acc: 1.0000, Val F1: 1.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 20/30
Train Loss: 1.8819, Train Acc: 0.6800, Train F1: 0.6923
Val Loss: 1.8936, Val Acc: 1.0000, Val F1: 1.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 21/30
Train Loss: 1.8799, Train Acc: 0.6800, Train F1: 0.6923
Val Loss: 1.9046, Val Acc: 1.0000, Val F1: 1.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 22/30
Train Loss: 1.8805, Train Acc: 0.6800, Train F1: 0.6923
Val Loss: 1.9034, Val Acc: 1.0000, Val F1: 1.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 23/30
Train Loss: 1.8737, Train Acc: 0.7200, Train F1: 0.7407
Val Loss: 1.8875, Val Acc: 0.9231, Val F1: 0.8889


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 24/30
Train Loss: 1.8755, Train Acc: 0.6400, Train F1: 0.6400
Val Loss: 1.9206, Val Acc: 0.9231, Val F1: 0.9091


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 25/30
Train Loss: 1.8784, Train Acc: 0.6800, Train F1: 0.6923
Val Loss: 1.9016, Val Acc: 1.0000, Val F1: 1.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 26/30
Train Loss: 1.8790, Train Acc: 0.7200, Train F1: 0.7407
Val Loss: 1.9032, Val Acc: 1.0000, Val F1: 1.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 27/30
Train Loss: 1.8768, Train Acc: 0.7200, Train F1: 0.7407
Val Loss: 1.9122, Val Acc: 0.9231, Val F1: 0.9091


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 28/30
Train Loss: 1.8758, Train Acc: 0.7200, Train F1: 0.7407
Val Loss: 1.9014, Val Acc: 1.0000, Val F1: 1.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 29/30
Train Loss: 1.8687, Train Acc: 0.7200, Train F1: 0.7407
Val Loss: 1.9146, Val Acc: 0.9231, Val F1: 0.9091


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 30/30
Train Loss: 1.8650, Train Acc: 0.7200, Train F1: 0.7407
Val Loss: 1.9118, Val Acc: 0.9231, Val F1: 0.9091
=== Fold 3/3 ===


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():


New best model for Fold 3 saved with val accuracy: 0.9167

Epoch 1/30
Train Loss: 1.8711, Train Acc: 0.7308, Train F1: 0.6957
Val Loss: 1.9702, Val Acc: 0.9167, Val F1: 0.9231


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 2/30
Train Loss: 1.8724, Train Acc: 0.6923, Train F1: 0.6667
Val Loss: 1.9841, Val Acc: 0.8333, Val F1: 0.8750


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():


New best model for Fold 3 saved with val accuracy: 1.0000

Epoch 3/30
Train Loss: 1.8708, Train Acc: 0.7308, Train F1: 0.6957
Val Loss: 1.9885, Val Acc: 1.0000, Val F1: 1.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 4/30
Train Loss: 1.8755, Train Acc: 0.6923, Train F1: 0.6364
Val Loss: 1.9679, Val Acc: 0.9167, Val F1: 0.9231


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 5/30
Train Loss: 1.8686, Train Acc: 0.6538, Train F1: 0.5714
Val Loss: 1.9777, Val Acc: 1.0000, Val F1: 1.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 6/30
Train Loss: 1.8668, Train Acc: 0.6923, Train F1: 0.6364
Val Loss: 1.9847, Val Acc: 0.8333, Val F1: 0.8750


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 7/30
Train Loss: 1.8688, Train Acc: 0.7308, Train F1: 0.6667
Val Loss: 1.9720, Val Acc: 1.0000, Val F1: 1.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 8/30
Train Loss: 1.8646, Train Acc: 0.6923, Train F1: 0.6364
Val Loss: 1.9802, Val Acc: 1.0000, Val F1: 1.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 9/30
Train Loss: 1.8698, Train Acc: 0.7308, Train F1: 0.6667
Val Loss: 1.9872, Val Acc: 1.0000, Val F1: 1.0000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 10/30
Train Loss: 1.8638, Train Acc: 0.6923, Train F1: 0.6000
Val Loss: 2.0092, Val Acc: 0.5833, Val F1: 0.4444


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 11/30
Train Loss: 1.8627, Train Acc: 0.6923, Train F1: 0.6000
Val Loss: 1.9951, Val Acc: 0.8333, Val F1: 0.8750


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 12/30
Train Loss: 1.8619, Train Acc: 0.6923, Train F1: 0.6000
Val Loss: 2.0055, Val Acc: 0.5833, Val F1: 0.4444


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 13/30
Train Loss: 1.8582, Train Acc: 0.6923, Train F1: 0.6000
Val Loss: 2.0051, Val Acc: 0.6667, Val F1: 0.6000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 14/30
Train Loss: 1.8612, Train Acc: 0.7308, Train F1: 0.6316
Val Loss: 1.9689, Val Acc: 0.9167, Val F1: 0.9231


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 15/30
Train Loss: 1.8603, Train Acc: 0.6923, Train F1: 0.6000
Val Loss: 2.0061, Val Acc: 0.6667, Val F1: 0.6000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 16/30
Train Loss: 1.8607, Train Acc: 0.7308, Train F1: 0.6316
Val Loss: 1.9901, Val Acc: 0.5833, Val F1: 0.4444


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 17/30
Train Loss: 1.8604, Train Acc: 0.7308, Train F1: 0.6316
Val Loss: 2.0027, Val Acc: 0.6667, Val F1: 0.6000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 18/30
Train Loss: 1.8595, Train Acc: 0.6923, Train F1: 0.6000
Val Loss: 1.9885, Val Acc: 0.5000, Val F1: 0.2500


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 19/30
Train Loss: 1.8510, Train Acc: 0.7692, Train F1: 0.6667
Val Loss: 1.9863, Val Acc: 0.5000, Val F1: 0.2500


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 20/30
Train Loss: 1.8534, Train Acc: 0.7692, Train F1: 0.6667
Val Loss: 2.0106, Val Acc: 0.6667, Val F1: 0.6000


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4987
Gate Weigh

  with autocast():



Epoch 21/30
Train Loss: 1.8519, Train Acc: 0.6923, Train F1: 0.5556
Val Loss: 1.9761, Val Acc: 0.7500, Val F1: 0.7273


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weigh

  with autocast():



Epoch 22/30
Train Loss: 1.8502, Train Acc: 0.8077, Train F1: 0.7059
Val Loss: 2.0065, Val Acc: 0.5833, Val F1: 0.4444


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weigh

  with autocast():



Epoch 23/30
Train Loss: 1.8442, Train Acc: 0.7692, Train F1: 0.6250
Val Loss: 1.9963, Val Acc: 0.5833, Val F1: 0.4444


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weigh

  with autocast():



Epoch 24/30
Train Loss: 1.8505, Train Acc: 0.7692, Train F1: 0.6250
Val Loss: 1.9935, Val Acc: 0.5833, Val F1: 0.4444


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weigh

  with autocast():



Epoch 25/30
Train Loss: 1.8521, Train Acc: 0.7308, Train F1: 0.5882
Val Loss: 2.0054, Val Acc: 0.5833, Val F1: 0.4444


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weigh

  with autocast():



Epoch 26/30
Train Loss: 1.8487, Train Acc: 0.7692, Train F1: 0.6250
Val Loss: 1.9874, Val Acc: 0.5833, Val F1: 0.4444


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weigh

  with autocast():



Epoch 27/30
Train Loss: 1.8454, Train Acc: 0.7692, Train F1: 0.6250
Val Loss: 1.9854, Val Acc: 0.5833, Val F1: 0.4444


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weigh

  with autocast():



Epoch 28/30
Train Loss: 1.8519, Train Acc: 0.7692, Train F1: 0.6250
Val Loss: 2.0077, Val Acc: 0.5833, Val F1: 0.4444


  with autocast():


Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weights - Spatial: 0.3509, Temporal: 0.6491, Balance: 0.4986
Gate Weigh

  with autocast():



Epoch 29/30
Train Loss: 1.8451, Train Acc: 0.7692, Train F1: 0.6250
Val Loss: 1.9905, Val Acc: 0.5000, Val F1: 0.2500


  with autocast():


Gate Weights - Spatial: 0.3510, Temporal: 0.6490, Balance: 0.4986
Gate Weights - Spatial: 0.3510, Temporal: 0.6490, Balance: 0.4986
Gate Weights - Spatial: 0.3510, Temporal: 0.6490, Balance: 0.4986
Gate Weights - Spatial: 0.3510, Temporal: 0.6490, Balance: 0.4986
Gate Weights - Spatial: 0.3510, Temporal: 0.6490, Balance: 0.4986
Gate Weights - Spatial: 0.3510, Temporal: 0.6490, Balance: 0.4986
Gate Weights - Spatial: 0.3510, Temporal: 0.6490, Balance: 0.4986
Gate Weights - Spatial: 0.3510, Temporal: 0.6490, Balance: 0.4986
Gate Weights - Spatial: 0.3510, Temporal: 0.6490, Balance: 0.4986
Gate Weights - Spatial: 0.3510, Temporal: 0.6490, Balance: 0.4986
Gate Weights - Spatial: 0.3510, Temporal: 0.6490, Balance: 0.4986
Gate Weights - Spatial: 0.3510, Temporal: 0.6490, Balance: 0.4986
Gate Weights - Spatial: 0.3510, Temporal: 0.6490, Balance: 0.4986
Gate Weights - Spatial: 0.3510, Temporal: 0.6490, Balance: 0.4986
Gate Weights - Spatial: 0.3510, Temporal: 0.6490, Balance: 0.4986
Gate Weigh

  with autocast():



Epoch 30/30
Train Loss: 1.8432, Train Acc: 0.7692, Train F1: 0.6250
Val Loss: 2.0155, Val Acc: 0.5833, Val F1: 0.4444

=== K-Fold Cross Validation Results ===
Fold 1:
  loss: 2.0752
  accuracy: 0.5385
  f1: 0.2500
  precision: 0.5000
  recall: 0.1667
  roc_auc: 0.5000
  pr_auc: 0.5370
Fold 2:
  loss: 1.9118
  accuracy: 0.9231
  f1: 0.9091
  precision: 0.8333
  recall: 1.0000
  roc_auc: 1.0000
  pr_auc: 1.0000
Fold 3:
  loss: 2.0155
  accuracy: 0.5833
  f1: 0.4444
  precision: 1.0000
  recall: 0.2857
  roc_auc: 1.0000
  pr_auc: 1.0000

Overall Results:
  loss: 2.0009 ± 0.0675
  accuracy: 0.6816 ± 0.1717
  f1: 0.5345 ± 0.2765
  precision: 0.7778 ± 0.2079
  recall: 0.4841 ± 0.3680
  roc_auc: 0.8333 ± 0.2357
  pr_auc: 0.8457 ± 0.2183

K-Fold Cross Validation completed.
