In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode


# Common transforms for images
transform = T.Compose([
    T.Resize((256, 256), interpolation=InterpolationMode.NEAREST),
    T.ToTensor(),
    # Optionally, normalize here
    # T.Normalize(mean=[0.485, 0.456, 0.406], 
    #             std=[0.229, 0.224, 0.225])
])

# Common transforms for masks
transform_mask = T.Compose([
    T.Resize((256, 256), interpolation=InterpolationMode.NEAREST),
    T.ToTensor()
])

class MyLabeledDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None, transform_mask=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.transform_mask = transform_mask
        
        self.images = sorted([
            os.path.join(image_dir, x) 
            for x in os.listdir(image_dir) 
            if x.endswith('.png') or x.endswith('.jpg')
        ])
        
        self.masks = sorted([
            os.path.join(mask_dir, x) 
            for x in os.listdir(mask_dir) 
            if x.endswith('.png') or x.endswith('.jpg')
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Load image
        image_path = self.images[idx]
        image = Image.open(image_path).convert("RGB")  # or "L" if grayscale

        # Load corresponding mask
        mask_path = self.masks[idx]
        mask = Image.open(mask_path).convert("L")  # single-channel

        if self.transform:
            image = self.transform(image)
        if self.transform_mask:
            mask = self.transform_mask(mask)

        # Binarize the mask (assuming foreground is > 0.5)
        mask = (mask > 0.5).long()  
        mask = mask.squeeze(0)     

        return image, mask

class MyUnlabeledDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.images = sorted([
            os.path.join(image_dir, x) 
            for x in os.listdir(image_dir) 
            if x.endswith('.png') or x.endswith('.jpg')
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]
        image = Image.open(image_path).convert("RGB")
        
        if self.transform:
            image = self.transform(image)
        
        return image


class SimpleUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=2):
        super(SimpleUNet, self).__init__()

        self.enc1 = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.pool2 = nn.MaxPool2d(2)

        self.bottleneck = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),
        )

        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec2 = nn.Sequential(
            nn.Conv2d(64+64, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),
        )

        self.up1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec1 = nn.Sequential(
            nn.Conv2d(32+32, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(inplace=True),
        )

        self.seg_head = nn.Conv2d(32, out_channels, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)
        p1 = self.pool1(e1)
        e2 = self.enc2(p1)
        p2 = self.pool2(e2)
        b = self.bottleneck(p2)

        u2 = self.up2(b)
        cat2 = torch.cat([u2, e2], dim=1)
        d2 = self.dec2(cat2)

        u1 = self.up1(d2)
        cat1 = torch.cat([u1, e1], dim=1)
        d1 = self.dec1(cat1)

        return self.seg_head(d1)  # shape: [N, out_channels, H, W]


class UltraSemiNet(nn.Module):
    def __init__(self, in_channels=1, num_classes=3, alpha=0.99):
        super(UltraSemiNet, self).__init__()
        # Student and Teacher share the same architecture
        self.student_net = SimpleUNet(in_channels, num_classes)
        self.teacher_net = SimpleUNet(in_channels, num_classes)
        
        # Initialize teacher weights to match student initially
        self._update_teacher(0.0)
        # Exponential moving average factor for teacher update
        self.alpha = alpha
        self.num_classes = num_classes

    def forward(self, x):
        return self.student_net(x)

    @torch.no_grad()
    def _update_teacher(self, alpha=None):
        if alpha is None:
            alpha = self.alpha
        for teacher_param, student_param in zip(self.teacher_net.parameters(), 
                                                self.student_net.parameters()):
            teacher_param.data = alpha * teacher_param.data + (1 - alpha) * student_param.data


def sat_loss(anchor, pos, neg, temperature=0.07):
    # Dot products for anchor-positive and anchor-negative
    sim_pos = (anchor * pos).sum(dim=1) / temperature
    sim_neg = (anchor * neg).sum(dim=1) / temperature
    
    # SAT Loss is based on softmax cross-entropy
    logits = torch.stack([sim_pos, sim_neg], dim=1)
    labels = torch.zeros(anchor.size(0), dtype=torch.long, device=anchor.device)
    return F.cross_entropy(logits, labels)

def compute_hardness_map(logits):
    probs = torch.softmax(logits, dim=1)
    ent = -torch.sum(probs * torch.log(probs + 1e-8), dim=1, keepdim=True)
    return ent

def aldc_loss(features, labels, mask, temperature=0.07):
    B, C, H, W = features.shape
    features_2d = features.permute(0,2,3,1).reshape(-1, C) 
    labels_2d = labels.reshape(-1)     
    mask_2d = mask.reshape(-1) > 0.5

    idxs = torch.where(mask_2d)[0]
    if len(idxs) < 2:
        return torch.tensor(0.0, device=features.device)  # no "hard" region

    # Pick random anchor in the masked region
    anchor_idx = idxs[torch.randint(0, len(idxs), (1,))]
    anchor_feat = features_2d[anchor_idx]  # shape (C,)
    anchor_label = labels_2d[anchor_idx]
    same_label_idx = idxs[(labels_2d[idxs] == anchor_label)]
    if len(same_label_idx) < 2:
        return torch.tensor(0.0, device=features.device)
    pos_idx = same_label_idx[torch.randint(0, len(same_label_idx), (1,))]
    pos_feat = features_2d[pos_idx]
    diff_label_idx = idxs[(labels_2d[idxs] != anchor_label)]
    if len(diff_label_idx) < 1:
        return torch.tensor(0.0, device=features.device)
    neg_idx = diff_label_idx[torch.randint(0, len(diff_label_idx), (1,))]
    neg_feat = features_2d[neg_idx]

    # Same form as sat_loss
    sim_pos = (anchor_feat * pos_feat).sum() / temperature
    sim_neg = (anchor_feat * neg_feat).sum() / temperature
    logits = torch.stack([sim_pos, sim_neg], dim=0).unsqueeze(0)
    labels_val = torch.zeros(1, dtype=torch.long, device=features.device)

    return F.cross_entropy(logits, labels_val)


def train_ultraseminet(
    student_teacher_model,
    dataloader_labeled,
    dataloader_unlabeled,
    optimizer,
    num_epochs=10,
    temperature=0.07,
    lambda_sat=0.5,
    lambda_aldc=0.5,
    save_path="model.pth"
):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_teacher_model = student_teacher_model.to(device)
    criterion_ce = nn.CrossEntropyLoss()
    best_loss = float("inf")

    for epoch in range(num_epochs):
        student_teacher_model.train()
        running_loss = 0.0
        steps = 0
        steps_per_epoch = min(len(dataloader_labeled), len(dataloader_unlabeled))

        # Create the tqdm progress bar
        pbar = tqdm(
            zip(dataloader_labeled, dataloader_unlabeled),
            total=steps_per_epoch,
            desc=f"Epoch {epoch+1}/{num_epochs}"
        )
        for (x_l, y_l), x_u in pbar:
            x_l, y_l = x_l.to(device), y_l.to(device)
            x_u = x_u.to(device)

            # Supervised loss
            logits_l = student_teacher_model(x_l)
            sup_loss = criterion_ce(logits_l, y_l)

            # Pseudo-label generation (teacher side)
            with torch.no_grad():
                logits_u_teacher = student_teacher_model.teacher_net(x_u)
                pseudo_labels = torch.argmax(logits_u_teacher, dim=1)
            
            # Student forward on unlabeled data
            logits_u_student = student_teacher_model(x_u)
            unsup_loss_ce = criterion_ce(logits_u_student, pseudo_labels)

            # Features for SAT loss
            features_student_u = F.adaptive_avg_pool2d(logits_u_student, (1,1)).squeeze(-1).squeeze(-1)
            features_teacher_u = F.adaptive_avg_pool2d(logits_u_teacher, (1,1)).squeeze(-1).squeeze(-1)

            # Negative examples by shuffling
            batch_size = features_student_u.size(0)
            indices = torch.randperm(batch_size, device=device)
            neg_features = features_student_u[indices]

            # SAT loss
            sat_loss_val = sat_loss(features_student_u, features_teacher_u, neg_features, temperature)

            # ALDC loss on labeled data (using hardness map from teacher)
            with torch.no_grad():
                logits_l_hard = student_teacher_model.teacher_net(x_l)
            hardness_map = compute_hardness_map(logits_l_hard)
            mask = (hardness_map > 0.5).float()

            aldc_val = aldc_loss(logits_l, y_l.unsqueeze(1), mask, temperature)

            # Total loss
            total_loss = sup_loss + unsup_loss_ce + lambda_sat*sat_loss_val + lambda_aldc*aldc_val

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            # EMA update for teacher
            student_teacher_model._update_teacher()

            running_loss += total_loss.item()
            steps += 1

            # Update the progress bar
            pbar.set_postfix({
                "SupLoss": f"{sup_loss.item():.4f}",
                "UnsupLoss": f"{unsup_loss_ce.item():.4f}",
                "SAT": f"{sat_loss_val.item():.4f}",
                "ALDC": f"{aldc_val.item():.4f}",
                "Total": f"{total_loss.item():.4f}"
            })

        epoch_loss = running_loss / steps if steps > 0 else 0.0
        print(f"[Epoch {epoch+1}/{num_epochs}] Loss: {epoch_loss:.4f}")

        # Save best model
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(student_teacher_model.state_dict(), save_path)
            print(f"Model saved at epoch {epoch+1} with loss={epoch_loss:.4f}")

    print("Training complete!")


labeled_image_dir = 'dataset/labeled_data/images'
labeled_mask_dir = 'dataset/labeled_data/labels'
unlabeled_image_dir = 'dataset/unlabeled_data/images'
labeled_dataset = MyLabeledDataset(
    image_dir=labeled_image_dir,
    mask_dir=labeled_mask_dir,
    transform=transform,
    transform_mask=transform_mask
)
unlabeled_dataset = MyUnlabeledDataset(
    image_dir=unlabeled_image_dir,
    transform=transform
)
labeled_loader = DataLoader(labeled_dataset, batch_size=4, shuffle=True)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=4, shuffle=True)
model = UltraSemiNet(in_channels=3, num_classes=3, alpha=0.99)  # or in_channels=1 if grayscale
optimizer = optim.Adam(model.parameters(), lr=1e-4)

train_ultraseminet(
    student_teacher_model=model,
    dataloader_labeled=labeled_loader,
    dataloader_unlabeled=unlabeled_loader,
    optimizer=optimizer,
    num_epochs=10,
    temperature=0.07,
    lambda_sat=0.5,
    lambda_aldc=0.5,
    save_path="model.pth"
)


Epoch 1/10: 100%|██████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  9.04it/s, SupLoss=1.2245, UnsupLoss=1.0118, SAT=0.6378, ALDC=0.0000, Total=2.5552]


[Epoch 1/10] Loss: 2.5828
Model saved at epoch 1 with loss=2.5828


Epoch 2/10: 100%|██████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 15.69it/s, SupLoss=1.1659, UnsupLoss=1.0160, SAT=0.6303, ALDC=0.0000, Total=2.4970]


[Epoch 2/10] Loss: 2.5214
Model saved at epoch 2 with loss=2.5214


Epoch 3/10: 100%|██████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 16.14it/s, SupLoss=1.1104, UnsupLoss=1.0156, SAT=0.6918, ALDC=0.0000, Total=2.4720]


[Epoch 3/10] Loss: 2.4857
Model saved at epoch 3 with loss=2.4857


Epoch 4/10: 100%|██████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 12.90it/s, SupLoss=1.1134, UnsupLoss=1.0172, SAT=0.7012, ALDC=0.0000, Total=2.4812]


[Epoch 4/10] Loss: 2.4830
Model saved at epoch 4 with loss=2.4830


Epoch 5/10: 100%|██████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 13.00it/s, SupLoss=1.1092, UnsupLoss=1.0182, SAT=0.6989, ALDC=0.0000, Total=2.4769]


[Epoch 5/10] Loss: 2.4815
Model saved at epoch 5 with loss=2.4815


Epoch 6/10: 100%|██████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 13.03it/s, SupLoss=1.1084, UnsupLoss=1.0190, SAT=0.7056, ALDC=0.0000, Total=2.4802]


[Epoch 6/10] Loss: 2.4796
Model saved at epoch 6 with loss=2.4796


Epoch 7/10: 100%|██████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 13.07it/s, SupLoss=1.1049, UnsupLoss=1.0212, SAT=0.7304, ALDC=0.0000, Total=2.4913]


[Epoch 7/10] Loss: 2.4801


Epoch 8/10: 100%|██████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 13.04it/s, SupLoss=1.1036, UnsupLoss=1.0203, SAT=0.7219, ALDC=0.0000, Total=2.4849]


[Epoch 8/10] Loss: 2.4813


Epoch 9/10: 100%|██████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 12.94it/s, SupLoss=1.0963, UnsupLoss=1.0207, SAT=0.7269, ALDC=0.0000, Total=2.4804]


[Epoch 9/10] Loss: 2.4800


Epoch 10/10: 100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 12.98it/s, SupLoss=1.0906, UnsupLoss=1.0225, SAT=0.7369, ALDC=0.0000, Total=2.4816]

[Epoch 10/10] Loss: 2.4773
Model saved at epoch 10 with loss=2.4773
Training complete!





In [6]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from torch.optim.lr_scheduler import CosineAnnealingLR
import numpy as np
import argparse

# ----------------------
# Enhanced Transforms
# ----------------------
class RandomCutMix:
    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def __call__(self, img):
        if np.random.rand() > 0.3:  # 30% probability
            return img
            
        lam = np.random.beta(self.alpha, self.alpha)
        batch_size = img.size(0)
        index = torch.randperm(batch_size)
        
        # Create mixed image
        bbx1, bby1, bbx2, bby2 = self.rand_bbox(img.size(), lam)
        img[:, :, bbx1:bbx2, bby1:bby2] = img[index, :, bbx1:bbx2, bby1:bby2]
        return img

    @staticmethod
    def rand_bbox(size, lam):
        W, H = size[2], size[3]
        cut_rat = np.sqrt(1. - lam)
        cut_w = int(W * cut_rat)
        cut_h = int(H * cut_rat)

        cx = np.random.randint(W)
        cy = np.random.randint(H)

        bbx1 = np.clip(cx - cut_w // 2, 0, W)
        bby1 = np.clip(cy - cut_h // 2, 0, H)
        bbx2 = np.clip(cx + cut_w // 2, 0, W)
        bby2 = np.clip(cy + cut_h // 2, 0, H)

        return bbx1, bby1, bbx2, bby2

transform_labeled = T.Compose([
    T.Resize((256, 256)),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
    T.RandomRotation(15),
    T.ColorJitter(0.2, 0.2, 0.2, 0.1),
    T.ToTensor(),
])

transform_unlabeled = T.Compose([
    T.Resize((256, 256)),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
    T.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    T.ColorJitter(0.3, 0.3, 0.3, 0.2),
    T.GaussianBlur(3, sigma=(0.1, 2.0)),
    T.ToTensor(),
    RandomCutMix(alpha=1.0),
    T.RandomErasing(p=0.2, scale=(0.02, 0.2), ratio=(0.3, 3.3)),
])

transform_mask = T.Compose([
    T.Resize((256, 256), interpolation=InterpolationMode.NEAREST),
    T.PILToTensor(),
    lambda x: x.squeeze(0).long()
])

# ----------------------
# Dataset Classes
# ----------------------
class CervicalDataset(Dataset):
    def __init__(self, img_dir, mask_dir=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.is_labeled = mask_dir is not None
        self.image_paths = sorted([os.path.join(img_dir, f) for f in os.listdir(img_dir) 
                              if f.endswith(('.png', '.jpg'))])
        if self.is_labeled:
            self.mask_paths = sorted([os.path.join(mask_dir, f) for f in os.listdir(mask_dir) 
                                 if f.endswith(('.png', '.jpg'))])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        if self.is_labeled:
            mask = Image.open(self.mask_paths[idx]).convert("L")
            return transform_labeled(image), transform_mask(mask)
        return transform_unlabeled(image)

# ----------------------
# Enhanced U-Net Model
# ----------------------
class AttentionBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        att = self.conv(x)
        return x * self.sigmoid(att)

class CervicalUNet(nn.Module):
    def __init__(self, in_channels=3, num_classes=3):
        super().__init__()
        
        # Encoder
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.pool1 = nn.MaxPool2d(2)
        
        self.enc2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.pool2 = nn.MaxPool2d(2)
        
        # Bottleneck with attention
        self.bottleneck = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        self.attention = AttentionBlock(256)
        self.feature_pool = nn.AdaptiveAvgPool2d((1, 1))
        
        # Decoder
        self.up1 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec1 = nn.Sequential(
            nn.Conv2d(256, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        
        self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec2 = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        
        self.seg_head = nn.Sequential(
            nn.Conv2d(64, num_classes, 1),
            nn.BatchNorm2d(num_classes),
            nn.ReLU(inplace=True)
        )

    def forward_features(self, x):
        e1 = self.enc1(x)
        p1 = self.pool1(e1)
        e2 = self.enc2(p1)
        p2 = self.pool2(e2)
        bn = self.bottleneck(p2)
        att = self.attention(bn)
        return self.feature_pool(att).squeeze(-1).squeeze(-1)

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        p1 = self.pool1(e1)
        e2 = self.enc2(p1)
        p2 = self.pool2(e2)
        
        # Bottleneck
        bn = self.bottleneck(p2)
        bn = self.attention(bn)
        
        # Decoder
        d1 = self.up1(bn)
        d1 = torch.cat([d1, e2], dim=1)
        d1 = self.dec1(d1)
        
        d2 = self.up2(d1)
        d2 = torch.cat([d2, e1], dim=1)
        d2 = self.dec2(d2)
        
        return self.seg_head(d2)

# ----------------------
# Semi-Supervised Framework
# ----------------------
class SemiSupervisedModel(nn.Module):
    def __init__(self, num_classes=3, alpha=0.999):
        super().__init__()
        self.student = CervicalUNet(num_classes=num_classes)
        self.teacher = CervicalUNet(num_classes=num_classes)
        self.alpha = alpha
        self._init_teacher()
        
        # Loss parameters
        self.lambda_unsup = 0.1
        self.lambda_sat = 0.1
        self.conf_thresh = 0.65

    def _init_teacher(self):
        with torch.no_grad():
            for t_param, s_param in zip(self.teacher.parameters(), self.student.parameters()):
                t_param.data.copy_(s_param.data)

    @torch.no_grad()
    def update_teacher(self, global_step):
        alpha = min(1 - 1/(global_step/100 + 1), self.alpha)
        for t_param, s_param in zip(self.teacher.parameters(), self.student.parameters()):
            t_param.data.mul_(alpha).add_(s_param.data, alpha=1-alpha)

    def forward(self, x):
        return self.student(x)

# ----------------------
# Loss Functions
# ----------------------
def compute_sat_loss(student_feats, teacher_feats, temperature=0.1):
    student_feats = F.normalize(student_feats, p=2, dim=1)
    teacher_feats = F.normalize(teacher_feats, p=2, dim=1)
    
    sim_matrix = torch.mm(student_feats, teacher_feats.t()) / temperature
    pos_sim = torch.diag(sim_matrix)
    neg_sim = (sim_matrix.sum(dim=1) - pos_sim) / (sim_matrix.size(1) - 1)
    
    loss = -torch.log(torch.exp(pos_sim) / (torch.exp(pos_sim) + torch.exp(neg_sim) + 1e-8)).mean()
    return loss

class AdaptiveLoss(nn.Module):
    def __init__(self, class_weights=None):
        super().__init__()
        self.class_weights = class_weights
        
    def forward(self, preds, targets):
        # Cross-Entropy
        ce_loss = F.cross_entropy(preds, targets, weight=self.class_weights)
        
        # Dice Loss
        smooth = 1e-6
        preds_soft = F.softmax(preds, dim=1)
        targets_oh = F.one_hot(targets, num_classes=preds.shape[1]).permute(0,3,1,2).float()
        
        intersection = (preds_soft * targets_oh).sum(dim=(2,3))
        union = preds_soft.sum(dim=(2,3)) + targets_oh.sum(dim=(2,3))
        dice_loss = 1 - (2. * intersection + smooth) / (union + smooth)
        
        return ce_loss + dice_loss.mean()

# ----------------------
# Training Utilities
# ----------------------
def dice_score(pred, target):
    smooth = 1e-6
    pred = pred.argmax(1)
    return (2.0 * (pred * target).sum() + smooth) / (pred.sum() + target.sum() + smooth)

def train(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Dataset
    labeled_ds = CervicalDataset(args.labeled_img, args.labeled_mask)
    unlabeled_ds = CervicalDataset(args.unlabeled_img)
    
    # Class weights
    class_counts = torch.zeros(3)
    for _, mask in labeled_ds:
        class_counts += torch.bincount(mask.flatten(), minlength=3)
    class_weights = 1.0 / (class_counts / class_counts.sum()).to(device)
    
    # Data loaders
    labeled_loader = DataLoader(labeled_ds, batch_size=args.batch_size, shuffle=True, 
                               num_workers=2, pin_memory=True)
    unlabeled_loader = DataLoader(unlabeled_ds, batch_size=args.batch_size*2, shuffle=True,
                                 num_workers=2, pin_memory=True)
    
    # Model
    model = SemiSupervisedModel(num_classes=3).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs)
    
    # Losses
    sup_criterion = AdaptiveLoss(class_weights)
    unsup_criterion = AdaptiveLoss()
    
    best_loss = float('inf')
    global_step = 0
    
    for epoch in range(args.epochs):
        model.train()
        total_loss = 0.0
        sup_loss_total = 0.0
        unsup_loss_total = 0.0
        dice_total = 0.0
        
        # Dynamic parameters
        current_thresh = 0.65 + min(epoch/args.epochs, 1)*0.25  # 0.65 → 0.9
        current_lambda_unsup = min(epoch/10 * 0.5, 0.5)
        current_lambda_sat = 0.2 * (1 - epoch/args.epochs)
        
        pbar = tqdm(zip(labeled_loader, unlabeled_loader), 
                   total=min(len(labeled_loader), len(unlabeled_loader)), 
                   desc=f"Epoch {epoch+1}/{args.epochs}")
        
        for (labeled_x, labeled_y), unlabeled_x in pbar:
            labeled_x, labeled_y = labeled_x.to(device), labeled_y.to(device)
            unlabeled_x = unlabeled_x.to(device)
            
            # Supervised Forward
            student_preds = model(labeled_x)
            sup_loss = sup_criterion(student_preds, labeled_y)
            
            # Unsupervised Forward
            with torch.no_grad():
                teacher_preds = model.teacher(unlabeled_x)
                pseudo_probs = F.softmax(teacher_preds, dim=1)
                max_probs, pseudo_labels = torch.max(pseudo_probs, dim=1)
                mask = (max_probs > current_thresh).float()
                
            unsup_loss = 0.0
            if mask.sum() > 0:
                student_u_preds = model(unlabeled_x)
                unsup_loss = unsup_criterion(student_u_preds, pseudo_labels) * mask.mean()
            
            # Feature Alignment Loss
            with torch.no_grad():
                t_features = model.teacher.forward_features(unlabeled_x)
            s_features = model.student.forward_features(unlabeled_x)
            sat_loss = compute_sat_loss(s_features, t_features)
            
            # Total Loss
            total_loss = sup_loss + current_lambda_unsup*unsup_loss + current_lambda_sat*sat_loss
            
            # Optimization
            optimizer.zero_grad()
            total_loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            model.update_teacher(global_step)
            global_step += 1
            
            # Metrics
            sup_loss_total += sup_loss.item()
            unsup_loss_total += unsup_loss.item() if unsup_loss != 0 else 0
            total_loss += total_loss.item()
            dice_total += dice_score(student_preds, labeled_y).item()
            
            pbar.set_postfix({
                "Sup": f"{sup_loss.item():.3f}",
                "Unsup": f"{unsup_loss.item():.3f}" if unsup_loss != 0 else "0.000",
                "SAT": f"{sat_loss.item():.3f}",
                "Dice": f"{dice_total/(pbar.n+1):.3f}",
                "Total": f"{total_loss.item():.3f}"
            })
        
        # Update scheduler
        scheduler.step()
        
        # Save best model
        avg_loss = total_loss / len(labeled_loader)
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), f"{args.save_dir}/best_model.pth")
        
        print(f"Epoch {epoch+1} | Sup: {sup_loss_total/len(labeled_loader):.3f} "
              f"Unsup: {unsup_loss_total/len(unlabeled_loader):.3f} "
              f"Dice: {dice_total/len(labeled_loader):.3f}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--labeled_img", type=str, default="dataset/labeled_data/images")
    parser.add_argument("--labeled_mask", type=str, default="dataset/labeled_data/labels")
    parser.add_argument("--unlabeled_img", type=str, default="dataset/unlabeled_data/images")
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--save_dir", type=str, default="checkpoints")
    args = parser.parse_args()
    
    os.makedirs(args.save_dir, exist_ok=True)
    train(args)

usage: ipykernel_launcher.py [-h] [--labeled_img LABELED_IMG]
                             [--labeled_mask LABELED_MASK]
                             [--unlabeled_img UNLABELED_IMG] [--epochs EPOCHS]
                             [--batch_size BATCH_SIZE] [--lr LR]
                             [--save_dir SAVE_DIR]
ipykernel_launcher.py: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-f02115d4-53f7-4de5-994e-a1ef8a0bb2de.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
