# In [None]:
import os
import warnings

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from ultralytics import YOLO

def segment_lesion_auto(image_rgb, min_lesion_frac=0.005):
    """Return a boolean mask of the largest dark region in an RGB image.

    Pipeline: grayscale -> CLAHE contrast enhancement -> invert (so darker
    pixels become bright) -> Otsu threshold -> 5x5 median blur -> elliptical
    morphological close then open -> keep only the largest connected
    foreground component.

    Args:
        image_rgb: (H, W, 3) RGB image. Presumably 8-bit, as expected by the
            cv2 CLAHE/Otsu calls; confirm for other dtypes.
        min_lesion_frac: if the kept region covers less than this fraction of
            the image, an all-False mask is returned instead.

    Returns:
        (H, W) boolean numpy array, True on the selected region.
    """
    gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    gray = clahe.apply(gray)
    inv = 255 - gray
    _, th = cv2.threshold(inv, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    th = cv2.medianBlur(th, 5)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, kernel, iterations=1)
    th = cv2.morphologyEx(th, cv2.MORPH_OPEN, kernel, iterations=1)
    num, labels = cv2.connectedComponents(th)

    if num > 1:
        # Count every label's pixels in a single pass over the label image.
        # The alternative is one full-image comparison per component, which is
        # costly on noisy thresholds. Label 0 is the background and is dropped.
        # np.argmax returns the first maximum, so ties go to the lowest label.
        areas = np.bincount(labels.ravel(), minlength=num)[1:]
        mask = labels == (1 + int(np.argmax(areas)))
    else:
        # No foreground component was found.
        mask = th.astype(bool)

    if mask.mean() < min_lesion_frac:
        mask[:] = False
    return mask

def mask_by_region_auto(image_rgb, lesion_ratio=0.6, bg_ratio=0.4, mask_value=0, seed=None):
    """Return a copy of ``image_rgb`` with random lesion and background pixels masked.

    The lesion region comes from ``segment_lesion_auto``. A ``lesion_ratio``
    fraction of lesion pixels and a ``bg_ratio`` fraction of background pixels
    are sampled without replacement and set to ``mask_value``. Each count is
    rounded and clamped to [0, number of pixels in that region].

    Args:
        image_rgb: (H, W, 3) image array. It is not modified.
        lesion_ratio: fraction of lesion pixels to mask.
        bg_ratio: fraction of background pixels to mask.
        mask_value: fill value. Either a scalar, applied to all three channels,
            or any 3-element sequence/array, giving one value per channel.
        seed: seed for ``np.random.default_rng``; ``None`` uses fresh entropy.

    Returns:
        The masked copy, with the same shape and dtype as ``image_rgb``.

    Raises:
        ValueError: if ``image_rgb`` is not of shape (H, W, 3).
    """
    # Use an explicit check rather than ``assert`` so validation survives ``python -O``.
    if image_rgb.ndim != 3 or image_rgb.shape[2] != 3:
        raise ValueError(f"image_rgb must have shape (H, W, 3), got {image_rgb.shape}")

    flat_lesion = segment_lesion_auto(image_rgb).ravel()
    lesion_idx = np.flatnonzero(flat_lesion)
    bg_idx = np.flatnonzero(~flat_lesion)

    rng = np.random.default_rng(seed)
    n_lesion = int(min(len(lesion_idx), max(0, round(lesion_ratio * len(lesion_idx)))))
    n_bg = int(min(len(bg_idx), max(0, round(bg_ratio * len(bg_idx)))))

    # Lesion pixels are drawn before background pixels. Keep this order so a
    # fixed seed keeps producing the same selection.
    empty = np.array([], dtype=int)
    pick_lesion = rng.choice(lesion_idx, size=n_lesion, replace=False) if n_lesion > 0 else empty
    pick_bg = rng.choice(bg_idx, size=n_bg, replace=False) if n_bg > 0 else empty
    pick_all = np.concatenate([pick_lesion, pick_bg])

    masked_img = image_rgb.copy()
    # copy() yields a C-contiguous array, so reshape returns a view and the
    # assignment writes into masked_img. NumPy broadcasting applies a scalar to
    # every channel, and a 3-element value (tuple, list or ndarray) per channel.
    masked_img.reshape(-1, 3)[pick_all] = mask_value
    return masked_img

class DSMAEDataset(Dataset):
    """Dataset of ``(masked, original)`` float tensor pairs for DS-MAE pretraining.

    Every file under ``img_dir`` (searched recursively) whose extension is
    .jpg, .jpeg, .png or .bmp (case-insensitive) is included. Paths are sorted
    so the index -> file mapping does not depend on filesystem listing order.

    Args:
        img_dir: root directory to search for images.
        img_size: every image is resized to ``img_size`` x ``img_size``.
        lesion_ratio: fraction of lesion pixels masked per sample.
        bg_ratio: fraction of background pixels masked per sample.

    Raises:
        ValueError: if no matching image is found under ``img_dir``.
    """

    _EXTS = frozenset({'.jpg', '.jpeg', '.png', '.bmp'})

    def __init__(self, img_dir, img_size=224, lesion_ratio=0.8, bg_ratio=0.1):
        self.img_size = img_size
        self.lesion_ratio = lesion_ratio
        self.bg_ratio = bg_ratio

        paths = []
        for root, _, files in os.walk(img_dir):
            for name in files:
                # splitext avoids matching extensionless files literally named e.g. "png".
                if os.path.splitext(name)[1].lower() in self._EXTS:
                    paths.append(os.path.join(root, name))
        self.img_paths = sorted(paths)

        if not self.img_paths:
            raise ValueError(f"No images found in {img_dir}")

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Return ``(masked_tensor, original_tensor)``, each (3, img_size, img_size) in [0, 1]."""
        img_path = self.img_paths[idx]
        bgr = cv2.imread(img_path)

        if bgr is None:
            # Keep the best-effort fallback, so one unreadable file does not abort
            # an epoch, but make it visible instead of silently training on zeros.
            warnings.warn(f"Could not read image, substituting zeros: {img_path}")
            blank = torch.zeros(3, self.img_size, self.img_size)
            return blank, blank.clone()

        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        rgb = cv2.resize(rgb, (self.img_size, self.img_size))

        masked_img_np = mask_by_region_auto(
            rgb, lesion_ratio=self.lesion_ratio, bg_ratio=self.bg_ratio, mask_value=0
        )

        # HWC uint8 -> CHW float in [0, 1].
        original_tensor = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0
        masked_tensor = torch.from_numpy(masked_img_np).permute(2, 0, 1).float() / 255.0

        return masked_tensor, original_tensor

class DS_MAE(nn.Module):
    """Masked autoencoder whose encoder is the front of a pretrained YOLO model.

    The encoder is layers 0-9 of ``YOLO(yolo_model_path).model.model`` chained
    in an ``nn.Sequential``, with all parameters trainable. The decoder
    upsamples the encoder output 32x, using five stride-2 transposed
    convolutions that each exactly double the spatial size. It then maps to 3
    channels squashed into [0, 1] by a Sigmoid. The reconstruction matches the
    input's spatial size only if the encoder's total stride is 32 and the input
    side is a multiple of 32 (TODO confirm for other backbones).

    Args:
        yolo_model_path: weights/model spec passed to ``ultralytics.YOLO``.
    """

    def __init__(self, yolo_model_path='yolo11s.pt'):
        super().__init__()
        full_yolo = YOLO(yolo_model_path)

        # Extract backbone layers (0-9). nn.Sequential assumes each layer takes
        # only the previous layer's output.
        original_layers = list(full_yolo.model.model.children())
        self.encoder = nn.Sequential(*original_layers[:10])

        for param in self.encoder.parameters():
            param.requires_grad = True

        self.enc_channels = self._probe_out_channels(self.encoder)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(self.enc_channels, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256), nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128), nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64), nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(True),
            nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(True),
            nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    @staticmethod
    def _probe_out_channels(module, size=224):
        """Return the channel count of ``module``'s output for a dummy image.

        Every submodule runs in eval mode during the probe, so BatchNorm
        running statistics are not updated by the all-zero input. Each
        submodule's original ``training`` flag is restored afterwards.
        """
        modes = [(m, m.training) for m in module.modules()]
        ref = next(module.parameters(), None)
        device = ref.device if ref is not None else None
        dtype = ref.dtype if ref is not None else None
        module.eval()
        try:
            with torch.no_grad():
                out = module(torch.zeros(1, 3, size, size, device=device, dtype=dtype))
        finally:
            for m, was_training in modes:
                m.training = was_training
        return out.shape[1]

    def forward(self, x):
        """Reconstruct an image from ``x`` (presumably (N, 3, H, W) in [0, 1])."""
        features = self.encoder(x)
        return self.decoder(features)

    def get_backbone(self):
        """Return the encoder (the YOLO layers 0-9) for saving or transfer."""
        return self.encoder

if __name__ == "__main__":
    # DS-MAE pretraining: reconstruct original images from region-masked inputs,
    # then save only the encoder (backbone) weights for the fine-tuning cells.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    BATCH_SIZE = 16
    NUM_EPOCHS = 3
    IMAGE_DIR = "/home/hank52052/Dataset/isic/HAM10000/yolo_format/train"

    model = DS_MAE("yolo11s.pt").to(device)

    if not os.path.exists(IMAGE_DIR):
        # Report the skip so a missing dataset is not mistaken for a completed run.
        print(f"[WARNING] Image directory not found: {IMAGE_DIR}; skipping DS-MAE pretraining.")
    else:
        # DSMAEDataset raises ValueError when it finds no images, so it is never empty here.
        dataset = DSMAEDataset(IMAGE_DIR)
        dataloader = DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=4,
            pin_memory=device.type == "cuda",
        )
        criterion = nn.MSELoss()
        # NOTE: CyclicLR overwrites each param group's lr with base_lr when it is
        # constructed. Adadelta's lr=1.0 is therefore never used; the rate cycles
        # between base_lr and max_lr.
        optimizer = optim.Adadelta(model.parameters(), lr=1.0)
        scheduler = optim.lr_scheduler.CyclicLR(
            optimizer,
            base_lr=1e-5,
            max_lr=1e-1,
            step_size_up=200,
            mode='triangular',
            cycle_momentum=False,
        )

        for epoch in range(NUM_EPOCHS):
            model.train()
            total_loss = 0.0

            for masked, original in dataloader:
                masked = masked.to(device, non_blocking=True)
                original = original.to(device, non_blocking=True)

                optimizer.zero_grad()
                outputs = model(masked)
                loss = criterion(outputs, original)
                loss.backward()
                optimizer.step()
                scheduler.step()  # CyclicLR is stepped once per batch.

                total_loss += loss.item()

            print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] Loss: {total_loss/len(dataloader):.4f}")

        torch.save(model.get_backbone().state_dict(), "ds_mae_backbone_weights.pth")

# In [None]:
from ultralytics import YOLO
import torch
import torch.nn as nn

# Stage 1: initialise the classification model's backbone from DS-MAE weights,
# then train on the four-class split.
ds_mae_weights_path = "ds_mae_backbone_weights.pth"
custom_yaml_path = 'yolo11s-cls-shuffle.yaml'
# Ultralytics-style checkpoint written below so model.train() starts from these weights.
init_ckpt_path = "ds_mae_init_cls.pt"

model = YOLO(custom_yaml_path)

print(f"[INFO] Loading DS-MAE pretrained backbone weights from: {ds_mae_weights_path}...")

# map_location: the state dict may have been saved from CUDA tensors, so load on
# CPU to keep this cell runnable without a GPU. weights_only: a plain state dict
# needs no arbitrary unpickling.
ds_mae_state_dict = torch.load(ds_mae_weights_path, map_location="cpu", weights_only=True)
current_model_dict = model.model.state_dict()
new_state_dict = {}

for k, v in ds_mae_state_dict.items():
    # DS_MAE saved nn.Sequential keys (e.g. "0.conv.weight"); here they sit under "model.".
    target_key = f"model.{k}"
    target = current_model_dict.get(target_key)

    if target is None:
        print(f"[WARNING] Skipping unknown layer: {target_key}")
    elif target.shape != v.shape:
        print(f"[WARNING] Skipping layer due to shape mismatch: {target_key}")
    else:
        new_state_dict[target_key] = v

model.model.load_state_dict(new_state_dict, strict=False)
print(f"[INFO] Successfully loaded {len(new_state_dict)} layers from DS-MAE backbone.")

# Ultralytics' Model.train() hands weights to the trainer only when the YOLO object
# holds a checkpoint. Otherwise it rebuilds the network from the YAML, which would
# discard the weights copied above. Persist the populated network and attach it
# through the documented YOLO.load() path (re-check if upgrading ultralytics).
torch.save({"model": model.model}, init_ckpt_path)
model.load(init_ckpt_path)

# NOTE(review): several augmentation values are far from the Ultralytics defaults
# (e.g. hsv_h default 0.015, translate default 0.1). Flip probabilities of 0.9
# flip most images rather than half of them. Some geometric args may be ignored
# by the classification pipeline; confirm against the installed version.
results = model.train(
    data="/home/hank52052/Dataset/isic/HAM10000/Four_Classes",
    project="/home/hank52052/code/isic/yolo_training",
    epochs=1200,
    imgsz=256,
    batch=64,
    patience=40,
    plots=True,

    # Augmentation Hyperparameters
    degrees=0.95,
    scale=0.9,
    shear=0.8,
    flipud=0.9,
    fliplr=0.9,
    hsv_h=0.8,
    hsv_s=0.8,
    hsv_v=0.9,
    translate=0.8
)

# In [None]:
from ultralytics import YOLO
import torch

# Stage 2: fine-tune on the original seven classes, starting from the stage-1 backbone.
custom_yaml_path = 'yolo11s-cls-shuffle.yaml'
model = YOLO(custom_yaml_path)

stage1_weights_path = "/home/hank52052/code/isic/yolo_training/train/weights/best.pt"

print(f"[INFO] Loading Stage 1 pretrained weights from: {stage1_weights_path}...")

# YOLO.load() is the documented "build from YAML and transfer weights" route. It
# copies only tensors whose names and shapes match, so the stage-1 head (trained on
# Four_Classes) is skipped, and it attaches a checkpoint so model.train() keeps the
# weights. A raw load_state_dict(strict=False) is not equivalent: strict=False still
# raises on shape mismatches, and weights copied that way are not passed on to training.
try:
    model.load(stage1_weights_path)
    print("[INFO] Successfully loaded backbone weights. Classification head initialized for 7-class fine-grained task.")

except Exception as e:
    # Deliberate best-effort fallback at this top-level cell; the reason is reported.
    print(f"[WARNING] Failed to load Stage 1 weights: {e}")
    print("[INFO] Proceeding with random initialization.")
    # Rebuild so no partially-applied load state (e.g. a recorded 'pretrained' path) lingers.
    model = YOLO(custom_yaml_path)

# NOTE(review): the augmentation values mirror stage 1 and are far from the
# Ultralytics defaults; `rect` and some geometric args may be ignored for
# classification. Confirm against the installed version.
results = model.train(
    data="/home/hank52052/Dataset/isic/HAM10000/Original_Classes",
    project="/home/hank52052/code/isic/yolo_training",
    name="seven_classes_finetune",
    epochs=1200,
    imgsz=256,
    patience=40,
    batch=64,
    rect=True,
    plots=True,

    degrees=0.95,
    scale=0.9,
    shear=0.8,
    flipud=0.9,
    fliplr=0.9,
    hsv_h=0.8,
    hsv_s=0.8,
    hsv_v=0.9,
    translate=0.8,
)