In [17]:
import os
from pathlib import Path
import math
import random
import numpy as np
from PIL import Image, ImageOps, ImageEnhance

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
from torch.cuda.amp import GradScaler, autocast

import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import functional as TF
from torchvision.models import resnet18, resnet34, resnet50
import timm
from timm.data import Mixup
from timm.loss import SoftTargetCrossEntropy

import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import KFold

# Set random seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed()

#####################################
# 1. Custom Dataset with Augmentations
#####################################
class SmallDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None, return_idx=False):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.return_idx = return_idx
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
            
        if self.return_idx:
            return image, label, idx
        return image, label

#####################################
# 2. Strong Augmentation Pipelines
#####################################
class RandAugment:
    def __init__(self, n=2, m=9):
        self.n = n  # Number of augmentation transformations to apply
        self.m = m  # Magnitude for all transformations
        self.augment_list = [
            self.auto_contrast,
            self.equalize,
            self.rotate,
            self.solarize,
            self.color,
            self.contrast,
            self.brightness,
            self.sharpness,
            self.shear_x,
            self.shear_y,
            self.translate_x,
            self.translate_y,
        ]
            
    def auto_contrast(self, img, magnitude):
        return ImageOps.autocontrast(img)
    
    def equalize(self, img, magnitude):
        return ImageOps.equalize(img)
    
    def rotate(self, img, magnitude):
        magnitude = (magnitude / 10) * 30
        return img.rotate(magnitude)
    
    def solarize(self, img, magnitude):
        magnitude = (magnitude / 10) * 256
        return ImageOps.solarize(img, magnitude)
    
    def color(self, img, magnitude):
        magnitude = (magnitude / 10) * 1.8 + 0.1
        return ImageEnhance.Color(img).enhance(magnitude)
    
    def contrast(self, img, magnitude):
        magnitude = (magnitude / 10) * 1.8 + 0.1
        return ImageEnhance.Contrast(img).enhance(magnitude)
    
    def brightness(self, img, magnitude):
        magnitude = (magnitude / 10) * 1.8 + 0.1
        return ImageEnhance.Brightness(img).enhance(magnitude)
    
    def sharpness(self, img, magnitude):
        magnitude = (magnitude / 10) * 1.8 + 0.1
        return ImageEnhance.Sharpness(img).enhance(magnitude)
    
    def shear_x(self, img, magnitude):
        magnitude = (magnitude / 10) * 0.3
        return img.transform(img.size, Image.AFFINE, (1, magnitude, 0, 0, 1, 0))
    
    def shear_y(self, img, magnitude):
        magnitude = (magnitude / 10) * 0.3
        return img.transform(img.size, Image.AFFINE, (1, 0, 0, magnitude, 1, 0))
    
    def translate_x(self, img, magnitude):
        magnitude = (magnitude / 10) * float(img.size[0] / 3)
        return img.transform(img.size, Image.AFFINE, (1, 0, magnitude, 0, 1, 0))
    
    def translate_y(self, img, magnitude):
        magnitude = (magnitude / 10) * float(img.size[1] / 3)
        return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude))
    
    def __call__(self, img):
        ops = random.choices(self.augment_list, k=self.n)
        for op in ops:
            img = op(img, self.m)
        return img

def get_strong_augmentation(img_size=224):
    return transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.2),
        transforms.RandomApply([transforms.RandomRotation(20)], p=0.5),
        transforms.RandomApply([transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2)], p=0.5),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], p=0.2),
        transforms.RandomPerspective(distortion_scale=0.3, p=0.3),
        RandAugment(n=2, m=9),
        transforms.RandomApply([transforms.RandomErasing(p=0.3, scale=(0.02, 0.33), ratio=(0.3, 3.3))], p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

def get_test_augmentation(img_size=224):
    return transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

#####################################
# 3. Test-Time Augmentation
#####################################
class TTAWrapper(nn.Module):
    def __init__(self, model, tta_transforms):
        super(TTAWrapper, self).__init__()
        self.model = model
        self.tta_transforms = tta_transforms
        
    def forward(self, x):
        batch_size = x.size(0)
        outputs = []
        
        for transform in self.tta_transforms:
            transformed_x = torch.stack([transform(img) for img in x])
            output = self.model(transformed_x)
            outputs.append(output)
            
        # Average predictions
        outputs = torch.stack(outputs, dim=0)
        return outputs.mean(dim=0)

def get_tta_transforms(img_size=224):
    tta_transforms = [
        transforms.Compose([
            transforms.Lambda(lambda x: x),  # Identity
        ]),
        transforms.Compose([
            transforms.Lambda(lambda x: TF.hflip(x)),
        ]),
        transforms.Compose([
            transforms.Lambda(lambda x: TF.vflip(x)),
        ]),
        transforms.Compose([
            transforms.Lambda(lambda x: TF.rotate(x, 90)),
        ]),
        transforms.Compose([
            transforms.Lambda(lambda x: TF.rotate(x, 180)),
        ])
    ]
    return tta_transforms

#####################################
# 4. Mixup and CutMix Implementation
#####################################
def mixup_data(x, y, alpha=0.8):
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def cutmix_data(x, y, alpha=0.8):
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(x.device)

    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    x_mixed = x.clone()
    x_mixed[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]
    
    # Adjust lambda to match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
    y_a, y_b = y, y[index]
    
    return x_mixed, y_a, y_b, lam

def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

#####################################
# 5. Progressive Layer Unfreezing
#####################################
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

def unfreeze_layers(model, layer_name):
    for name, param in model.named_parameters():
        if layer_name in name:
            param.requires_grad = True

def unfreeze_model_stages(model, model_type, stages=None):
    """
    Unfreeze specific stages of ResNet or DeiT
    """
    if stages is None:
        # Unfreeze all
        for param in model.parameters():
            param.requires_grad = True
        return
    
    # First freeze all
    for param in model.parameters():
        param.requires_grad = False
        
    if model_type == 'resnet':
        # ResNet stages
        if 'head' in stages:
            for param in model.fc.parameters():
                param.requires_grad = True
        if 'layer4' in stages:
            for param in model.layer4.parameters():
                param.requires_grad = True
        if 'layer3' in stages:
            for param in model.layer3.parameters():
                param.requires_grad = True
        if 'layer2' in stages:
            for param in model.layer2.parameters():
                param.requires_grad = True
        if 'layer1' in stages:
            for param in model.layer1.parameters():
                param.requires_grad = True
    
    elif model_type == 'deit':
        # DeiT stages
        if 'head' in stages:
            for param in model.head.parameters():
                param.requires_grad = True
        if 'blocks' in stages:
            # Determine which blocks to unfreeze
            num_blocks = len(model.blocks)
            blocks_to_unfreeze = []
            if 'blocks_last_3' in stages:
                blocks_to_unfreeze.extend(range(num_blocks-3, num_blocks))
            if 'blocks_last_6' in stages:
                blocks_to_unfreeze.extend(range(num_blocks-6, num_blocks-3))
            if 'blocks_all' in stages:
                blocks_to_unfreeze.extend(range(num_blocks))
            
            # Unfreeze specified blocks
            for i in blocks_to_unfreeze:
                for param in model.blocks[i].parameters():
                    param.requires_grad = True

#####################################
# 6. Knowledge Distillation
#####################################
class DistillationLoss(nn.Module):
    def __init__(self, base_criterion, teacher_model, distillation_alpha=0.5, distillation_tau=3.0):
        super().__init__()
        self.base_criterion = base_criterion
        self.teacher_model = teacher_model
        self.distillation_alpha = distillation_alpha
        self.distillation_tau = distillation_tau
    
    def forward(self, inputs, outputs, targets):
        # Compute base loss
        base_loss = self.base_criterion(outputs, targets)
        
        # Compute distillation loss (KL divergence)
        with torch.no_grad():
            teacher_outputs = self.teacher_model(inputs)
        
        # Compute soft targets
        # Apply temperature scaling
        T = self.distillation_tau
        soft_targets = F.softmax(teacher_outputs / T, dim=-1)
        soft_outputs = F.log_softmax(outputs / T, dim=-1)
        distillation_loss = F.kl_div(soft_outputs, soft_targets, reduction='batchmean') * (T * T)
        
        # Combine losses
        loss = base_loss * (1 - self.distillation_alpha) + distillation_loss * self.distillation_alpha
        return loss

#####################################
# 7. Create Models with Modifications
#####################################
def create_resnet_model(num_classes, model_type='resnet18', pretrained=True, feature_extract=True):
    if model_type == 'resnet18':
        model = resnet18(pretrained=pretrained)
    elif model_type == 'resnet34':
        model = resnet34(pretrained=pretrained)
    elif model_type == 'resnet50':
        model = resnet50(pretrained=pretrained)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")
    
    # Freeze all layers if feature extracting
    if feature_extract:
        set_parameter_requires_grad(model, feature_extracting=True)
    
    # Modify the final fully connected layer
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(num_ftrs, num_classes)
    )
    
    return model

def create_deit_model(num_classes, model_type='deit_tiny_patch16_224', pretrained=True, feature_extract=True):
    model = timm.create_model(model_type, pretrained=pretrained)
    
    # Freeze all layers if feature extracting
    if feature_extract:
        set_parameter_requires_grad(model, feature_extracting=True)
    
    # Modify the head
    num_ftrs = model.head.in_features
    model.head = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(num_ftrs, num_classes)
    )
    
    return model

#####################################
# 8. Cross-Validation Implementation
#####################################
def k_fold_cross_validation(dataset, model_fn, num_classes, k=5, batch_size=32, num_epochs=100, learning_rate=0.001, 
                           weight_decay=0.01, device='cuda', model_type='resnet'):
    
    kfold = KFold(n_splits=k, shuffle=True, random_state=42)
    fold_results = []
    
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    
    for fold, (train_idx, val_idx) in enumerate(kfold.split(indices)):
        print(f"Fold {fold+1}/{k}")
        
        # Create data samplers
        train_sampler = SubsetRandomSampler(train_idx)
        val_sampler = SubsetRandomSampler(val_idx)
        
        # Create data loaders
        train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
        val_loader = DataLoader(dataset, batch_size=batch_size, sampler=val_sampler)
        
        # Create the model
        model = model_fn(num_classes=num_classes)
        model = model.to(device)
        
        # Create optimizer and scheduler
        optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
        
        criterion = nn.CrossEntropyLoss(label_smoothing=0.2)
        
        # Train the model
        best_acc = 0.0
        best_model_state = None
        
        for epoch in range(num_epochs):
            # Progressive unfreezing
            if epoch == 0:
                if model_type == 'resnet':
                    unfreeze_model_stages(model, 'resnet', ['head'])
                else:
                    unfreeze_model_stages(model, 'deit', ['head'])
            elif epoch == 20:
                if model_type == 'resnet':
                    unfreeze_model_stages(model, 'resnet', ['head', 'layer4'])
                else:
                    unfreeze_model_stages(model, 'deit', ['head', 'blocks_last_3'])
            elif epoch == 40:
                if model_type == 'resnet':
                    unfreeze_model_stages(model, 'resnet', ['head', 'layer4', 'layer3'])
                else:
                    unfreeze_model_stages(model, 'deit', ['head', 'blocks_last_3', 'blocks_last_6'])
            
            # Training phase
            model.train()
            running_loss = 0.0
            
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                
                # Apply mixup or cutmix
                if random.random() < 0.5:
                    inputs, labels_a, labels_b, lam = mixup_data(inputs, labels, alpha=0.8)
                elif random.random() < 0.5:
                    inputs, labels_a, labels_b, lam = cutmix_data(inputs, labels, alpha=0.8)
                else:
                    labels_a, labels_b, lam = labels, labels, 1.0
                
                optimizer.zero_grad()
                
                with autocast():
                    outputs = model(inputs)
                    if lam != 1.0:
                        loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)
                    else:
                        loss = criterion(outputs, labels)
                
                loss.backward()
                optimizer.step()
                
                running_loss += loss.item() * inputs.size(0)
            
            # Validation phase
            model.eval()
            correct = 0
            total = 0
            
            with torch.no_grad():
                for inputs, labels in val_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = model(inputs)
                    _, predicted = torch.max(outputs, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
            
            epoch_acc = correct / total
            print(f"Fold {fold+1}, Epoch {epoch+1}/{num_epochs}, Validation Accuracy: {epoch_acc:.4f}")
            
            scheduler.step()
            
            # Save best model
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_state = model.state_dict().copy()
        
        fold_results.append(best_acc)
        print(f"Fold {fold+1} best accuracy: {best_acc:.4f}")
        
        # Save the model
        torch.save(best_model_state, f"model_fold_{fold+1}.pth")
    
    # Print overall results
    print(f"K-fold cross-validation results: {fold_results}")
    print(f"Average accuracy: {sum(fold_results)/len(fold_results):.4f}")
    
    return fold_results

#####################################
# 9. Semi-Supervised Learning (Pseudo-Labeling)
#####################################
def pseudo_labeling(labeled_dataset, unlabeled_paths, model, device, confidence_threshold=0.85):
    model.eval()
    transform = get_test_augmentation()
    pseudo_labeled_data = []
    
    with torch.no_grad():
        for img_path in unlabeled_paths:
            img = Image.open(img_path).convert('RGB')
            img_tensor = transform(img).unsqueeze(0).to(device)
            outputs = model(img_tensor)
            probs = F.softmax(outputs, dim=1)
            confidence, predicted = torch.max(probs, 1)
            
            if confidence.item() > confidence_threshold:
                pseudo_labeled_data.append((img_path, predicted.item()))
    
    # Create new dataset with both labeled and pseudo-labeled data
    all_paths = [p for p, _ in labeled_dataset.image_paths] + [p for p, _ in pseudo_labeled_data]
    all_labels = labeled_dataset.labels + [l for _, l in pseudo_labeled_data]
    
    return all_paths, all_labels

#####################################
# 10. Self-Supervised Learning with SimCLR
#####################################
class ProjectionHead(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super(ProjectionHead, self).__init__()
        self.projection = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.BatchNorm1d(hidden_features),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_features, out_features)
        )
    
    def forward(self, x):
        return self.projection(x)

class SimCLR(nn.Module):
    def __init__(self, backbone, projection_dim=128):
        super(SimCLR, self).__init__()
        self.backbone = backbone
        self.backbone.fc = nn.Identity()  # Remove classifier
        
        # Add projection head
        self.projection_head = ProjectionHead(
            in_features=self.backbone.inplanes,
            hidden_features=self.backbone.inplanes,
            out_features=projection_dim
        )
    
    def forward(self, x):
        features = self.backbone(x)
        projections = self.projection_head(features)
        return projections

def nt_xent_loss(z1, z2, temperature=0.5):
    """
    Normalized Temperature-scaled Cross Entropy Loss from SimCLR paper
    """
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    
    batch_size = z1.shape[0]
    representations = torch.cat([z1, z2], dim=0)
    
    # Calculate similarity matrix
    similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)
    
    # Create masks for positive pairs
    sim_i_j = torch.diag(similarity_matrix, batch_size)
    sim_j_i = torch.diag(similarity_matrix, -batch_size)
    
    # We need to remove the diagonal elements from the similarity matrix 
    # to get only negative samples
    mask = torch.ones_like(similarity_matrix) - torch.eye(2 * batch_size, device=z1.device)
    
    # Filter out the positives from the negative samples
    pos_mask = torch.zeros(2 * batch_size, 2 * batch_size, device=z1.device)
    pos_mask[:batch_size, batch_size:] = torch.eye(batch_size, device=z1.device)
    pos_mask[batch_size:, :batch_size] = torch.eye(batch_size, device=z1.device)
    mask = mask * (1 - pos_mask)
    
    # Get negative samples
    neg_sim = similarity_matrix * mask
    
    # For numerical stability
    neg_sim = neg_sim.view(2 * batch_size, -1)
    
    # Positive pairs
    pos_sim = torch.cat([sim_i_j.unsqueeze(1), sim_j_i.unsqueeze(1)], dim=0)
    
    # Logits: [2*batch_size, 1+2*batch_size-2]
    logits = torch.cat([pos_sim, neg_sim], dim=1) / temperature
    
    # Labels: positives are the first column
    labels = torch.zeros(2 * batch_size, dtype=torch.long, device=z1.device)
    
    return F.cross_entropy(logits, labels)

def simclr_transform(img_size=224):
    """
    SimCLR augmentation pipeline
    """
    return transforms.Compose([
        transforms.RandomResizedCrop(img_size, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0))], p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

class TwoCropTransform:
    """
    Create two crops of the same image
    """
    def __init__(self, transform):
        self.transform = transform
    
    def __call__(self, x):
        return [self.transform(x), self.transform(x)]

def train_simclr(model, dataloader, optimizer, device, epochs=100):
    model.train()
    
    for epoch in range(epochs):
        running_loss = 0.0
        
        for images, _ in dataloader:
            # Get the two views
            img1 = images[0].to(device)
            img2 = images[1].to(device)
            
            optimizer.zero_grad()
            
            # Forward pass
            z1 = model(img1)
            z2 = model(img2)
            
            # Compute loss
            loss = nt_xent_loss(z1, z2)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(dataloader):.4f}")
    
    return model

def load_image_dataset(data_directory):
    """
    Load images and labels from a directory structure where:
    data/
    ├── Class1/
    │   ├── image1.jpg
    │   └── image2.jpg
    └── Class2/
        └── image3.jpg
    
    Returns:
    - image_paths: List of paths to all images
    - labels: List of corresponding labels for each image
    """
    image_paths = []
    labels = []
    
    # Get the absolute path to ensure consistent behavior
    data_dir = Path(data_directory)
    
    # Supported image extensions
    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']
    
    # Walk through the data directory
    for class_name in os.listdir(data_dir):
        class_dir = data_dir / class_name
        
        # Skip if not a directory
        if not class_dir.is_dir():
            continue
        
        # Process all image files in the class directory
        for file_name in os.listdir(class_dir):
            file_path = class_dir / file_name
            
            # Check if it's a valid image file
            if file_path.is_file() and file_path.suffix.lower() in image_extensions:
                # Add the image path and its label
                image_paths.append(str(file_path))
                labels.append(class_name)
    
    return image_paths, labels

#####################################
# 11. Full Training Pipeline - Main Functions
#####################################
def full_training_pipeline(data_dir, num_classes, device='cuda'):
    # 1. Load dataset
    print("Loading dataset...")
    # Implement your data loading logic here to get image_paths and labels
    # For this example:
    # image_pathss = [os.path.join(data_dir, f) for f in os.listdir(data_dir) ]
    # # Assuming equal distribution of classes for demonstration
    # labelss = [i % num_classes for i in range(len(image_pathss))]
    # print(image_pathss)
    image_paths, labels = load_image_dataset("./data")
    
    # 2. Data augmentation
    print("Setting up data augmentation...")
    train_transform = get_strong_augmentation()
    val_transform = get_test_augmentation()
    
    # 3. Split data into train and validation
    train_size = int(0.8 * len(image_paths))
    val_size = len(image_paths) - train_size
    
    train_dataset = SmallDataset(image_paths[:train_size], labels[:train_size], transform=train_transform)
    val_dataset = SmallDataset(image_paths[train_size:], labels[train_size:], transform=val_transform)
    
    # 4. Set up dataloaders
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=0)
    
    # 5. Optional: Self-supervised pretraining with SimCLR
    print("Starting self-supervised pretraining...")
    
    # Create base model for SimCLR
    base_model = resnet18(pretrained=True)
    simclr_model = SimCLR(backbone=base_model).to(device)
    
    # Create SimCLR dataset with two-crop transform
    simclr_transforms = TwoCropTransform(simclr_transform())
    simclr_dataset = SmallDataset(image_paths, labels, transform=simclr_transforms)
    simclr_loader = DataLoader(simclr_dataset, batch_size=16, shuffle=True, num_workers=0)
    
    # Train with SimCLR
    simclr_optimizer = optim.AdamW(simclr_model.parameters(), lr=0.0003, weight_decay=0.1)
    simclr_model = train_simclr(simclr_model, simclr_loader, simclr_optimizer, device, epochs=50)
    
    # 6. Create models for supervised fine-tuning
    print("Creating supervised models...")
    
    # Create ResNet model (using pretrained weights from SimCLR)
    resnet_model = create_resnet_model(num_classes, model_type='resnet18', pretrained=True, feature_extract=True)
    # Load SimCLR backbone weights
    state_dict = simclr_model.backbone.state_dict()
    resnet_model.load_state_dict(state_dict, strict=False)
    resnet_model = resnet_model.to(device)
    
    # Create DeiT model
    deit_model = create_deit_model(num_classes, model_type='deit_tiny_patch16_224', pretrained=True, feature_extract=True)
    deit_model = deit_model.to(device)
    
    # 7. Create teacher model for knowledge distillation
    teacher_model = resnet50(pretrained=True)
    num_ftrs = teacher_model.fc.in_features
    teacher_model.fc = nn.Linear(num_ftrs, num_classes)
    teacher_model = teacher_model.to(device)
    teacher_model.eval()  # Set to evaluation mode
    
    # 8. Set up optimizers with lower learning rates
    resnet_optimizer = optim.AdamW(resnet_model.parameters(), lr=0.0001, weight_decay=0.01)
    deit_optimizer = optim.AdamW(deit_model.parameters(), lr=0.0001, weight_decay=0.01)
    
    # 9. Set up learning rate schedulers
    resnet_scheduler = CosineAnnealingLR(resnet_optimizer, T_max=100)
    deit_scheduler = CosineAnnealingLR(deit_optimizer, T_max=100)
    
    # 10. Set up loss functions with label smoothing
    criterion = nn.CrossEntropyLoss(label_smoothing=0.2)
    
    # Distillation loss for the DeiT model
    distillation_loss = DistillationLoss(
        base_criterion=criterion,
        teacher_model=teacher_model,
        distillation_alpha=0.5,
        distillation_tau=3.0
    )
    
    # 11. Set up gradient scaler for mixed precision training
    scaler = GradScaler()
    
    # 12. Training loops for both models
    best_resnet_acc = 0.0
    best_deit_acc = 0.0
    best_resnet_model = None
    best_deit_model = None
    patience = 20
    resnet_patience_counter = 0
    deit_patience_counter = 0
    
    print("Starting ResNet training...")
    # ResNet training loop
    for epoch in range(100):  # 100 epochs
        # Progressive unfreezing for ResNet
        if epoch == 0:
            unfreeze_model_stages(resnet_model, 'resnet', ['head'])
        elif epoch == 20:
            unfreeze_model_stages(resnet_model, 'resnet', ['head', 'layer4'])
        elif epoch == 40:
            unfreeze_model_stages(resnet_model, 'resnet', ['head', 'layer4', 'layer3'])
        elif epoch == 60:
            unfreeze_model_stages(resnet_model, 'resnet', ['head', 'layer4', 'layer3', 'layer2'])
        
        # Training phase
        resnet_model.train()
        train_loss = 0.0
        
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            # Apply mixup or cutmix with high probability
            r = random.random()
            if r < 0.4:  # 40% chance of mixup
                inputs, labels_a, labels_b, lam = mixup_data(inputs, labels, alpha=0.8)
            elif r < 0.8:  # 40% chance of cutmix
                inputs, labels_a, labels_b, lam = cutmix_data(inputs, labels, alpha=0.8)
            else:  # 20% no augmentation
                labels_a, labels_b, lam = labels, labels, 1.0
            
            # Zero the parameter gradients
            resnet_optimizer.zero_grad()
            
            # Mixed precision forward pass
            with autocast():
                outputs = resnet_model(inputs)
                if lam != 1.0:
                    loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)
                else:
                    loss = criterion(outputs, labels)
            
            # Backward and optimize with gradient scaling
            scaler.scale(loss).backward()
            scaler.step(resnet_optimizer)
            scaler.update()
            
            train_loss += loss.item() * inputs.size(0)
        
        # Calculate epoch loss
        epoch_loss = train_loss / len(train_loader.dataset)
        
        # Validation phase
        resnet_model.eval()
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = resnet_model(inputs)
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
        
        val_acc = val_correct / val_total
        print(f"Epoch {epoch+1}/100, ResNet Loss: {epoch_loss:.4f}, Accuracy: {val_acc:.4f}")
        
        # Update learning rate
        resnet_scheduler.step()
        
        # Early stopping
        if val_acc > best_resnet_acc:
            best_resnet_acc = val_acc
            best_resnet_model = resnet_model.state_dict().copy()
            resnet_patience_counter = 0
        else:
            resnet_patience_counter += 1
            if resnet_patience_counter >= patience:
                print(f"Early stopping ResNet training at epoch {epoch+1}")
                break
    
    print(f"Best ResNet validation accuracy: {best_resnet_acc:.4f}")
    
    # Save the best ResNet model
    torch.save(best_resnet_model, "best_resnet_model.pth")
    
    print("Starting DeiT training...")
    # DeiT training loop
    for epoch in range(100):  # 100 epochs
        # Progressive unfreezing for DeiT
        if epoch == 0:
            unfreeze_model_stages(deit_model, 'deit', ['head'])
        elif epoch == 20:
            unfreeze_model_stages(deit_model, 'deit', ['head', 'blocks_last_3'])
        elif epoch == 40:
            unfreeze_model_stages(deit_model, 'deit', ['head', 'blocks_last_3', 'blocks_last_6'])
        
        # Training phase
        deit_model.train()
        train_loss = 0.0
        
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            # Apply mixup or cutmix with high probability
            r = random.random()
            if r < 0.4:  # 40% chance of mixup
                inputs, labels_a, labels_b, lam = mixup_data(inputs, labels, alpha=0.8)
            elif r < 0.8:  # 40% chance of cutmix
                inputs, labels_a, labels_b, lam = cutmix_data(inputs, labels, alpha=0.8)
            else:  # 20% no augmentation
                labels_a, labels_b, lam = labels, labels, 1.0
            
            # Zero the parameter gradients
            deit_optimizer.zero_grad()
            
            # Mixed precision forward pass
            with autocast():
                outputs = deit_model(inputs)
                # Use distillation loss
                if lam != 1.0:
                    # For mixed samples, use base criterion with mixup
                    loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)
                else:
                    # For regular samples, use distillation loss
                    loss = distillation_loss(inputs, outputs, labels)
            
            # Backward and optimize with gradient scaling
            scaler.scale(loss).backward()
            scaler.step(deit_optimizer)
            scaler.update()
            
            train_loss += loss.item() * inputs.size(0)
        
        # Calculate epoch loss
        epoch_loss = train_loss / len(train_loader.dataset)
        
        # Validation phase
        deit_model.eval()
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = deit_model(inputs)
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
        
        val_acc = val_correct / val_total
        print(f"Epoch {epoch+1}/100, DeiT Loss: {epoch_loss:.4f}, Accuracy: {val_acc:.4f}")
        
        # Update learning rate
        deit_scheduler.step()
        
        # Early stopping
        if val_acc > best_deit_acc:
            best_deit_acc = val_acc
            best_deit_model = deit_model.state_dict().copy()
            deit_patience_counter = 0
        else:
            deit_patience_counter += 1
            if deit_patience_counter >= patience:
                print(f"Early stopping DeiT training at epoch {epoch+1}")
                break
    
    print(f"Best DeiT validation accuracy: {best_deit_acc:.4f}")
    
    # Save the best DeiT model
    torch.save(best_deit_model, "best_deit_model.pth")
    
    # 13. Pseudo-labeling to leverage unlabeled data
    print("Starting pseudo-labeling...")
    
    # Load the best ResNet model for pseudo-labeling
    resnet_model.load_state_dict(best_resnet_model)
    resnet_model.eval()
    
    # Assuming there's a directory with unlabeled images
    unlabeled_dir = os.path.join(data_dir, 'unlabeled')
    if os.path.exists(unlabeled_dir):
        unlabeled_paths = [os.path.join(unlabeled_dir, f) for f in os.listdir(unlabeled_dir) 
                           if f.endswith(('.jpg', '.png', '.jpeg'))]
        
        # Get pseudo-labels for unlabeled data
        all_paths, all_labels = pseudo_labeling(
            labeled_dataset=train_dataset,
            unlabeled_paths=unlabeled_paths,
            model=resnet_model,
            device=device,
            confidence_threshold=0.85
        )
        
        # Create new dataset with pseudo-labeled data
        combined_dataset = SmallDataset(all_paths, all_labels, transform=train_transform)
        
        # Re-train models with the extended dataset
        # This is a simplified version; you may want to implement a full training loop
        print(f"Re-training with {len(all_paths)} images (original + pseudo-labeled)")
        # ... implement re-training here ...
    
    # 14. Test-time augmentation
    print("Evaluating with test-time augmentation...")
    
    # Load best models
    resnet_model.load_state_dict(best_resnet_model)
    deit_model.load_state_dict(best_deit_model)
    
    # Create TTA wrappers
    tta_transforms = get_tta_transforms()
    resnet_tta = TTAWrapper(resnet_model, tta_transforms).to(device)
    deit_tta = TTAWrapper(deit_model, tta_transforms).to(device)
    
    # Evaluate with TTA
    resnet_tta.eval()
    deit_tta.eval()
    
    resnet_correct = 0
    deit_correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            # ResNet predictions with TTA
            resnet_outputs = resnet_tta(inputs)
            _, resnet_predicted = torch.max(resnet_outputs, 1)
            
            # DeiT predictions with TTA
            deit_outputs = deit_tta(inputs)
            _, deit_predicted = torch.max(deit_outputs, 1)
            
            total += labels.size(0)
            resnet_correct += (resnet_predicted == labels).sum().item()
            deit_correct += (deit_predicted == labels).sum().item()
    
    resnet_acc = resnet_correct / total
    deit_acc = deit_correct / total
    
    print(f"ResNet TTA Accuracy: {resnet_acc:.4f}")
    print(f"DeiT TTA Accuracy: {deit_acc:.4f}")
    
    # 15. Ensemble predictions
    print("Evaluating ensemble model...")
    
    ensemble_correct = 0
    
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            # Get predictions from both models
            resnet_outputs = resnet_tta(inputs)
            deit_outputs = deit_tta(inputs)
            
            # Average the predictions (weighted ensemble)
            ensemble_outputs = 0.6 * resnet_outputs + 0.4 * deit_outputs
            _, ensemble_predicted = torch.max(ensemble_outputs, 1)
            
            ensemble_correct += (ensemble_predicted == labels).sum().item()
    
    ensemble_acc = ensemble_correct / total
    print(f"Ensemble Accuracy: {ensemble_acc:.4f}")
    
    # Return best models and accuracies
    return {
        'resnet_model': best_resnet_model,
        'deit_model': best_deit_model,
        'resnet_acc': resnet_acc,
        'deit_acc': deit_acc,
        'ensemble_acc': ensemble_acc
    }

# Example usage
def main():
    data_dir = "data"
    num_classes = 3  # Adjust based on your dataset
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Set random seed for reproducibility
    set_seed(42)
    
    # Run the full pipeline
    results = full_training_pipeline(data_dir, num_classes, device)
    
    print("\nFinal Results:")
    print(f"ResNet Accuracy: {results['resnet_acc']:.4f}")
    print(f"DeiT Accuracy: {results['deit_acc']:.4f}")
    print(f"Ensemble Accuracy: {results['ensemble_acc']:.4f}")

if __name__ == "__main__":
    main()

Loading dataset...
Setting up data augmentation...
Starting self-supervised pretraining...
Epoch 1/50, Loss: 2.4413
Epoch 2/50, Loss: 2.2131
Epoch 3/50, Loss: 2.1813
Epoch 4/50, Loss: 2.1443
Epoch 5/50, Loss: 2.1240
Epoch 6/50, Loss: 2.0516
Epoch 7/50, Loss: 2.0532
Epoch 8/50, Loss: 2.0435
Epoch 9/50, Loss: 2.0527
Epoch 10/50, Loss: 2.0554
Epoch 11/50, Loss: 2.0440
Epoch 12/50, Loss: 2.0320


KeyboardInterrupt: 