In [None]:
# -*- coding: utf-8 -*-
"""CaTS_Framework_Pytorch.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1... (Replace with actual Colab link if desired)

# Causality-Driven Teacher-Student (CaTS) Framework Implementation

This notebook implements the CaTS framework based on the paper "Consistency-Regularized Causal Teacher-Student Learning with Diffusion Synthesis for PCOS Diagnosis". It integrates causal feature disentanglement within a semi-supervised Mean Teacher architecture.
"""

# %% [markdown]
# ## 1. Setup and Imports

# %%
!pip install torch torchvision torchaudio tqdm pandas scikit-learn tensorboard Pillow -q

# %%
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset, Sampler
from PIL import Image
import os
import pandas as pd
from tqdm.notebook import tqdm
import numpy as np
import random
import copy
from sklearn.model_selection import train_test_split
from torch.utils.tensorboard import SummaryWriter
import itertools
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import math
import warnings

# Suppress warnings (UserWarning category only, for the whole session).
warnings.filterwarnings("ignore", category=UserWarning)


def seed_everything(seed=1337):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs and force deterministic cuDNN.

    Args:
        seed: value handed to every RNG seeder (the default equals config["seed"]).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers multi-GPU setups
    )
    for seeder in seeders:
        seeder(seed)
    # Deterministic cuDNN kernels, and no autotuning (which may choose different
    # algorithms from run to run).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything()

# %% [markdown]
# ## 2. Configuration

# %%
# --- Hyperparameters ---
# Every tunable knob lives in this dict; derived values are added right after it.
config = {
    # Data / batching
    "image_size": 256,
    "batch_size": 16,                  # total samples per batch (labeled + unlabeled)
    "labeled_bs_ratio": 0.5,           # share of each batch that is labeled (batch composition, not the overall labeled split)
    "num_classes": 2,                  # PCOS positive / negative
    # Optimisation
    "base_lr": 1e-3,                   # paper: max 1e-3
    "weight_decay": 1e-2,              # paper: 1e-2
    "epochs": 100,                     # number of training epochs
    "ema_decay": 0.99,                 # teacher EMA decay
    # Loss weights
    "consistency_lambda": 1.0,         # L_CR
    "confound_lambda": 1.0,            # L_con
    "backdoor_lambda": 1.0,            # L_b
    # Causal transformer
    "n_transformer_layers": 6,         # paper: nL=6
    "n_causal_queries": 8,             # paper: nQ=8
    "transformer_embed_dim": 2048,     # matches ResNet output channels
    "transformer_nhead": 8,
    "transformer_ff_dim": 2048,        # feed-forward width
    "memory_bank_size": 1024,          # memory bank size for L_b
    # Splits
    "train_test_split_ratio": 0.8,
    "labeled_unlabeled_split_ratio": 0.2,  # paper: 20% of the training set is labeled
    "seed": 1337,
    # Runtime / paths
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "data_root": "/content/PCOSGen-train/images",  # adjust if necessary
    "output_dir": "./cats_output",
}

# Labeled samples per batch; the rest of each batch is unlabeled.
config["labeled_bs"] = int(config["batch_size"] * config["labeled_bs_ratio"])

print(f"Using device: {config['device']}")
# Output root plus its checkpoint and log sub-directories.
for _subdir in ("", "checkpoints", "logs"):
    os.makedirs(os.path.join(config["output_dir"], _subdir), exist_ok=True)

# TensorBoard writer
writer = SummaryWriter(log_dir=os.path.join(config["output_dir"], "logs"))

# %% [markdown]
# ## 3. Data Loading and Preparation

# %% [markdown]
# #### Unzip Data (Run Once)

# %%
# Mount Google Drive if data is there
# from google.colab import drive
# drive.mount('/content/drive')
# !unzip /content/drive/MyDrive/PCOSGen-train.zip -d /content/

# If data is uploaded directly or already present, adjust path in config["data_root"]

# Create CSV file from folder structure (Run Once)
def infer_label_from_filename(filename):
    """Infer the PCOS label encoded in an image filename.

    Returns 1 (PCOS positive / infected), 0 (PCOS negative / not infected), or
    None if the name contains none of the known keywords. 'notinfected' is
    checked before the positive keywords because it contains the substring
    'infected' and would otherwise be labeled positive.
    """
    name = filename.lower()
    if 'notinfected' in name:
        return 0
    if 'pco' in name or 'polycystic' in name or 'infected' in name:
        return 1
    if 'normal' in name:
        return 0
    return None


if not os.path.exists('/content/pcos_dataset.csv'):
    print("Creating dataset CSV...")
    # config["data_root"] is the folder that directly contains the images
    # (adjust it in the config cell if your unzipped structure differs).
    image_folder = config["data_root"]
    image_paths = []
    labels = []
    # Sorted so the CSV row order -- and therefore the seeded train/val/test
    # splits -- does not depend on the filesystem's directory order.
    for filename in sorted(os.listdir(image_folder)):
        if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            continue
        label = infer_label_from_filename(filename)
        full_path = os.path.join(image_folder, filename)
        if label is not None and os.path.isfile(full_path):
            image_paths.append(full_path)
            labels.append(label)

    if not image_paths:
        raise FileNotFoundError("No images found. Check data_root and folder structure.")

    df = pd.DataFrame({'image_path': image_paths, 'label': labels})
    df.to_csv('/content/pcos_dataset.csv', index=False)
    print("Dataset CSV created.")
else:
    print("Dataset CSV already exists.")

# %% [markdown]
# #### Dataset Class and Augmentations

# %%
# Define augmentations.
# All three pipelines share the same resize up front and ImageNet normalisation
# at the end; only the augmentation steps in between differ.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _resize_augment_normalize(*augmentations):
    """Build Resize -> *augmentations -> ToTensor -> Normalize at config["image_size"]."""
    side = config["image_size"]
    return transforms.Compose([
        transforms.Resize((side, side)),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])


# Weak view (teacher input; the training loop also uses it for the student's
# supervised pass): horizontal flip + reflect-padded random crop, padding = 12.5% of the side.
weak_transform = _resize_augment_normalize(
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(config["image_size"], padding=int(config["image_size"] * 0.125), padding_mode='reflect'),
)

# Strong view (student on unlabeled data): RandAugment with 2 ops at magnitude 10.
strong_transform = _resize_augment_normalize(transforms.RandAugment(num_ops=2, magnitude=10))

# Deterministic pipeline for validation/test.
val_transform = _resize_augment_normalize()


class PCOSImageDataset(Dataset):
    """Image dataset backed by a CSV with ``image_path`` and ``label`` columns.

    Items:
        mode == 'train': ``(weak_view, strong_view, label)``
        any other mode ('val' / 'test'): ``(image, label)``

    Missing or unreadable images are replaced by an all-zero (3, S, S)
    placeholder with label ``INVALID_LABEL`` (-1), where S is
    ``placeholder_size`` or, if that is None, ``config["image_size"]``.

    Raises:
        ValueError: if the transforms required by ``mode`` are not provided.
    """

    # Label attached to placeholder samples.
    # NOTE(review): nn.CrossEntropyLoss only skips targets equal to its
    # ignore_index (-100 by default); a -1 target makes it raise, so the
    # training loop must filter these samples out.
    INVALID_LABEL = -1

    def __init__(self, csv_file, weak_transform=None, strong_transform=None, val_transform=None, mode='train',
                 placeholder_size=None):
        if mode == 'train':
            if weak_transform is None or strong_transform is None:
                raise ValueError("mode='train' requires both weak_transform and strong_transform")
        elif val_transform is None:
            raise ValueError(f"mode={mode!r} requires val_transform")
        self.dataframe = pd.read_csv(csv_file)
        self.weak_transform = weak_transform
        self.strong_transform = strong_transform
        self.val_transform = val_transform
        self.mode = mode
        self.placeholder_size = placeholder_size

    def __len__(self):
        return len(self.dataframe)

    def _placeholder(self):
        """Return the zero-image / INVALID_LABEL tuple shaped like a normal item for this mode."""
        size = self.placeholder_size if self.placeholder_size is not None else config["image_size"]
        placeholder_img = torch.zeros((3, size, size))
        invalid_label = torch.tensor(self.INVALID_LABEL, dtype=torch.long)
        if self.mode == 'train':
            return placeholder_img, placeholder_img, invalid_label
        return placeholder_img, invalid_label

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        img_path = row['image_path']
        label = torch.tensor(row['label'], dtype=torch.long)

        try:
            if not os.path.exists(img_path):
                print(f"Warning: Image path not found {img_path}. Skipping.")
                return self._placeholder()

            # Context manager closes the file handle; convert() yields an
            # independent in-memory image.
            with Image.open(img_path) as img:
                image = img.convert('RGB')

            if self.mode == 'train':
                return self.weak_transform(image), self.strong_transform(image), label
            return self.val_transform(image), label
        except Exception as e:
            # Best-effort loading: report and substitute a placeholder so one bad
            # file does not abort the whole epoch.
            print(f"Error loading image {img_path}: {e}")
            return self._placeholder()



# %% [markdown]
# #### Data Splitting and Sampler

# %%
# Load full dataset info
full_df = pd.read_csv('/content/pcos_dataset.csv')

# Split into Train+Val (80%) and Test (20%), stratified by label
train_val_indices, test_indices = train_test_split(
    range(len(full_df)),
    test_size=1.0 - config["train_test_split_ratio"],
    stratify=full_df['label'],
    random_state=config["seed"]
)

# Split Train+Val into actual Train and Validation. The validation fraction can
# be set with config["val_split_ratio"]; it defaults to 0.2 (20% of train+val).
val_split_ratio = config.get("val_split_ratio", 0.2)
train_indices, val_indices = train_test_split(
    train_val_indices,
    test_size=val_split_ratio,
    stratify=full_df.iloc[train_val_indices]['label'],
    random_state=config["seed"]
)

# Further split Train indices into Labeled and Unlabeled
labeled_indices, unlabeled_indices = train_test_split(
    train_indices,
    test_size=1.0 - config["labeled_unlabeled_split_ratio"],
    stratify=full_df.iloc[train_indices]['label'],
    random_state=config["seed"]
)

# Sanity checks: the four index sets are pairwise disjoint and together cover
# every CSV row exactly once.
_all_split_sets = [set(labeled_indices), set(unlabeled_indices), set(val_indices), set(test_indices)]
assert sum(len(s) for s in _all_split_sets) == len(full_df)
assert len(set().union(*_all_split_sets)) == len(full_df)

print(f"Total samples: {len(full_df)}")
print(f"Train samples: {len(train_indices)} (Labeled: {len(labeled_indices)}, Unlabeled: {len(unlabeled_indices)})")
print(f"Validation samples: {len(val_indices)}")
print(f"Test samples: {len(test_indices)}")

# Create dataset instances
train_dataset = PCOSImageDataset(csv_file='/content/pcos_dataset.csv', weak_transform=weak_transform, strong_transform=strong_transform, mode='train')
val_dataset = PCOSImageDataset(csv_file='/content/pcos_dataset.csv', val_transform=val_transform, mode='val')
test_dataset = PCOSImageDataset(csv_file='/content/pcos_dataset.csv', val_transform=val_transform, mode='test')

# Create subset datasets
val_subset = Subset(val_dataset, val_indices)
test_subset = Subset(test_dataset, test_indices)

# --- Sampler for Training ---
class TwoStreamBatchSampler(Sampler):
    """Batch sampler that mixes a labeled and an unlabeled index stream.

    Every batch is ``primary_batch_size`` indices from ``primary_indices``
    (labeled; shuffled once per epoch) followed by ``secondary_batch_size``
    indices from ``secondary_indices`` (unlabeled; reshuffled and cycled as
    often as needed).

    Epoch length is ``len(primary_indices) // primary_batch_size`` batches,
    additionally capped at ``len(secondary_indices) // secondary_batch_size``
    whenever the secondary stream is in use. Incomplete trailing batches are
    dropped. With ``secondary_batch_size == 0`` the sampler is fully supervised
    and yields labeled-only batches.
    """

    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        self.primary_indices = primary_indices  # Labeled indices
        self.secondary_indices = secondary_indices  # Unlabeled indices
        self.secondary_batch_size = secondary_batch_size  # Unlabeled batch size
        self.primary_batch_size = batch_size - secondary_batch_size  # Labeled batch size

        assert len(self.primary_indices) >= self.primary_batch_size > 0
        assert len(self.secondary_indices) >= self.secondary_batch_size >= 0  # Allow 0 unlabeled

        print(f"Sampler: Total Batch={batch_size}, Labeled={self.primary_batch_size}, Unlabeled={self.secondary_batch_size}")

    def __iter__(self):
        # Single source of truth for the epoch length, shared with __len__.
        num_batches = len(self)
        # The labeled shuffle is drawn first; the unlabeled shuffles are drawn lazily.
        primary_batches = grouper(iterate_once(self.primary_indices), self.primary_batch_size)
        if self.secondary_batch_size == 0:
            # Fully supervised: no unlabeled stream is created at all.
            return itertools.islice(primary_batches, num_batches)
        secondary_batches = grouper(iterate_eternally(self.secondary_indices), self.secondary_batch_size)
        combined = (primary + secondary for primary, secondary in zip(primary_batches, secondary_batches))
        return itertools.islice(combined, num_batches)

    def __len__(self):
        num_primary = len(self.primary_indices) // self.primary_batch_size
        if self.secondary_batch_size == 0:
            return num_primary
        return min(num_primary, len(self.secondary_indices) // self.secondary_batch_size)


def iterate_once(iterable):
    """Return one randomly permuted pass over ``iterable`` (uses NumPy's global RNG)."""
    shuffled = np.random.permutation(iterable)
    return shuffled


def iterate_eternally(indices):
    """Endless iterator over ``indices``; each full pass is a fresh random permutation.

    A new permutation is drawn only when the previous pass is exhausted.
    """
    def _endless_passes():
        while True:
            yield from np.random.permutation(indices)
    return _endless_passes()


def grouper(iterable, n):
    """Collect items into fixed-length tuples, dropping any incomplete tail.

    grouper('ABCDEFG', 3) --> ('A', 'B', 'C') ('D', 'E', 'F')

    All ``n`` zip arguments are the same iterator, so each tuple consumes ``n``
    consecutive items; plain ``zip`` stops at the first short group.
    """
    shared_iter = iter(iterable)
    return zip(*(shared_iter,) * n)

# Create the sampler and dataloaders.
# The sampler's 4th argument is the unlabeled batch size, so the first
# config["labeled_bs"] samples of every training batch are labeled.
train_batch_sampler = TwoStreamBatchSampler(
    labeled_indices, unlabeled_indices, config["batch_size"], config["batch_size"] - config["labeled_bs"]
)

# Worker count is configurable (default 4). Pinned host memory only speeds up
# host-to-GPU copies, so enable it only when training on CUDA.
num_workers = config.get("num_workers", 4)
pin_memory = config["device"].type == "cuda"

train_loader = DataLoader(train_dataset, batch_sampler=train_batch_sampler, num_workers=num_workers, pin_memory=pin_memory)
val_loader = DataLoader(val_subset, batch_size=config["batch_size"], shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
test_loader = DataLoader(test_subset, batch_size=config["batch_size"], shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

print(f"Train loader batches: {len(train_loader)}")
print(f"Val loader batches: {len(val_loader)}")
print(f"Test loader batches: {len(test_loader)}")


# %% [markdown]
# ## 4. Model Architecture (CaTS)

# %%
# --- Attention Modules (from Causal.ipynb) ---
class PositionAttention(nn.Module):
    """Spatial (position) self-attention with a learnable, zero-initialised residual gate.

    Each of the H*W locations attends over all locations; the attended values
    are scaled by ``gamma`` (0 at init, so the module starts as the identity)
    and added back to the input. Input and output shape: (B, C, H, W).
    """

    def __init__(self, in_channels):
        super().__init__()
        # Queries/keys use a reduced width (C // 8, never below 64) to cut the
        # projection and affinity cost; values keep all C channels.
        qk_channels = max(in_channels // 8, 64)
        self.query_conv = nn.Conv2d(in_channels, qk_channels, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, qk_channels, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        b, c, h, w = x.size()
        n = h * w
        query = self.query_conv(x).view(b, -1, n).transpose(1, 2)  # (B, N, C')
        key = self.key_conv(x).view(b, -1, n)                      # (B, C', N)
        attention = self.softmax(torch.bmm(query, key))            # (B, N, N), rows sum to 1
        value = self.value_conv(x).view(b, -1, n)                  # (B, C, N)
        attended = torch.bmm(value, attention.transpose(1, 2)).view(b, c, h, w)
        return x + self.gamma * attended

class ChannelAttention(nn.Module):
    """Channel self-attention (C x C affinity) with a learnable, zero-initialised residual gate.

    ``in_channels`` is accepted for symmetry with PositionAttention but unused:
    the module has no projections, only the scalar ``gamma``.
    Input and output shape: (B, C, H, W).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        b, c, h, w = x.size()
        flat = x.view(b, c, -1)                           # (B, C, N)
        # Raw, unscaled channel affinities; the alternative
        # "max(energy) - energy" normalisation is not applied.
        affinity = torch.bmm(flat, flat.transpose(1, 2))  # (B, C, C)
        weights = self.softmax(affinity)
        mixed = torch.bmm(weights, flat).view(b, c, h, w)
        return x + self.gamma * mixed


# --- Feature Extractor (Representation Learning Block) ---
# Adapted from Causal.ipynb and cats.jpg (b) diagram
class FeatureExtractor(nn.Module):
    """Dual-branch representation block applied to the backbone feature map.

    Branch 1 projects the input with two separate 3x3 convolutions, refines one
    with channel attention and the other with position attention, and sums
    them. Branch 2 gates a 1x1 "skip" projection with a globally pooled context
    vector. The branches are concatenated along channels, so the output has
    ``2 * out_channels`` channels at the input's spatial size.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by each branch.
    """

    def __init__(self, in_channels=2048, out_channels=1024):
        super(FeatureExtractor, self).__init__()
        self.conv_ch = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.ch_attn = ChannelAttention(out_channels)  # applied on the reduced channels

        self.conv_pos = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.pos_attn = PositionAttention(out_channels)  # applied on the reduced channels

        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.conv_gap = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv_skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        # One BatchNorm per branch. A single shared BN would accumulate running
        # statistics from two differently distributed inputs, so in eval mode
        # neither branch would be normalised with its own statistics.
        self.bn1 = nn.BatchNorm2d(out_channels)  # attention branch
        self.bn2 = nn.BatchNorm2d(out_channels)  # GAP / skip branch
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Branch 1 (attention): channel- and position-attention on separate projections, summed.
        x_ch_attn = self.ch_attn(self.conv_ch(x))
        x_pos_attn = self.pos_attn(self.conv_pos(x))
        branch1_out = self.relu(self.bn1(x_ch_attn + x_pos_attn))

        # Branch 2 (GAP + skip): the pooled context (B, out_channels, 1, 1)
        # broadcasts over H x W to gate the skip projection, plus the ungated skip.
        x_gap = self.conv_gap(self.gap(x))
        x_skip = self.conv_skip(x)
        branch2_out = self.relu(self.bn2(x_skip * x_gap + x_skip))

        # Output channels = 2 * out_channels
        return torch.cat((branch1_out, branch2_out), dim=1)

# --- Positional Encoding for Transformer ---
class FixedPositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a (B, SeqLen, Dim) token sequence.

    Even columns hold sin(pos * w_i) and odd columns cos(pos * w_i), with
    w_i = 10000^(-2i / d_model).

    Args:
        d_model: token size; odd sizes are supported (one fewer cosine column).
        max_len: longest supported sequence. For a flattened CNN feature map this
            is H*W; the default 64 corresponds to an 8x8 map.

    Raises:
        ValueError: in forward, if the sequence is longer than ``max_len``.
    """

    def __init__(self, d_model, max_len=64):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there are ceil(d/2) sine but only floor(d/2) cosine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # Shape: (1, max_len, d_model)

    def forward(self, x):
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"Sequence length {seq_len} exceeds positional-encoding max_len {self.pe.size(1)}; "
                "increase max_len (H*W of the feature map)."
            )
        return x + self.pe[:, :seq_len]


# --- Causal Disentanglement with Transformer Decoder ---
class CausalDisentanglement(nn.Module):
    """Split a feature map into a causal vector and a residual "confounding" vector.

    Learnable causal queries cross-attend, through a Transformer decoder, to the
    flattened and positionally encoded feature map. The decoder outputs are
    averaged into ``F_cau``, and ``F_con`` is the mean encoded feature minus
    ``F_cau``. This is a simplification of / alternative to the paper's Eq. 13.

    Args:
        d_model: channel count C of the incoming (B, C, H, W) feature map.
        nhead, num_decoder_layers, dim_feedforward: decoder hyperparameters.
        n_queries: number of learnable causal queries.
        max_seq_len: largest H*W the positional encoding supports (64 = 8x8 map).
    """

    def __init__(self, d_model, nhead, num_decoder_layers, dim_feedforward, n_queries, max_seq_len=64):
        super().__init__()
        self.d_model = d_model
        self.n_queries = n_queries

        self.pos_encoder = FixedPositionalEncoding(d_model, max_len=max_seq_len)
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        # Learnable causal queries. A small random init makes the queries distinct
        # from the start; all-zero queries would be identical, so the decoder would
        # produce the same output for every query and only dropout noise could
        # separate them.
        self.causal_queries = nn.Parameter(torch.empty(1, n_queries, d_model))
        nn.init.normal_(self.causal_queries, std=0.02)

    def forward(self, features):
        """Return ``(F_cau, F_con)``, each of shape (B, d_model), from (B, C, H, W) features."""
        B, C, H, W = features.shape
        features_flat = features.flatten(2).permute(0, 2, 1)  # (B, H*W, C)
        features_pos = self.pos_encoder(features_flat)

        queries = self.causal_queries.repeat(B, 1, 1)  # (B, n_queries, d_model)

        # tgt = causal queries, memory = encoded spatial tokens.
        causal_output = self.transformer_decoder(tgt=queries, memory=features_pos)  # (B, n_queries, d_model)

        # Simplified confounder extraction: no per-layer attention weights are
        # used; the confound is the residual between the mean spatial feature and
        # the mean causal feature.
        F_cau_avg = causal_output.mean(dim=1)  # (B, d_model)
        # NOTE(review): the baseline is averaged *after* positional encoding, so
        # F_con carries the constant mean of the encoding as an offset.
        S_avg = features_pos.mean(dim=1)  # (B, d_model)
        F_con = S_avg - F_cau_avg

        return F_cau_avg, F_con


# --- MLP Heads ---
class MLPHead(nn.Module):
    """Two-layer MLP head that returns both its hidden features and its logits.

    forward(x) -> ``(eta, mu)``: ``eta`` is the post-ReLU hidden layer, taken
    before dropout, and ``mu`` is the output logits.
    """

    def __init__(self, in_dim, hidden_dim=512, out_dim=2, dropout=0.5):
        super(MLPHead, self).__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        eta = self.relu(self.fc1(x))      # intermediate features (eta)
        mu = self.fc2(self.dropout(eta))  # final logits (mu)
        return eta, mu


# --- Complete CaTS Model ---
class CaTSModel(nn.Module):
    """Full CaTS network: frozen ResNet-101 -> FeatureExtractor -> CausalDisentanglement -> two MLP heads.

    forward(x) returns ``(eta_cau, eta_con, mu_cau, mu_con)``: the hidden
    features and logits of the causal head and of the confounding head.

    Note: ``transformer_embed_dim`` is accepted for config compatibility but is
    not used; the decoder width equals the feature extractor's output channels.
    """

    def __init__(self, num_classes, n_transformer_layers, n_causal_queries,
                 transformer_embed_dim, transformer_nhead, transformer_ff_dim, dropout=0.5):
        super().__init__()
        # Backbone: ImageNet ResNet-101 without avgpool/fc, frozen.
        resnet = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        for param in self.backbone.parameters():
            param.requires_grad = False  # Freeze backbone
        self.backbone_out_channels = 2048

        # Representation Learning
        self.feature_extractor = FeatureExtractor(in_channels=self.backbone_out_channels, out_channels=self.backbone_out_channels // 2)
        self.feature_extractor_out_channels = self.backbone_out_channels  # two branches of half width, concatenated

        # Causal Disentanglement
        self.causal_disentanglement = CausalDisentanglement(
            d_model=self.feature_extractor_out_channels,
            nhead=transformer_nhead,
            num_decoder_layers=n_transformer_layers,
            dim_feedforward=transformer_ff_dim,
            n_queries=n_causal_queries
        )
        self.causal_out_dim = self.feature_extractor_out_channels

        # MLP Heads
        self.mlp_cau = MLPHead(in_dim=self.causal_out_dim, out_dim=num_classes, dropout=dropout)
        self.mlp_con = MLPHead(in_dim=self.causal_out_dim, out_dim=num_classes, dropout=dropout)

    def train(self, mode=True):
        """Set train/eval mode, but keep a fully frozen backbone in eval mode.

        ``requires_grad=False`` stops weight updates, but BatchNorm layers still
        update their running statistics whenever they run in training mode. While
        every backbone parameter is frozen, the backbone therefore stays in eval
        mode so its ImageNet statistics are preserved. If the backbone is later
        unfrozen, it follows ``mode`` as usual.
        """
        super().train(mode)
        if mode and not any(p.requires_grad for p in self.backbone.parameters()):
            self.backbone.eval()
        return self

    def forward(self, x):
        x = self.backbone(x)  # [B, 2048, 8, 8] for 256x256 input
        x = self.feature_extractor(x)  # [B, 2048, 8, 8]
        F_cau_vec, F_con_vec = self.causal_disentanglement(x)  # each [B, 2048]

        eta_cau, mu_cau = self.mlp_cau(F_cau_vec)  # Causal intermediate features and logits
        eta_con, mu_con = self.mlp_con(F_con_vec)  # Confound intermediate features and logits

        return eta_cau, eta_con, mu_cau, mu_con


# %% [markdown]
# ## 5. Loss Functions

# %%
# --- Causal Loss (L_cau) ---
# Cross-entropy on the causal head's logits.
causal_loss_fn = nn.CrossEntropyLoss()

# --- Confounding Suppression Loss (L_con) ---
# KLDivLoss takes log-probabilities as input and probabilities as target;
# 'batchmean' sums the KL over classes and divides by the batch size.
kl_loss_fn = nn.KLDivLoss(reduction='batchmean')
# Uniform class distribution for one full batch, pre-built on the training device.
_uniform_prob = 1.0 / config["num_classes"]
uniform_dist = torch.full((config["batch_size"], config["num_classes"]), _uniform_prob).to(config["device"])

def confound_loss_fn(mu_con_stu):
    """L_con: push the confounding head's predictions towards a uniform distribution.

    Computes KL(U || softmax(mu_con_stu)) with 'batchmean' reduction (sum over
    classes, divided by batch size), where U is uniform over the classes.

    Args:
        mu_con_stu: logits with classes on dim 1, e.g. (batch, num_classes).
    Returns:
        Scalar loss tensor (0 when every row of logits is constant).
    """
    log_probs = F.log_softmax(mu_con_stu, dim=1)
    # The target is built from the logits themselves, so batch size, class count,
    # dtype and device always match (including a short final batch).
    uniform_target = torch.full_like(log_probs, 1.0 / log_probs.size(1))
    return F.kl_div(log_probs, uniform_target, reduction='batchmean')

# --- Backdoor Adjustment Loss (L_b) ---
# Cross-entropy criterion for the backdoor adjustment loss.
backdoor_loss_fn = nn.CrossEntropyLoss()
# FIFO memory bank of confounding features (eta_con) from labeled batches, with
# the matching labels stored batch-for-batch in a parallel list.
confound_memory_bank, memory_bank_labels = [], []

def update_memory_bank(eta_con_stu_labeled, labels_labeled):
    """Push one labeled batch of confounding features and labels into the FIFO bank.

    Both tensors are detached and moved to CPU. The bank holds at most
    ``memory_bank_size // labeled_bs`` batches; once that is exceeded, the
    oldest batch is evicted.
    """
    confound_memory_bank.append(eta_con_stu_labeled.detach().cpu())
    memory_bank_labels.append(labels_labeled.detach().cpu())
    max_batches = config["memory_bank_size"] // config["labeled_bs"]
    if len(confound_memory_bank) > max_batches:
        del confound_memory_bank[0]
        del memory_bank_labels[0]


def sample_from_memory_bank(current_labels, bank=None, bank_labels=None):
    """Stratified random sampling of one stored confounding feature per current label.

    For each label in ``current_labels``, one feature row is drawn uniformly
    (Python ``random``) from the bank entries with the same label. If no entry
    has that label, a row is drawn uniformly from the whole bank.

    Args:
        current_labels: 1-D tensor of labels for the current batch.
        bank: list of per-batch feature tensors; defaults to ``confound_memory_bank``.
        bank_labels: list of per-batch label tensors aligned with ``bank``;
            defaults to ``memory_bank_labels``.
    Returns:
        Tensor of stacked sampled features on ``current_labels.device``, or None
        if the bank is empty (or holds no rows) or ``current_labels`` is empty.
    """
    if bank is None:
        bank = confound_memory_bank
    if bank_labels is None:
        bank_labels = memory_bank_labels
    if not bank:
        return None  # Return None if bank is empty

    all_eta_con = torch.cat(bank, dim=0)
    all_labels = torch.cat(bank_labels, dim=0)
    if all_eta_con.shape[0] == 0:
        return None  # bank lists exist but contain no rows

    current_device = current_labels.device
    # Per-label candidate rows, computed once per distinct label rather than once per sample.
    indices_by_label = {}
    sampled_eta_con = []
    for label in current_labels.tolist():
        if label not in indices_by_label:
            indices_by_label[label] = (all_labels == label).nonzero(as_tuple=True)[0].tolist()
        eligible_indices = indices_by_label[label]
        if eligible_indices:
            sampled_idx = random.choice(eligible_indices)
        else:
            # No stored feature with this label: fall back to the whole bank.
            sampled_idx = random.randrange(len(all_eta_con))
        sampled_eta_con.append(all_eta_con[sampled_idx])

    if not sampled_eta_con:
        return None

    return torch.stack(sampled_eta_con).to(current_device)


# --- Consistency Loss (L_CR) ---
# Mean squared error criterion; 'mean' averages the squared error over every element.
consistency_loss_fn = nn.MSELoss(reduction='mean')


# %% [markdown]
# ## 6. Training Setup

# %%
# --- Initialize Models ---
# One set of constructor arguments, so student and teacher cannot drift apart architecturally.
cats_model_kwargs = dict(
    num_classes=config["num_classes"],
    n_transformer_layers=config["n_transformer_layers"],
    n_causal_queries=config["n_causal_queries"],
    transformer_embed_dim=config["transformer_embed_dim"],
    transformer_nhead=config["transformer_nhead"],
    transformer_ff_dim=config["transformer_ff_dim"],
    dropout=0.5,
)
student_model = CaTSModel(**cats_model_kwargs).to(config["device"])

# The teacher starts as an exact copy of the student (parameters and buffers).
# deepcopy gives the same state as building a second CaTSModel and loading the
# student's state_dict, without constructing a second ResNet-101 (and loading its
# ImageNet weights).
teacher_model = copy.deepcopy(student_model)
# The teacher is not trained by backprop, so none of its parameters need gradients.
for param in teacher_model.parameters():
    param.requires_grad_(False)


# --- Optimizer ---
# Adam over all student parameters. Frozen backbone weights never receive a
# .grad, and PyTorch optimizers skip parameters whose grad is None.
# NOTE(review): Adam's weight_decay is a coupled L2 penalty; if the paper's 1e-2
# assumes decoupled decay, optim.AdamW is the matching optimizer.
optimizer = optim.Adam(
    params=student_model.parameters(),
    lr=config["base_lr"],
    weight_decay=config["weight_decay"],
)

# --- EMA Update Function ---
def update_ema_variables(model, ema_model, alpha, global_step, copy_buffers=True):
    """Update ``ema_model`` in place as an exponential moving average of ``model``.

    Mean-Teacher style update. Every teacher parameter becomes::

        alpha_eff * teacher + (1 - alpha_eff) * student,
        alpha_eff = min(1 - 1 / (global_step + 1), alpha)

    At step 0 the teacher is overwritten with the student. Afterwards the
    decay ramps up towards ``alpha``.

    Args:
        model: Student module whose parameters are averaged in.
        ema_model: Teacher module, updated in place. It must have the same
            parameter/buffer layout as ``model``, because pairs are matched by
            iteration order.
        alpha: Maximum EMA decay (e.g. 0.99).
        global_step: Number of optimizer steps taken so far (0-based).
        copy_buffers: If True (default), also copy the student's buffers
            (e.g. BatchNorm running statistics, if the model has any) into
            the teacher. Parameters alone do not carry these statistics. A
            teacher evaluated in ``eval()`` mode would otherwise keep the
            statistics it was initialized with for the whole run.
    """
    alpha = min(1 - 1 / (global_step + 1), alpha)
    # no_grad replaces the legacy ``.data`` access: the update must not be tracked by autograd.
    with torch.no_grad():
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.mul_(alpha).add_(param, alpha=1 - alpha)
        if copy_buffers:
            for ema_buf, buf in zip(ema_model.buffers(), model.buffers()):
                ema_buf.copy_(buf)


# Counts optimizer steps across all epochs. It is passed to
# update_ema_variables (EMA decay warm-up) and stored in each checkpoint.
global_step = 0

# %% [markdown]
# ## 7. Training Loop

# %%
print("Starting Training...")
for epoch in range(config["epochs"]):
    student_model.train()
    teacher_model.eval() # Teacher is always in eval mode; it only changes through the EMA update

    # Per-epoch running sums of batch losses, averaged over len(train_loader) at epoch end.
    epoch_loss = 0.0
    epoch_loss_cau = 0.0
    epoch_loss_con = 0.0
    epoch_loss_bd = 0.0
    epoch_loss_cr = 0.0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}", leave=False)

    for batch_idx, (img_weak, img_strong, labels) in enumerate(progress_bar):
        # Each batch provides a weakly and a strongly augmented view of the same images.
        img_weak = img_weak.to(config["device"])
        img_strong = img_strong.to(config["device"])
        labels = labels.to(config["device"])

        # Split batch into labeled and unlabeled
        # Note: The sampler ensures the first `labeled_bs` are labeled indices
        img_weak_lab, img_strong_lab = img_weak[:config["labeled_bs"]], img_strong[:config["labeled_bs"]]
        img_weak_unlab, img_strong_unlab = img_weak[config["labeled_bs"]:], img_strong[config["labeled_bs"]:]
        labels_lab = labels[:config["labeled_bs"]]

        # ---------------------------------
        # Supervised Loss Calculation (L_sup) on labeled data
        # ---------------------------------
        # Zero defaults so logging works even when a term is skipped.
        loss_sup = torch.tensor(0.0).to(config["device"])
        loss_cau = torch.tensor(0.0).to(config["device"])
        loss_con = torch.tensor(0.0).to(config["device"])
        loss_bd = torch.tensor(0.0).to(config["device"])

        if config["labeled_bs"] > 0:
            # Use weak augmentation for supervised training part
            eta_cau_stu_lab, eta_con_stu_lab, mu_cau_stu_lab, mu_con_stu_lab = student_model(img_weak_lab)

            # 1. Causal Loss (L_cau)
            loss_cau = causal_loss_fn(mu_cau_stu_lab, labels_lab)

            # 2. Confounding Suppression Loss (L_con)
            loss_con = confound_loss_fn(mu_con_stu_lab)

            # Update memory bank for L_b
            update_memory_bank(eta_con_stu_lab, labels_lab)

            # 3. Backdoor Adjustment Loss (L_b)
            sampled_eta_con = sample_from_memory_bank(labels_lab)
            if sampled_eta_con is not None:
                # Treat the sampled confounder features as constants. If the bank
                # holds non-detached tensors from earlier iterations, their autograd
                # graphs were already freed by a previous backward() and would make
                # this step fail. The detach is a no-op when the bank stores
                # detached copies.
                sampled_eta_con = sampled_eta_con.detach()
            if sampled_eta_con is not None and sampled_eta_con.shape[0] == eta_cau_stu_lab.shape[0]:
                # Perturb causal features
                eta_cau_perturbed = eta_cau_stu_lab + sampled_eta_con
                # Pass *perturbed intermediate features* through the rest of the causal MLP
                # Re-apply layers after the intermediate feature extraction in MLP_cau
                x_perturbed = student_model.mlp_cau.dropout(eta_cau_perturbed) # Assuming eta is after ReLU
                mu_cau_perturbed = student_model.mlp_cau.fc2(x_perturbed)

                loss_bd = backdoor_loss_fn(mu_cau_perturbed, labels_lab)
            else:
                loss_bd = torch.tensor(0.0).to(config["device"]) # Skip if bank is empty or size mismatch

            # Total Supervised Loss
            loss_sup = loss_cau + config["confound_lambda"] * loss_con + config["backdoor_lambda"] * loss_bd

        # ---------------------------------
        # Consistency Loss Calculation (L_CR) on unlabeled data
        # ---------------------------------
        loss_cr = torch.tensor(0.0).to(config["device"])
        num_unlabeled = img_strong_unlab.shape[0]

        if num_unlabeled > 0:
            # Student forward pass on strong augmentation
            _, _, mu_cau_stu_unlab, _ = student_model(img_strong_unlab)

            # Teacher forward pass on weak augmentation (no gradients: the teacher is an EMA target)
            with torch.no_grad():
                _, _, mu_cau_tea_unlab, _ = teacher_model(img_weak_unlab)

            # Calculate Consistency Loss
            loss_cr = consistency_loss_fn(mu_cau_stu_unlab, mu_cau_tea_unlab)

        # ---------------------------------
        # Total Loss and Backpropagation
        # ---------------------------------
        total_loss = loss_sup + config["consistency_lambda"] * loss_cr

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Update Teacher Model EMA weights (after the student step, once per batch)
        update_ema_variables(student_model, teacher_model, config["ema_decay"], global_step)

        global_step += 1

        # Logging losses
        epoch_loss += total_loss.item()
        epoch_loss_cau += loss_cau.item() if config["labeled_bs"] > 0 else 0
        epoch_loss_con += loss_con.item() if config["labeled_bs"] > 0 else 0
        epoch_loss_bd += loss_bd.item() if config["labeled_bs"] > 0 and sampled_eta_con is not None else 0
        epoch_loss_cr += loss_cr.item() if num_unlabeled > 0 else 0

        progress_bar.set_postfix({
            'Loss': f"{total_loss.item():.4f}",
            'L_cau': f"{loss_cau.item():.4f}" if config["labeled_bs"] > 0 else "N/A",
            'L_con': f"{loss_con.item():.4f}" if config["labeled_bs"] > 0 else "N/A",
            'L_bd': f"{loss_bd.item():.4f}" if config["labeled_bs"] > 0 and sampled_eta_con is not None else "N/A",
            'L_cr': f"{loss_cr.item():.4f}" if num_unlabeled > 0 else "N/A"
        })

    # --- End of Epoch ---
    avg_epoch_loss = epoch_loss / len(train_loader)
    avg_loss_cau = epoch_loss_cau / len(train_loader)
    avg_loss_con = epoch_loss_con / len(train_loader)
    avg_loss_bd = epoch_loss_bd / len(train_loader)
    avg_loss_cr = epoch_loss_cr / len(train_loader)

    # Log average epoch losses to TensorBoard
    writer.add_scalar('Loss/Train_Total', avg_epoch_loss, epoch)
    writer.add_scalar('Loss/Train_Causal', avg_loss_cau, epoch)
    writer.add_scalar('Loss/Train_Confound', avg_loss_con, epoch)
    writer.add_scalar('Loss/Train_Backdoor', avg_loss_bd, epoch)
    writer.add_scalar('Loss/Train_Consistency', avg_loss_cr, epoch)

    print(f"\nEpoch {epoch+1} Average Loss: {avg_epoch_loss:.4f} [L_cau:{avg_loss_cau:.4f}, L_con:{avg_loss_con:.4f}, L_bd:{avg_loss_bd:.4f}, L_cr:{avg_loss_cr:.4f}]")

    # --- Validation ---
    # Evaluates the student's causal head; the metrics use average='binary', i.e. a two-class task.
    student_model.eval()
    val_loss = 0.0
    all_preds_val = []
    all_labels_val = []

    with torch.no_grad():
        for img_val, labels_val in val_loader:
            img_val, labels_val = img_val.to(config["device"]), labels_val.to(config["device"])
            _, _, mu_cau_val, _ = student_model(img_val)
            loss = causal_loss_fn(mu_cau_val, labels_val) # Use causal loss for validation
            val_loss += loss.item()

            preds = torch.argmax(mu_cau_val, dim=1)
            all_preds_val.extend(preds.cpu().numpy())
            all_labels_val.extend(labels_val.cpu().numpy())

    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = accuracy_score(all_labels_val, all_preds_val)
    val_precision = precision_score(all_labels_val, all_preds_val, average='binary', zero_division=0)
    val_recall = recall_score(all_labels_val, all_preds_val, average='binary', zero_division=0)
    val_f1 = f1_score(all_labels_val, all_preds_val, average='binary', zero_division=0)

    writer.add_scalar('Loss/Validation', avg_val_loss, epoch)
    writer.add_scalar('Metrics/Val_Accuracy', val_accuracy, epoch)
    writer.add_scalar('Metrics/Val_Precision', val_precision, epoch)
    writer.add_scalar('Metrics/Val_Recall', val_recall, epoch)
    writer.add_scalar('Metrics/Val_F1', val_f1, epoch)

    print(f"Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.4f}, Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, F1: {val_f1:.4f}")

    # Save checkpoint every 10 epochs and after the final epoch
    if (epoch + 1) % 10 == 0 or epoch == config["epochs"] - 1:
        checkpoint_dir = os.path.join(config["output_dir"], "checkpoints")
        # torch.save does not create missing parent directories.
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, f"cats_epoch_{epoch+1}.pth")
        torch.save({
            'epoch': epoch + 1,
            'student_state_dict': student_model.state_dict(),
            'teacher_state_dict': teacher_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step,
        }, checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

writer.close()
print("Training Finished.")


# %% [markdown]
# ## 8. Evaluation on Test Set

# %%
# Score the final (last-epoch) student on the held-out test split. The causal
# head output drives both the loss and the argmax predictions. Metrics use
# average='binary', i.e. a two-class task is assumed.
print("Evaluating on Test Set...")
student_model.eval()
test_loss, all_preds_test, all_labels_test = 0.0, [], []

with torch.no_grad():
    for img_test, labels_test in tqdm(test_loader, desc="Testing"):
        img_test = img_test.to(config["device"])
        labels_test = labels_test.to(config["device"])

        # Only the causal path's output (third element) is used for evaluation.
        _, _, mu_cau_test, _ = student_model(img_test)

        test_loss += causal_loss_fn(mu_cau_test, labels_test).item()
        all_preds_test.extend(mu_cau_test.argmax(dim=1).cpu().numpy())
        all_labels_test.extend(labels_test.cpu().numpy())

avg_test_loss = test_loss / len(test_loader)
test_accuracy = accuracy_score(all_labels_test, all_preds_test)
test_precision, test_recall, test_f1 = (
    metric_fn(all_labels_test, all_preds_test, average='binary', zero_division=0)
    for metric_fn in (precision_score, recall_score, f1_score)
)

print("\n--- Test Set Results ---")
print("\n".join(
    f"{label}: {value:.4f}"
    for label, value in (
        ("Test Loss", avg_test_loss),
        ("Accuracy", test_accuracy),
        ("Precision", test_precision),
        ("Recall", test_recall),
        ("F1 Score", test_f1),
    )
))

# %% [markdown]
# ## 9. Optional: Save Final Model

# %%
# Persist the final student's state_dict (parameters and buffers) for later inference.
final_model_path = os.path.join(config["output_dir"], "cats_final_student_model.pth")
final_student_state = student_model.state_dict()
torch.save(final_student_state, final_model_path)
print("Final student model saved to {}".format(final_model_path))