In [1]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

from transformers import ViTForImageClassification, ViTImageProcessor
from sklearn.metrics import (
    accuracy_score, 
    precision_score, 
    recall_score, 
    f1_score,
    confusion_matrix, 
    classification_report
)
from tqdm import tqdm

# Set random seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Configuration
CONFIG = {
    'seed': 42,
    'model_name': "google/vit-base-patch16-224",
    'batch_size': 16,
    'num_epochs': 30,
    'learning_rate': 1e-4,
    'weight_decay': 1e-2,
    'lasso_lambda': 1e-3,
    'train_dir': '/Users/user/Desktop/project/image/train2',
    'val_dir': '/Users/user/Desktop/project/image/vaildation2',
    'checkpoint_path': 'vit.pth'
}

def main():
    # Set random seed
    set_seed(CONFIG['seed'])

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load pre-trained ViT model and image processor
    model = ViTForImageClassification.from_pretrained(
        CONFIG['model_name'], 
        num_labels=2, 
        ignore_mismatched_sizes=True
    )
    image_processor = ViTImageProcessor.from_pretrained(CONFIG['model_name'])

    # Custom classifier head
    model.classifier = nn.Sequential(
        nn.Linear(model.config.hidden_size, 1024),
        nn.BatchNorm1d(1024),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(1024, 512),
        nn.BatchNorm1d(512),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(512, 2)
    )

    # Load pre-trained weights if exists
    start_epoch = 0
    if os.path.exists(CONFIG['checkpoint_path']):
        try:
            checkpoint = torch.load(CONFIG['checkpoint_path'], map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            start_epoch = checkpoint['epoch'] + 1
            print(f"Loaded pre-trained model from {CONFIG['checkpoint_path']} at epoch {start_epoch}")
        except Exception as e:
            print(f"Error loading checkpoint: {e}")

    # Move model to device
    model.to(device)

    # Define data transforms
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(20),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Load datasets
    train_dataset = ImageFolder(CONFIG['train_dir'], transform=train_transform)
    val_dataset = ImageFolder(CONFIG['val_dir'], transform=val_transform)

    # Create data loaders
    train_loader = DataLoader(
        train_dataset, 
        batch_size=CONFIG['batch_size'], 
        shuffle=True, 
        num_workers=0
    )
    val_loader = DataLoader(
        val_dataset, 
        batch_size=CONFIG['batch_size'], 
        num_workers=4
    )

    # Custom loss function with Lasso regularization
    def lasso_loss(model, criterion, outputs, labels, lasso_lambda):
        # Standard cross-entropy loss
        standard_loss = criterion(outputs.logits, labels)
        
        # Lasso regularization (L1 penalty)
        l1_penalty = sum(p.abs().sum() for p in model.classifier.parameters() if p.requires_grad)
        
        # Combined loss
        return standard_loss + lasso_lambda * l1_penalty

    # Optimizer and loss function
    optimizer = torch.optim.AdamW(
        model.parameters(), 
        lr=CONFIG['learning_rate'], 
        weight_decay=CONFIG['weight_decay']
    )
    criterion = nn.CrossEntropyLoss()

    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, 
        max_lr=CONFIG['learning_rate'], 
        epochs=CONFIG['num_epochs'], 
        steps_per_epoch=len(train_loader)
    )

    # Training metrics storage
    train_losses, train_accuracies = [], []
    val_losses, val_accuracies = [], []
    val_precisions, val_recalls, val_f1_scores = [], [], []

    # Best model tracking
    best_val_accuracy = 0

    # Training loop
    for epoch in range(start_epoch, CONFIG['num_epochs']):
        # Training phase
        model.train()
        epoch_loss, epoch_accuracy = 0, 0

        for data, label in tqdm(train_loader, desc=f"Epoch {epoch+1}/{CONFIG['num_epochs']} - Training"):
            data, label = data.to(device), label.to(device)

            optimizer.zero_grad()
            outputs = model(data, labels=label)
            
            # Custom loss with Lasso regularization
            loss = lasso_loss(model, criterion, outputs, label, CONFIG['lasso_lambda'])

            loss.backward()
            optimizer.step()
            scheduler.step()

            acc = (outputs.logits.argmax(dim=1) == label).float().mean()
            epoch_accuracy += acc.item() / len(train_loader)
            epoch_loss += loss.item() / len(train_loader)

        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_accuracy)

        # Validation phase
        model.eval()
        all_preds, all_labels, all_probs = [], [], []
        epoch_val_loss = 0

        with torch.no_grad():
            for data, label in tqdm(val_loader, desc=f"Epoch {epoch+1}/{CONFIG['num_epochs']} - Validation"):
                data, label = data.to(device), label.to(device)

                outputs = model(data, labels=label)
                val_loss = criterion(outputs.logits, label)

                probs = torch.nn.functional.softmax(outputs.logits, dim=1)
                preds = torch.argmax(probs, dim=1)
                
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(label.cpu().numpy())
                all_probs.extend(probs[:, 1].cpu().numpy())

                epoch_val_loss += val_loss.item() / len(val_loader)

        # Calculate validation metrics
        val_accuracy = accuracy_score(all_labels, all_preds)
        val_precision = precision_score(all_labels, all_preds, average='weighted')
        val_recall = recall_score(all_labels, all_preds, average='weighted')
        val_f1 = f1_score(all_labels, all_preds, average='weighted')

        # Save the best model
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_accuracy': val_accuracy,
            }, CONFIG['checkpoint_path'])
            print(f"\nBest model saved with validation accuracy: {best_val_accuracy:.4f}")

        # Store validation metrics
        val_losses.append(epoch_val_loss)
        val_accuracies.append(val_accuracy)
        val_precisions.append(val_precision)
        val_recalls.append(val_recall)
        val_f1_scores.append(val_f1)

        # Print epoch summary
        print(
            f"Epoch : {epoch+1} - "
            f"train_loss : {epoch_loss:.4f} - train_acc: {epoch_accuracy:.4f} - "
            f"val_loss : {epoch_val_loss:.4f} - val_acc: {val_accuracy:.4f}\n"
            f"Validation Metrics: "
            f"Precision: {val_precision:.4f}, "
            f"Recall: {val_recall:.4f}, "
            f"F1-Score: {val_f1:.4f}\n"
        )

    # Optional: Plot training and validation metrics
    plt.figure(figsize=(12, 4))
    plt.subplot(131)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.title('Loss Curves')
    plt.legend()

    plt.subplot(132)
    plt.plot(train_accuracies, label='Train Accuracy')
    plt.plot(val_accuracies, label='Val Accuracy')
    plt.title('Accuracy Curves')
    plt.legend()

    plt.subplot(133)
    plt.plot(val_precisions, label='Precision')
    plt.plot(val_recalls, label='Recall')
    plt.plot(val_f1_scores, label='F1-Score')
    plt.title('Validation Metrics')
    plt.legend()

    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    main()

Using device: cuda


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1/30 - Training: 100%|███████████████████████████████████████████████████████████| 81/81 [05:05<00:00,  3.77s/it]
Epoch 1/30 - Validation: 100%|█████████████████████████████████████████████████████████| 19/19 [00:16<00:00,  1.15it/s]



Best model saved with validation accuracy: 0.7567
Epoch : 1 - train_loss : 24.5352 - train_acc: 0.6356 - val_loss : 0.5583 - val_acc: 0.7567
Validation Metrics: Precision: 0.7567, Recall: 0.7567, F1-Score: 0.7567



Epoch 2/30 - Training: 100%|███████████████████████████████████████████████████████████| 81/81 [05:01<00:00,  3.72s/it]
Epoch 2/30 - Validation: 100%|█████████████████████████████████████████████████████████| 19/19 [00:16<00:00,  1.17it/s]



Best model saved with validation accuracy: 0.8100
Epoch : 2 - train_loss : 24.0880 - train_acc: 0.8424 - val_loss : 0.4640 - val_acc: 0.8100
Validation Metrics: Precision: 0.8204, Recall: 0.8100, F1-Score: 0.8084



Epoch 3/30 - Training:  90%|█████████████████████████████████████████████████████▏     | 73/81 [04:39<00:30,  3.83s/it]


KeyboardInterrupt: 