# In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from transformers import DeiTForImageClassification
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_curve, auc

# Release cached CUDA memory, e.g. left over from an earlier run in the same
# notebook kernel (no-op when CUDA was never initialised).
torch.cuda.empty_cache()

# Configuration shared by every stage of the pipeline.
CONFIG = dict(
    seed=42,
    model_name="facebook/deit-base-patch16-224",  # Hugging Face hub id of the DeiT checkpoint
    batch_size=16,
    num_epochs=300,
    learning_rate=2e-4,  # lowered learning rate for fine-tuning
    # NOTE(review): main() hard-codes weight_decay=0.05 for AdamW and does
    # not read this key.
    weight_decay=1e-4,
    train_dir='/Users/user/Desktop/project/image/train2',
    # NOTE(review): 'vaildation2' looks like a typo for 'validation2'; it is
    # kept verbatim because it must match the directory name on disk.
    val_dir='/Users/user/Desktop/project/image/vaildation2',  # update these paths to your actual paths
    pretrained_weights='best_model_0114.pth',
    new_checkpoint_path='deit_model_checkpoint.pth',
    accumulation_steps=2,  # micro-batches per optimizer step
)

def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs so runs are repeatable.

    When CUDA is available, every GPU generator is seeded as well. cuDNN is
    also switched to deterministic, non-benchmarking mode.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def load_pretrained_model(model, weights_path, device):
    """Load saved weights into ``model`` in place, falling back gracefully.

    Two file layouts are accepted:
    - a training checkpoint dict containing a ``'model_state_dict'`` entry;
    - a bare state dict.

    A missing file, or any exception raised while reading or applying the
    file, is reported on stdout. In those cases the model is returned
    without further changes.

    Args:
        model: module whose ``load_state_dict`` receives the weights.
        weights_path: path to a file written by ``torch.save``.
        device: forwarded to ``torch.load`` as ``map_location``.

    Returns:
        Tuple ``(model, start_epoch)``. ``start_epoch`` is always 0, so
        training never resumes mid-schedule.
    """
    if not os.path.exists(weights_path):
        print(f"Warning: Weights file {weights_path} not found!")
        return model, 0

    try:
        checkpoint = torch.load(weights_path, map_location=device)
        is_full_checkpoint = isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint
        state_dict = checkpoint['model_state_dict'] if is_full_checkpoint else checkpoint
        # NOTE(review): on a strict-mode mismatch, load_state_dict may raise
        # only after copying the tensors whose shapes did match. Confirm the
        # model really is "fresh" in that case.
        model.load_state_dict(state_dict)
        print("Loaded checkpoint and starting from epoch 1" if is_full_checkpoint else "Loaded model weights")
        print(f"Successfully loaded weights from {weights_path}")
    except Exception as e:
        print(f"Error loading weights: {e}")
        print("Starting with fresh model")
        return model, 0

    return model, 0

def _plot_metric_curves(train_values, val_values, metric, out_path):
    """Draw one training-vs-validation figure for ``metric`` and save it to ``out_path``."""
    plt.figure(figsize=(10, 5))
    try:
        # 1-based x values so the axis matches the "Epoch N/..." numbering
        # printed during training (and plot_performance_metrics' x-axis).
        plt.plot(range(1, len(train_values) + 1), train_values, label=f'Training {metric}')
        plt.plot(range(1, len(val_values) + 1), val_values, label=f'Validation {metric}')
        plt.title(f'{metric} Curves')
        plt.xlabel('Epoch')
        plt.ylabel(metric)
        plt.legend()
        plt.grid(True)
        plt.savefig(out_path)
    finally:
        # Always release the figure, even if saving fails.
        plt.close()


def plot_training_curves(train_losses, val_losses, train_accuracies, val_accuracies, save_dir='./plots'):
    """Plot training and validation loss/accuracy curves.

    Writes two PNGs into ``save_dir``, which is created if missing:
    ``deit_loss_curves.png`` and ``deit_accuracy_curves.png``. Each list
    holds one value per epoch. The x-axis starts at epoch 1.
    """
    os.makedirs(save_dir, exist_ok=True)
    _plot_metric_curves(train_losses, val_losses, 'Loss',
                        os.path.join(save_dir, 'deit_loss_curves.png'))
    _plot_metric_curves(train_accuracies, val_accuracies, 'Accuracy',
                        os.path.join(save_dir, 'deit_accuracy_curves.png'))

def plot_confusion_matrix(y_true, y_pred, save_dir='./plots'):
    """Save an annotated heatmap of the confusion matrix of ``y_true`` vs ``y_pred``.

    Rows are true labels and columns are predicted labels. The figure is
    written to ``deit_confusion_matrix.png`` under ``save_dir``, which is
    created if missing.
    """
    os.makedirs(save_dir, exist_ok=True)

    counts = confusion_matrix(y_true, y_pred)
    out_file = os.path.join(save_dir, 'deit_confusion_matrix.png')

    plt.figure(figsize=(10, 8))
    sns.heatmap(counts, annot=True, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.savefig(out_file)
    plt.close()

def plot_roc_curve(y_true, y_prob, save_dir='./plots'):
    """Save a ROC curve, with its AUC shown in the legend.

    Args:
        y_true: ground-truth labels.
        y_prob: score per sample, passed to sklearn's ``roc_curve`` as ``y_score``.
        save_dir: output directory, created if missing. The figure is written
            to ``deit_roc_curve.png`` inside it.
    """
    os.makedirs(save_dir, exist_ok=True)

    false_pos, true_pos, _thresholds = roc_curve(y_true, y_prob)
    area = auc(false_pos, true_pos)
    out_file = os.path.join(save_dir, 'deit_roc_curve.png')

    plt.figure(figsize=(10, 8))
    plt.plot(false_pos, true_pos, color='darkorange', lw=2, label=f'ROC curve (AUC = {area:.2f})')
    # Dashed diagonal: the chance-level reference line.
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc="lower right")
    plt.grid(True)
    plt.savefig(out_file)
    plt.close()

def plot_performance_metrics(precisions, recalls, f1_scores, save_dir='./plots'):
    """Plot per-epoch precision, recall and F1 on a single figure.

    The x-axis is numbered from epoch 1, with its length taken from
    ``precisions``. The figure is written to ``deit_performance_metrics.png``
    under ``save_dir``, which is created if missing.
    """
    os.makedirs(save_dir, exist_ok=True)

    x_epochs = range(1, len(precisions) + 1)
    series = (
        (precisions, 'b-', 'Precision'),
        (recalls, 'r-', 'Recall'),
        (f1_scores, 'g-', 'F1-Score'),
    )

    plt.figure(figsize=(12, 6))
    for values, line_fmt, name in series:
        plt.plot(x_epochs, values, line_fmt, label=name)
    plt.title('Performance Metrics Over Time')
    plt.xlabel('Epoch')
    plt.ylabel('Score')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(save_dir, 'deit_performance_metrics.png'))
    plt.close()

def validate_with_predictions(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` and keep per-sample outputs.

    The model is put in eval mode and run under ``torch.no_grad()``. The
    loss is averaged over batches, not samples. Precision, recall and F1
    use sklearn's ``average='weighted'`` with ``zero_division=0``.

    Returns:
        Tuple ``(mean_batch_loss, accuracy, precision, recall, f1, targets,
        predictions, positive_probs)``. The last three are flat lists with
        one entry per validation sample. ``positive_probs`` holds the softmax
        probability of class index 1.
    """
    model.eval()
    loss_sum = 0
    y_pred, y_true, y_score = [], [], []

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Validating"):
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images).logits

            loss_sum += criterion(logits, labels).item()
            class_probs = torch.softmax(logits, dim=1)
            batch_pred = logits.max(1)[1]

            y_pred.extend(batch_pred.cpu().numpy())
            y_true.extend(labels.cpu().numpy())
            y_score.extend(class_probs[:, 1].cpu().numpy())

    metric_kwargs = {'average': 'weighted', 'zero_division': 0}
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, **metric_kwargs)
    recall = recall_score(y_true, y_pred, **metric_kwargs)
    f1 = f1_score(y_true, y_pred, **metric_kwargs)

    mean_loss = loss_sum / len(val_loader)
    return (mean_loss, accuracy, precision, recall, f1,
            y_true, y_pred, y_score)

def mixup_data(x, y, alpha=0.2):
    """Apply mixup to a batch.

    How the batch is mixed:
    - Draw ``lam ~ Beta(alpha, alpha)``, or use ``lam = 1`` when
      ``alpha <= 0``, which leaves the inputs unchanged.
    - Pair every sample with a randomly permuted partner from the same batch.
    - Blend the two linearly.

    Returns:
        ``(mixed_x, y_a, y_b, lam)``. ``y_a`` is ``y`` itself and ``y_b``
        holds the partners' targets, ready for ``mixup_criterion``.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    # The permutation is drawn without a device argument and then moved to x's device.
    partner = torch.randperm(x.shape[0]).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[partner]
    return mixed_x, y, y[partner], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Blend the loss against both mixup target sets.

    Returns ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    This uses the same coefficient ``lam`` that mixed the inputs in ``mixup_data``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

def train_one_epoch(model, train_loader, optimizer, criterion, device, scheduler=None, accumulation_steps=1, clip_value=1.0):
    """Train ``model`` for one epoch with mixup, gradient accumulation and clipping.

    Gradient accumulation:
    - Gradients are summed over ``accumulation_steps`` consecutive batches,
      then clipped to a global norm of ``clip_value`` before each optimizer
      step.
    - If the number of batches is not a multiple of ``accumulation_steps``,
      the final, shorter group is still applied. Its gradients are never
      left dangling into the next epoch.
    - Each batch's loss is divided by the size of its accumulation group.

    Args:
        model: module whose forward output exposes ``.logits``.
        train_loader: sized iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking ``(logits, targets)``.
        device: device the batches are moved to.
        scheduler: optional LR scheduler, stepped once after every optimizer step.
        accumulation_steps: batches per optimizer step.
        clip_value: max gradient norm passed to ``clip_grad_norm_``.

    Returns:
        ``(mean_batch_loss, accuracy)``. Accuracy is mixup-weighted, i.e.
        ``lam`` * hits against ``y_a`` + ``(1 - lam)`` * hits against
        ``y_b``, divided by the sample count.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader yielded no batches; check the dataset and batch size")

    model.train()
    total_loss = 0.0
    correct = 0.0
    total = 0

    # Start from clean gradients so nothing accumulated earlier leaks in.
    optimizer.zero_grad()

    for batch_idx, (inputs, targets) in enumerate(tqdm(train_loader, desc="Training")):
        inputs, targets = inputs.to(device), targets.to(device)

        # Apply mixup
        inputs, targets_a, targets_b, lam = mixup_data(inputs, targets)

        # Forward pass
        outputs = model(inputs).logits
        loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)

        # The last accumulation group may be shorter than accumulation_steps.
        group_start = (batch_idx // accumulation_steps) * accumulation_steps
        group_size = min(accumulation_steps, num_batches - group_start)
        (loss / group_size).backward()

        is_group_end = (batch_idx + 1) % accumulation_steps == 0 or batch_idx + 1 == num_batches
        if is_group_end:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
            optimizer.step()
            optimizer.zero_grad()
            if scheduler is not None:
                scheduler.step()

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += (lam * predicted.eq(targets_a).sum().item()
                    + (1 - lam) * predicted.eq(targets_b).sum().item())

    return total_loss / num_batches, float(correct / total)


def main():
    """Fine-tune a DeiT classifier on a two-class ImageFolder dataset.

    Pipeline:
    1. Seed the RNGs.
    2. Build ``DeiTForImageClassification`` from ``CONFIG['model_name']``
       with a custom MLP head.
    3. Try to load ``CONFIG['pretrained_weights']``.
    4. Build an augmented training loader and a plain validation loader.
    5. Train for ``CONFIG['num_epochs']`` epochs with AdamW and cosine warm
       restarts (stepped per epoch), saving the checkpoint with the best
       validation accuracy to ``CONFIG['new_checkpoint_path']``.
    6. Write summary plots to ``./plots``.

    Any exception raised inside the pipeline is caught and its traceback is
    printed rather than propagated.
    """
    set_seed(CONFIG['seed'])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    try:
        print("Initializing Data-efficient Vision Transformer (DeiT) model...")
        model = DeiTForImageClassification.from_pretrained(
            CONFIG['model_name'],
            num_labels=2,
            ignore_mismatched_sizes=True
        )

        # Replace the stock head with a small MLP head (BatchNorm + dropout).
        model.classifier = torch.nn.Sequential(
            torch.nn.Linear(model.config.hidden_size, 256),
            torch.nn.BatchNorm1d(256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(256, 2)
        )

        # Weights saved from a standard ViT model may not be compatible with
        # DeiT due to architecture differences, so you might need fresh weights.
        # load_pretrained_model reports and recovers from its own load errors,
        # so there is no extra try/except here. A bare ``except:`` would also
        # swallow KeyboardInterrupt.
        model, start_epoch = load_pretrained_model(model, CONFIG['pretrained_weights'], device)

        model.to(device)

        # DeiT-specific data transforms
        # DeiT uses stronger augmentation than standard ViT
        train_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(45),  # Stronger rotation
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),  # Stronger color jitter
            transforms.RandomAffine(degrees=0, translate=(0.2, 0.2), scale=(0.8, 1.2)),  # Stronger affine
            transforms.ToTensor(),
            transforms.RandomErasing(p=0.3),  # Increased erasing probability
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        val_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # Dataset loading
        print("Loading datasets...")
        train_dataset = ImageFolder(CONFIG['train_dir'], transform=train_transform)
        val_dataset = ImageFolder(CONFIG['val_dir'], transform=val_transform)

        print(f"Number of training samples: {len(train_dataset)}")
        print(f"Number of validation samples: {len(val_dataset)}")

        # BatchNorm1d in the head raises in train mode on a one-sample batch.
        # Drop the final training batch only when it would hold exactly one sample.
        drop_single_tail = len(train_dataset) % CONFIG['batch_size'] == 1
        train_loader = DataLoader(
            train_dataset,
            batch_size=CONFIG['batch_size'],
            shuffle=True,
            num_workers=0,
            drop_last=drop_single_tail,
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=CONFIG['batch_size'],
            shuffle=False,
            num_workers=0
        )

        # AdamW with weight_decay=0.05. NOTE: CONFIG['weight_decay'] is not
        # used here.
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=CONFIG['learning_rate'],
            weight_decay=0.05,
            betas=(0.9, 0.999),
            eps=1e-8
        )

        # Cosine annealing with warm restarts, stepped once per epoch below.
        # The LR decays from CONFIG['learning_rate'] toward eta_min over 10
        # epochs, then restarts. There is no warmup phase.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=10,
            T_mult=1,
            eta_min=1e-6
        )

        # Per-class loss weights: class index 1 is down-weighted to 0.42.
        # NOTE(review): presumably compensates for class imbalance; confirm
        # against the dataset's class counts.
        class_weights = torch.tensor([1.0, 0.42]).to(device)
        criterion = nn.CrossEntropyLoss(weight=class_weights)

        # Training metrics storage
        train_losses = []
        val_losses = []
        train_accuracies = []
        val_accuracies = []
        precisions = []
        recalls = []
        f1_scores = []
        best_val_accuracy = 0

        print("Starting training with Data-efficient Vision Transformer (DeiT)...")
        for epoch in range(start_epoch, CONFIG['num_epochs']):
            print(f"\nEpoch {epoch + 1}/{CONFIG['num_epochs']}")

            # Train
            train_loss, train_acc = train_one_epoch(
                model, train_loader, optimizer, criterion, device,
                scheduler=None,  # the scheduler is stepped per epoch below instead
                accumulation_steps=CONFIG['accumulation_steps'],
                clip_value=1.0
            )

            # Step the scheduler after each epoch
            scheduler.step()

            # Validate
            val_loss, val_acc, precision, recall, f1, targets, predictions, probabilities = validate_with_predictions(
                model, val_loader, criterion, device
            )

            # Store metrics as plain Python floats
            train_losses.append(float(train_loss))
            val_losses.append(float(val_loss))
            train_accuracies.append(float(train_acc))
            val_accuracies.append(float(val_acc))
            precisions.append(float(precision))
            recalls.append(float(recall))
            f1_scores.append(float(f1))

            print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
            print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}%")
            print(f"Precision: {precision:.4f} | Recall: {recall:.4f} | F1: {f1:.4f}")
            print(f"Current learning rate: {optimizer.param_groups[0]['lr']:.6f}")

            # Save best model
            if val_acc > best_val_accuracy:
                best_val_accuracy = val_acc
                print(f"Saving best model with validation accuracy: {val_acc*100:.2f}%")
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'val_accuracy': val_acc,
                }, CONFIG['new_checkpoint_path'])

        if not train_losses:
            # start_epoch >= num_epochs: nothing was trained or validated, so
            # targets/predictions do not exist and there is nothing to plot.
            print("No epochs were run; skipping plots.")
            return

        # The confusion matrix and ROC curve use the LAST epoch's validation
        # predictions, not necessarily those of the saved best checkpoint.
        print("\nGenerating visualization plots...")
        plot_training_curves(train_losses, val_losses, train_accuracies, val_accuracies)
        plot_confusion_matrix(targets, predictions)
        plot_roc_curve(targets, probabilities)
        plot_performance_metrics(precisions, recalls, f1_scores)
        print("All plots have been saved in the 'plots' directory")

    except Exception as e:
        print(f"An error occurred: {str(e)}")
        import traceback
        print(traceback.format_exc())

if __name__ == "__main__":
    # Run the full training pipeline only when executed as a script, not on import.
    main()