In [1]:
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode

# Common transforms for images.
# Photometric images are resampled with bilinear interpolation: nearest-neighbour
# resampling is meant for label maps and produces blocky/aliased results when
# applied to continuous-valued images.
transform = T.Compose([
    T.Resize((256, 256), interpolation=InterpolationMode.BILINEAR),
    T.ToTensor(),
    # Optionally, normalize here
    # T.Normalize(mean=[0.485, 0.456, 0.406],
    #             std=[0.229, 0.224, 0.225])
])

# Mask pipeline: resize to 256x256 with nearest-neighbour sampling so that no
# blended (intermediate) label values are created, then convert to a tensor.
transform_mask = T.Compose([
    T.Resize(size=(256, 256), interpolation=InterpolationMode.NEAREST),
    T.ToTensor(),
])

In [2]:
import os
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

class MyLabeledDataset(Dataset):
    """Paired image / segmentation-mask dataset read from two directories.

    Files ending in ``.png`` or ``.jpg`` are collected from each directory and
    sorted by path. The i-th image is paired with the i-th mask, so matching
    files must sort into the same order in both directories.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the ground-truth masks.
        transform: Optional callable applied to the RGB PIL image.
        transform_mask: Optional callable applied to the single-channel PIL
            mask. If the mask is still a PIL image afterwards (e.g. no
            transform was given) it is converted to a ``[1, H, W]`` float
            tensor scaled to ``[0, 1]`` so that it can be binarized.

    Raises:
        ValueError: If the two directories contain a different number of
            matching files; index-based pairing would otherwise be wrong or
            fail with an IndexError in the middle of an epoch.
    """

    _EXTENSIONS = ('.png', '.jpg')

    def __init__(self, image_dir, mask_dir, transform=None, transform_mask=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.transform_mask = transform_mask

        self.images = self._list_files(image_dir)
        self.masks = self._list_files(mask_dir)

        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Number of images ({len(self.images)}) in '{image_dir}' does not "
                f"match number of masks ({len(self.masks)}) in '{mask_dir}'"
            )

    @classmethod
    def _list_files(cls, directory):
        """Return the sorted full paths of the .png/.jpg files in ``directory``."""
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.endswith(cls._EXTENSIONS)
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` where ``mask`` is a binary ``long`` tensor."""
        # Open inside `with` so the underlying file handle is released as soon
        # as the pixel data has been converted/copied.
        with Image.open(self.images[idx]) as img:
            image = img.convert("RGB")  # or "L" if grayscale
        with Image.open(self.masks[idx]) as msk:
            mask = msk.convert("L")  # single-channel

        if self.transform:
            image = self.transform(image)
        if self.transform_mask:
            mask = self.transform_mask(mask)

        # Without a tensor-producing mask transform the mask is still a PIL
        # image and cannot be thresholded; convert it to [1, H, W] in [0, 1].
        if isinstance(mask, Image.Image):
            mask = torch.from_numpy(
                np.array(mask, dtype=np.float32) / 255.0
            ).unsqueeze(0)

        # Binarize the mask (assuming foreground is > 0.5)
        mask = (mask > 0.5).long()
        # Drop the leading channel dimension: [1, H, W] -> [H, W]
        mask = mask.squeeze(0)

        return image, mask

class MyUnlabeledDataset(Dataset):
    """Image-only dataset built from the ``.png`` / ``.jpg`` files of a directory.

    Args:
        image_dir: Directory that is scanned (non-recursively) for images.
        transform: Optional callable applied to each RGB PIL image.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform

        # Collect the full paths of matching files, then order them by path.
        found = []
        for fname in os.listdir(image_dir):
            if fname.endswith(('.png', '.jpg')):
                found.append(os.path.join(image_dir, fname))
        found.sort()
        self.images = found

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Load the idx-th image as RGB and apply the transform, if any."""
        image = Image.open(self.images[idx]).convert("RGB")
        return self.transform(image) if self.transform else image


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleUNet(nn.Module):
    """Small two-level U-Net producing per-pixel class logits.

    Encoder: two double-conv stages (32 and 64 channels), each followed by a
    2x2 max-pool. Bottleneck: double-conv with 128 channels. Decoder: two
    transposed-conv upsampling steps, each concatenated with the matching
    encoder output and followed by a double-conv. A final 1x1 convolution maps
    to ``out_channels``.

    Inputs whose height/width are not multiples of 4 are supported: the
    upsampled feature map is zero-padded to the size of its skip connection
    before concatenation (max-pooling floors odd sizes, so the upsampled map
    can be one pixel smaller).

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of output classes (logit channels).
    """

    def __init__(self, in_channels=3, out_channels=2):
        super().__init__()

        # NOTE: attribute names and creation order are kept stable so that
        # state_dict keys and parameter ordering do not change.
        self.enc1 = self._double_conv(in_channels, 32)
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = self._double_conv(32, 64)
        self.pool2 = nn.MaxPool2d(2)

        self.bottleneck = self._double_conv(64, 128)

        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec2 = self._double_conv(64 + 64, 64)

        self.up1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec1 = self._double_conv(32 + 32, 32)

        self.seg_head = nn.Conv2d(32, out_channels, kernel_size=1)

    @staticmethod
    def _double_conv(channels_in, channels_out):
        """Two 3x3 same-padding convolutions, each followed by a ReLU."""
        return nn.Sequential(
            nn.Conv2d(channels_in, channels_out, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels_out, channels_out, 3, padding=1),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _pad_to(upsampled, skip):
        """Zero-pad ``upsampled`` so its spatial size equals that of ``skip``."""
        diff_h = skip.size(2) - upsampled.size(2)
        diff_w = skip.size(3) - upsampled.size(3)
        if diff_h == 0 and diff_w == 0:
            return upsampled
        # F.pad order for the last two dims: (left, right, top, bottom)
        return F.pad(
            upsampled,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def forward(self, x):
        """Return logits of shape ``[N, out_channels, H, W]`` for input ``[N, in_channels, H, W]``."""
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        b = self.bottleneck(self.pool2(e2))

        u2 = self._pad_to(self.up2(b), e2)
        d2 = self.dec2(torch.cat([u2, e2], dim=1))

        u1 = self._pad_to(self.up1(d2), e1)
        d1 = self.dec1(torch.cat([u1, e1], dim=1))

        return self.seg_head(d1)  # shape: [N, out_channels, H, W]


In [6]:
import torch
import torch.nn as nn

class UltraSemiNet(nn.Module):
    """Student/teacher wrapper around two identically-shaped ``SimpleUNet`` networks.

    The student is the network returned by ``forward`` (and therefore the one
    trained by back-propagation). The teacher starts as a copy of the student
    and is afterwards only changed through ``_update_teacher``, which applies
    an exponential moving average (EMA) of the student weights.

    Args:
        in_channels: Number of input image channels for both networks.
        num_classes: Number of output classes for both networks.
        alpha: EMA decay used by ``_update_teacher`` when none is passed;
            ``teacher = alpha * teacher + (1 - alpha) * student``.
    """

    def __init__(self, in_channels=1, num_classes=2, alpha=0.99):
        super().__init__()
        # Student and Teacher share the same architecture
        self.student_net = SimpleUNet(in_channels, num_classes)
        self.teacher_net = SimpleUNet(in_channels, num_classes)

        # Exponential moving average factor for teacher update
        self.alpha = alpha
        self.num_classes = num_classes

        # alpha=0 makes the EMA a plain copy: teacher starts equal to student.
        self._update_teacher(0.0)

        # The teacher is never trained by gradient descent; freezing it keeps
        # its weights out of autograd and makes accidental optimisation of the
        # teacher impossible even though it is part of `self.parameters()`.
        for teacher_param in self.teacher_net.parameters():
            teacher_param.requires_grad_(False)

    def forward(self, x):
        """Run the student network on ``x`` and return its logits."""
        return self.student_net(x)

    @torch.no_grad()
    def _update_teacher(self, alpha=None):
        """EMA-update the teacher weights from the student weights.

        Args:
            alpha: Decay factor; defaults to ``self.alpha`` when ``None``.
        """
        if alpha is None:
            alpha = self.alpha
        for teacher_param, student_param in zip(self.teacher_net.parameters(),
                                                self.student_net.parameters()):
            # In-place update: same result as alpha*t + (1-alpha)*s without
            # allocating a new tensor for every parameter on every step.
            teacher_param.mul_(alpha).add_(student_param.detach(), alpha=1.0 - alpha)
        # Non-trainable state (if the networks ever gain any, e.g. BatchNorm
        # statistics) is mirrored directly rather than averaged.
        for teacher_buf, student_buf in zip(self.teacher_net.buffers(),
                                            self.student_net.buffers()):
            teacher_buf.copy_(student_buf)


In [8]:
import torch
import torch.nn.functional as F

def sat_loss(anchor, pos, neg, temperature=0.07):
    """Two-way contrastive loss between an anchor, a positive and a negative.

    For every row, the temperature-scaled dot products anchor·pos and
    anchor·neg are treated as logits of a 2-class problem whose correct class
    is the positive; the mean cross-entropy over the rows is returned.

    Args:
        anchor, pos, neg: Tensors of identical shape ``[N, D]``.
        temperature: Divisor applied to the dot products.

    Returns:
        Scalar tensor with the mean cross-entropy.
    """
    def scaled_similarity(other):
        # Row-wise dot product with the anchor, divided by the temperature
        return torch.sum(anchor * other, dim=1) / temperature

    # Column 0 holds the positive score, so the target class is always 0.
    pair_scores = torch.stack(
        (scaled_similarity(pos), scaled_similarity(neg)), dim=1
    )
    target = torch.zeros(anchor.size(0), dtype=torch.long, device=anchor.device)
    return F.cross_entropy(pair_scores, target)

def compute_hardness_map(logits):
    """Entropy of the softmax distribution taken over dim 1 of ``logits``.

    Args:
        logits: Tensor whose dim 1 indexes the classes.

    Returns:
        Tensor with the same shape as ``logits`` except that dim 1 has size 1,
        holding ``-sum(p * log(p + 1e-8))`` over the class dimension.
    """
    class_probs = logits.softmax(dim=1)
    # The small epsilon keeps log() finite when a probability underflows to 0.
    log_probs = (class_probs + 1e-8).log()
    return -(class_probs * log_probs).sum(dim=1, keepdim=True)

def aldc_loss(features, labels, mask, temperature=0.07):
    """Single-triplet contrastive loss sampled inside the masked region.

    One anchor pixel is drawn at random from the pixels where ``mask > 0.5``.
    A positive is drawn from the *other* masked pixels sharing the anchor's
    label and a negative from the masked pixels with a different label. The
    loss is the 2-way cross-entropy over the temperature-scaled dot products
    (anchor·positive, anchor·negative) with the positive as the target.

    The anchor itself is excluded from the positive candidates; otherwise the
    "positive" could be the anchor, which carries no contrastive signal.

    Args:
        features: Tensor of shape ``[B, C, H, W]``.
        labels: Tensor with ``B*H*W`` elements, flattened in the same
            (batch, row, column) order as ``features`` (e.g. ``[B, 1, H, W]``).
        mask: Tensor with ``B*H*W`` elements; values > 0.5 mark eligible pixels.
        temperature: Divisor applied to the dot products.

    Returns:
        Scalar loss tensor. A constant ``0.0`` (no gradient) is returned when
        no valid (anchor, positive, negative) triplet exists in the mask.
    """
    _, C, _, _ = features.shape
    features_2d = features.permute(0, 2, 3, 1).reshape(-1, C)
    labels_2d = labels.reshape(-1)
    hard_idx = torch.where(mask.reshape(-1) > 0.5)[0]

    no_triplet = torch.tensor(0.0, device=features.device)
    if hard_idx.numel() < 2:
        return no_triplet  # no "hard" region

    def pick_one(candidates):
        # Uniformly sample one element; the result is a 0-dim index tensor.
        return candidates[int(torch.randint(0, candidates.numel(), (1,)))]

    # Pick random anchor in the masked region
    anchor_idx = pick_one(hard_idx)
    anchor_label = labels_2d[anchor_idx]

    hard_labels = labels_2d[hard_idx]
    same_label = hard_labels == anchor_label
    # Positives: same label as the anchor, but never the anchor pixel itself.
    pos_candidates = hard_idx[same_label & (hard_idx != anchor_idx)]
    neg_candidates = hard_idx[~same_label]
    if pos_candidates.numel() < 1 or neg_candidates.numel() < 1:
        return no_triplet

    anchor_feat = features_2d[anchor_idx]  # shape (C,)
    pos_feat = features_2d[pick_one(pos_candidates)]
    neg_feat = features_2d[pick_one(neg_candidates)]

    # Same form as sat_loss: 2-way softmax with the positive as class 0
    sim_pos = (anchor_feat * pos_feat).sum() / temperature
    sim_neg = (anchor_feat * neg_feat).sum() / temperature
    logits = torch.stack([sim_pos, sim_neg], dim=0).unsqueeze(0)
    target = torch.zeros(1, dtype=torch.long, device=features.device)

    return F.cross_entropy(logits, target)


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm


def train_ultraseminet(
    student_teacher_model,
    dataloader_labeled,
    dataloader_unlabeled,
    optimizer,
    num_epochs=10,
    temperature=0.07,
    lambda_sat=0.5,
    lambda_aldc=0.5,
    save_path="model.pth"
):
    """Train a student/teacher model on labeled and unlabeled batches.

    Each step consumes one labeled and one unlabeled batch (the loaders are
    zipped, so an epoch ends when the shorter one is exhausted) and minimises

        sup_CE + pseudo_label_CE + lambda_sat * SAT + lambda_aldc * ALDC

    after which the teacher is updated via ``_update_teacher()``. The model's
    ``state_dict`` is written to ``save_path`` whenever the mean epoch loss
    improves on the best value seen so far.

    Args:
        student_teacher_model: Module callable on a batch (student forward)
            that exposes ``teacher_net`` and ``_update_teacher()``.
        dataloader_labeled: Iterable yielding ``(images, masks)`` batches.
        dataloader_unlabeled: Iterable yielding image batches.
        optimizer: Optimizer stepping the trainable parameters.
        num_epochs: Number of passes over the zipped loaders.
        temperature: Temperature passed to ``sat_loss`` and ``aldc_loss``.
        lambda_sat: Weight of the SAT loss term.
        lambda_aldc: Weight of the ALDC loss term.
        save_path: File the best ``state_dict`` is saved to.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    student_teacher_model = student_teacher_model.to(device)
    criterion_ce = nn.CrossEntropyLoss()
    best_loss = float("inf")

    for epoch in range(num_epochs):
        student_teacher_model.train()
        running_loss = 0.0
        steps = 0
        steps_per_epoch = min(len(dataloader_labeled), len(dataloader_unlabeled))

        # Create the tqdm progress bar
        pbar = tqdm(
            zip(dataloader_labeled, dataloader_unlabeled),
            total=steps_per_epoch,
            desc=f"Epoch {epoch+1}/{num_epochs}"
        )
        for (x_l, y_l), x_u in pbar:
            x_l, y_l = x_l.to(device), y_l.to(device)
            x_u = x_u.to(device)

            # Supervised loss
            logits_l = student_teacher_model(x_l)
            sup_loss = criterion_ce(logits_l, y_l)

            # Pseudo-label generation (teacher side)
            with torch.no_grad():
                logits_u_teacher = student_teacher_model.teacher_net(x_u)
                pseudo_labels = torch.argmax(logits_u_teacher, dim=1)

            # Student forward on unlabeled data
            logits_u_student = student_teacher_model(x_u)
            unsup_loss_ce = criterion_ce(logits_u_student, pseudo_labels)

            # Features for SAT loss: spatially averaged logits, shape [N, num_classes]
            features_student_u = F.adaptive_avg_pool2d(logits_u_student, (1,1)).squeeze(-1).squeeze(-1)
            features_teacher_u = F.adaptive_avg_pool2d(logits_u_teacher, (1,1)).squeeze(-1).squeeze(-1)

            # Negative examples: other samples of the same batch.
            # A random permutation can map a sample onto itself (it always does
            # for a batch of one), which would make the "negative" identical to
            # the anchor. A random non-zero cyclic shift never has fixed points.
            batch_size = features_student_u.size(0)
            if batch_size > 1:
                shift = int(torch.randint(1, batch_size, (1,)))
                neg_features = torch.roll(features_student_u, shifts=shift, dims=0)

                # SAT loss
                sat_loss_val = sat_loss(features_student_u, features_teacher_u, neg_features, temperature)
            else:
                # No other sample available to serve as a negative.
                sat_loss_val = torch.zeros((), device=device)

            # ALDC loss on labeled data (using hardness map from teacher)
            with torch.no_grad():
                logits_l_hard = student_teacher_model.teacher_net(x_l)
            hardness_map = compute_hardness_map(logits_l_hard)
            mask = (hardness_map > 0.5).float()

            aldc_val = aldc_loss(logits_l, y_l.unsqueeze(1), mask, temperature)

            # Total loss
            total_loss = sup_loss + unsup_loss_ce + lambda_sat*sat_loss_val + lambda_aldc*aldc_val

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            # EMA update for teacher
            student_teacher_model._update_teacher()

            running_loss += total_loss.item()
            steps += 1

            # Update the progress bar
            pbar.set_postfix({
                "SupLoss": f"{sup_loss.item():.4f}",
                "UnsupLoss": f"{unsup_loss_ce.item():.4f}",
                "SAT": f"{sat_loss_val.item():.4f}",
                "ALDC": f"{aldc_val.item():.4f}",
                "Total": f"{total_loss.item():.4f}"
            })

        epoch_loss = running_loss / steps if steps > 0 else 0.0
        print(f"[Epoch {epoch+1}/{num_epochs}] Loss: {epoch_loss:.4f}")

        # Save best model
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(student_teacher_model.state_dict(), save_path)
            print(f"Model saved at epoch {epoch+1} with loss={epoch_loss:.4f}")

    print("Training complete!")


In [None]:
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
# from ultra_semi_net import UltraSemiNet
# from train import train_ultraseminet
# from datasets import MyLabeledDataset, MyUnlabeledDataset
# from transformations import transform, transform_mask

# --- Dataset locations (absolute paths specific to the authoring machine) ---
labeled_image_dir = '/home/ufaqkhan/UltraSemiNet/Dataset/labeled/original'
labeled_mask_dir = '/home/ufaqkhan/UltraSemiNet/Dataset/labeled/groundtruth'
unlabeled_image_dir = '/home/ufaqkhan/UltraSemiNet/Dataset/unlabeled'

# --- Datasets: paired image/mask data plus image-only data ---
labeled_dataset = MyLabeledDataset(
    labeled_image_dir,
    labeled_mask_dir,
    transform=transform,
    transform_mask=transform_mask,
)
unlabeled_dataset = MyUnlabeledDataset(unlabeled_image_dir, transform=transform)

# --- Loaders: both shuffled, four samples per batch ---
labeled_loader = DataLoader(dataset=labeled_dataset, batch_size=4, shuffle=True)
unlabeled_loader = DataLoader(dataset=unlabeled_dataset, batch_size=4, shuffle=True)

# --- Model and optimizer ---
model = UltraSemiNet(
    in_channels=3,  # or in_channels=1 if grayscale
    num_classes=2,
    alpha=0.99,
)
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# --- Run training; the best checkpoint is written to "model.pth" ---
train_ultraseminet(
    model,
    labeled_loader,
    unlabeled_loader,
    optimizer,
    num_epochs=10,
    temperature=0.07,
    lambda_sat=0.5,
    lambda_aldc=0.5,
    save_path="model.pth",
)
