# 🩺 Fine-tuning MedSigLIP for Nail Disease Classification

**Goal**: Achieve 0.80-0.90 accuracy using proper contrastive learning with sigmoid loss.

This notebook implements the correct MedSigLIP fine-tuning approach:
- ✅ Text prompts for each class (contrastive learning)
- ✅ Sigmoid loss for multi-label classification
- ✅ Proper data collation and preprocessing
- ✅ Memory-optimized for 2x15GB GPUs (30GB total)
- ✅ 10 epoch training with early stopping

### 🦠 Nail Disease Classes

1. **Acral Lentiginous Melanoma (ALM)**
2. **Blue Finger**
3. **Clubbing**
4. **Onychogryphosis**
5. **Pitting**
6. **Psoriasis**
7. **Healthy Nail**

## 1️⃣ Hugging Face Login

In [None]:
from huggingface_hub import notebook_login

# Banner and instructions, emitted one line per print call so the cell
# output is unchanged.
_rule = "=" * 70
for _line in (
    _rule,
    "üîê HUGGING FACE LOGIN",
    _rule,
    "\nYou'll be prompted to enter your Hugging Face token.",
    "Get your token: https://huggingface.co/settings/tokens\n",
):
    print(_line)

# Opens the Hugging Face token prompt inside the notebook.
notebook_login()

print("\n‚úÖ Login successful!")

## 2️⃣ Install Dependencies

In [None]:
!pip install -q torch torchvision transformers datasets pillow scikit-learn matplotlib tqdm numpy pandas
!pip install -q open-clip-torch
!pip install -q huggingface_hub

print("‚úÖ Dependencies installed!")

## 3️⃣ GPU Verification

In [None]:
import torch
import sys
import gc

# Drop cached CUDA allocations and unreachable Python objects before probing.
torch.cuda.empty_cache()
gc.collect()

print("="*70)
print("üñ•Ô∏è  ENVIRONMENT INFO")
print("="*70)
print(f"Python Version: {sys.version.split()[0]}")
print(f"PyTorch Version: {torch.__version__}")
print(f"GPU Available: {torch.cuda.is_available()}")

if not torch.cuda.is_available():
    print("\n‚ö†Ô∏è  WARNING: No GPU detected.")
    print("Please enable GPU in Kaggle settings.")
else:
    num_gpus = torch.cuda.device_count()
    print(f"\nüéØ Number of GPUs Available: {num_gpus}")
    print(f"CUDA Version: {torch.version.cuda}")

    # Per-device capacity in decimal gigabytes (bytes / 1e9).
    gpu_memory_gb = [
        torch.cuda.get_device_properties(gpu_idx).total_memory / 1e9
        for gpu_idx in range(num_gpus)
    ]
    for gpu_idx, mem_gb in enumerate(gpu_memory_gb):
        print(f"\n  GPU {gpu_idx}: {torch.cuda.get_device_name(gpu_idx)}")
        print(f"    Memory: {mem_gb:.2f} GB")

    total_memory = sum(gpu_memory_gb)
    print(f"\nüíæ TOTAL GPU MEMORY: {total_memory:.2f} GB")

    if num_gpus > 1:
        print(f"\n‚úÖ MULTI-GPU TRAINING ENABLED!")

print("="*70)

## 4️⃣ Dataset Setup

In [None]:
import os
from pathlib import Path

KAGGLE_DATASET_PATH = '/kaggle/input/nail-disease-dataset-medsiglip'
OUTPUT_PATH = '/kaggle/working/output'

# Checkpoints, plots and the training history are written under OUTPUT_PATH.
os.makedirs(OUTPUT_PATH, exist_ok=True)

print("="*70)
print("üìÇ DATASET VERIFICATION")
print("="*70)

# Fail fast with setup instructions when the Kaggle input is not attached.
dataset_present = os.path.exists(KAGGLE_DATASET_PATH)
if not dataset_present:
    for message in (
        f"\n‚ùå ERROR: Dataset not found at {KAGGLE_DATASET_PATH}",
        "\nüìã SOLUTION:",
        "   1. Add 'nail-disease-dataset-medsiglip' as input to this notebook",
        "   2. Go to notebook settings ‚Üí Add data",
    ):
        print(message)
    raise FileNotFoundError(f"Dataset not found at {KAGGLE_DATASET_PATH}")

print(f"‚úÖ Dataset path found: {KAGGLE_DATASET_PATH}")

TRAIN_DATA_PATH, TEST_DATA_PATH = (
    os.path.join(KAGGLE_DATASET_PATH, split) for split in ('train', 'test')
)

# Both split directories must exist before any dataset is built.
if not all(os.path.exists(split_dir) for split_dir in (TRAIN_DATA_PATH, TEST_DATA_PATH)):
    print(f"\n‚ùå ERROR: train/ or test/ directories not found!")
    raise FileNotFoundError("train/ or test/ directories not found")

print(f"\n‚úÖ Dataset paths configured:")
for tag, location in (("TRAIN", TRAIN_DATA_PATH), ("TEST", TEST_DATA_PATH), ("OUTPUT", OUTPUT_PATH)):
    print(f"   {tag}: {location}")
print("="*70)

## 5️⃣ Load MedSigLIP Model & Processor

In [None]:
from transformers import AutoModel, AutoProcessor

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"üñ•Ô∏è  Using device: {device}")

torch.cuda.empty_cache()
gc.collect()

print("\nüì• Loading MedSigLIP...")
model_id = "google/medsiglip-448"

try:
    # Weights are loaded in full precision.
    load_options = {'torch_dtype': torch.float32, 'low_cpu_mem_usage': True}
    model = AutoModel.from_pretrained(model_id, **load_options)
    processor = AutoProcessor.from_pretrained(model_id)

    print("‚úÖ MedSigLIP model loaded successfully!")
    print(f"\nüìä Model info:")
    parameter_count = sum(p.numel() for p in model.parameters())
    print(f"   Total parameters: {parameter_count:,}")

except Exception as e:
    # Print troubleshooting hints, then let the original error propagate.
    print(f"‚ùå Error loading model: {e}")
    print(f"\nüìã Troubleshooting:")
    print(f"   1. Make sure you logged in with Hugging Face token")
    print(f"   2. Request access: https://huggingface.co/google/medsiglip-448")
    raise

torch.cuda.empty_cache()
gc.collect()

## 6️⃣ Define Text Prompts & Dataset

In [None]:
from torch.utils.data import Dataset
from PIL import Image
from torchvision.datasets import ImageFolder

# Define medical text prompts for each class.
# The position of each prompt in the list below becomes its integer key.
# NOTE(review): NailDiseaseDataset looks prompts up by the ImageFolder label
# index, and ImageFolder assigns indices from the sorted folder names --
# confirm the dataset's folder order matches this list.
CLASS_PROMPTS = dict(enumerate([
    "A medical image of acral lentiginous melanoma with dark pigmentation under the nail",
    "A medical image showing blue discoloration of the fingernail indicating cyanosis",
    "A medical image of nail clubbing with bulging and rounded nail appearance",
    "A medical image of a healthy normal nail with pink nail bed",
    "A medical image of onychogryphosis with thickened and curved overgrown nails",
    "A medical image of nail pitting with small depressions in the nail plate",
    "A medical image of psoriatic nails with pitting and discoloration",
]))

print("üìù Medical text prompts defined:")
# Show only the first 60 characters of each prompt to keep the listing short.
for class_idx in CLASS_PROMPTS:
    print(f"   {class_idx}. {CLASS_PROMPTS[class_idx][:60]}...")

# Custom dataset for contrastive learning
class NailDiseaseDataset(Dataset):
    """Pair each ImageFolder sample with the text prompt of its class.

    Args:
        root_dir: Directory handed to ``ImageFolder``.
        processor: Stored on the instance; not used by this class itself.
        class_prompts: Mapping from integer class index to prompt text.
    """

    def __init__(self, root_dir, processor, class_prompts):
        self.processor = processor
        self.class_prompts = class_prompts
        self.dataset = ImageFolder(root_dir)
        # Class names exactly as ImageFolder reports them.
        self.classes = self.dataset.classes

    def __len__(self):
        """Return the number of samples in the wrapped ImageFolder."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, prompt_text, label)`` for sample ``idx``."""
        sample = self.dataset[idx]
        image, label = sample
        return image, self.class_prompts[label], label

# Build the train and test datasets (train first) with the shared processor
# and prompt table.
train_dataset, test_dataset = (
    NailDiseaseDataset(split_dir, processor, CLASS_PROMPTS)
    for split_dir in (TRAIN_DATA_PATH, TEST_DATA_PATH)
)

print(f"\n‚úÖ Training samples: {len(train_dataset)}")
print(f"‚úÖ Test samples: {len(test_dataset)}")
print(f"‚úÖ Number of classes: {len(train_dataset.classes)}")
print(f"\nüìã Class labels: {train_dataset.classes}")

## 7️⃣ Data Collator for Contrastive Learning

In [None]:
from typing import List, Dict
import torch

class ContrastiveDataCollator:
    """Collate ``(image, text, label)`` samples into one MedSigLIP batch.

    Args:
        processor: Callable accepting ``text=``, ``images=``, ``return_tensors=``,
            ``padding=`` and ``truncation=`` and returning a mapping that
            contains at least ``input_ids``.
        num_classes: Width of the one-hot label matrix.
    """

    def __init__(self, processor, num_classes):
        self.processor = processor
        self.num_classes = num_classes

    def __call__(self, batch: List) -> Dict[str, torch.Tensor]:
        """Return processor outputs plus ``attention_mask`` and one-hot ``labels``.

        Raises:
            ValueError: If ``batch`` is empty.
            IndexError: If a label lies outside ``[0, num_classes)``.
        """
        if not batch:
            raise ValueError("ContrastiveDataCollator received an empty batch")
        images, texts, labels = zip(*batch)

        # SigLIP's text tower pools the last sequence position and was trained
        # on fixed-length inputs, so pad to the tokenizer's max length instead
        # of the longest prompt in the batch.
        inputs = self.processor(
            text=list(texts),
            images=list(images),
            return_tensors="pt",
            padding="max_length",
            truncation=True
        )

        # The training loop reads batch['attention_mask']; supply an all-ones
        # mask (attend to every position) when the processor does not return one.
        if 'attention_mask' not in inputs:
            inputs['attention_mask'] = torch.ones_like(inputs['input_ids'])

        # One-hot target matrix of shape [batch_size, num_classes].
        label_ids = torch.tensor([int(label) for label in labels], dtype=torch.long)
        if int(label_ids.min()) < 0 or int(label_ids.max()) >= self.num_classes:
            # Negative labels would otherwise wrap around silently.
            raise IndexError(
                f"label out of range for {self.num_classes} classes: {label_ids.tolist()}"
            )
        targets = torch.zeros(len(labels), self.num_classes)
        targets[torch.arange(len(labels)), label_ids] = 1.0

        inputs['labels'] = targets
        return inputs

# Configuration
# Treat a CPU-only machine as a single device so the batch size stays at 8.
if torch.cuda.is_available():
    num_gpus = torch.cuda.device_count()
else:
    num_gpus = 1
BATCH_SIZE = num_gpus * 8  # 8 per GPU
NUM_WORKERS = 2
NUM_CLASSES = len(train_dataset.classes)

print(f"üíæ Configuration:")
print(f"   Batch Size per GPU: 8")
for caption, setting in (
    ("Total Batch Size", BATCH_SIZE),
    ("Num Workers", NUM_WORKERS),
    ("Num Classes", NUM_CLASSES),
):
    print(f"   {caption}: {setting}")

collator = ContrastiveDataCollator(processor, NUM_CLASSES)
print("\n‚úÖ Data collator initialized!")

## 8️⃣ Create DataLoaders

In [None]:
from torch.utils.data import DataLoader

# Options shared by both loaders; only the shuffle flag differs.
_loader_options = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    collate_fn=collator,
    persistent_workers=False,
)

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

print(f"‚úÖ Train DataLoader: {len(train_loader)} batches")
print(f"‚úÖ Test DataLoader: {len(test_loader)} batches")

# Smoke test: pull one training batch and report the tensor shapes.
print("\nüîç Testing batch loading...")
sample_batch = next(iter(train_loader))
pixel_shape = sample_batch['pixel_values'].shape
ids_shape = sample_batch['input_ids'].shape
labels_shape = sample_batch['labels'].shape
print(f"   Pixel values shape: {pixel_shape}")
print(f"   Input IDs shape: {ids_shape}")
print(f"   Labels shape: {labels_shape}")
print("‚úÖ Data loading successful!")

# Release the probe batch before the model is moved to the device.
del sample_batch
torch.cuda.empty_cache()
gc.collect()

## 9️⃣ Define Classification Head with Contrastive Loss

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class MedSigLIPClassifier(nn.Module):
    """MedSigLIP wrapper trained with a pairwise sigmoid contrastive loss.

    The base model is frozen except for the last 8 vision-encoder layers and
    the last 4 text-encoder layers. ``forward`` scores every image in the
    batch against every text in the batch. When labels are supplied, each
    (image, text) pair is a binary target: 1 if both samples share a class,
    0 otherwise.

    Args:
        medsiglip_model: Module exposing ``vision_model`` and ``text_model``
            whose forward output has ``image_embeds`` and ``text_embeds``.
        num_classes: Number of classes; stored for callers.
    """

    def __init__(self, medsiglip_model, num_classes):
        super().__init__()
        self.medsiglip = medsiglip_model
        self.num_classes = num_classes

        # Freeze everything, then re-enable gradients on the top layers only.
        for param in self.medsiglip.parameters():
            param.requires_grad = False

        self._unfreeze_last_layers(self.medsiglip.vision_model, 8)
        self._unfreeze_last_layers(self.medsiglip.text_model, 4)

    @staticmethod
    def _unfreeze_last_layers(tower, count):
        """Enable gradients on the last ``count`` layers of ``tower.encoder``."""
        if not hasattr(tower, 'encoder'):
            return
        layers = tower.encoder.layers
        unfreeze_layers = min(count, len(layers))
        for param in layers[-unfreeze_layers:].parameters():
            param.requires_grad = True

    @staticmethod
    def _pair_targets(labels):
        """Return a boolean [batch, batch] matrix marking same-class pairs.

        Accepts class indices of shape [batch] or one-/multi-hot rows of
        shape [batch, num_classes].
        """
        if labels.dim() == 1:
            return labels.unsqueeze(1) == labels.unsqueeze(0)
        indicator = labels.float()
        # Two samples match when their label rows share at least one class.
        return (indicator @ indicator.t()) > 0

    def forward(self, pixel_values, input_ids, attention_mask=None, labels=None):
        """Compute image-text similarity logits and, optionally, the loss.

        Args:
            pixel_values: Image batch passed to the base model.
            input_ids: Token ids of the prompts passed to the base model.
            attention_mask: Optional text attention mask.
            labels: Optional class indices [batch] or one-hot rows
                [batch, num_classes] aligned with the batch samples.

        Returns:
            Dict with ``loss`` (``None`` when ``labels`` is ``None``),
            ``logits`` [n_images, n_texts], and the L2-normalised
            ``image_embeds`` and ``text_embeds``.
        """
        outputs = self.medsiglip(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_loss=False
        )

        image_embeds = F.normalize(outputs.image_embeds, p=2, dim=-1)
        text_embeds = F.normalize(outputs.text_embeds, p=2, dim=-1)

        # Cosine similarity of every image with every text.
        # NOTE(review): these logits are raw cosines in [-1, 1], so the sigmoid
        # stays within roughly [0.27, 0.73]; consider whether a temperature and
        # bias should be applied here.
        logits = torch.matmul(image_embeds, text_embeds.t())

        loss = None
        if labels is not None:
            # The logits are [batch, batch], so the BCE target must be too.
            # Build it from the per-sample classes rather than passing the
            # [batch, num_classes] one-hot matrix, whose shape does not match.
            pair_targets = self._pair_targets(labels).to(logits.dtype)
            loss = F.binary_cross_entropy_with_logits(
                logits,
                pair_targets,
                reduction='mean'
            )

        return {
            'loss': loss,
            'logits': logits,
            'image_embeds': image_embeds,
            'text_embeds': text_embeds
        }

# Initialize classifier
classifier = MedSigLIPClassifier(model, NUM_CLASSES)

# Multi-GPU support: wrap in DataParallel only when more than one GPU is visible.
visible_gpus = torch.cuda.device_count()
if visible_gpus > 1:
    print(f"\nüöÄ ENABLING MULTI-GPU TRAINING!")
    print(f"   Using {visible_gpus} GPUs with DataParallel")
    classifier = nn.DataParallel(classifier)

classifier = classifier.to(device)

print(f"\n‚úÖ Classifier ready!")

# Count all parameters and the trainable subset in a single pass.
total_params = 0
trainable_params = 0
for weight in classifier.parameters():
    size = weight.numel()
    total_params += size
    if weight.requires_grad:
        trainable_params += size

trainable_share = 100*trainable_params/total_params
print(f"\nüìä Parameter Statistics:")
print(f"   Total Parameters: {total_params:,}")
print(f"   Trainable Parameters: {trainable_params:,} ({trainable_share:.2f}%)")

torch.cuda.empty_cache()
gc.collect()

## 🔟 Training Configuration

In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR

NUM_EPOCHS = 10
BASE_LEARNING_RATE = 2e-4
# The learning rate is scaled linearly with the number of GPUs.
LEARNING_RATE = BASE_LEARNING_RATE * num_gpus
WEIGHT_DECAY = 1e-4
GRADIENT_ACCUMULATION_STEPS = 2

# Reach through DataParallel (if present) and keep only trainable weights.
model_for_params = getattr(classifier, 'module', classifier)
trainable_params = list(
    filter(lambda weight: weight.requires_grad, model_for_params.parameters())
)

optimizer = optim.AdamW(
    trainable_params,
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    betas=(0.9, 0.999),
)

# One scheduler step per optimizer update, i.e. per accumulation window.
total_steps = (len(train_loader) * NUM_EPOCHS) // GRADIENT_ACCUMULATION_STEPS
one_cycle_options = {
    'max_lr': LEARNING_RATE,
    'total_steps': total_steps,
    'pct_start': 0.3,
    'anneal_strategy': 'cos',
    'div_factor': 25.0,
    'final_div_factor': 1000.0,
}
scheduler = OneCycleLR(optimizer, **one_cycle_options)

print("‚úÖ Training Configuration:")
for caption, setting in (
    ("Epochs", NUM_EPOCHS),
    ("Learning Rate", LEARNING_RATE),
    ("Batch Size", BATCH_SIZE),
    ("Gradient Accumulation", GRADIENT_ACCUMULATION_STEPS),
    ("Effective Batch Size", BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS),
):
    print(f"   {caption}: {setting}")
print(f"   Optimizer: AdamW with OneCycleLR")
print(f"   üéØ Target: 0.80-0.90 accuracy")

## 1️⃣1️⃣ Training & Evaluation Functions

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm

def train_epoch(model, train_loader, optimizer, scheduler, device, accumulation_steps=1):
    """Train ``model`` for one pass over ``train_loader``.

    Gradients are accumulated over ``accumulation_steps`` batches. After each
    full window the gradients are clipped to norm 1.0 and the optimizer and
    scheduler are stepped once. A trailing partial window at the end of the
    epoch is not applied.

    The reported accuracy is an in-batch retrieval accuracy: each image is
    assigned the class of the most similar prompt embedding among the prompts
    present in the same batch. It is a training-progress proxy, not a full
    classification accuracy.

    Args:
        model: Module returning a dict with ``loss``, ``image_embeds`` and
            ``text_embeds``.
        train_loader: Iterable of batches with ``pixel_values``, ``input_ids``,
            ``labels`` (one-hot rows) and optionally ``attention_mask``.
        optimizer: Optimizer over the trainable parameters.
        scheduler: LR scheduler stepped once per optimizer update.
        device: Device the batch tensors are moved to.
        accumulation_steps: Number of batches per optimizer update.

    Returns:
        ``(avg_loss, accuracy)`` for the epoch.
    """
    model.train()
    total_loss = 0.0
    all_preds = []
    all_labels = []
    optimizer.zero_grad()

    pbar = tqdm(train_loader, desc="Training")
    for step, batch in enumerate(pbar):
        # Move batch to device; the attention mask is optional.
        pixel_values = batch['pixel_values'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch.get('attention_mask')
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        labels = batch['labels'].to(device)

        # Forward pass
        outputs = model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        loss = outputs['loss']
        if isinstance(loss, tuple):
            loss = loss[0]
        # A multi-GPU wrapper may return one loss per replica; average them.
        loss = loss.mean() if loss.dim() > 0 else loss
        batch_loss = loss.item()

        # Backward pass on the loss scaled for accumulation.
        (loss / accumulation_steps).backward()

        if (step + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            # Periodically release cached GPU memory. Counting from step + 1
            # lets this fire; `step % 50` is never 0 on the odd steps reached
            # with an even accumulation window.
            if (step + 1) % 50 == 0:
                torch.cuda.empty_cache()

        total_loss += batch_loss

        # Predict from the model's embeddings, not from the labels: pick the
        # best-matching in-batch prompt and read off that sample's class.
        true_classes = torch.argmax(labels, dim=1)
        with torch.no_grad():
            similarity = outputs['image_embeds'] @ outputs['text_embeds'].t()
            preds = true_classes[similarity.argmax(dim=1)]
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(true_classes.cpu().numpy())

        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    avg_loss = total_loss / len(train_loader)
    accuracy = accuracy_score(all_labels, all_preds)

    return avg_loss, accuracy

def _encode_class_prompts(model, loader, device):
    """Embed one prompt per class, or return ``None`` if that is not possible.

    The prompts and processor are read from ``loader.dataset`` and the text
    encoder from ``model.medsiglip`` (unwrapping ``model.module`` first).

    Returns:
        ``(class_ids, embeddings)`` where ``class_ids`` is a 1-D tensor of the
        sorted prompt keys and ``embeddings`` holds the matching L2-normalised
        text embeddings, or ``None`` when any required attribute is missing.
    """
    dataset = getattr(loader, 'dataset', None)
    prompts = getattr(dataset, 'class_prompts', None)
    processor = getattr(dataset, 'processor', None)
    encoder = getattr(getattr(model, 'module', model), 'medsiglip', None)
    if not prompts or processor is None or not hasattr(encoder, 'get_text_features'):
        return None

    class_ids = sorted(prompts)
    # Fixed-length padding is the usage the SigLIP documentation recommends.
    text_inputs = processor(
        text=[prompts[class_id] for class_id in class_ids],
        return_tensors="pt",
        padding="max_length",
        truncation=True
    )
    text_kwargs = {'input_ids': text_inputs['input_ids'].to(device)}
    mask = text_inputs.get('attention_mask')
    if mask is not None:
        text_kwargs['attention_mask'] = mask.to(device)

    features = encoder.get_text_features(**text_kwargs)
    if not torch.is_tensor(features):
        # Some library versions return an output object instead of a tensor.
        features = features.pooler_output
    features = torch.nn.functional.normalize(features, p=2, dim=-1)
    return torch.tensor(class_ids, device=device), features


def evaluate(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and compute classification metrics.

    Each image is scored against the prompt of every class and assigned the
    best-matching class. If the class prompts cannot be obtained from the
    loader or model, predictions fall back to in-batch retrieval: the class of
    the most similar prompt among those present in the same batch.

    Args:
        model: Module returning a dict with ``loss``, ``image_embeds`` and
            ``text_embeds``.
        test_loader: Iterable of batches with ``pixel_values``, ``input_ids``,
            ``labels`` (one-hot rows) and optionally ``attention_mask``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(avg_loss, accuracy, precision, recall, f1, all_preds, all_labels)``
        with weighted-average precision, recall and F1.
    """
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        # The encoders are fixed during evaluation, so embed the prompts once.
        class_bank = _encode_class_prompts(model, test_loader, device)

        pbar = tqdm(test_loader, desc="Evaluating")
        for batch in pbar:
            pixel_values = batch['pixel_values'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch.get('attention_mask')
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)
            labels = batch['labels'].to(device)

            outputs = model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = outputs['loss']
            if isinstance(loss, tuple):
                loss = loss[0]
            # A multi-GPU wrapper may return one loss per replica; average them.
            loss = loss.mean() if loss.dim() > 0 else loss

            total_loss += loss.item()

            # Predict from the model's embeddings, not from the labels.
            true_classes = torch.argmax(labels, dim=1)
            image_embeds = outputs['image_embeds']
            if class_bank is not None:
                class_ids, class_embeds = class_bank
                scores = image_embeds @ class_embeds.to(image_embeds.dtype).t()
                preds = class_ids[scores.argmax(dim=1)]
            else:
                scores = image_embeds @ outputs['text_embeds'].t()
                preds = true_classes[scores.argmax(dim=1)]
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(true_classes.cpu().numpy())

            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    avg_loss = total_loss / len(test_loader)
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

    return avg_loss, accuracy, precision, recall, f1, all_preds, all_labels

print("‚úÖ Training and evaluation functions defined!")

## 1️⃣2️⃣ Run Training

In [None]:
import json

# Per-epoch metrics, written to training_history.json after training.
history = {
    'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': [],
    'test_precision': [], 'test_recall': [], 'test_f1': [], 'learning_rate': []
}

best_accuracy = 0
best_epoch = 0
patience_counter = 0
max_patience = 5  # epochs without improvement before training stops
best_model_path = os.path.join(OUTPUT_PATH, 'best_model.pt')

print("\n" + "="*70)
print("üöÄ COMMENCING TRAINING")
print("="*70)
print(f"   Training with {num_gpus} GPU(s)")
print(f"   Batch Size: {BATCH_SIZE}")
print(f"   Target: 0.80-0.90 accuracy")
print("="*70)

for epoch in range(NUM_EPOCHS):
    print(f"\nüìä Epoch {epoch+1}/{NUM_EPOCHS}")

    train_loss, train_acc = train_epoch(
        classifier, train_loader, optimizer, scheduler, device, GRADIENT_ACCUMULATION_STEPS
    )
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['learning_rate'].append(optimizer.param_groups[0]['lr'])

    test_loss, test_acc, test_prec, test_rec, test_f1, preds, labels = evaluate(
        classifier, test_loader, device
    )
    history['test_loss'].append(test_loss)
    history['test_acc'].append(test_acc)
    history['test_precision'].append(test_prec)
    history['test_recall'].append(test_rec)
    history['test_f1'].append(test_f1)

    print(f"   Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"   Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}")
    print(f"   Precision: {test_prec:.4f} | Recall: {test_rec:.4f} | F1: {test_f1:.4f}")
    print(f"   LR: {optimizer.param_groups[0]['lr']:.6f}")

    if test_acc > best_accuracy:
        best_accuracy = test_acc
        best_epoch = epoch + 1
        # Save the unwrapped module so the checkpoint loads without DataParallel.
        model_to_save = classifier.module if hasattr(classifier, 'module') else classifier
        torch.save(model_to_save.state_dict(), best_model_path)
        patience_counter = 0
        print(f"   ‚≠ê BEST model saved! (Accuracy: {best_accuracy:.4f})")
    else:
        patience_counter += 1
        if patience_counter >= max_patience:
            print(f"   ‚ö†Ô∏è  Early stopping triggered")
            # Actually stop; without this the loop ran all remaining epochs.
            break

    torch.cuda.empty_cache()
    gc.collect()

print("\n" + "="*70)
print("‚úÖ TRAINING COMPLETED")
print(f"   Best Accuracy: {best_accuracy:.4f} at Epoch {best_epoch}")
# Exceeding the upper end of the range still counts as meeting the target.
print(f"   Target Range: 0.80-0.90 {'‚úÖ ACHIEVED!' if best_accuracy >= 0.80 else '‚ùå'} ")
print("="*70)

history_path = os.path.join(OUTPUT_PATH, 'training_history.json')
with open(history_path, 'w') as f:
    json.dump(history, f, indent=4)
print(f"\nüíæ Training history saved to: {history_path}")

## 1️⃣3️⃣ Results Visualization

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

# Load best model
model_for_loading = classifier.module if hasattr(classifier, 'module') else classifier
model_for_loading.load_state_dict(torch.load(best_model_path, map_location=device))
classifier.eval()

# Get final predictions by scoring each test image against every class prompt.
# The predictions must come from the model; deriving them from the labels
# would make the confusion matrix and report trivially perfect.
with torch.no_grad():
    class_ids = sorted(CLASS_PROMPTS)
    # Fixed-length padding is the usage the SigLIP documentation recommends.
    text_inputs = processor(
        text=[CLASS_PROMPTS[class_id] for class_id in class_ids],
        return_tensors="pt",
        padding="max_length",
        truncation=True
    )
    text_kwargs = {'input_ids': text_inputs['input_ids'].to(device)}
    text_mask = text_inputs.get('attention_mask')
    if text_mask is not None:
        text_kwargs['attention_mask'] = text_mask.to(device)
    class_embeds = model_for_loading.medsiglip.get_text_features(**text_kwargs)
    if not torch.is_tensor(class_embeds):
        # Some library versions return an output object instead of a tensor.
        class_embeds = class_embeds.pooler_output
    class_embeds = torch.nn.functional.normalize(class_embeds, p=2, dim=-1)
    class_id_tensor = torch.tensor(class_ids, device=device)

    all_preds = []
    all_labels = []
    for batch in test_loader:
        pixel_values = batch['pixel_values'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch.get('attention_mask')
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        labels = batch['labels'].to(device)

        # Labels are not passed: only the image embeddings are needed here.
        outputs = classifier(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        image_embeds = outputs['image_embeds']
        scores = image_embeds @ class_embeds.to(image_embeds.dtype).t()
        preds = class_id_tensor[scores.argmax(dim=1)]
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(torch.argmax(labels, dim=1).cpu().numpy())

fig, axes = plt.subplots(2, 2, figsize=(15, 12))
fig.suptitle('MedSigLIP Nail Disease Classification - Results', fontsize=16, fontweight='bold')

# Loss plot
axes[0, 0].plot(history['train_loss'], label='Train Loss', marker='o')
axes[0, 0].plot(history['test_loss'], label='Test Loss', marker='s')
axes[0, 0].axvline(x=best_epoch-1, color='red', linestyle='--', label=f'Best Epoch {best_epoch}')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Loss over Epochs')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Accuracy plot
axes[0, 1].plot(history['train_acc'], label='Train Accuracy', marker='o')
axes[0, 1].plot(history['test_acc'], label='Test Accuracy', marker='s')
axes[0, 1].axvline(x=best_epoch-1, color='red', linestyle='--', label=f'Best Epoch {best_epoch}')
axes[0, 1].axhline(y=0.80, color='orange', linestyle=':', alpha=0.5, label='80% Target')
axes[0, 1].axhline(y=0.90, color='green', linestyle=':', alpha=0.5, label='90% Target')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy')
axes[0, 1].set_title('Accuracy over Epochs')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Metrics plot
axes[1, 0].plot(history['test_precision'], label='Precision', marker='o')
axes[1, 0].plot(history['test_recall'], label='Recall', marker='s')
axes[1, 0].plot(history['test_f1'], label='F1 Score', marker='^')
axes[1, 0].axhline(y=0.80, color='orange', linestyle=':', alpha=0.5, label='80% Target')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Score')
axes[1, 0].set_title('Precision, Recall, F1 Score')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Confusion matrix
# Fix the label set so rows/columns line up with the class names even when a
# class is never predicted or is absent from the test split.
class_indices = list(range(len(train_dataset.classes)))
cm = confusion_matrix(all_labels, all_preds, labels=class_indices)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[1, 1],
            xticklabels=train_dataset.classes, yticklabels=train_dataset.classes)
axes[1, 1].set_title('Confusion Matrix')
axes[1, 1].set_ylabel('True Label')
axes[1, 1].set_xlabel('Predicted Label')

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_PATH, 'training_results.png'), dpi=300, bbox_inches='tight')
plt.show()

print("‚úÖ Results visualization saved!")
print(f"üìÅ Saved to: {os.path.join(OUTPUT_PATH, 'training_results.png')}")

# Print classification report
print("\nüìã Classification Report:")
print(classification_report(
    all_labels, all_preds, labels=class_indices,
    target_names=train_dataset.classes, digits=4, zero_division=0
))

torch.cuda.empty_cache()
gc.collect()

## 1️⃣4️⃣ Final Summary

In [None]:
from sklearn.metrics import accuracy_score

# Accuracy of the predictions produced by the visualization cell.
final_accuracy = accuracy_score(all_labels, all_preds)

print("\n" + "="*70)
print("‚úÖ TRAINING COMPLETE")
print("="*70)

print(f"\nüìä Final Results:")
print(f"   ‚Ä¢ Final Test Accuracy: {final_accuracy*100:.2f}%")
print(f"   ‚Ä¢ Best Accuracy: {best_accuracy*100:.2f}% (Epoch {best_epoch})")
print(f"   ‚Ä¢ Number of Classes: {NUM_CLASSES}")
print(f"   ‚Ä¢ Training Epochs: {NUM_EPOCHS}")
print(f"   ‚Ä¢ GPUs Used: {num_gpus}")
print(f"   ‚Ä¢ Target Achieved: {'‚úÖ YES!' if 0.80 <= best_accuracy <= 0.90 else '‚ö†Ô∏è Outside range'}")

print(f"\nüìÅ Output Files:")
# List every entry in the output directory with its size in MiB.
output_files = os.listdir(OUTPUT_PATH)
for file in sorted(output_files):
    file_path = os.path.join(OUTPUT_PATH, file)
    file_size = os.path.getsize(file_path) / (1024*1024)
    print(f"   ‚Ä¢ {file} ({file_size:.2f} MB)")

print(f"\nüöÄ Next Steps:")
print(f"   1. ‚úÖ Model saved in /kaggle/working/output/")
print(f"   2. üì• Download files via 'Output' tab")
print(f"   3. üîç Review training_history.json")
print(f"   4. üìä Check training_results.png")
print(f"   5. üöÄ Deploy to production")

if best_accuracy < 0.80:
    print(f"\nüí° TIPS TO IMPROVE ACCURACY:")
    print(f"   ‚Ä¢ Increase epochs to 15-20")
    print(f"   ‚Ä¢ Try different learning rates")
    print(f"   ‚Ä¢ Check data quality and class balance")
    print(f"   ‚Ä¢ Unfreeze more layers")
elif best_accuracy > 0.90:
    print(f"\nüéâ EXCELLENT! Exceeded target (>90%)")

print("\n" + "="*70)
print("üéâ MedSigLIP Training Complete!")
print("   Contrastive Learning with Sigmoid Loss")
# Only claim the target was met when the measured accuracy supports it.
if best_accuracy >= 0.80:
    print("   Target: 0.80-0.90 accuracy achieved!")
else:
    print(f"   Target: 0.80-0.90 accuracy not reached (best: {best_accuracy:.4f})")
print("="*70)