# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset, Sampler
from PIL import Image
import os
import pandas as pd
from tqdm.notebook import tqdm
import numpy as np
import random
import copy
from sklearn.model_selection import train_test_split
from torch.utils.tensorboard import SummaryWriter
import itertools
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import math
import warnings

# Suppress every warning of category UserWarning for the rest of the session.
warnings.filterwarnings("ignore", category=UserWarning)

def seed_everything(seed=1337):
    """Seed the Python, NumPy and PyTorch generators and pin cuDNN behaviour.

    Args:
        seed: Integer handed unchanged to every seeding function.
    """
    # Same seed, same call order: Python, NumPy, torch CPU, torch CUDA (current
    # device), torch CUDA (all devices).
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Ask cuDNN for deterministic kernels and switch off its autotuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed every RNG with the default seed before any split or model is created.
seed_everything()

# Experiment-wide settings; the rest of the script reads this dict as a global.
config = {
    "image_size": 256,  # square side length every transform resizes images to
    "batch_size": 16,  # samples per training batch, labeled and unlabeled together
    "labeled_bs_ratio": 0.5,  # share of each training batch reserved for labeled samples
    "num_classes": 2,  # output width of both classification heads
    "base_lr": 1e-3,  # learning rate handed to the Adam optimizer
    "weight_decay": 1e-2,  # weight decay handed to the Adam optimizer
    "epochs": 100,  # number of iterations of the outer training loop
    "ema_decay": 0.99,  # upper bound on the teacher's EMA coefficient
    "consistency_lambda": 1.0,  # weight of the student/teacher consistency loss
    "confound_lambda": 1.0,  # weight of the confounder (KL-to-uniform) loss
    "backdoor_lambda": 1.0,  # weight of the backdoor loss
    "n_transformer_layers": 6,  # decoder layers inside CausalDisentanglement
    "n_causal_queries": 8,  # learned query tokens inside CausalDisentanglement
    "transformer_embed_dim": 2048,  # passed to CaTSModel, whose visible code never reads it
    "transformer_nhead": 8,  # attention heads per decoder layer
    "transformer_ff_dim": 2048,  # feed-forward width of each decoder layer
    "memory_bank_size": 1024,  # divided by labeled_bs to cap the number of stored batches
    "train_test_split_ratio": 0.8,  # fraction of all rows kept for train+validation
    "labeled_unlabeled_split_ratio": 0.2,  # fraction of training rows treated as labeled
    "seed": 1337,  # random_state for the three train_test_split calls
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "data_root": "INPUT PATH",  # placeholder; joined with the relative image folder below
    "output_dir": "OUTPUT PATH"  # placeholder; not referenced by the visible code
}

# Labeled samples per batch: int(16 * 0.5) == 8 with the values above.
config["labeled_bs"] = int(config["batch_size"] * config["labeled_bs_ratio"])

# Build the image-path/label index once; skipped when the CSV already exists.
# Labels are inferred purely from keywords in each file name.
if not os.path.exists('../pcos_dataset.csv'):
    image_paths = []
    labels = []
    class_map = {'PCOS_positive': 1, 'PCOS_negative': 0}
    positive_folder = os.path.join(config["data_root"], '../PCOSGen-train/images')
    # NOTE(review): identical to positive_folder and never read below; kept so
    # the name still exists for any later cell that might use it.
    negative_folder = os.path.join(config["data_root"], '../PCOSGen-train/images')

    # Sorted so the CSV row order (and with it every seeded split, which works
    # on row positions) does not depend on the filesystem's listing order.
    all_files = sorted(
        f for f in os.listdir(positive_folder)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )

    for filename in all_files:
        lowered_name = filename.lower()
        # 'notinfected' contains 'infected', so that substring only counts as a
        # positive marker when the name is not a 'notinfected' one. Without
        # this guard every "notinfected" file was labelled positive.
        is_positive = (
            'pco' in lowered_name
            or 'polycystic' in lowered_name
            or ('infected' in lowered_name and 'notinfected' not in lowered_name)
        )
        is_negative = 'normal' in lowered_name or 'notinfected' in lowered_name

        if is_positive:
            file_label = class_map['PCOS_positive']
        elif is_negative:
            file_label = class_map['PCOS_negative']
        else:
            continue  # no recognised keyword: the file is left out of the index

        full_path = os.path.join(positive_folder, filename)
        # listdir can return entries that no longer resolve (e.g. broken links).
        if os.path.exists(full_path):
            image_paths.append(full_path)
            labels.append(file_label)

    if not image_paths:
        raise FileNotFoundError("No images found. Check data_root and folder structure.")

    df = pd.DataFrame({'image_path': image_paths, 'label': labels})
    df.to_csv('../pcos_dataset.csv', index=False)

# Pieces shared by the three pipelines below.
_resize_to = (config["image_size"], config["image_size"])
_crop_padding = int(config["image_size"] * 0.125)
_normalize_stats = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Weak view: resize, random horizontal flip, reflect-padded random crop.
weak_transform = transforms.Compose([
    transforms.Resize(_resize_to),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(config["image_size"], padding=_crop_padding, padding_mode='reflect'),
    transforms.ToTensor(),
    transforms.Normalize(**_normalize_stats),
])

# Strong view: resize followed by RandAugment (two ops at magnitude 10).
strong_transform = transforms.Compose([
    transforms.Resize(_resize_to),
    transforms.RandAugment(num_ops=2, magnitude=10),
    transforms.ToTensor(),
    transforms.Normalize(**_normalize_stats),
])

# Evaluation view: resize and normalise only, no randomness.
val_transform = transforms.Compose([
    transforms.Resize(_resize_to),
    transforms.ToTensor(),
    transforms.Normalize(**_normalize_stats),
])

class PCOSImageDataset(Dataset):
    """Image dataset backed by a CSV with ``image_path`` and ``label`` columns.

    In ``mode='train'`` each item is ``(weak_view, strong_view, label)``; in any
    other mode it is ``(image, label)`` produced by ``val_transform``.

    A sample that cannot be loaded or transformed is replaced by an all-zero
    image with label ``-1`` so one bad file does not abort a run. Each such
    substitution emits a ``RuntimeWarning`` so a corrupt file or an unset
    transform is visible instead of silently producing blank samples.

    The placeholder size is read from the module-level ``config["image_size"]``.
    """

    def __init__(self, csv_file, weak_transform=None, strong_transform=None, val_transform=None, mode='train'):
        """Read the index CSV and remember the transforms and the mode.

        Args:
            csv_file: Path to the CSV index.
            weak_transform: Callable applied to the image in train mode.
            strong_transform: Second callable applied to the image in train mode.
            val_transform: Callable applied to the image in every other mode.
            mode: ``'train'`` selects the two-view output; anything else the
                single-view output.
        """
        self.dataframe = pd.read_csv(csv_file)
        self.weak_transform = weak_transform
        self.strong_transform = strong_transform
        self.val_transform = val_transform
        self.mode = mode

    def __len__(self):
        """Return the number of rows in the index."""
        return len(self.dataframe)

    def _placeholder_sample(self):
        """Return the zero-image / label ``-1`` stand-in shaped for the current mode."""
        blank = torch.zeros((3, config["image_size"], config["image_size"]))
        missing_label = torch.tensor(-1, dtype=torch.long)
        if self.mode == 'train':
            return blank, blank, missing_label
        return blank, missing_label

    def __getitem__(self, idx):
        """Load, transform and return the sample at positional index ``idx``."""
        row = self.dataframe.iloc[idx]
        img_path = row['image_path']
        label = torch.tensor(row['label'], dtype=torch.long)

        try:
            if not os.path.exists(img_path):
                warnings.warn(
                    f"Image not found, using placeholder: {img_path}",
                    RuntimeWarning,
                )
                return self._placeholder_sample()

            # Close the file handle as soon as the pixel data has been converted.
            with Image.open(img_path) as handle:
                image = handle.convert('RGB')

            if self.mode == 'train':
                # Weak view first, then strong view (keeps RNG consumption order).
                img_weak = self.weak_transform(image)
                img_strong = self.strong_transform(image)
                return img_weak, img_strong, label
            return self.val_transform(image), label
        except Exception as exc:
            # Deliberate best-effort: keep the batch alive, but say what went wrong.
            warnings.warn(
                f"Failed to load {img_path!r} ({type(exc).__name__}: {exc}); using placeholder.",
                RuntimeWarning,
            )
            return self._placeholder_sample()

# Reload the index; every split below works on row positions of this frame.
_dataset_csv = '../pcos_dataset.csv'
full_df = pd.read_csv(_dataset_csv)
_split_seed = config["seed"]

# Split 1: hold out the test rows, stratified by label.
train_val_indices, test_indices = train_test_split(
    range(len(full_df)),
    stratify=full_df['label'],
    test_size=1.0 - config["train_test_split_ratio"],
    random_state=_split_seed,
)

# Split 2: carve a 20% validation set out of the remaining rows.
train_indices, val_indices = train_test_split(
    train_val_indices,
    stratify=full_df['label'].iloc[train_val_indices],
    test_size=0.2,
    random_state=_split_seed,
)

# Split 3: divide the training rows into labeled and unlabeled pools.
labeled_indices, unlabeled_indices = train_test_split(
    train_indices,
    stratify=full_df['label'].iloc[train_indices],
    test_size=1.0 - config["labeled_unlabeled_split_ratio"],
    random_state=_split_seed,
)

# All three datasets wrap the full CSV; the index lists select their rows.
train_dataset = PCOSImageDataset(
    csv_file=_dataset_csv,
    weak_transform=weak_transform,
    strong_transform=strong_transform,
    mode='train',
)
val_dataset = PCOSImageDataset(csv_file=_dataset_csv, val_transform=val_transform, mode='val')
test_dataset = PCOSImageDataset(csv_file=_dataset_csv, val_transform=val_transform, mode='test')

val_subset = Subset(val_dataset, val_indices)
test_subset = Subset(test_dataset, test_indices)

class TwoStreamBatchSampler(Sampler):
    """Batch sampler that joins a primary chunk with a secondary chunk.

    Every batch is ``primary_batch + secondary_batch``: the first
    ``batch_size - secondary_batch_size`` indices come from one shuffled pass
    over ``primary_indices``, the rest from an endlessly reshuffled stream of
    ``secondary_indices``.
    """

    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        """Store both index pools and derive the per-stream chunk sizes.

        Raises:
            AssertionError: If either pool cannot fill one chunk, or the
                primary chunk size is not positive.
        """
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices
        self.secondary_batch_size = secondary_batch_size
        self.primary_batch_size = batch_size - secondary_batch_size

        assert 0 < self.primary_batch_size <= len(self.primary_indices)
        assert 0 <= self.secondary_batch_size <= len(self.secondary_indices)

    def __iter__(self):
        """Return an iterator over index tuples, one tuple per batch."""
        primary_order = iterate_once(self.primary_indices)
        batch_count = len(self)
        primary_chunks = grouper(primary_order, self.primary_batch_size)

        if self.secondary_batch_size == 0:
            # No secondary stream: batches are just the primary chunks.
            return itertools.islice(primary_chunks, batch_count)

        secondary_chunks = grouper(
            iterate_eternally(self.secondary_indices), self.secondary_batch_size
        )
        joined = (
            head + tail for head, tail in zip(primary_chunks, secondary_chunks)
        )
        return itertools.islice(joined, batch_count)

    def __len__(self):
        """Return the number of batches produced per pass."""
        primary_batches = len(self.primary_indices) // self.primary_batch_size
        if self.secondary_batch_size > 0:
            secondary_batches = len(self.secondary_indices) // self.secondary_batch_size
            return min(primary_batches, secondary_batches)
        return primary_batches

def iterate_once(iterable):
    """Return a single random permutation of ``iterable`` as a NumPy array.

    The order is drawn from NumPy's global RNG.
    """
    shuffled = np.random.permutation(iterable)
    return shuffled

def iterate_eternally(indices):
    """Return an endless flat stream of ``indices``, reshuffled on every pass.

    A fresh permutation is drawn from NumPy's global RNG only when the previous
    one has been exhausted.
    """
    endless_passes = (np.random.permutation(indices) for _ in itertools.count())
    return itertools.chain.from_iterable(endless_passes)

def grouper(iterable, n):
    """Yield consecutive ``n``-tuples from ``iterable``, dropping a short tail."""
    # zip pulls from one shared iterator n times per tuple, which is what
    # slices the stream into consecutive chunks.
    shared = iter(iterable)
    return zip(*itertools.repeat(shared, n))

# Labeled rows are the primary stream, so they fill the first labeled_bs slots
# of every batch; the remaining slots come from the unlabeled rows.
_unlabeled_per_batch = config["batch_size"] - config["labeled_bs"]
train_batch_sampler = TwoStreamBatchSampler(
    primary_indices=labeled_indices,
    secondary_indices=unlabeled_indices,
    batch_size=config["batch_size"],
    secondary_batch_size=_unlabeled_per_batch,
)

_worker_opts = dict(num_workers=4, pin_memory=True)
_eval_opts = dict(batch_size=config["batch_size"], shuffle=False, **_worker_opts)

train_loader = DataLoader(train_dataset, batch_sampler=train_batch_sampler, **_worker_opts)
val_loader = DataLoader(val_subset, **_eval_opts)
test_loader = DataLoader(test_subset, **_eval_opts)

class PositionAttention(nn.Module):
    """Self-attention across the spatial positions of a feature map.

    The attended map is scaled by a learnable ``gamma`` (initialised to zero)
    and added back onto the input, so the module starts as an identity.
    """

    def __init__(self, in_channels):
        """Create the 1x1 query/key/value projections for ``in_channels`` inputs."""
        super().__init__()
        # Query and key use a narrower width, but never fewer than 64 channels.
        qk_channels = max(in_channels // 8, 64)
        self.query_conv = nn.Conv2d(in_channels, qk_channels, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, qk_channels, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` for ``x`` of shape (B, C, H, W)."""
        n, c, h, w = x.shape

        queries = self.query_conv(x).flatten(2).transpose(1, 2)  # (B, H*W, C')
        keys = self.key_conv(x).flatten(2)                        # (B, C', H*W)
        affinity = self.softmax(torch.bmm(queries, keys))         # (B, H*W, H*W)

        values = self.value_conv(x).flatten(2)                    # (B, C, H*W)
        attended = torch.bmm(values, affinity.transpose(1, 2)).view(n, c, h, w)

        return self.gamma * attended + x

class ChannelAttention(nn.Module):
    """Self-attention across the channels of a feature map.

    Channel affinities come from the softmax of the channel Gram matrix; the
    re-weighted map is scaled by a learnable ``gamma`` (initialised to zero)
    and added back onto the input.
    """

    def __init__(self, in_channels):
        """Create the module; ``in_channels`` is accepted but not used."""
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` for ``x`` of shape (B, C, H, W)."""
        n, c, h, w = x.shape
        flat = x.view(n, c, -1)                               # (B, C, H*W)

        gram = torch.bmm(flat, flat.transpose(1, 2))          # (B, C, C)
        channel_weights = self.softmax(gram)
        mixed = torch.bmm(channel_weights, flat).view(n, c, h, w)

        return self.gamma * mixed + x

class FeatureExtractor(nn.Module):
    """Two-branch head that turns backbone features into a richer map.

    Branch 1 sums a channel-attended and a position-attended 3x3 projection.
    Branch 2 gates a 1x1 projection with a globally pooled descriptor. The two
    branches are concatenated on the channel axis, so the output carries
    ``2 * out_channels`` channels.
    """

    def __init__(self, in_channels=2048, out_channels=1024):
        """Create both branches for the given input/output channel counts."""
        super().__init__()
        # Branch 1a: 3x3 projection followed by channel attention.
        self.conv_ch = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.ch_attn = ChannelAttention(out_channels)

        # Branch 1b: 3x3 projection followed by position attention.
        self.conv_pos = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.pos_attn = PositionAttention(out_channels)

        # Branch 2: global-average gate and 1x1 skip projection.
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.conv_gap = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv_skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the concatenated branch outputs for ``x`` of shape (B, C, H, W)."""
        channel_path = self.ch_attn(self.conv_ch(x))
        position_path = self.pos_attn(self.conv_pos(x))
        branch1_out = self.relu(self.bn1(channel_path + position_path))

        skip = self.conv_skip(x)
        global_gate = self.conv_gap(self.gap(x))
        gated = skip * global_gate.expand_as(skip) + skip
        # NOTE(review): bn1 is applied a second time here, so one set of
        # running statistics covers both branches -- confirm this is intended.
        branch2_out = self.relu(self.bn1(gated))

        return torch.cat((branch1_out, branch2_out), dim=1)

class FixedPositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a (B, L, d_model) sequence.

    A table for ``max_len`` positions is precomputed and stored in the ``pe``
    buffer. Longer sequences no longer fail with a broadcast error: their
    table is computed on the fly and the stored buffer keeps its shape, so
    existing state dicts still load.
    """

    def __init__(self, d_model, max_len=64):
        """Precompute the encoding table.

        Args:
            d_model: Size of the last dimension of the sequences to encode.
            max_len: Number of positions held in the stored table.
        """
        super().__init__()
        self.register_buffer('pe', self._build_table(max_len, d_model))

    @staticmethod
    def _build_table(length, d_model):
        """Return the (1, length, d_model) sinusoidal table."""
        pe = torch.zeros(length, d_model)
        position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one cosine column fewer than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        return pe.unsqueeze(0)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions."""
        seq_len = x.size(1)
        if seq_len <= self.pe.size(1):
            return x + self.pe[:, :seq_len]
        # Sequence is longer than the stored table: build a matching one.
        table = self._build_table(seq_len, self.pe.size(2))
        return x + table.to(device=x.device, dtype=self.pe.dtype)

class CausalDisentanglement(nn.Module):
    """Split a feature map into a causal vector and a confounder vector.

    Learned query tokens attend to the flattened, position-encoded feature map
    through a Transformer decoder. The mean decoder output is the causal
    vector; the confounder vector is the mean feature token minus it.
    """

    def __init__(self, d_model, nhead, num_decoder_layers, dim_feedforward, n_queries):
        """Build the decoder stack and the learned queries.

        Args:
            d_model: Token width; must equal the channel count of the feature
                map given to ``forward``.
            nhead: Attention heads per decoder layer.
            num_decoder_layers: Number of stacked decoder layers.
            dim_feedforward: Hidden width of each layer's feed-forward block.
            n_queries: Number of learned query tokens.
        """
        super().__init__()
        self.d_model = d_model
        self.n_queries = n_queries

        self.pos_encoder = FixedPositionalEncoding(d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        self.causal_queries = nn.Parameter(torch.zeros(1, n_queries, d_model))
        # All-zero queries are interchangeable: nothing else distinguishes one
        # query slot from another, so in eval mode every query yields the same
        # output and n_queries has no effect. A small random init gives each
        # query its own starting point.
        nn.init.normal_(self.causal_queries, std=0.02)

    def forward(self, features):
        """Return ``(F_cau_avg, F_con)``, each of shape (B, d_model).

        Args:
            features: Feature map of shape (B, C, H, W) with ``C == d_model``.
        """
        B, C, H, W = features.shape
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        features_flat = features.flatten(2).permute(0, 2, 1)

        features_pos = self.pos_encoder(features_flat)

        queries = self.causal_queries.repeat(B, 1, 1)

        causal_output = self.transformer_decoder(tgt=queries, memory=features_pos)

        # Average over the queries to get one causal vector per sample.
        F_cau_avg = causal_output.mean(dim=1)

        # Average over the position-encoded tokens.
        S_avg = features_pos.mean(dim=1)

        # Confounder vector: what the causal vector leaves unexplained.
        F_con = S_avg - F_cau_avg

        return F_cau_avg, F_con

class MLPHead(nn.Module):
    """Two-layer classifier that also exposes its hidden activation."""

    def __init__(self, in_dim, hidden_dim=512, out_dim=2, dropout=0.5):
        """Create ``Linear -> ReLU -> Dropout -> Linear``.

        Args:
            in_dim: Width of the input vectors.
            hidden_dim: Width of the hidden layer.
            out_dim: Number of output logits.
            dropout: Dropout probability applied before the output layer.
        """
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Return ``(hidden, logits)``; ``hidden`` is taken before dropout."""
        hidden = self.relu(self.fc1(x))
        logits = self.fc2(self.dropout(hidden))
        return hidden, logits

class CaTSModel(nn.Module):
    """Frozen ResNet-101 backbone + FeatureExtractor + causal/confounder heads.

    ``forward`` returns ``(eta_cau, eta_con, mu_cau, mu_con)``: the hidden
    activations and the class logits of the causal and confounder heads.
    """

    def __init__(self, num_classes, n_transformer_layers, n_causal_queries,
                 transformer_embed_dim, transformer_nhead, transformer_ff_dim, dropout=0.5):
        """Assemble the network.

        Args:
            num_classes: Number of logits produced by each head.
            n_transformer_layers: Decoder layers in the disentanglement module.
            n_causal_queries: Learned queries in the disentanglement module.
            transformer_embed_dim: Accepted for interface compatibility but not
                read; the token width is fixed by the feature extractor output
                (2048 channels).
            transformer_nhead: Attention heads per decoder layer.
            transformer_ff_dim: Feed-forward width of each decoder layer.
            dropout: Dropout probability of both MLP heads.
        """
        super().__init__()
        resnet = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
        # Drop the final average pool and fc layer; keep the convolutional trunk.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        for param in self.backbone.parameters():
            param.requires_grad = False
        # requires_grad=False does not stop BatchNorm from updating its running
        # statistics in train mode, so the backbone is also put in eval mode.
        self.backbone.eval()
        self.backbone_out_channels = 2048

        self.feature_extractor = FeatureExtractor(in_channels=self.backbone_out_channels, out_channels=self.backbone_out_channels // 2)
        # The extractor concatenates two branches of out_channels each.
        self.feature_extractor_out_channels = self.backbone_out_channels

        self.causal_disentanglement = CausalDisentanglement(
            d_model=self.feature_extractor_out_channels,
            nhead=transformer_nhead,
            num_decoder_layers=n_transformer_layers,
            dim_feedforward=transformer_ff_dim,
            n_queries=n_causal_queries
        )
        self.causal_out_dim = self.feature_extractor_out_channels

        self.mlp_cau = MLPHead(in_dim=self.causal_out_dim, out_dim=num_classes, dropout=dropout)
        self.mlp_con = MLPHead(in_dim=self.causal_out_dim, out_dim=num_classes, dropout=dropout)

    def train(self, mode=True):
        """Set train/eval mode on the trainable parts; the backbone stays in eval.

        Without this, ``model.train()`` would let the frozen backbone's
        BatchNorm layers normalise with batch statistics and overwrite their
        pretrained running statistics.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        """Return ``(eta_cau, eta_con, mu_cau, mu_con)`` for an image batch ``x``."""
        x = self.backbone(x)
        x = self.feature_extractor(x)
        F_cau_vec, F_con_vec = self.causal_disentanglement(x)
        # Each head returns (hidden activation, logits).
        eta_cau, mu_cau = self.mlp_cau(F_cau_vec)
        eta_con, mu_con = self.mlp_con(F_con_vec)
        return eta_cau, eta_con, mu_cau, mu_con

# Cross-entropy on the causal head's logits; used for training, validation and test.
causal_loss_fn = nn.CrossEntropyLoss()

# KL divergence with 'batchmean' reduction; confound_loss_fn feeds it log-softmax outputs.
kl_loss_fn = nn.KLDivLoss(reduction='batchmean')
# Uniform class distribution sized for a full batch, placed on the configured device.
uniform_dist = torch.full((config["batch_size"], config["num_classes"]), 1.0 / config["num_classes"]).to(config["device"])

def confound_loss_fn(mu_con_stu):
    """Return the batch-mean KL divergence between uniform and softmax(logits).

    The loss is zero when the confounder head predicts every class with equal
    probability, so minimising it pushes that head towards carrying no class
    information.

    The uniform target is derived from the logits themselves (shape, dtype and
    device), so the function works for any batch size and class count and does
    not depend on the global ``config``.

    Args:
        mu_con_stu: Confounder-head logits of shape (batch, num_classes).

    Returns:
        Scalar tensor, equal to what ``nn.KLDivLoss(reduction='batchmean')``
        gives for the same inputs.
    """
    num_classes = mu_con_stu.shape[1]
    u_dist = torch.full_like(mu_con_stu, 1.0 / num_classes)
    log_softmax_mu_con = F.log_softmax(mu_con_stu, dim=1)
    # kl_div expects log-probabilities as input and probabilities as target.
    return F.kl_div(log_softmax_mu_con, u_dist, reduction='batchmean')

# Cross-entropy applied to logits recomputed from confounder-perturbed causal features.
backdoor_loss_fn = nn.CrossEntropyLoss()
# Rolling store of confounder features and their labels, one list entry per
# labeled batch; filled by update_memory_bank, read by sample_from_memory_bank.
confound_memory_bank = []
memory_bank_labels = []

def update_memory_bank(eta_con_stu_labeled, labels_labeled):
    """Push one labeled batch into the memory bank and evict the oldest if full.

    Detached CPU copies of both tensors are appended to the module-level lists
    ``confound_memory_bank`` and ``memory_bank_labels``. The bank may hold at
    most ``config["memory_bank_size"] // config["labeled_bs"]`` batches.

    Args:
        eta_con_stu_labeled: Confounder features of the labeled batch.
        labels_labeled: Labels of the same batch.
    """
    confound_memory_bank.append(eta_con_stu_labeled.detach().cpu())
    memory_bank_labels.append(labels_labeled.detach().cpu())

    max_batches = config["memory_bank_size"] // config["labeled_bs"]
    if len(confound_memory_bank) <= max_batches:
        return

    # Over capacity: drop the oldest batch from both lists.
    del confound_memory_bank[0]
    del memory_bank_labels[0]

def sample_from_memory_bank(current_labels):
    """Draw one stored confounder feature per label in ``current_labels``.

    For each label a stored feature with the same label is picked at random;
    if the bank holds none with that label, any stored feature is picked.

    Args:
        current_labels: 1-D tensor of labels for the current batch.

    Returns:
        Tensor of shape (len(current_labels), feature_dim) on the device of
        ``current_labels``, or ``None`` when the bank holds no features or
        ``current_labels`` is empty.
    """
    if not confound_memory_bank:
        return None

    all_eta_con = torch.cat(confound_memory_bank, dim=0)
    all_labels = torch.cat(memory_bank_labels, dim=0)

    bank_size = all_eta_con.shape[0]
    if bank_size == 0 or len(current_labels) == 0:
        # Nothing to sample from (or for); previously the empty-bank case
        # reached all_eta_con[0] and raised IndexError.
        return None

    # One device-to-host transfer for the batch instead of one .item() per sample.
    label_values = current_labels.tolist()

    rows_by_label = {}
    picked_rows = []
    for label in label_values:
        if label not in rows_by_label:
            matches = (all_labels == label).nonzero(as_tuple=True)[0]
            rows_by_label[label] = matches.tolist()
        candidates = rows_by_label[label]
        if candidates:
            picked_rows.append(random.choice(candidates))
        else:
            # No stored feature with this label: fall back to any stored row.
            picked_rows.append(random.randrange(bank_size))

    return all_eta_con[picked_rows].to(current_labels.device)

# Mean-squared error between student and teacher logits on unlabeled images.
consistency_loss_fn = nn.MSELoss()

# Student and teacher are built from the same settings.
_model_kwargs = dict(
    num_classes=config["num_classes"],
    n_transformer_layers=config["n_transformer_layers"],
    n_causal_queries=config["n_causal_queries"],
    transformer_embed_dim=config["transformer_embed_dim"],
    transformer_nhead=config["transformer_nhead"],
    transformer_ff_dim=config["transformer_ff_dim"],
    dropout=0.5,
)

student_model = CaTSModel(**_model_kwargs).to(config["device"])
teacher_model = CaTSModel(**_model_kwargs).to(config["device"])

# The teacher starts as an exact copy of the student and is detached from
# autograd; update_ema_variables changes it afterwards.
teacher_model.load_state_dict(student_model.state_dict())
for _teacher_param in teacher_model.parameters():
    _teacher_param.detach_()

optimizer = optim.Adam(
    student_model.parameters(),
    lr=config["base_lr"],
    weight_decay=config["weight_decay"],
)

def update_ema_variables(model, ema_model, alpha, global_step):
    """Move ``ema_model`` towards ``model`` by an exponential moving average.

    Parameters are updated as ``ema = alpha * ema + (1 - alpha) * param``.
    Buffers (for example BatchNorm running statistics) are copied from
    ``model``. Without that copy a teacher kept in eval mode would normalise
    with its initial running statistics for the whole run, because buffers are
    not part of ``parameters()``.

    Args:
        model: Source (student) module.
        ema_model: Target (teacher) module with the same structure.
        alpha: Upper bound on the EMA coefficient.
        global_step: Number of completed updates; the coefficient ramps up as
            ``1 - 1 / (global_step + 1)``, so step 0 copies the student exactly.
    """
    alpha = min(1 - 1 / (global_step + 1), alpha)
    with torch.no_grad():
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.mul_(alpha).add_(param, alpha=1 - alpha)
        for ema_buffer, buffer in zip(ema_model.buffers(), model.buffers()):
            ema_buffer.copy_(buffer)

# Number of optimizer steps taken so far; drives the EMA warm-up.
global_step = 0

# Training: each epoch is one pass over the batch sampler followed by a full
# validation pass with the student.
for epoch in range(config["epochs"]):
    student_model.train()
    teacher_model.eval()

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}", leave=False)

    for batch_idx, (img_weak, img_strong, labels) in enumerate(progress_bar):
        img_weak = img_weak.to(config["device"])
        img_strong = img_strong.to(config["device"])
        labels = labels.to(config["device"])

        # The sampler puts the labeled indices first, so the leading labeled_bs
        # items are the labeled part and the remainder the unlabeled part.
        img_weak_lab, img_strong_lab = img_weak[:config["labeled_bs"]], img_strong[:config["labeled_bs"]]
        img_weak_unlab, img_strong_unlab = img_weak[config["labeled_bs"]:], img_strong[config["labeled_bs"]:]
        labels_lab = labels[:config["labeled_bs"]]

        # Zero defaults so logging and the total still work when a branch is skipped.
        loss_sup = torch.tensor(0.0).to(config["device"])
        loss_cau = torch.tensor(0.0).to(config["device"])
        loss_con = torch.tensor(0.0).to(config["device"])
        loss_bd = torch.tensor(0.0).to(config["device"])

        if config["labeled_bs"] > 0:
            # Supervised branch: student on the weak view of the labeled samples.
            eta_cau_stu_lab, eta_con_stu_lab, mu_cau_stu_lab, mu_con_stu_lab = student_model(img_weak_lab)
            loss_cau = causal_loss_fn(mu_cau_stu_lab, labels_lab)
            loss_con = confound_loss_fn(mu_con_stu_lab)
            # The bank is updated before sampling, so the current batch's own
            # confounder features are among the candidates.
            update_memory_bank(eta_con_stu_lab, labels_lab)
            sampled_eta_con = sample_from_memory_bank(labels_lab)
            if sampled_eta_con is not None and sampled_eta_con.shape[0] == eta_cau_stu_lab.shape[0]:
                # Backdoor term: add a sampled confounder feature to the causal
                # hidden activation and classify it with the causal head's
                # dropout and output layer.
                eta_cau_perturbed = eta_cau_stu_lab + sampled_eta_con
                x_perturbed = student_model.mlp_cau.dropout(eta_cau_perturbed)
                mu_cau_perturbed = student_model.mlp_cau.fc2(x_perturbed)
                loss_bd = backdoor_loss_fn(mu_cau_perturbed, labels_lab)
            else:
                 loss_bd = torch.tensor(0.0).to(config["device"])

            loss_sup = loss_cau + config["confound_lambda"] * loss_con + config["backdoor_lambda"] * loss_bd

        loss_cr = torch.tensor(0.0).to(config["device"])
        num_unlabeled = img_strong_unlab.shape[0]

        if num_unlabeled > 0:
            # Consistency branch: student sees the strong view, teacher the weak
            # view of the same unlabeled images; their causal logits are compared.
            _, _, mu_cau_stu_unlab, _ = student_model(img_strong_unlab)
            with torch.no_grad():
                _, _, mu_cau_tea_unlab, _ = teacher_model(img_weak_unlab)
            loss_cr = consistency_loss_fn(mu_cau_stu_unlab, mu_cau_tea_unlab)

        total_loss = loss_sup + config["consistency_lambda"] * loss_cr

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # The teacher follows the student after every optimizer step.
        update_ema_variables(student_model, teacher_model, config["ema_decay"], global_step)
        global_step += 1

        progress_bar.set_postfix({
            'Loss': f"{total_loss.item():.4f}",
            'L_cau': f"{loss_cau.item():.4f}" if config["labeled_bs"] > 0 else "N/A",
            'L_con': f"{loss_con.item():.4f}" if config["labeled_bs"] > 0 else "N/A",
            'L_bd': f"{loss_bd.item():.4f}" if config["labeled_bs"] > 0 and sampled_eta_con is not None else "N/A",
            'L_cr': f"{loss_cr.item():.4f}" if num_unlabeled > 0 else "N/A"
        })

    # Validation with the student network on the held-out validation subset.
    student_model.eval()
    val_loss = 0.0
    all_preds_val = []
    all_labels_val = []

    with torch.no_grad():
        for img_val, labels_val in val_loader:
            img_val, labels_val = img_val.to(config["device"]), labels_val.to(config["device"])
            _, _, mu_cau_val, _ = student_model(img_val)
            loss = causal_loss_fn(mu_cau_val, labels_val)
            val_loss += loss.item()

            preds = torch.argmax(mu_cau_val, dim=1)
            all_preds_val.extend(preds.cpu().numpy())
            all_labels_val.extend(labels_val.cpu().numpy())

    # Loss is averaged over batches; the other metrics are computed over all samples.
    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = accuracy_score(all_labels_val, all_preds_val)
    val_precision = precision_score(all_labels_val, all_preds_val, average='binary', zero_division=0)
    val_recall = recall_score(all_labels_val, all_preds_val, average='binary', zero_division=0)
    val_f1 = f1_score(all_labels_val, all_preds_val, average='binary', zero_division=0)

    print(f"Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.4f}, Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, F1: {val_f1:.4f}")

# Final evaluation of the trained student on the held-out test subset.
student_model.eval()
test_loss = 0.0
all_preds_test, all_labels_test = [], []

with torch.no_grad():
    for img_test, labels_test in tqdm(test_loader, desc="Testing"):
        img_test = img_test.to(config["device"])
        labels_test = labels_test.to(config["device"])

        # Only the causal head's logits are used for prediction.
        _, _, mu_cau_test, _ = student_model(img_test)

        loss = causal_loss_fn(mu_cau_test, labels_test)
        test_loss += loss.item()

        preds = mu_cau_test.argmax(dim=1)
        all_preds_test.extend(preds.cpu().numpy())
        all_labels_test.extend(labels_test.cpu().numpy())

# Loss is averaged per batch; the other metrics use every test sample.
avg_test_loss = test_loss / len(test_loader)
_binary_metric_opts = dict(average='binary', zero_division=0)
test_accuracy = accuracy_score(all_labels_test, all_preds_test)
test_precision = precision_score(all_labels_test, all_preds_test, **_binary_metric_opts)
test_recall = recall_score(all_labels_test, all_preds_test, **_binary_metric_opts)
test_f1 = f1_score(all_labels_test, all_preds_test, **_binary_metric_opts)

print("\n--- Test Set Results ---")
for _metric_name, _metric_value in (
    ("Accuracy", test_accuracy),
    ("Precision", test_precision),
    ("Recall", test_recall),
    ("F1 Score", test_f1),
):
    print(f"{_metric_name}: {_metric_value:.4f}")