# MLP

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import time
import os
import h5py
from skimage.transform import resize
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split

In [None]:
# --- Configuration ---
# Training schedule / optimiser settings.
NUM_EPOCHS_CENTRALIZED = 50
LEARNING_RATE = 1e-5
BATCH_SIZE = 8

# Problem dimensions.
NUM_CLASSES = 4
IMG_SIZE = 256

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# --- Standard Convolutional Block ---
class BasicConvBlock(nn.Module):
    """Conv2d -> optional BatchNorm2d -> in-place ReLU, held in one ``nn.Sequential``.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution.
        kernel_size: Convolution kernel size (default 3).
        padding: Convolution padding (default 1).
        use_bn: When true, a ``BatchNorm2d`` follows the convolution and the
            convolution is created without a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, use_bn=True):
        super().__init__()
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=not use_bn,
        )
        normalisation = [nn.BatchNorm2d(out_channels)] if use_bn else []
        self.block = nn.Sequential(convolution, *normalisation, nn.ReLU(inplace=True))

    def forward(self, x):
        """Run ``x`` through the sequential conv/norm/activation stack."""
        return self.block(x)

# ePURE

In [None]:
# --- ePURE Implementation (Provided) ---
class ePURE(nn.Module):
    """Small convolutional noise-profile estimator.

    Two 3x3 convolutions with a ReLU in between map the input to a
    single-channel map of the same spatial size. The map is returned without
    any output activation; ``adaptive_spline_smoothing`` applies the sigmoid.

    Args:
        in_channels: Number of channels of the incoming tensor.
        base_channels: Width of the hidden convolution (default 32).
    """

    def __init__(self, in_channels, base_channels=32):
        super().__init__()
        hidden = nn.Conv2d(in_channels, base_channels, 3, padding=1)
        # The head always emits exactly one channel: the noise profile.
        head = nn.Conv2d(base_channels, 1, 3, padding=1)
        self.conv = nn.Sequential(hidden, nn.ReLU(), head)

    def forward(self, x):
        """Return the raw noise map for ``x`` (input is cast to float first)."""
        return self.conv(x.float())

# Adaptive_Spline_Function

In [None]:
import torchvision.transforms.functional as TF
# --- Adaptive Spline Smoothing Implementation (Provided) ---
def adaptive_spline_smoothing(x, noise_profile, kernel_size=5, sigma=1.0):
    """Blend ``x`` with a Gaussian-blurred copy of itself, pixel by pixel.

    The per-pixel blend weight is ``sigmoid(noise_profile)``, so larger profile
    values pull the output towards the blurred version:
    ``out = (1 - w) * x + w * blur(x)``.

    Args:
        x: Input image or feature map, indexed as [B, C, H, W].
        noise_profile: Noise map indexed as [B, 1, H, W]. The sigmoid is
            applied here, so raw (unbounded) values are accepted. If it has
            more than one channel a warning is printed and only the first
            channel is used.
        kernel_size: Gaussian kernel size; an int is used for both axes,
            anything else is passed through as given.
        sigma: Gaussian sigma; a number is used for both axes. Every value is
            raised to at least 0.1.

    Returns:
        A float tensor with the same shape as ``x``.
    """
    image = x.float()

    profile = noise_profile.float()
    profile_channels = profile.size(1)
    if profile_channels != 1:
        print(f"Warning: Noise profile expected 1 channel but got {profile_channels}. Using first channel.")
        profile = profile[:, :1, :, :]

    # Step 1: Gaussian blur. Scalars are expanded to per-axis pairs.
    blur_kernel = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
    blur_sigma = (float(sigma), float(sigma)) if isinstance(sigma, (int, float)) else sigma
    # Keep sigma strictly positive.
    blur_sigma = tuple(max(0.1, value) for value in blur_sigma)
    blurred = TF.gaussian_blur(image, kernel_size=blur_kernel, sigma=blur_sigma)

    # Step 2: squash the profile into (0, 1) and copy it across all channels of x.
    alpha = torch.sigmoid(profile).repeat(1, image.size(1), 1, 1)
    assert alpha.shape == image.shape, f"Blending weights shape {alpha.shape} does not match input shape {image.shape}"

    # Step 3: convex combination of the original and the blurred tensor.
    return image * (1 - alpha) + blurred * alpha

In [None]:
def quantum_noise_injection(features, T=1.25, pauli_prob=None):
    """Randomly perturb elements of ``features`` with Pauli-gate-inspired noise.

    Every element independently draws one of four actions:

    * ``X``: replace ``v`` by ``1 - v`` (assumes values normalised to [0, 1]);
    * ``Y``: replace ``v`` by ``1 - v`` plus Gaussian noise with std 0.1;
    * ``Z``: replace ``v`` by ``-v``;
    * ``None``: leave the value untouched.

    The whole result is then clamped to [0, 1].

    Args:
        features (torch.Tensor): Input tensor, indexed as
            (batch, channels, height, width).
        T (float): Multiplier applied to the X/Y/Z probabilities.
        pauli_prob (dict | None): Base probabilities under the keys ``'X'``,
            ``'Y'``, ``'Z'`` and ``'None'``. Defaults to 0.00096 for each gate
            and 0.99712 for ``'None'``. The ``'None'`` entry is not read: the
            "no gate" probability is recomputed as ``1 - (X + Y + Z) * T``.

    Returns:
        torch.Tensor: Float tensor with the noise applied. The float-cast input
        is returned unchanged when it has fewer than 4 dimensions, when height
        or width is below 2, or when sampling raises a ``RuntimeError`` (for
        example because the scaled probabilities are invalid).
    """
    # A fresh dict per call avoids sharing one mutable default between calls.
    if pauli_prob is None:
        pauli_prob = {'X': 0.00096, 'Y': 0.00096, 'Z': 0.00096, 'None': 0.99712}

    features_float = features.float()

    if features_float.dim() < 4 or features_float.size(2) < 2 or features_float.size(3) < 2:
        print("Warning: Features too small for quantum noise injection.")
        return features_float

    try:
        gate_probs = [pauli_prob[gate] * T for gate in ('X', 'Y', 'Z')]
        none_prob = 1.0 - (pauli_prob['X'] + pauli_prob['Y'] + pauli_prob['Z']) * T
        probabilities = torch.tensor(gate_probs + [none_prob], device=features_float.device)

        # One categorical draw per element: 0 -> X, 1 -> Y, 2 -> Z, 3 -> None.
        choice = torch.multinomial(
            probabilities,
            features_float.numel(),
            replacement=True,
        ).view(features_float.shape)

        noisy_features = features_float.clone()
        x_mask, y_mask, z_mask = (choice == index for index in range(3))

        # The masks are disjoint, so the updates below never interact.
        noisy_features[x_mask] = 1.0 - noisy_features[x_mask]
        flipped = 1.0 - noisy_features[y_mask]
        noisy_features[y_mask] = flipped + 0.1 * torch.randn_like(flipped)
        noisy_features[z_mask] = -noisy_features[z_mask]

        return torch.clamp(noisy_features, 0.0, 1.0)

    except RuntimeError as e:
        print(f"Quantum noise injection failed: {e}. Returning original features.")
        return features_float

In [None]:
# --- Model Components (U-Net based) ---
class EncoderBlock(nn.Module):
    """Encoder stage: noise-adaptive smoothing followed by two conv blocks.

    An ``ePURE`` estimator predicts a noise profile from the stage input, the
    input is smoothed with ``adaptive_spline_smoothing`` according to that
    profile, and the result goes through two ``BasicConvBlock`` layers.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by both conv blocks.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_block1 = BasicConvBlock(in_channels, out_channels)
        self.conv_block2 = BasicConvBlock(out_channels, out_channels)
        self.noise_estimator = ePURE(in_channels=in_channels)

    def forward(self, x):
        """Smooth ``x`` by its estimated noise profile, then convolve twice."""
        denoised = adaptive_spline_smoothing(x, self.noise_estimator(x))
        return self.conv_block2(self.conv_block1(denoised))

In [None]:
class MaxwellSolver(nn.Module):
    """Predicts (eps, sigma) maps and evaluates a 2-D Helmholtz residual.

    ``forward`` runs a two-layer convolutional encoder whose two output
    channels are returned separately as the ``eps`` and ``sigma`` maps.
    ``compute_helmholtz_residual`` evaluates
    ``|lap(B1) + k0^2 * (eps - j * sigma / omega) * B1|^2`` per pixel.

    Args:
        in_channels: Channels of the feature map fed to the encoder.
        hidden_dim: Width of the hidden convolution (default 32).
    """

    # Angular frequency used by the residual: 2*pi*42.58 MHz.
    # NOTE(review): 42.58 MHz is presumably the 1H Larmor frequency at 1 T - confirm.
    _OMEGA = 2 * np.pi * 42.58e6

    def __init__(self, in_channels, hidden_dim=32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(hidden_dim, 2, kernel_size=3, padding=1))
        # Vacuum permeability and permittivity; k0 = omega * sqrt(mu_0 * eps_0).
        mu_0, eps_0 = 4 * np.pi * 1e-7, 8.854187817e-12
        k0 = torch.tensor(self._OMEGA * np.sqrt(mu_0 * eps_0), dtype=torch.float32)
        # Non-persistent buffer: follows module.to(...) but stays out of
        # state_dict, so existing checkpoints keep loading unchanged.
        self.register_buffer('k0', k0, persistent=False)

    def forward(self, x):
        """Return ``(eps, sigma)``, each with one channel and the spatial size of ``x``."""
        eps_sigma_map = self.encoder(x)
        return eps_sigma_map[:, 0:1, :, :], eps_sigma_map[:, 1:2, :, :]

    def compute_helmholtz_residual(self, b1_map, eps, sigma):
        """Return the squared magnitude of the Helmholtz residual per pixel.

        Args:
            b1_map: Real or complex field, indexed as [B, C, H, W]. A real
                input is promoted to complex with a zero imaginary part.
            eps: Permittivity map; bilinearly resized to the spatial size of
                ``b1_map``.
            sigma: Conductivity map; resized the same way.

        Returns:
            Real tensor ``res.real**2 + res.imag**2`` on the device of ``b1_map``.
        """
        if b1_map.is_complex():
            b1_complex = b1_map
        else:
            b1_complex = torch.complex(b1_map, torch.zeros_like(b1_map))
        device = b1_complex.device
        # Use a local copy instead of re-assigning the attribute on every call.
        k0 = self.k0.to(device)

        size = b1_complex.shape[2:]
        up_eps = F.interpolate(eps.to(device), size=size, mode='bilinear', align_corners=False)
        up_sig = F.interpolate(sigma.to(device), size=size, mode='bilinear', align_corners=False)
        eps_c = torch.complex(up_eps, -up_sig / self._OMEGA)

        res = self._laplacian_2d(b1_complex) + (k0 ** 2) * eps_c * b1_complex
        return res.real ** 2 + res.imag ** 2

    def _laplacian_2d(self, x_complex):
        """Apply the 5-point Laplacian stencil to each channel (zero padding)."""
        real, imag = x_complex.real, x_complex.imag
        # Grouped convolution needs at least one group.
        channels = max(real.size(1), 1)
        # Build the stencil in the input's own real dtype so double-precision
        # fields work as well as single-precision ones.
        stencil = torch.tensor(
            [[0., 1., 0.], [1., -4., 1.], [0., 1., 0.]],
            dtype=real.dtype, device=real.device,
        ).reshape(1, 1, 3, 3)
        weight = stencil.repeat(channels, 1, 1, 1)
        real_lap = F.conv2d(real, weight, padding=1, groups=channels)
        imag_lap = F.conv2d(imag, weight, padding=1, groups=channels)
        return torch.complex(real_lap, imag_lap)

In [None]:
class DecoderBlock(nn.Module):
    """Decoder stage: upsample, merge with the skip tensor, refine.

    Besides the refined features, the stage also runs a ``MaxwellSolver`` on
    the merged tensor and returns its ``(eps, sigma)`` pair.

    Args:
        in_channels: Channels of the tensor coming from the deeper stage; the
            transposed convolution halves this number.
        skip_channels: Channels of the encoder skip connection.
        out_channels: Channels produced by both conv blocks.
    """

    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        upsampled_channels = in_channels // 2
        self.up = nn.ConvTranspose2d(in_channels, upsampled_channels, kernel_size=2, stride=2)
        merged_channels = upsampled_channels + skip_channels
        self.maxwell_solver = MaxwellSolver(merged_channels)
        self.conv_block1 = BasicConvBlock(merged_channels, out_channels)
        self.conv_block2 = BasicConvBlock(out_channels, out_channels)

    def forward(self, x, skip_connection):
        """Return ``(features, (eps, sigma))`` for this stage."""
        upsampled = self.up(x)

        # Pad the upsampled tensor so its height/width match the skip tensor.
        gap_h = skip_connection.size(2) - upsampled.size(2)
        gap_w = skip_connection.size(3) - upsampled.size(3)
        pad_left, pad_top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, [pad_left, gap_w - pad_left, pad_top, gap_h - pad_top])

        merged = torch.cat([skip_connection, upsampled], dim=1)
        eps_sigma = self.maxwell_solver(merged)
        features = self.conv_block2(self.conv_block1(merged))
        return features, eps_sigma

In [None]:
class RobustMedVFL_UNet(nn.Module):
    """U-Net built from ``EncoderBlock`` / ``DecoderBlock`` stages.

    Four encoder stages (64, 128, 256, 512 channels), each followed by 2x2
    max pooling, feed a 1024-channel bottleneck. Four decoder stages mirror
    the encoders, and a 1x1 convolution produces the class logits.

    Args:
        n_channels: Channels of the input image (default 1).
        n_classes: Number of output classes (default 4).
    """

    def __init__(self, n_channels=1, n_classes=4):
        super().__init__()
        self.enc1 = EncoderBlock(n_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = EncoderBlock(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = EncoderBlock(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = EncoderBlock(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.bottleneck = EncoderBlock(512, 1024)
        self.dec1 = DecoderBlock(1024, 512, 512)
        self.dec2 = DecoderBlock(512, 256, 256)
        self.dec3 = DecoderBlock(256, 128, 128)
        self.dec4 = DecoderBlock(128, 64, 64)
        self.out_conv = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        """Return ``(logits, eps_sigma)``.

        ``eps_sigma`` is a 4-tuple with one ``(eps, sigma)`` pair per decoder
        stage, ordered from ``dec1`` (deepest) to ``dec4`` (shallowest).
        """
        encoder_stages = (
            (self.enc1, self.pool1),
            (self.enc2, self.pool2),
            (self.enc3, self.pool3),
            (self.enc4, self.pool4),
        )
        skips = []
        hidden = x
        for encode, pool in encoder_stages:
            hidden = encode(hidden)
            skips.append(hidden)
            hidden = pool(hidden)

        hidden = self.bottleneck(hidden)

        eps_sigma_pairs = []
        for decode in (self.dec1, self.dec2, self.dec3, self.dec4):
            # The most recently stored skip belongs to the current decoder stage.
            hidden, pair = decode(hidden, skips.pop())
            eps_sigma_pairs.append(pair)

        return self.out_conv(hidden), tuple(eps_sigma_pairs)

# Loss

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Optional, Dict, List

class Adaptive_tvmf_dice_loss(nn.Module):
    """Per-class sum of a t-vMF-style similarity loss and a soft Dice loss.

    For every class that is present in the targets the loss is
    ``(1 - exp(kappa * (cos_theta - 1))) + (1 - dice)``, where ``cos_theta``
    is the cosine similarity between the flattened class probabilities and
    the flattened one-hot target. Classes absent from the targets contribute
    zero. The per-class values are summed and divided by ``num_classes``.

    Args:
        num_classes: Number of classes.
        lambda_val: Default kappa used for every class when ``kappa_values``
            is not given.
        kappa_values: Optional per-class kappa values.
        epsilon: Small constant used for numerical stability.
    """

    def __init__(
            self,
            num_classes: int = 4,
            lambda_val: float = 15.0,
            kappa_values=None,
            epsilon: float = 1e-6):
        super().__init__()
        self.num_classes = num_classes
        self.lambda_val = lambda_val
        self.epsilon = epsilon
        if kappa_values is not None:
            kappa = torch.tensor(kappa_values, dtype=torch.float32)
        else:
            kappa = torch.ones(num_classes) * lambda_val
        self.register_buffer('kappa_values', kappa)

    def update_kappa_values(self, new_kappa_values) -> None:
        """Replace the per-class kappa buffer, keeping it on its current device."""
        if isinstance(new_kappa_values, (list, np.ndarray)):
            new_kappa_values = torch.tensor(new_kappa_values, dtype=torch.float32)
        device = self.kappa_values.device
        self.kappa_values.data = new_kappa_values.to(device)

    def t_vmf_similarity(self, cos_theta, kappa):
        """Return ``exp(kappa * (cos_theta - 1))`` with kappa forced positive."""
        kappa = F.relu(kappa) + self.epsilon
        return torch.exp(kappa * (cos_theta - 1))

    def compute_dice_coefficient(self, pred, target):
        """Return the epsilon-smoothed soft Dice coefficient of two tensors."""
        intersection = torch.sum(pred * target)
        union = torch.sum(pred) + torch.sum(target)
        return (2.0 * intersection + self.epsilon) / (union + self.epsilon)

    def forward(self, inputs, targets):
        """Return the averaged loss as a 0-d tensor.

        Args:
            inputs: Raw logits; softmax is taken over dim 1 for 4-D input
                and over the last dim otherwise.
            targets: Either integer labels with 3 dimensions (one-hot encoded
                here) or an already one-hot tensor laid out like ``inputs``.
        """
        probs = F.softmax(inputs, dim=1 if inputs.dim() == 4 else -1)
        if targets.dim() == 3:
            targets_one_hot = F.one_hot(targets.long(), num_classes=self.num_classes)
            targets_one_hot = targets_one_hot.permute(0, 3, 1, 2).float()
        else:
            targets_one_hot = targets.float()

        class_losses = []
        for class_idx in range(self.num_classes):
            pred_class = probs[:, class_idx, :, :]
            target_class = targets_one_hot[:, class_idx, :, :]
            pred_flat = pred_class.contiguous().view(-1)
            target_flat = target_class.contiguous().view(-1)
            # A class with no target pixels contributes nothing.
            if torch.sum(target_flat) < self.epsilon:
                class_losses.append(torch.tensor(0.0, device=inputs.device))
                continue
            cos_theta = F.cosine_similarity(
                pred_flat.unsqueeze(0), target_flat.unsqueeze(0), dim=1, eps=self.epsilon
            ).squeeze()
            similarity = self.t_vmf_similarity(cos_theta, self.kappa_values[class_idx])
            dice_coeff = self.compute_dice_coefficient(pred_class, target_class)
            class_losses.append((1.0 - similarity) + (1.0 - dice_coeff))

        stacked = torch.stack(class_losses)
        self.last_class_losses = stacked
        # Summing the stacked tensor keeps the result a tensor even when every
        # class was skipped (previously a plain Python float was returned).
        return stacked.sum() / self.num_classes

    def get_class_losses(self) -> Any:
        """Return the per-class losses of the last forward pass as a NumPy array."""
        if hasattr(self, 'last_class_losses'):
            return self.last_class_losses.detach().cpu().numpy()
        return np.zeros(self.num_classes)

    def get_adaptive_info(self) -> Any:
        """Return the current kappa values, ``lambda_val`` and ``num_classes`` as a dict."""
        return {
            'kappa_values': self.kappa_values.detach().cpu().numpy().tolist(),
            'lambda_val': self.lambda_val,
            'num_classes': self.num_classes,
        }


class PhysicsLoss(nn.Module):
    """Mean Helmholtz residual of a B1 map under predicted (eps, sigma) maps.

    Args:
        in_channels_solver: Passed to the internal ``MaxwellSolver``; only the
            solver's ``compute_helmholtz_residual`` is used by this loss.
    """

    def __init__(self, in_channels_solver):
        super().__init__()
        self.ms = MaxwellSolver(in_channels_solver)

    def forward(self, b1, eps, sig):
        """Return the mean squared residual as a 0-d tensor.

        All tensors are moved to the device of ``eps`` (the model output)
        instead of a module-level global, so the loss lands on the same
        device as the other loss terms wherever the model lives.
        """
        device = eps.device
        residual = self.ms.compute_helmholtz_residual(b1.to(device), eps, sig.to(device))
        return torch.mean(residual)


class SmoothnessLoss(nn.Module):
    """Anisotropic total-variation penalty.

    Returns the mean absolute difference between neighbours along dim 2 plus
    the mean absolute difference between neighbours along dim 3.
    """

    def forward(self, x):
        vertical_step = torch.diff(x, dim=2).abs()
        horizontal_step = torch.diff(x, dim=3).abs()
        return vertical_step.mean() + horizontal_step.mean()


class AnatomicalRuleLoss(nn.Module):
    """Loss based on anatomical rules about the relative position of heart regions.

    - Penalises left-ventricle (LV) neighbourhoods that are not myocardium (MYO).
    - Penalises right-ventricle (RV) neighbourhoods that touch myocardium (MYO).
    """

    def __init__(self, class_indices: Dict[str, int]):
        """
        Args:
            class_indices (Dict[str, int]): Maps class names to channel indices.
                Must contain the keys 'LV', 'MYO' and 'RV'.

        Raises:
            ValueError: If any of the three keys is missing.
        """
        super().__init__()
        required_keys = ('LV', 'MYO', 'RV')
        if any(key not in class_indices for key in required_keys):
            raise ValueError("class_indices must contain keys 'LV', 'MYO', and 'RV'.")
        self.class_indices = class_indices

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Args:
            logits (torch.Tensor): Raw model output, shape (B, C, H, W).

        Returns:
            torch.Tensor: Scalar loss value.
        """
        probabilities = torch.softmax(logits, dim=1)
        lv, myo, rv = (probabilities[:, self.class_indices[name]] for name in ('LV', 'MYO', 'RV'))

        def dilate(prob_map):
            # A 3x3 max-pool with stride 1 acts as a soft morphological dilation.
            pooled = F.max_pool2d(prob_map.unsqueeze(1), kernel_size=3, stride=1, padding=1)
            return pooled.squeeze(1)

        # Rule 1: the neighbourhood of LV should be myocardium.
        lv_not_wrapped = dilate(lv) * (1 - myo)
        # Rule 2: the neighbourhood of RV should not be myocardium.
        rv_touching_myo = dilate(rv) * myo

        return torch.mean(lv_not_wrapped + rv_touching_myo)
        

class DynamicLossWeighter(nn.Module):
    """Learns one log-variance per loss term and combines the terms with it.

    The combined loss is ``sum_i exp(-log_var_i) * loss_i + log_var_i``, so
    the effective weight of term ``i`` is ``exp(-log_var_i)``.

    Args:
        num_losses: Number of loss terms to combine.
        tau: Stored on the instance; not used by any method of this class.
        initial_weights: Optional starting weights, one per loss term. When
            omitted (or empty) every term starts with weight 1.
    """

    def __init__(self, num_losses: int, tau: float = 1.0, initial_weights: Optional[List[float]] = None):
        super().__init__()
        self.num_losses = num_losses
        self.tau = tau
        if initial_weights:
            assert len(initial_weights) == num_losses, "Number of initial weights must be equal to num_losses"
            weights = torch.tensor(initial_weights, dtype=torch.float32)
        else:
            weights = torch.ones(num_losses, dtype=torch.float32)
        # The effective weight is exp(-log_var), so a requested weight w needs
        # log_var = -log(w). Using +log(w) would start training at 1/w.
        self.log_vars = nn.Parameter(-torch.log(weights))

    def forward(self, individual_losses: torch.Tensor) -> torch.Tensor:
        """Return the total weighted loss for a 1-D tensor (or sequence) of losses."""
        if not isinstance(individual_losses, torch.Tensor):
            individual_losses = torch.stack(individual_losses)
        assert individual_losses.dim() == 1 and individual_losses.size(0) == self.num_losses, \
            f"Input individual_losses must be a 1D tensor of size {self.num_losses}"
        precisions = torch.exp(-self.log_vars)
        return torch.sum(precisions * individual_losses + self.log_vars)

    def get_current_weights(self) -> Dict[str, float]:
        """Return the current weights ``exp(-log_var)`` keyed as ``weight_<i>``."""
        with torch.no_grad():
            weights = torch.exp(-self.log_vars)
            return {f"weight_{i}": w.item() for i, w in enumerate(weights)}


class ClassWeightUpdater(nn.Module):
    """
    Dynamically adjusts class weights for CrossEntropyLoss based on a combined
    metric of class-wise Dice and IoU scores.
    Uses an Exponential Moving Average (EMA) to stabilize weight updates.
    """
    def __init__(self, num_classes: int, alpha: float = 0.9, epsilon: float = 1e-6):
        """
        Args:
            num_classes (int): Number of segmentation classes.
            alpha (float): Smoothing factor for EMA. Higher alpha means slower updates.
            epsilon (float): Small value to prevent division by zero.
        """
        super().__init__()
        self.num_classes = num_classes
        self.alpha = alpha
        self.epsilon = epsilon
        # EMA of the combined (Dice + IoU) score per class, starting at 1.
        self.register_buffer('ema_combined_scores', torch.ones(num_classes))

    def _calculate_per_class_metrics(self, logits: torch.Tensor, targets: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Calculates soft Dice and IoU scores for each class over the whole batch.

        Returns:
            A tuple containing (dice_scores, iou_scores), each of shape (num_classes,).
        """
        # Only the first num_classes channels take part.
        probs = F.softmax(logits, dim=1)[:, :self.num_classes]
        targets_one_hot = F.one_hot(targets.long(), num_classes=self.num_classes).permute(0, 3, 1, 2).float()

        # Reduce over batch and both spatial axes, keeping the class axis.
        reduce_dims = (0, 2, 3)
        intersection = (probs * targets_one_hot).sum(dim=reduce_dims)
        pred_sum = probs.sum(dim=reduce_dims)
        target_sum = targets_one_hot.sum(dim=reduce_dims)

        dice_scores = (2. * intersection + self.epsilon) / (pred_sum + target_sum + self.epsilon)
        # IoU (Jaccard index): union = |pred| + |target| - |intersection|.
        union = pred_sum + target_sum - intersection
        iou_scores = (intersection + self.epsilon) / (union + self.epsilon)
        return dice_scores, iou_scores

    def update_and_get_weights(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Updates EMA and returns new class weights based on combined Dice and IoU performance.
        Args:
            logits (torch.Tensor): Raw output from the model (detached).
            targets (torch.Tensor): Ground truth labels.
        Returns:
            torch.Tensor: New weights for CrossEntropyLoss; they sum to num_classes.
        """
        with torch.no_grad():
            # 1. Compute both metrics.
            current_dice, current_iou = self._calculate_per_class_metrics(logits, targets)

            # 2. Combine them as a weighted mean in which IoU counts twice.
            current_combined_score = (current_dice + 2 * current_iou) / 3.0

            # 3. Update the EMA. The buffer is first moved to the device of the
            #    fresh scores so the update also works when this module was
            #    never moved to the model's device.
            previous = self.ema_combined_scores.to(current_combined_score.device)
            self.ema_combined_scores = self.alpha * previous + (1 - self.alpha) * current_combined_score

            # 4. Weight each class by the inverse of its smoothed score.
            inverse_scores = 1.0 / (self.ema_combined_scores + self.epsilon)

            # 5. Normalise so the weights sum to num_classes.
            normalized_weights = self.num_classes * inverse_scores / torch.sum(inverse_scores)

            return normalized_weights


class CombinedLoss(nn.Module):
    """
    Combined loss with two levels of dynamic weighting:
    1. Learned weights between the loss terms (CE, Dice, Physics, Smoothness, Anatomical).
    2. Dynamic per-class weights inside CrossEntropyLoss.
    """
    def __init__(self, 
                 in_channels_maxwell=1024, 
                 num_classes=4, 
                 lambda_val=15.0, 
                 initial_loss_weights: Optional[List[float]] = None,
                 # Class-name -> index mapping handed to AnatomicalRuleLoss.
                 class_indices_for_rules: Dict[str, int] = None):
        super().__init__()

        # Per-class weight tracker for the cross-entropy term.
        self.class_weighter = ClassWeightUpdater(num_classes=num_classes)

        # Term 1: cross entropy (its class weights are replaced on every forward).
        self.ce = nn.CrossEntropyLoss()
        # Term 2: adaptive t-vMF Dice loss.
        self.dl = Adaptive_tvmf_dice_loss(num_classes=num_classes, lambda_val=lambda_val)
        # Term 3: physics (Helmholtz residual) loss.
        self.pl = PhysicsLoss(in_channels_maxwell)
        # Term 4: smoothness loss.
        self.sl = SmoothnessLoss()

        # Term 5: anatomical rule loss; the class mapping is mandatory.
        if class_indices_for_rules is None:
            raise ValueError("`class_indices_for_rules` must be provided for AnatomicalRuleLoss.")
        self.arl = AnatomicalRuleLoss(class_indices=class_indices_for_rules)
        print("Initialized AnatomicalRuleLoss.")

        # Learned weighting across the five terms above.
        self.loss_weighter = DynamicLossWeighter(num_losses=5, initial_weights=initial_loss_weights)

        print("Initialized CombinedLoss with 5 components (CE, Dice, Physics, Smoothness, Anatomical).")

    def forward(self, logits, targets, b1=None, all_es=None, feat_sm=None):
        """Return the dynamically weighted sum of the five loss terms.

        Args:
            logits: Raw model output.
            targets: Ground-truth labels.
            b1: Optional B1 map; the physics term is zero without it.
            all_es: Optional sequence of (eps, sigma) pairs; only the first
                pair is used for the physics term.
            feat_sm: Optional tensor for the smoothness term; zero without it.
        """
        # Step 1: refresh the per-class weights used by cross entropy.
        refreshed = self.class_weighter.update_and_get_weights(logits.detach(), targets)
        self.ce.weight = refreshed.to(logits.device)

        # Step 2: evaluate each term.
        ce_term = self.ce(logits, targets.long())
        dice_term = self.dl(logits, targets)

        physics_term = torch.tensor(0.0, device=logits.device)
        if self.pl is not None and b1 is not None and all_es:
            try:
                eps_map, sigma_map = all_es[0]
                physics_term = self.pl(b1, eps_map, sigma_map)
            except (IndexError, TypeError):
                print("Warning: Physics loss skipped due to unexpected `all_es` format.")

        if feat_sm is not None:
            smooth_term = self.sl(feat_sm)
        else:
            smooth_term = torch.tensor(0.0, device=logits.device)

        anatomy_term = self.arl(logits)

        # Step 3: combine the five terms with the learned weights.
        stacked_terms = torch.stack([ce_term, dice_term, physics_term, smooth_term, anatomy_term])
        return self.loss_weighter(stacked_terms)

    def get_current_loss_weights(self) -> Dict[str, float]:
        """Return the learned weight of each loss term, keyed by term name."""
        raw = self.loss_weighter.get_current_weights()
        term_names = ("CE", "Dice", "Physics", "Smoothness", "Anatomical")
        return {f"weight_{name}": raw[f"weight_{position}"] for position, name in enumerate(term_names)}

    def get_current_class_weights(self) -> Dict[str, float]:
        """Return the class weights currently set on the cross-entropy term."""
        with torch.no_grad():
            # No weights exist until the first forward pass has run.
            if self.ce.weight is None:
                return {}
            on_cpu = self.ce.weight.cpu()
            return {f"class_{i}_weight": w.item() for i, w in enumerate(on_cpu)}

    def get_kappa_values(self):
        """Return the Dice term's kappa info for monitoring, or ``{}`` if unavailable."""
        if not isinstance(self.dl, Adaptive_tvmf_dice_loss):
            return {}
        return self.dl.get_adaptive_info()

In [None]:
import os
import nibabel as nib
import numpy as np
from skimage.transform import resize
import sys
import configparser

# --- Data Loading ---
def _resize_slice(slice_2d, target_sz, is_mask):
    """Resize one 2-D slice to ``target_sz``.

    Masks use order-0 interpolation without anti-aliasing and are cast to
    uint8; images use order-1 interpolation with anti-aliasing and are cast
    to float32.
    """
    if is_mask:
        resized = resize(
            slice_2d, target_sz,
            order=0, preserve_range=True,
            anti_aliasing=False, mode='reflect'
        )
        return resized.astype(np.uint8)
    resized = resize(
        slice_2d, target_sz,
        order=1, preserve_range=True,
        anti_aliasing=True, mode='reflect'
    )
    return resized.astype(np.float32)


def _read_ed_es_frames(info_cfg_path):
    """Return the ``(ED, ES)`` frame numbers stored in a patient's Info.cfg."""
    parser = configparser.ConfigParser()
    with open(info_cfg_path, 'r') as f:
        # Info.cfg has no section header; prepend one so configparser accepts it.
        parser.read_string('[DEFAULT]\n' + f.read())
    defaults = parser['DEFAULT']
    return int(defaults['ED']), int(defaults['ES'])


def _load_nifti_and_process_slices(img_fpath, mask_fpath, target_sz, is_train_flag_for_warning):
    """Load one NIfTI image (and its mask, if present) as lists of resized slices.

    Returns:
        ``(images, masks)`` where each image has shape ``target_sz + (1,)`` and
        ``masks`` is empty when no mask file exists. ``(None, None)`` is
        returned when the image is missing, when the mask is missing and
        ``is_train_flag_for_warning`` is true, when the image is neither 2-D
        nor 3-D, or when any step raises.
    """
    try:
        if not os.path.exists(img_fpath):
            print(f"Warning: Image file not found at {img_fpath}. Skipping this pair.", file=sys.stderr)
            return None, None

        img_data = nib.load(img_fpath).get_fdata()

        mask_data = None
        if os.path.exists(mask_fpath):
            mask_data = nib.load(mask_fpath).get_fdata()
        elif is_train_flag_for_warning:
            # A training pair without its mask is unusable.
            print(f"Warning: Mask file not found for {img_fpath}. (Expected for training). Skipping this pair.", file=sys.stderr)
            return None, None

        if img_data.ndim == 3:
            # Slices are taken along the last axis.
            depth = img_data.shape[2]
            img_slices = [img_data[:, :, i] for i in range(depth)]
            mask_slices = [mask_data[:, :, i] for i in range(depth)] if mask_data is not None else []
        elif img_data.ndim == 2:
            img_slices = [img_data]
            mask_slices = [mask_data] if mask_data is not None else []
        else:
            print(f"Warning: Unexpected image dimension ({img_data.ndim}) for {img_fpath}. Skipping.", file=sys.stderr)
            return None, None

        current_images = [
            np.expand_dims(_resize_slice(s, target_sz, is_mask=False), axis=-1)
            for s in img_slices
        ]
        current_masks = [_resize_slice(s, target_sz, is_mask=True) for s in mask_slices]
    except Exception as e:
        # Best effort: one unreadable file must not abort the whole load.
        print(f"Error processing {img_fpath}: {e}", file=sys.stderr)
        return None, None
    return current_images, current_masks


def load_acdc_data(directory, is_training=True, target_size=(256, 256), max_patients=None):
    """Load the ED and ES frames of every patient folder as resized 2-D slices.

    Each sub-folder of ``directory`` is treated as one patient. Its
    ``Info.cfg`` supplies the ED and ES frame numbers, which select the files
    ``<patient>_frameNN.nii`` and ``<patient>_frameNN_gt.nii``.

    Args:
        directory: Root folder containing one sub-folder per patient.
        is_training: When true, a frame whose mask file is missing is skipped
            entirely. When false, the image is still loaded without a mask, so
            the mask array can then hold fewer entries than the image array.
        target_size: ``(height, width)`` every slice is resized to.
        max_patients: Optional cap on the number of patients processed;
            patients skipped because of Info.cfg problems do not count.

    Returns:
        ``(images, masks)``. ``images`` is a float32 array of shape
        ``(N, H, W, 1)``; ``masks`` is a uint8 array of shape ``(M, H, W)`` or
        ``None`` when no mask was collected. If ``directory`` does not exist,
        ``(np.array([]), None)`` is returned.
    """
    imgs, msks = [], []
    patient_count = 0

    if not os.path.exists(directory):
        print(f"Error: Dataset directory not found at {directory}. "
              "Please ensure the ACDC dataset is added to your Kaggle notebook inputs.", file=sys.stderr)
        return np.array([]), None

    patient_folders = sorted(d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d)))

    for patient_folder in patient_folders:
        if max_patients and patient_count >= max_patients:
            break

        patient_path = os.path.join(directory, patient_folder)
        info_cfg_path = os.path.join(patient_path, 'Info.cfg')

        if not os.path.exists(info_cfg_path):
            print(f"Warning: Info.cfg not found for patient {patient_folder} in {directory}. Skipping patient.", file=sys.stderr)
            continue
        try:
            frames = _read_ed_es_frames(info_cfg_path)
        except Exception as e:
            print(f"Warning: Could not parse Info.cfg for {patient_folder} in {directory}: {e}. Skipping patient.", file=sys.stderr)
            continue

        # ED first, then ES - the same order in which slices are appended.
        for frame in frames:
            stem = f'{patient_folder}_frame{frame:02d}'
            frame_imgs, frame_msks = _load_nifti_and_process_slices(
                os.path.join(patient_path, f'{stem}.nii'),
                os.path.join(patient_path, f'{stem}_gt.nii'),
                target_size,
                is_training,
            )
            if frame_imgs:
                imgs.extend(frame_imgs)
                # Masks are kept whenever they were loaded, regardless of is_training.
                if frame_msks:
                    msks.extend(frame_msks)

        patient_count += 1

    im_np = np.array(imgs, dtype=np.float32) if imgs else np.empty((0, target_size[0], target_size[1], 1), dtype=np.float32)

    # Masks are only returned when at least one was collected.
    msk_np = np.array(msks, dtype=np.uint8) if msks else None

    if msk_np is not None and msk_np.ndim == 4 and msk_np.shape[-1] == 1:
        msk_np = np.squeeze(msk_np, axis=-1)

    return im_np, msk_np

In [None]:
# --- Metrics ---
def evaluate_metrics(model, dataloader, device, num_classes=4):
    """
    Compute per-class segmentation metrics for `model` over `dataloader`.

    Dice and IoU are computed per batch and then averaged over batches, so a
    class that is absent from both prediction and target in a batch scores
    ~1.0 for that batch (the 1e-6 smoothing terms dominate). Precision, recall
    and F1 are computed once from pixel counts accumulated over all batches.

    Args:
        model: Module whose forward returns a tuple whose first element is the
            logits with the class axis at dim 1. The model is switched to
            eval mode and left there.
        dataloader: Iterable yielding `(images, targets, ...)` batches; only
            the first two entries are used.
        device: Device the batches are moved to before the forward pass.
        num_classes: Number of label values to score (0 .. num_classes - 1).

    Returns:
        dict with keys 'dice_scores', 'iou', 'precision', 'recall' and
        'f1_score', each a list of `num_classes` floats. All values are 0.0
        when no non-empty batch was seen.

    Raises:
        ValueError: If a batch carries no targets (e.g. an image-only
            TensorDataset), since metrics cannot be computed without them.
    """
    eps = 1e-6
    model.eval()
    tp = [0.0] * num_classes
    fp = [0.0] * num_classes
    fn = [0.0] * num_classes
    dice_s = [0.0] * num_classes
    iou_s = [0.0] * num_classes
    batches = 0

    with torch.no_grad():
        for batch in dataloader:
            if len(batch) < 2:
                raise ValueError(
                    "evaluate_metrics needs (images, targets) batches, "
                    "but the dataloader yielded a batch without targets."
                )
            imgs, tgts = batch[0].to(device), batch[1].to(device)
            if imgs.size(0) == 0:
                continue
            logits, _ = model(imgs)
            # argmax straight on the logits: softmax is monotonic, and applying
            # it first can underflow small probabilities into artificial ties.
            preds = torch.argmax(logits, dim=1)
            batches += 1
            for c in range(num_classes):
                pred_c = (preds == c).float().view(-1)
                tgt_c = (tgts == c).float().view(-1)
                inter = (pred_c * tgt_c).sum()
                pred_sum = pred_c.sum()
                tgt_sum = tgt_c.sum()
                dice_s[c] += ((2.0 * inter + eps) / (pred_sum + tgt_sum + eps)).item()
                iou_s[c] += ((inter + eps) / (pred_sum + tgt_sum - inter + eps)).item()
                tp[c] += inter.item()
                fp[c] += (pred_sum - inter).item()
                fn[c] += (tgt_sum - inter).item()

    metrics = {'dice_scores': [], 'iou': [], 'precision': [], 'recall': [], 'f1_score': []}
    if batches == 0:
        # Nothing was evaluated: report zeros rather than dividing by zero.
        for key in metrics:
            metrics[key] = [0.0] * num_classes
        return metrics

    for c in range(num_classes):
        metrics['dice_scores'].append(dice_s[c] / batches)
        metrics['iou'].append(iou_s[c] / batches)
        prec = tp[c] / (tp[c] + fp[c] + eps)
        rec = tp[c] / (tp[c] + fn[c] + eps)
        metrics['precision'].append(prec)
        metrics['recall'].append(rec)
        metrics['f1_score'].append(2 * prec * rec / (prec + rec + eps) if (prec + rec > 0) else 0.0)
    return metrics

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Optional, Union
import os

class B1MapCommonCalculator:
    """
    Compute one shared ("common") B1 map for a whole image dataset.

    The maps are synthetic: they are generated from the images themselves with
    hand-tuned heuristics (radial fall-off, intensity-dependent factors, a
    fixed off-centre "cardiac" bump and Gaussian noise), not measured. The
    per-image maps are averaged into a single map that can be cached on disk
    and broadcast to every batch.
    """

    def __init__(self, img_size: int = 256, device: Union[str, torch.device] = 'cuda'):
        """
        Args:
            img_size: Nominal image side length; only stored and saved with
                the cached map.
            device: Preferred device. Falls back to CPU when CUDA is not
                available.
        """
        self.img_size = img_size
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.common_b1_map: Optional[torch.Tensor] = None
        self.dataset_statistics: dict = {}

    def simulate_b1_map_physics_based(self, image_batch: torch.Tensor) -> torch.Tensor:
        """
        Simulate one B1 map per image from heuristic, image-driven factors.

        Only channel 0 of each image is used. The result contains random
        Gaussian noise, so repeated calls differ.

        Args:
            image_batch: Tensor of shape (B, C, H, W).
        Returns:
            Tensor of shape (B, 1, H, W) with values clamped to [0.4, 1.3].
        """
        batch_size, _, height, width = image_batch.shape
        device = image_batch.device

        if batch_size == 0:
            return torch.empty((0, 1, height, width), dtype=torch.float32, device=device)

        # Coordinate grids shared by every image in the batch.
        y_coords = torch.arange(height, dtype=torch.float32, device=device)
        x_coords = torch.arange(width, dtype=torch.float32, device=device)
        y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing='ij')

        center_y, center_x = height // 2, width // 2

        # Distance from the image centre.
        distance = torch.sqrt((x_grid - center_x) ** 2 + (y_grid - center_y) ** 2)
        # Plain-float maximum with a floor so a 1x1 input cannot divide by zero.
        max_distance = max((center_x ** 2 + center_y ** 2) ** 0.5, 1e-8)

        # 1. Base inhomogeneity pattern, decreasing from the centre outwards.
        #    It does not depend on the image, so it is computed once.
        b1_base = 1.0 - 0.25 * (distance / max_distance)

        b1_maps = []
        for current_image in image_batch[:, 0]:  # each item has shape (H, W)
            # 2. Intensity-dependent ("tissue") variation.
            image_normalized = current_image / (torch.max(current_image) + 1e-8)
            tissue_factor = 0.85 + 0.3 * image_normalized

            # 3. Cardiac-specific adjustment.
            cardiac_factor = self._get_cardiac_b1_pattern(current_image, height, width, device)

            # 4. RF coil loading effect.
            loading_effect = self._calculate_rf_loading(current_image, distance, device)

            # 5. Combine all factors.
            b1_map = b1_base * tissue_factor * cardiac_factor * loading_effect

            # 6. Add noise, then clip to the allowed B1 range.
            b1_map = b1_map + torch.randn_like(b1_map) * 0.03
            b1_map = torch.clamp(b1_map, 0.4, 1.3)

            b1_maps.append(b1_map.unsqueeze(0))  # add channel dimension

        return torch.stack(b1_maps, dim=0)  # (B, 1, H, W)

    def _get_cardiac_b1_pattern(self, image: torch.Tensor, height: int, width: int, device: torch.device) -> torch.Tensor:
        """
        Build the multiplicative "cardiac" factor for one (H, W) image.

        An exponential bump centred at (height // 2, 0.4 * width) is modulated
        by the max-normalised image intensity.
        """
        # The bump sits slightly left of the image centre.
        cardiac_center_y = height // 2
        cardiac_center_x = int(width * 0.4)

        y_coords = torch.arange(height, dtype=torch.float32, device=device)
        x_coords = torch.arange(width, dtype=torch.float32, device=device)
        y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing='ij')

        cardiac_distance = torch.sqrt((x_grid - cardiac_center_x) ** 2 + (y_grid - cardiac_center_y) ** 2)

        # Up to +10% at the bump centre, decaying with distance.
        cardiac_enhancement = 1.0 + 0.1 * torch.exp(-cardiac_distance / (width * 0.15))

        # Modulate by image intensity.
        image_weight = image / (torch.max(image) + 1e-8)
        return cardiac_enhancement * (0.9 + 0.2 * image_weight)

    def _calculate_rf_loading(self, image: torch.Tensor, distance: torch.Tensor, device: torch.device) -> torch.Tensor:
        """
        Build the multiplicative "RF loading" factor for one (H, W) image.

        Combines a radial term (1.0 at the centre, 0.9 at the farthest pixel)
        with the max-normalised image intensity. `device` is accepted for
        signature compatibility and is not used.
        """
        tissue_density = image / (torch.max(image) + 1e-8)

        # Floor the divisor so an all-zero distance grid (1x1 input) is safe.
        loading_base = 1.0 - 0.1 * (distance / torch.max(distance).clamp_min(1e-8))

        return loading_base * (0.95 + 0.1 * tissue_density)

    def calculate_dataset_common_b1_map(self,
                                      all_images: torch.Tensor,
                                      use_weighted_average: bool = True,
                                      save_path: Optional[str] = None) -> torch.Tensor:
        """
        Simulate a B1 map for every image and reduce them to one common map.

        Args:
            all_images: Tensor of shape (N, C, H, W).
            use_weighted_average: Weight each image by its clipped contrast
                score instead of taking the plain mean.
            save_path: When given, the map and statistics are saved there.
        Returns:
            Common B1 map of shape (1, 1, H, W), on CPU. It is also stored in
            `self.common_b1_map`, with summary values in
            `self.dataset_statistics`.
        Raises:
            ValueError: If `all_images` contains no images.
        """
        print("Calculating common B1 map for entire ACDC dataset...")

        num_images = all_images.shape[0]
        if num_images == 0:
            raise ValueError("Cannot compute a common B1 map from an empty image tensor.")
        batch_size = min(16, num_images)  # process in chunks to bound memory use

        all_b1_maps = []
        image_statistics = []

        for i in range(0, num_images, batch_size):
            end_idx = min(i + batch_size, num_images)
            batch_images = all_images[i:end_idx].to(self.device)

            batch_b1_maps = self.simulate_b1_map_physics_based(batch_images)
            all_b1_maps.append(batch_b1_maps.cpu())

            for image in batch_images:
                image_statistics.append({
                    'mean_intensity': torch.mean(image).item(),
                    'std_intensity': torch.std(image).item(),
                    'max_intensity': torch.max(image).item()
                })

            if (i // batch_size + 1) % 10 == 0:
                # Report the true count; i + batch_size can exceed num_images.
                print(f"Processed {end_idx}/{num_images} images...")

        all_b1_maps = torch.cat(all_b1_maps, dim=0)  # (N, 1, H, W)

        if use_weighted_average:
            common_b1_map = self._calculate_weighted_average(all_b1_maps, image_statistics)
        else:
            common_b1_map = torch.mean(all_b1_maps, dim=0, keepdim=True)

        self.common_b1_map = common_b1_map
        # Plain Python floats only: NumPy scalars inside the saved dict need
        # full pickle support and can be rejected by torch.load(weights_only=True).
        self.dataset_statistics = {
            'num_images': num_images,
            'b1_range': (float(torch.min(common_b1_map)), float(torch.max(common_b1_map))),
            'b1_mean': float(torch.mean(common_b1_map)),
            'b1_std': float(torch.std(common_b1_map)),
            'image_stats': {
                'mean_intensity_avg': float(np.mean([s['mean_intensity'] for s in image_statistics])),
                'std_intensity_avg': float(np.mean([s['std_intensity'] for s in image_statistics]))
            }
        }

        print("Common B1 map calculated successfully!")
        print(f"  - B1 range: {self.dataset_statistics['b1_range'][0]:.3f} - {self.dataset_statistics['b1_range'][1]:.3f}")
        print(f"  - B1 mean: {self.dataset_statistics['b1_mean']:.3f}")
        print(f"  - B1 std: {self.dataset_statistics['b1_std']:.3f}")

        if save_path:
            self.save_common_b1_map(save_path)

        return common_b1_map

    def _calculate_weighted_average(self, all_b1_maps: torch.Tensor, image_stats: list) -> torch.Tensor:
        """
        Average the (N, 1, H, W) maps with one contrast-based weight per image.

        Each image's weight is std / mean intensity, clipped to [0.5, 2.0].

        A per-pixel spatial weight shared by all images would appear in both
        the numerator and the denominator and cancel, so none is applied; this
        also avoids a 0/0 where such a weight underflows on strongly
        non-square inputs.

        Returns:
            Tensor of shape (1, 1, H, W).
        """
        weights = [
            min(max(stats['std_intensity'] / (stats['mean_intensity'] + 1e-8), 0.5), 2.0)
            for stats in image_stats
        ]
        image_weights = torch.tensor(
            weights, dtype=all_b1_maps.dtype, device=all_b1_maps.device
        ).view(-1, 1, 1, 1)  # (N, 1, 1, 1)

        weighted_sum = torch.sum(all_b1_maps * image_weights, dim=0, keepdim=True)
        return weighted_sum / torch.sum(image_weights)

    def get_b1_map_for_batch(self, batch_images: torch.Tensor) -> torch.Tensor:
        """
        Return a B1 map for every image in the batch.

        Uses the common map when one is available, otherwise simulates maps
        for this batch.

        Args:
            batch_images: Tensor of shape (B, C, H, W).
        Returns:
            Tensor of shape (B, 1, H, W) on `batch_images.device`.
        """
        if self.common_b1_map is None:
            return self.simulate_b1_map_physics_based(batch_images)

        batch_size = batch_images.shape[0]
        # Move the single map first, then expand: expand() is a view, so this
        # avoids transferring one copy per batch element.
        return self.common_b1_map.to(batch_images.device).expand(batch_size, -1, -1, -1)

    def save_common_b1_map(self, save_path: str):
        """Save the common B1 map, statistics and img_size; no-op if no map exists."""
        if self.common_b1_map is not None:
            save_dict = {
                'common_b1_map': self.common_b1_map,
                'dataset_statistics': self.dataset_statistics,
                'img_size': self.img_size
            }
            torch.save(save_dict, save_path)

    def load_common_b1_map(self, load_path: str):
        """
        Load a previously saved common B1 map onto `self.device`.

        Silently does nothing when `load_path` does not exist, leaving
        `self.common_b1_map` unchanged.
        """
        if os.path.exists(load_path):
            save_dict = torch.load(load_path, map_location=self.device)
            self.common_b1_map = save_dict['common_b1_map']
            self.dataset_statistics = save_dict['dataset_statistics']
            self.img_size = save_dict.get('img_size', 256)

def integrate_b1_map_into_training(X_train_tensor: torch.Tensor, 
                                 X_val_tensor: torch.Tensor, 
                                 X_test_tensor: torch.Tensor,
                                 img_size: int = 256,
                                 device: Union[str, torch.device] = 'cuda',
                                 save_path: str = "acdc_common_b1_map.pth") -> B1MapCommonCalculator:
    """
    Create a B1MapCommonCalculator that holds a common B1 map.

    A cached map at `save_path` is reused when it exists and matches the
    spatial size of the training images; otherwise a new map is computed from
    the train, validation and test images together and saved to `save_path`.

    This touches the disk and may process the whole dataset, so call it once
    before training rather than once per batch.

    Args:
        X_train_tensor, X_val_tensor, X_test_tensor: Image tensors of shape
            (N, C, H, W); they are concatenated along dim 0.
        img_size: Forwarded to the calculator.
        device: Forwarded to the calculator.
        save_path: Location of the cached map.
    Returns:
        The calculator, with `common_b1_map` set.
    """
    b1_calculator = B1MapCommonCalculator(img_size=img_size, device=device)

    # Reuse a cached map when one exists (no-op if the file is missing).
    b1_calculator.load_common_b1_map(save_path)

    # A cache produced with a different image size cannot be broadcast onto
    # these images, so discard it and recompute instead of failing later.
    cached_map = b1_calculator.common_b1_map
    expected_hw = tuple(X_train_tensor.shape[-2:])
    if cached_map is not None and tuple(cached_map.shape[-2:]) != expected_hw:
        print(f"Cached B1 map at '{save_path}' has spatial size {tuple(cached_map.shape[-2:])}, "
              f"expected {expected_hw}. Recomputing.")
        b1_calculator.common_b1_map = None

    if b1_calculator.common_b1_map is None:
        all_images = torch.cat([X_train_tensor, X_val_tensor, X_test_tensor], dim=0)
        b1_calculator.calculate_dataset_common_b1_map(
            all_images,
            use_weighted_average=True,
            save_path=save_path
        )

    return b1_calculator

# Convenience wrapper for use inside the training loop.
def get_b1_map_for_training(images: torch.Tensor, 
                          b1_calculator: B1MapCommonCalculator) -> torch.Tensor:
    """Return the B1 maps that `b1_calculator` provides for the batch `images`."""
    b1_maps = b1_calculator.get_b1_map_for_batch(images)
    return b1_maps

In [None]:
import os
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
import numpy as np
from itertools import chain

# --- Main Execution (Centralized Training) ---
if __name__ == "__main__":
    # Centralized training entry point: load ACDC data (or dummy data when the
    # dataset is missing), build the loaders, train, and validate every epoch.

    # Warnings below go to sys.stderr; import here so this cell does not rely
    # on another cell having imported `sys`.
    import sys

    print(f"Device: {DEVICE}")

    base_dataset_root = '/kaggle/input/automated-cardiac-diagnosis-challenge-miccai17/database'

    # Sub-directories holding the training and testing patients.
    train_data_path = os.path.join(base_dataset_root, 'training')
    test_data_path = os.path.join(base_dataset_root, 'testing')

    use_pinned_memory = DEVICE.type == 'cuda'

    if not os.path.exists(train_data_path) or not os.listdir(train_data_path):
        print(f"Path '{train_data_path}' not found or empty. Using DUMMY data.", file=sys.stderr)
        # Random stand-in data so the rest of the pipeline can still run.
        X_train_tensor = torch.randn(100, 1, IMG_SIZE, IMG_SIZE)
        y_train_tensor = torch.randint(0, NUM_CLASSES, (100, IMG_SIZE, IMG_SIZE))
        X_val_tensor = torch.randn(20, 1, IMG_SIZE, IMG_SIZE)
        y_val_tensor = torch.randint(0, NUM_CLASSES, (20, IMG_SIZE, IMG_SIZE))
        X_test_tensor = torch.randn(30, 1, IMG_SIZE, IMG_SIZE)
        y_test_tensor = torch.randint(0, NUM_CLASSES, (30, IMG_SIZE, IMG_SIZE))

        # Define the same test-set names as the real-data branch so the
        # visualization / evaluation cells that follow also work on dummy data.
        all_test_images_np = X_test_tensor.permute(0, 2, 3, 1).contiguous().numpy()
        all_test_masks_np = y_test_tensor.numpy()
        test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
        test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=use_pinned_memory)

    else:
        # Load the full training set (masks are mandatory).
        print(f"Loading training data from: {train_data_path} (all patients)...")
        all_train_images_np, all_train_masks_np = load_acdc_data(
            train_data_path,
            is_training=True,
            target_size=(IMG_SIZE, IMG_SIZE),
            max_patients=None
        )
        print(f"Loaded {all_train_images_np.shape[0]} training images.")
        print(f"Shape of training images: {all_train_images_np.shape}")
        if all_train_masks_np is not None:
            print(f"Shape of training masks: {all_train_masks_np.shape}")
        else:
            raise ValueError("Training masks are None. They are required for training.")

        # Load the full test set (masks are optional here).
        print(f"\nLoading testing data from: {test_data_path} (all patients)...")
        all_test_images_np, all_test_masks_np = load_acdc_data(
            test_data_path,
            is_training=False,
            target_size=(IMG_SIZE, IMG_SIZE),
            max_patients=None
        )
        print(f"Loaded {all_test_images_np.shape[0]} testing images.")
        print(f"Shape of testing images: {all_test_images_np.shape}")
        if all_test_masks_np is None:
            print("No masks loaded for testing set (as expected if not present in source).")
        else:
            print(f"Shape of testing masks (if loaded): {all_test_masks_np.shape}")

        if all_train_images_np.size == 0:
            raise ValueError("Training data is empty after loading. Check data path and content.")

        if all_test_images_np.size == 0:
            print("Warning: Test data is empty after loading. Evaluation might be affected.", file=sys.stderr)

        # Scale each set by its own global maximum; skip when the maximum is 0
        # to avoid producing NaN/inf.
        if np.max(all_train_images_np) > 0:
            all_train_images_np = all_train_images_np / np.max(all_train_images_np)
        else:
            print("Warning: Max value of training images is 0, no normalization applied.", file=sys.stderr)

        # np.max raises on a zero-size array, so an empty test set (already
        # warned about above) must bypass the normalization entirely.
        if all_test_images_np.size > 0:
            if np.max(all_test_images_np) > 0:
                all_test_images_np = all_test_images_np / np.max(all_test_images_np)
            else:
                print("Warning: Max value of testing images is 0, no normalization applied.", file=sys.stderr)

        # Test tensors: (N, H, W, 1) -> (N, 1, H, W).
        X_test_tensor = torch.tensor(all_test_images_np).permute(0, 3, 1, 2).float()

        # Include targets only when test masks were actually loaded.
        if all_test_masks_np is not None:
            y_test_tensor = torch.tensor(all_test_masks_np).long()
            test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
            print("Test DataLoader will include masks.")
        else:
            test_dataset = TensorDataset(X_test_tensor)  # images only
            print("Test DataLoader will NOT include masks (as they are not available for testing set).")

        test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=use_pinned_memory)
        print(f"Test samples: {len(test_dataset)}")

        # Hold out 20% of the training images for validation.
        # NOTE(review): this splits individual images, not patients — confirm
        # whether images of one patient ending up in both splits is acceptable.
        if all_train_images_np.shape[0] > 0 and all_train_masks_np.shape[0] > 0:
            X_train_np, X_val_np, y_train_np, y_val_np = train_test_split(
                all_train_images_np, all_train_masks_np, test_size=0.2, random_state=42
            )
        else:
            raise ValueError("Not enough training data to perform train/val split.")

        X_train_tensor = torch.tensor(X_train_np).permute(0, 3, 1, 2).float()
        y_train_tensor = torch.tensor(y_train_np).long()
        X_val_tensor = torch.tensor(X_val_np).permute(0, 3, 1, 2).float()
        y_val_tensor = torch.tensor(y_val_np).long()

    # Sanity-check the split sizes.
    if len(X_train_tensor) == 0: raise ValueError("No training samples after split.")
    if len(X_val_tensor) == 0: print("Warning: Validation set is empty after split.", file=sys.stderr)

    train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
    val_dataset = TensorDataset(X_val_tensor, y_val_tensor)

    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=use_pinned_memory)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=use_pinned_memory)

    print(f"\nTraining samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")
    print("Data loaded and prepared for centralized training.")

    # --- Initialize Model, Criterion, Optimizer ---
    # RobustMedVFL_UNet and CombinedLoss must be defined by earlier cells.
    model = RobustMedVFL_UNet(n_channels=1, n_classes=NUM_CLASSES).to(DEVICE)
    my_class_indices = {'RV': 1, 'MYO': 2, 'LV': 3}

    criterion = CombinedLoss(
        in_channels_maxwell=1024,
        num_classes=NUM_CLASSES,
        lambda_val=15.0,
        initial_loss_weights=[0.3, 0.5, 0.5, 1.0, 0.5],  # initial weights for the 5 loss terms
        class_indices_for_rules=my_class_indices
    ).to(DEVICE)
    # The criterion's own parameters are optimized together with the model's.
    optimizer = torch.optim.Adam(
        chain(model.parameters(), criterion.parameters()),
        lr=LEARNING_RATE
    )

    # Optional: scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

    # Build the B1 calculator ONCE. Creating it inside the batch loop would
    # reload the cached map from disk for every training batch.
    b1_calculator = integrate_b1_map_into_training(
        X_train_tensor, X_val_tensor, X_test_tensor,
        img_size=IMG_SIZE, device=DEVICE
    )

    # --- Centralized Training Loop ---
    best_val_metric = 0.0  # best average foreground Dice seen so far

    for epoch in range(NUM_EPOCHS_CENTRALIZED):
        print(f"\n--- Epoch {epoch + 1}/{NUM_EPOCHS_CENTRALIZED} ---")

        # Training phase
        model.train()
        epoch_train_loss = 0.0
        num_train_batches = 0

        for images, targets in train_dataloader:
            images, targets = images.to(DEVICE), targets.to(DEVICE)

            images_noisy = quantum_noise_injection(images)  # noise augmentation

            optimizer.zero_grad()
            logits, all_eps_sigma_tuples = model(images_noisy)

            # The B1 map is requested for the clean images, not the noisy ones.
            b1_map = get_b1_map_for_training(images, b1_calculator)
            loss = criterion(logits, targets, b1_map, all_eps_sigma_tuples)

            loss.backward()
            optimizer.step()

            epoch_train_loss += loss.item()
            num_train_batches += 1

        avg_train_loss = epoch_train_loss / num_train_batches if num_train_batches > 0 else 0
        print(f"   Epoch {epoch+1} - Training Loss: {avg_train_loss:.4f}")

        # Validation phase
        if val_dataloader.dataset and len(val_dataloader.dataset) > 0:
            print("   Evaluating on validation set...")
            val_metrics = evaluate_metrics(model, val_dataloader, DEVICE, NUM_CLASSES)
            # The headline numbers average the foreground classes (1 and up);
            # with a single class, that class is used as-is.
            fg_dice = val_metrics['dice_scores'][1:] if NUM_CLASSES > 1 else [val_metrics['dice_scores'][0]]
            fg_iou = val_metrics['iou'][1:] if NUM_CLASSES > 1 else [val_metrics['iou'][0]]
            fg_precision = val_metrics['precision'][1:] if NUM_CLASSES > 1 else [val_metrics['precision'][0]]
            fg_recall = val_metrics['recall'][1:] if NUM_CLASSES > 1 else [val_metrics['recall'][0]]
            fg_f1 = val_metrics['f1_score'][1:] if NUM_CLASSES > 1 else [val_metrics['f1_score'][0]]

            avg_fg_dice = np.mean(fg_dice)
            avg_fg_iou = np.mean(fg_iou)
            avg_fg_precision = np.mean(fg_precision)
            avg_fg_recall = np.mean(fg_recall)
            avg_fg_f1 = np.mean(fg_f1)

            print(f"   Epoch {epoch+1} - Validation (Avg Foreground): "
                  f"Dice: {avg_fg_dice:.4f}; IoU: {avg_fg_iou:.4f}; "
                  f"Precision: {avg_fg_precision:.4f}; Recall: {avg_fg_recall:.4f}; F1-score: {avg_fg_f1:.4f}")
            for c_idx in range(NUM_CLASSES):
                print(f"     Class {c_idx}: Dice: {val_metrics['dice_scores'][c_idx]:.4f}; "
                      f"IoU: {val_metrics['iou'][c_idx]:.4f}; "
                      f"Precision: {val_metrics['precision'][c_idx]:.4f}; "
                      f"Recall: {val_metrics['recall'][c_idx]:.4f}; "
                      f"F1-score: {val_metrics['f1_score'][c_idx]:.4f}")

            # Track the best validation score (checkpointing is left disabled).
            if avg_fg_dice > best_val_metric:
                best_val_metric = avg_fg_dice
                # torch.save(model.state_dict(), "best_centralized_model.pth")
                # print(f"     New best model saved with Val Dice: {best_val_metric:.4f}")
        else:
            print("   Validation dataset is empty. Skipping validation.")

        # Report the weights currently held by the loss object.
        current_loss_weights = criterion.get_current_loss_weights()
        current_class_weights = criterion.get_current_class_weights()

        print(f"\n   Epoch {epoch+1} - Learned Loss Weights:")
        for name, weight in current_loss_weights.items():
            print(f"     - {name}: {weight:.4f}")

        print(f"   Epoch {epoch+1} - Dynamic Class Weights (for CE):")
        class_weights_str = " | ".join([f"Class {i}: {w:.4f}" for i, w in enumerate(current_class_weights.values())])
        print(f"     - {class_weights_str}")

        print("-" * 60)

    print("\n--- Centralized Training Finished ---")

# Trực quan hóa kết quả sau huấn luyện

In [None]:
import random
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# Lookup tables used by the visualization code.
# Label index -> structure name.
ACDC_CLASS_MAP = dict(enumerate((
    "Background",
    "Right Ventricle",
    "Myocardium",
    "Left Ventricle",
)))
# Label index -> display colour.
ACDC_COLOR_MAP = dict(enumerate((
    'black',
    '#FF0000',
    '#00FF00',
    '#0000FF',
)))

def visualize_final_results(model, images_np, masks_np, num_classes, num_samples=3, device=None):
    """
    Plot model predictions (and ground truth, when given) for random images.

    For each sampled image one figure is shown with the input image, the
    predicted mask overlaid on it and, if a mask exists for that index, the
    ground-truth mask overlaid on it. Label 0 is left transparent.

    Args:
        model (torch.nn.Module): Trained model; its forward must return a
            tuple whose first element is the logits with classes at dim 1.
            It is moved to `device` and put in eval mode.
        images_np (np.ndarray): Images indexed along axis 0; each image is
            channel-last (H, W, C).
        masks_np (np.ndarray | None): Ground-truth label masks aligned with
            `images_np`, or None.
        num_classes (int): Number of segmentation classes. Classes missing
            from ACDC_COLOR_MAP / ACDC_CLASS_MAP are drawn in black and
            labelled "Class <i>".
        num_samples (int): Maximum number of random images to show.
        device (torch.device, optional): Device to run on; defaults to CUDA
            when available, else CPU.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.to(device)
    model.eval()

    total_images = images_np.shape[0]
    if total_images == 0:
        print("Không có ảnh nào để trực quan hóa.")
        return

    sample_indices = random.sample(range(total_images), min(num_samples, total_images))

    # Colormap and legend depend only on num_classes, so build them once.
    # The class count comes from the argument, not from a model attribute.
    colors = [ACDC_COLOR_MAP.get(i, 'black') for i in range(num_classes)]
    cmap = mcolors.ListedColormap(colors)
    max_label = len(colors) - 1

    def _new_legend_handles():
        # Fresh artists per figure; .get() keeps num_classes beyond the
        # predefined ACDC labels from raising KeyError.
        return [
            plt.Rectangle((0, 0), 1, 1,
                          color=ACDC_COLOR_MAP.get(i, 'black'),
                          label=ACDC_CLASS_MAP.get(i, f"Class {i}"))
            for i in range(1, num_classes)
        ]

    for idx in sample_indices:
        image_np_single = images_np[idx]
        # (H, W, C) -> (1, C, H, W)
        image_tensor = torch.from_numpy(image_np_single).permute(2, 0, 1).unsqueeze(0)
        image_tensor = image_tensor.to(device, dtype=torch.float32)

        with torch.no_grad():
            output, _ = model(image_tensor)
            pred_mask_tensor = torch.argmax(output, dim=1)

        pred_mask_np = pred_mask_tensor.cpu().squeeze().numpy()
        image_2d = image_np_single.squeeze()

        has_gt_mask = masks_np is not None and idx < len(masks_np)
        num_subplots = 3 if has_gt_mask else 2
        fig, axes = plt.subplots(1, num_subplots, figsize=(13 * num_subplots / 2, 7))
        fig.suptitle(f'Kết quả cho ảnh số {idx}', fontsize=16)

        axes[0].imshow(image_2d, cmap='gray')
        axes[0].set_title('Ảnh MRI Gốc')
        axes[0].axis('off')

        ax_pred = axes[1]
        ax_pred.imshow(image_2d, cmap='gray')
        # Mask out label 0 so the underlying image stays visible there.
        pred_masked_display = np.ma.masked_where(pred_mask_np == 0, pred_mask_np)
        ax_pred.imshow(pred_masked_display, cmap=cmap, alpha=0.6, vmin=0, vmax=max_label)
        ax_pred.set_title('Dự đoán của mô hình')
        ax_pred.axis('off')

        if has_gt_mask:
            gt_mask_np = masks_np[idx]
            ax_gt = axes[2]
            ax_gt.imshow(image_2d, cmap='gray')
            gt_masked_display = np.ma.masked_where(gt_mask_np == 0, gt_mask_np)
            ax_gt.imshow(gt_masked_display, cmap=cmap, alpha=0.6, vmin=0, vmax=max_label)
            ax_gt.set_title('Mặt nạ Ground Truth')
            ax_gt.axis('off')

        legend_elements = _new_legend_handles()
        if legend_elements:
            fig.legend(handles=legend_elements, loc='lower center', ncol=3, bbox_to_anchor=(0.5, 0.02))

        plt.tight_layout(rect=[0, 0.05, 1, 0.95])
        plt.show()


# --- VISUALIZE THE FINAL RESULTS ---
print("\n--- Visualizing Final Model Predictions on Test Set ---")

# Show predictions of the trained model on random test images.
# `model`, `all_test_images_np` and `all_test_masks_np` are the globals
# created by the training cell above, so that cell must have run first.
visualize_final_results(
    model=model,
    images_np=all_test_images_np,
    masks_np=all_test_masks_np, # ground-truth masks as well (may be None)
    num_classes=NUM_CLASSES,  # number of segmentation classes
    num_samples=3,              # how many random images to display
    device=DEVICE
)

# Đánh giá kết quả trên tập test


In [None]:
import numpy as np

def run_and_print_test_evaluation(model, test_dataloader, device, num_classes):
    """Evaluate a trained model on the test set and print the resulting metrics.

    Prints the Dice, IoU, precision, recall and F1-score averaged over the
    foreground classes (class 1 onward, or class 0 alone when there is only
    one class), followed by the same metrics for every individual class.
    When the dataloader or its dataset is missing or empty, a notice is
    printed and nothing is evaluated.

    Args:
        model (torch.nn.Module): Trained PyTorch model.
        test_dataloader (torch.utils.data.DataLoader): DataLoader for the test set.
        device (torch.device): Device to run the model on ('cuda' or 'cpu').
        num_classes (int): Number of classes in the task.
    """
    # Guard clause: bail out early when there is no usable test data.
    if (
        not test_dataloader
        or not test_dataloader.dataset
        or len(test_dataloader.dataset) <= 0
    ):
        print("\nTest dataset not available or empty. Skipping test evaluation.")
        return

    print("\n--- Evaluating on Test Set ---")

    # evaluate_metrics is defined elsewhere in the file; its result is indexed
    # below by metric name and then by class index.
    test_metrics = evaluate_metrics(model, test_dataloader, device, num_classes)

    metric_keys = ('dice_scores', 'iou', 'precision', 'recall', 'f1_score')

    # Keep only the foreground entries (class 1 onward); with a single class,
    # fall back to that class's own entry.
    if num_classes > 1:
        foreground = [test_metrics[key][1:] for key in metric_keys]
    else:
        foreground = [[test_metrics[key][0]] for key in metric_keys]

    # Averages over the foreground classes.
    avg_dice, avg_iou, avg_precision, avg_recall, avg_f1 = [
        np.mean(values) for values in foreground
    ]
    print(f"  Test (Avg Foreground): "
          f"Dice: {avg_dice:.4f}; IoU: {avg_iou:.4f}; "
          f"Precision: {avg_precision:.4f}; Recall: {avg_recall:.4f}; "
          f"F1-score: {avg_f1:.4f}")

    # Detailed per-class report.
    for c_idx in range(num_classes):
        class_name = ACDC_CLASS_MAP.get(c_idx, 'N/A')
        dice, iou, precision, recall, f1 = [
            test_metrics[key][c_idx] for key in metric_keys
        ]
        print(f"    Class {c_idx} ({class_name}): "
              f"Dice: {dice:.4f}; "
              f"IoU: {iou:.4f}; "
              f"Precision: {precision:.4f}; "
              f"Recall: {recall:.4f}; "
              f"F1-score: {f1:.4f}")

# --- Evaluate on Test Set ---

# 1. Run the evaluation and print the metrics
run_and_print_test_evaluation(
    model=model,
    test_dataloader=test_dataloader,
    device=DEVICE,
    num_classes=NUM_CLASSES
)

# 2. Visualize the results on a few random images from the test set
print("\n--- Visualizing Predictions on Test Set Samples ---")
# Uses the numpy arrays loaded earlier (all_test_images_np, all_test_masks_np)
visualize_final_results(
    model=model,
    images_np=all_test_images_np,
    masks_np=all_test_masks_np,
    num_classes=NUM_CLASSES,
    num_samples=50, # Number of images to display
    device=DEVICE
)