# # Skin Cancer Classification and Segmentation

# ## 1. Environment Setup

# In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.transforms import InterpolationMode
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Seed both random number generators (PyTorch and NumPy) with one shared value
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Run on the GPU when CUDA is usable, otherwise stay on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# ## 2. Data Loading and Splitting

# In [None]:
# Load metadata with column names
# NOTE(review): header=None makes pandas treat the first CSV row as data;
# confirm that archive/GroundTruth.csv really has no header row.
column_names = ['image_id'] + [f'class_{i}' for i in range(7)]
metadata = pd.read_csv('archive/GroundTruth.csv', header=None, names=column_names)

# Define class names (entry i names the one-hot column class_i)
class_names = [
    'melanoma',
    'nevus',
    'basal_cell_carcinoma',
    'actinic_keratosis',
    'benign_keratosis',
    'dermatofibroma',
    'vascular_lesion'
]

# Convert one-hot to dx column: take the column holding the row maximum and
# translate its name (class_i) into the matching class name.
one_hot_columns = column_names[1:]
column_to_class = dict(zip(one_hot_columns, class_names))
metadata['dx'] = metadata[one_hot_columns].idxmax(axis=1).map(column_to_class)

# Check for duplicates (fully identical rows only)
if metadata.duplicated().any():
    print("Warning: Duplicates found. Removing duplicates.")
    metadata = metadata.drop_duplicates()

# Split data 70 / 15 / 15, stratified by class.
# An explicit random_state keeps the split identical no matter how often or in
# which order notebook cells are re-run (the global NumPy seed alone does not,
# because any earlier NumPy random call shifts the global state).
SPLIT_SEED = 42
try:
    train_df, temp_df = train_test_split(
        metadata, test_size=0.3, stratify=metadata['dx'], random_state=SPLIT_SEED
    )
    val_df, test_df = train_test_split(
        temp_df, test_size=0.5, stratify=temp_df['dx'], random_state=SPLIT_SEED
    )
except ValueError as e:
    # Stratification fails when a class has too few members to be split
    print(f"Splitting error: {e}")
    print("Consider merging rare classes or using a different strategy.")
    raise

# Reset indices so each split is indexed 0..n-1
train_df.reset_index(drop=True, inplace=True)
val_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)

# Print stats
print(f"\nTraining set: {len(train_df)}")
print(f"Validation set: {len(val_df)}")
print(f"Test set: {len(test_df)}")

# Check corrupt samples with optional mask requirement
def _image_is_readable(path):
    """Return ``(True, None)`` if Pillow can open and verify *path*, else ``(False, error)``."""
    try:
        with Image.open(path) as handle:
            handle.verify()
    # SyntaxError is included because some Pillow format plugins report
    # malformed files that way from verify().
    except (OSError, ValueError, SyntaxError) as err:
        return False, err
    return True, None


def check_corrupt_samples(df, image_dir, mask_dir=None, require_masks=False):
    """Find rows whose image (and optionally mask) file is missing or unreadable.

    Args:
        df: DataFrame with an ``image_id`` column.
        image_dir: Directory holding ``<image_id>.jpg`` files.
        mask_dir: Directory holding ``<image_id>_segmentation.png`` files.
        require_masks: When true (and ``mask_dir`` is given), a missing or
            unreadable mask also marks the row as corrupt.

    Returns:
        list: Index labels of the offending rows, suitable for ``df.drop(...)``.
        Labels are taken from ``df.index`` itself, so the result is correct
        even when the frame has not been re-indexed to 0..n-1.
    """
    corrupt_indices = []
    check_masks = bool(require_masks and mask_dir)

    rows = tqdm(df['image_id'].items(), total=len(df), desc="Checking corrupt files")
    for idx, image_id in rows:
        # Check image
        img_path = os.path.join(image_dir, image_id + '.jpg')
        readable, error = _image_is_readable(img_path)
        if not readable:
            print(f"Corrupt or missing image {img_path}: {error}")
            corrupt_indices.append(idx)
            continue

        # Check mask only if required
        if not check_masks:
            continue

        mask_name = image_id + '_segmentation.png'
        mask_path = os.path.join(mask_dir, mask_name)
        if not os.path.exists(mask_path):
            print(f"Missing mask: {mask_name}")
            corrupt_indices.append(idx)
            continue

        readable, error = _image_is_readable(mask_path)
        if not readable:
            print(f"Corrupt or missing mask {mask_path}: {error}")
            corrupt_indices.append(idx)

    return corrupt_indices

# Classification only needs readable images, so masks are not checked here.
# The frames are modified in place so train_df / val_df / test_df keep pointing
# at the filtered data.
classification_splits = {"train": train_df, "val": val_df, "test": test_df}
for name, df in classification_splits.items():
    corrupt = check_corrupt_samples(df, 'archive/images', mask_dir=None, require_masks=False)
    if len(corrupt) > 0:
        df.drop(index=corrupt, inplace=True)
    df.reset_index(drop=True, inplace=True)
    print(f"{name.capitalize()} set size after image check: {len(df)}")

# Segmentation works on independent copies that additionally require a mask
seg_train_df, seg_val_df, seg_test_df = (
    split.copy() for split in (train_df, val_df, test_df)
)

segmentation_splits = {"train": seg_train_df, "val": seg_val_df, "test": seg_test_df}
for name, df in segmentation_splits.items():
    corrupt = check_corrupt_samples(df, 'archive/images', 'archive/masks', require_masks=True)
    if len(corrupt) > 0:
        df.drop(index=corrupt, inplace=True)
    df.reset_index(drop=True, inplace=True)
    print(f"Segmentation {name.capitalize()} set size after mask check: {len(df)}")
    if df.empty:
        print(f"Warning: Segmentation {name} set is empty. Masks may be missing in 'archive/masks'.")

# Label mapping: class names in alphabetical order -> consecutive integers
all_data = pd.concat([train_df, val_df, test_df])
unique_labels = sorted(all_data['dx'].unique())
label_mapping = {label: position for position, label in enumerate(unique_labels)}

print("\nClass distributions:")
for title, split in (("Training:", train_df), ("Validation:", val_df), ("Test:", test_df)):
    print(title, split['dx'].value_counts(normalize=True))
print("\nLabel mapping:", label_mapping)

# ## 3. Data Processing

# In [None]:
# Dataset Classes
class ClassificationDataset(Dataset):
    """Dataset yielding ``(image, label)`` pairs for lesion classification.

    Each row of ``df`` provides an ``image_id`` (file ``<image_id>.jpg`` inside
    ``image_dir``) and a ``dx`` class name, which is converted to an integer
    through the module-level ``label_mapping``.
    """

    def __init__(self, df, image_dir, transform=None):
        self.df = df
        self.image_dir = image_dir
        self.transform = transform
        # Class name -> integer index, shared with the rest of the script
        self.label_mapping = label_mapping

    def __len__(self):
        # One sample per metadata row
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        img_path = os.path.join(self.image_dir, record['image_id'] + '.jpg')
        image = Image.open(img_path).convert('RGB')
        label = self.label_mapping[record['dx']]

        if self.transform:
            image = self.transform(image)

        # The label is returned as a 0-dim int64 tensor
        label_tensor = torch.tensor(label, dtype=torch.long)
        return image, label_tensor

class SegmentationDataset(Dataset):
    """Dataset yielding ``(image, mask)`` pairs for lesion segmentation.

    For a row with id ``X`` the image is ``<image_dir>/X.jpg`` (loaded as RGB)
    and the mask is ``<mask_dir>/X_segmentation.png`` (loaded as single-channel
    greyscale). Optional transforms are applied to each independently.
    """

    def __init__(self, df, image_dir, mask_dir, transform=None, mask_transform=None):
        self.df = df
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = mask_transform

    def __len__(self):
        # One sample per metadata row
        return self.df.shape[0]

    def __getitem__(self, idx):
        image_id = self.df.iloc[idx]['image_id']

        # File names follow the dataset's naming scheme
        img_path = os.path.join(self.image_dir, image_id + '.jpg')
        mask_path = os.path.join(self.mask_dir, image_id + '_segmentation.png')

        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')

        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        return image, mask

# Transforms: both pipelines resize the shorter side to 224 and centre-crop to
# 224x224 so an image and its mask stay aligned. Images are normalised;
# masks use nearest-neighbour resizing and are not normalised.
img_transform = transforms.Compose([
    transforms.Resize(size=224),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

mask_transform = transforms.Compose([
    transforms.Resize(size=224, interpolation=InterpolationMode.NEAREST),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
])

# Create datasets (one per split)
classification_train, classification_val, classification_test = (
    ClassificationDataset(split_df, image_dir='archive/images', transform=img_transform)
    for split_df in (train_df, val_df, test_df)
)

# Create segmentation datasets from the mask-verified splits
segmentation_train, segmentation_val, segmentation_test = (
    SegmentationDataset(
        split_df,
        image_dir='archive/images',
        mask_dir='archive/masks',
        transform=img_transform,
        mask_transform=mask_transform,
    )
    for split_df in (seg_train_df, seg_val_df, seg_test_df)
)

# Check for empty datasets: training cannot proceed without train/val data
required_datasets = {
    "classification_train": classification_train,
    "classification_val": classification_val,
    "segmentation_train": segmentation_train,
    "segmentation_val": segmentation_val,
}
for name, dataset in required_datasets.items():
    if len(dataset) == 0:
        raise ValueError(f"Dataset {name} is empty after preprocessing. Check file paths or data availability.")

# ## 4. Classification Model

# In [None]:
class SkinCancerClassifier(nn.Module):
    """Pre-trained ResNet-18 with its last layer replaced by a ``num_classes``-way linear head."""

    def __init__(self, num_classes=7):
        super().__init__()
        backbone = models.resnet18(pretrained=True)

        # Keep the backbone's feature width; only the output size changes
        backbone.fc = nn.Linear(
            in_features=backbone.fc.in_features,
            out_features=num_classes,
        )
        self.base_model = backbone

    def forward(self, x):
        # Returns the raw (un-normalised) class scores from the new head
        logits = self.base_model(x)
        return logits

# Build the classifier on the selected device and report its size
classification_model = SkinCancerClassifier().to(device)
clf_param_count = sum(param.numel() for param in classification_model.parameters())
print(f"Classification model parameters: {clf_param_count:,}")

# ## 5. Segmentation Model

# In [None]:
class DecoderBlock(nn.Module):
    """One decoder stage: 3x3 convolution, batch norm, ReLU, then 2x bilinear upsampling.

    Maps ``in_channels`` feature maps to ``out_channels`` and doubles the
    spatial height and width.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            # padding=1 with a 3x3 kernel keeps the spatial size unchanged
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        upsampled = self.block(x)
        return upsampled
class UNet(nn.Module):
    """U-Net-style segmentation network built on a pre-trained ResNet-18 encoder.

    The ResNet stem (stride-2 convolution + max-pool) shrinks the input by 4x
    and layers 2-4 each halve it again. The four decoder blocks each upsample
    by 2x, which leaves the decoder output at half the input resolution, so
    the logits are resized to the input size before being returned. This keeps
    them shape-compatible with full-resolution masks (for example with
    ``BCEWithLogitsLoss``).

    Args:
        in_channels: Accepted for interface compatibility; the ResNet-18 stem
            is used unchanged, so this value is not used by the network.
        out_channels: Number of output channels (1 for a binary mask).

    Note:
        The additive skip connections require the encoder and decoder feature
        maps to have equal sizes, which holds when the input height and width
        are multiples of 32 (e.g. 224).
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        # Encoder (ResNet-18 backbone)
        self.encoder = models.resnet18(pretrained=True)

        # Encoder layers
        self.conv1 = nn.Sequential(
            self.encoder.conv1,
            self.encoder.bn1,
            self.encoder.relu,
            self.encoder.maxpool
        )
        self.encoder1 = self.encoder.layer1  # Output: 64 channels
        self.encoder2 = self.encoder.layer2  # Output: 128 channels
        self.encoder3 = self.encoder.layer3  # Output: 256 channels
        self.encoder4 = self.encoder.layer4  # Output: 512 channels

        # Decoder layers (each doubles the spatial size)
        self.decoder4 = DecoderBlock(512, 256)  # 512 -> 256 channels
        self.decoder3 = DecoderBlock(256, 128)  # 256 -> 128 channels
        self.decoder2 = DecoderBlock(128, 64)   # 128 -> 64 channels
        self.decoder1 = DecoderBlock(64, 64)    # 64 -> 64 channels

        # Final output layer
        self.final = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, out_channels, kernel_size=1)  # Output: 1 channel (binary mask)
        )

    def forward(self, x):
        """Return per-pixel logits with the same height and width as ``x``."""
        input_size = tuple(x.shape[-2:])

        # Encoder
        x = self.conv1(x)       # Initial conv + maxpool
        e1 = self.encoder1(x)   # Layer 1
        e2 = self.encoder2(e1)  # Layer 2
        e3 = self.encoder3(e2)  # Layer 3
        e4 = self.encoder4(e3)  # Layer 4

        # Decoder with additive skip connections
        d4 = self.decoder4(e4) + e3
        d3 = self.decoder3(d4) + e2
        d2 = self.decoder2(d3) + e1
        d1 = self.decoder1(d2)       # Final upsample (reaches 1/2 input size)

        logits = self.final(d1)

        # The decoder stops at half resolution; bring the logits back to the
        # input size so they can be compared with full-resolution masks.
        if tuple(logits.shape[-2:]) != input_size:
            logits = nn.functional.interpolate(
                logits, size=input_size, mode='bilinear', align_corners=False
            )
        return logits

# Build the segmentation network on the selected device and report its size
segmentation_model = UNet().to(device)
seg_param_count = sum(param.numel() for param in segmentation_model.parameters())
print(f"Segmentation model parameters: {seg_param_count:,}")

# ## 6. Training Utilities

# In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10, task='classification'):
    """
    Train a model for classification or segmentation.

    After every epoch the model is evaluated on ``val_loader``; whenever the
    validation loss improves, the weights are saved to ``best_<task>_model.pth``.

    Args:
        model (nn.Module): The model to train.
        train_loader (DataLoader): DataLoader for training data.
        val_loader (DataLoader): DataLoader for validation data.
        criterion: Loss function, called as ``criterion(outputs, targets)``.
        optimizer: Optimizer.
        num_epochs (int): Number of epochs to train.
        task (str): 'classification' or 'segmentation'.

    Returns:
        dict: ``train_loss`` and ``val_loss`` (one per-sample average per
        epoch) and ``val_acc`` (percent, filled for classification only).

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    """
    best_loss = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'val_acc': []}
    is_classification = task == 'classification'

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        samples_seen = 0
        progress = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')

        for inputs, targets in progress:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            # Both tasks use the same call; the criterion differs per task
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            batch_size = inputs.size(0)
            # Weight by batch size so a smaller final batch does not skew the
            # average and the value is comparable with the validation loss.
            running_loss += batch_loss * batch_size
            samples_seen += batch_size
            progress.set_postfix({'loss': batch_loss})

        if samples_seen == 0:
            raise ValueError("train_loader yielded no samples; cannot train.")

        # Calculate average training loss for the epoch
        train_loss = running_loss / samples_seen
        history['train_loss'].append(train_loss)

        # Validation
        val_loss, val_acc = evaluate_model(model, val_loader, criterion, task)
        history['val_loss'].append(val_loss)
        if is_classification:
            history['val_acc'].append(val_acc)

        # Save the best model
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), f'best_{task}_model.pth')

        # Print epoch results
        print(f"{task.capitalize()} Epoch {epoch+1}: "
              f"Train Loss: {train_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}" +
              (f", Val Acc: {val_acc:.2f}%" if is_classification else ""))

    return history


def evaluate_model(model, loader, criterion, task='classification'):
    """
    Evaluate a model on a given dataset.

    Args:
        model (nn.Module): The model to evaluate (switched to eval mode).
        loader (DataLoader): DataLoader for evaluation data.
        criterion: Loss function, called as ``criterion(outputs, targets)``.
        task (str): 'classification' or 'segmentation'.

    Returns:
        tuple: ``(average loss per sample, accuracy in percent)`` for
        classification, ``(average loss per sample, None)`` otherwise.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)

            batch_size = inputs.size(0)
            # criterion returns a batch mean; scale it back to a batch sum
            total_loss += criterion(outputs, targets).item() * batch_size
            total += batch_size

            if task != 'segmentation':
                predicted = outputs.argmax(dim=1)
                correct += (predicted == targets).sum().item()

    if total == 0:
        raise ValueError("loader yielded no samples; cannot evaluate.")

    # Divide by the samples actually seen, so the average stays correct even
    # if the loader does not cover the whole dataset (e.g. drop_last or a
    # sampler).
    avg_loss = total_loss / total
    if task == 'classification':
        accuracy = 100 * correct / total
        return avg_loss, accuracy
    return avg_loss, None

# ## 7. Model Training

# In [None]:
# Train Classification Model (cross-entropy loss, Adam, batches of 32)
clf_criterion = nn.CrossEntropyLoss()
clf_optimizer = optim.Adam(classification_model.parameters(), lr=1e-4)
clf_history = train_model(
    model=classification_model,
    train_loader=DataLoader(classification_train, batch_size=32, shuffle=True),
    val_loader=DataLoader(classification_val, batch_size=32),
    criterion=clf_criterion,
    optimizer=clf_optimizer,
    num_epochs=15,
    task='classification',
)

# Train Segmentation Model (binary cross-entropy on logits, Adam, batches of 16)
seg_criterion = nn.BCEWithLogitsLoss()
seg_optimizer = optim.Adam(segmentation_model.parameters(), lr=1e-4)
seg_history = train_model(
    model=segmentation_model,
    train_loader=DataLoader(segmentation_train, batch_size=16, shuffle=True),
    val_loader=DataLoader(segmentation_val, batch_size=16),
    criterion=seg_criterion,
    optimizer=seg_optimizer,
    num_epochs=25,
    task='segmentation',
)

# ## 8. Evaluation and Report

# In [None]:
def calculate_iou_dice(pred_mask, true_mask):
    """Return ``(IoU, Dice)`` as Python floats for two masks.

    Both tensors are binarised at 0.5 first. A small epsilon is added to the
    numerators and denominators, so two empty masks score 1.0 on both metrics.
    """
    eps = 1e-6
    pred_bin = pred_mask.gt(0.5).float()
    true_bin = true_mask.gt(0.5).float()

    overlap = torch.sum(pred_bin * true_bin)
    combined = pred_bin.sum() + true_bin.sum()  # |pred| + |true|

    iou = (overlap + eps) / (combined - overlap + eps)
    dice = (2 * overlap + eps) / (combined + eps)
    return iou.item(), dice.item()

def generate_classification_report(model, loader):
    """Print per-class metrics and plot the confusion matrix for ``model``.

    Args:
        model (nn.Module): Classifier returning one score per class.
        loader (DataLoader): Yields ``(inputs, integer targets)`` batches.

    The full list of class ids from ``label_mapping`` is passed explicitly to
    scikit-learn, so the report and the matrix always cover every class (and
    stay aligned with the tick labels) even if a class never appears in this
    loader's targets or predictions.
    """
    model.eval()
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for inputs, targets in loader:
            outputs = model(inputs.to(device))
            all_preds.extend(outputs.argmax(dim=1).cpu().tolist())
            all_targets.extend(targets.cpu().tolist())

    # Class names ordered by their integer id, as plain lists
    names = sorted(label_mapping, key=label_mapping.get)
    label_ids = [label_mapping[class_name] for class_name in names]

    print("Classification Report:")
    print(classification_report(all_targets, all_preds,
                                labels=label_ids, target_names=names))

    plt.figure(figsize=(10, 8))
    sns.heatmap(confusion_matrix(all_targets, all_preds, labels=label_ids),
                annot=True, fmt='d', cmap='Blues',
                xticklabels=names,
                yticklabels=names)
    # confusion_matrix puts true classes on rows and predictions on columns
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()

def _denormalize_image(image_tensor):
    """Undo the mean/std normalisation of a CxHxW image tensor.

    Uses the same constants as the Normalize step of the image transform and
    returns an HxWx3 NumPy array clipped to [0, 1] for display.
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = image_tensor.permute(1, 2, 0).cpu().numpy() * std + mean
    return np.clip(image, 0.0, 1.0)


def generate_segmentation_report(model, loader):
    """Print mean IoU / Dice over ``loader`` and plot a few example predictions.

    Args:
        model (nn.Module): Segmentation model returning per-pixel logits.
        loader (DataLoader): Yields ``(inputs, masks)`` batches.

    Metrics are computed per sample and then averaged. Up to three samples of
    the first batch are shown as image, ground truth, prediction and overlay.
    """
    model.eval()
    ious, dices = [], []
    first_batch = None

    with torch.no_grad():
        for inputs, masks in loader:
            probs = torch.sigmoid(model(inputs.to(device)))
            preds = (probs > 0.5).float()

            for pred, mask in zip(preds, masks.to(device)):
                iou, dice = calculate_iou_dice(pred, mask)
                ious.append(iou)
                dices.append(dice)

            # Keep the first batch for plotting instead of running the model
            # a second time (and outside no_grad) afterwards.
            if first_batch is None:
                first_batch = (inputs.cpu(), masks.cpu(), preds.cpu())

    if first_batch is None:
        print("No samples to evaluate; segmentation report skipped.")
        return

    print(f"Mean IoU: {np.mean(ious):.4f}")
    print(f"Mean Dice: {np.mean(dices):.4f}")

    # Visualization
    inputs, masks, preds = first_batch
    n_rows = min(3, inputs.size(0))  # the batch may hold fewer than 3 samples

    plt.figure(figsize=(15, 5))
    for i in range(n_rows):
        image = _denormalize_image(inputs[i])
        true_mask = masks[i].squeeze().numpy()
        pred_mask = preds[i].squeeze().numpy()

        plt.subplot(n_rows, 4, i * 4 + 1)
        plt.imshow(image)
        plt.title('Original Image')
        plt.axis('off')

        plt.subplot(n_rows, 4, i * 4 + 2)
        plt.imshow(true_mask, cmap='gray')
        plt.title('Ground Truth Mask')
        plt.axis('off')

        plt.subplot(n_rows, 4, i * 4 + 3)
        plt.imshow(pred_mask, cmap='gray')
        plt.title('Predicted Mask')
        plt.axis('off')

        plt.subplot(n_rows, 4, i * 4 + 4)
        # Draw the image first and the semi-transparent mask on top of it;
        # the opposite order would hide the mask completely.
        plt.imshow(image)
        plt.imshow(pred_mask, cmap='jet', alpha=0.5)
        plt.title('Overlay')
        plt.axis('off')
    plt.tight_layout()
    plt.show()

# Final reports on the held-out test splits
print("\nFinal Classification Performance:")
clf_test_loader = DataLoader(classification_test, batch_size=32)
generate_classification_report(classification_model, clf_test_loader)

print("\nFinal Segmentation Performance:")
seg_test_loader = DataLoader(segmentation_test, batch_size=16)
generate_segmentation_report(segmentation_model, seg_test_loader)

# Training curves: classification on the left, segmentation on the right
curve_fig = plt.figure(figsize=(12, 5))

clf_axis = curve_fig.add_subplot(1, 2, 1)
for history_key, legend_label in (('train_loss', 'Train Loss'),
                                  ('val_loss', 'Val Loss'),
                                  ('val_acc', 'Val Accuracy')):
    clf_axis.plot(clf_history[history_key], label=legend_label)
clf_axis.set_title('Classification Training')
clf_axis.legend()

seg_axis = curve_fig.add_subplot(1, 2, 2)
for history_key, legend_label in (('train_loss', 'Train Loss'),
                                  ('val_loss', 'Val Loss')):
    seg_axis.plot(seg_history[history_key], label=legend_label)
seg_axis.set_title('Segmentation Training')
seg_axis.legend()

plt.show()

# ## 9. Single Model for Both Tasks

# In [None]:
class MultiTaskModel(nn.Module):
    """Shared ResNet-18 encoder with a classification head and a segmentation decoder.

    ``forward`` returns ``(class logits, mask logits)``. The encoder reduces
    the input by 32x and the four decoder blocks bring it back up by 16x, so
    the decoder output has half the input resolution; it is therefore resized
    to the input size so the mask logits can be compared with full-resolution
    masks (for example with ``BCEWithLogitsLoss``).

    Args:
        num_classes: Number of classification outputs.

    Note:
        ``base_model`` is kept as an attribute, but its own ``fc`` layer is
        not used in ``forward``.
    """

    def __init__(self, num_classes=7):
        super().__init__()
        # Shared encoder
        self.base_model = models.resnet18(pretrained=True)
        self.encoder = nn.Sequential(
            self.base_model.conv1,
            self.base_model.bn1,
            self.base_model.relu,
            self.base_model.maxpool,
            self.base_model.layer1,
            self.base_model.layer2,
            self.base_model.layer3,
            self.base_model.layer4
        )

        # Classification head: global average pool -> linear layer
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, num_classes)
        )

        # Segmentation decoder (no skip connections)
        self.decoder = nn.Sequential(
            DecoderBlock(512, 256),
            DecoderBlock(256, 128),
            DecoderBlock(128, 64),
            DecoderBlock(64, 64),
            nn.Conv2d(64, 1, kernel_size=1)
        )

    def forward(self, x):
        """Return ``(cls_out, seg_out)``; ``seg_out`` matches the height/width of ``x``."""
        input_size = tuple(x.shape[-2:])

        features = self.encoder(x)
        cls_out = self.classifier(features)
        seg_out = self.decoder(features)

        # The decoder stops short of the input resolution; resize the logits
        # so they line up pixel-for-pixel with the target masks.
        if tuple(seg_out.shape[-2:]) != input_size:
            seg_out = nn.functional.interpolate(
                seg_out, size=input_size, mode='bilinear', align_corners=False
            )
        return cls_out, seg_out

# Build the multi-task network on the selected device and report its size
multi_task_model = MultiTaskModel().to(device)
multi_task_param_count = sum(param.numel() for param in multi_task_model.parameters())
print(f"Parameters: {multi_task_param_count:,}")

# Create multi-task dataset
class MultiTaskDataset(Dataset):
    """Dataset yielding ``(image, (label, mask))`` for joint classification and segmentation.

    For a row with id ``X`` the image is ``<image_dir>/X.jpg`` and the mask is
    ``<mask_dir>/X_segmentation.png``, the same naming used by the mask check
    and ``SegmentationDataset`` in this file. Image and mask go through
    matching resize + centre-crop steps so both end up 224x224 and aligned.
    The label comes from the module-level ``label_mapping``.
    """

    def __init__(self, df, image_dir, mask_dir):
        self.df = df
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        self.mask_transform = transforms.Compose([
            # Nearest-neighbour keeps mask values from being blended on resize
            transforms.Resize(224, interpolation=InterpolationMode.NEAREST),
            transforms.CenterCrop(224),
            transforms.ToTensor()
        ])

    def _image_path(self, idx):
        """Return the image file path for the row at position ``idx``."""
        return os.path.join(self.image_dir, f"{self.df.iloc[idx]['image_id']}.jpg")

    def _mask_path(self, idx):
        """Return the mask file path for the row at position ``idx``."""
        return os.path.join(self.mask_dir, f"{self.df.iloc[idx]['image_id']}_segmentation.png")

    def __getitem__(self, idx):
        label = label_mapping[self.df.iloc[idx]['dx']]

        # Force 3-channel RGB so Normalize always sees three channels
        with Image.open(self._image_path(idx)) as img:
            image = self.transform(img.convert('RGB'))
        with Image.open(self._mask_path(idx)) as raw_mask:
            mask = self.mask_transform(raw_mask.convert('L'))

        return image, (torch.tensor(label, dtype=torch.long), mask)

    def __len__(self):
        return len(self.df)

# Training loop
# Use the mask-verified splits: every row in seg_*_df passed the mask check,
# whereas train_df / test_df were only checked for readable images.
multitask_train = MultiTaskDataset(seg_train_df, 'archive/images', 'archive/masks')
train_loader = DataLoader(multitask_train, 16, shuffle=True)

clf_criterion = nn.CrossEntropyLoss()
seg_criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(multi_task_model.parameters(), lr=1e-4)

for epoch in range(10):
    multi_task_model.train()
    total_loss = 0.0

    for images, (labels, masks) in tqdm(train_loader):
        images, labels, masks = images.to(device), labels.to(device), masks.to(device)

        optimizer.zero_grad()
        cls_pred, seg_pred = multi_task_model(images)
        # Joint objective: unweighted sum of the two task losses
        loss = clf_criterion(cls_pred, labels) + seg_criterion(seg_pred, masks)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}: Loss {total_loss/max(len(train_loader), 1):.4f}")

# Evaluation
test_loader = DataLoader(
    MultiTaskDataset(seg_test_df, 'archive/images', 'archive/masks'),
    batch_size=16
)

# Switch BatchNorm layers to inference behaviour before measuring
multi_task_model.eval()
correct, n_samples = 0, 0
iou_scores, dice_scores = [], []

with torch.no_grad():
    for images, (labels, masks) in test_loader:
        images, labels, masks = images.to(device), labels.to(device), masks.to(device)
        cls_pred, seg_pred = multi_task_model(images)

        # Count per sample so a smaller last batch is not over-weighted
        correct += (cls_pred.argmax(1) == labels).sum().item()
        n_samples += labels.size(0)

        # calculate_iou_dice thresholds at 0.5, which is meant for
        # probabilities; convert the logits with a sigmoid first.
        probs = torch.sigmoid(seg_pred)
        for prob, mask in zip(probs, masks):
            iou, dice = calculate_iou_dice(prob, mask)
            iou_scores.append(iou)
            dice_scores.append(dice)

if n_samples == 0:
    print("Multi-task test set is empty; no metrics to report.")
else:
    print(f"Accuracy: {correct/n_samples:.4f}, "
          f"IoU: {np.mean(iou_scores):.4f}, "
          f"Dice: {np.mean(dice_scores):.4f}")