In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm
from sklearn.model_selection import train_test_split
from PIL import Image
import numpy as np
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import json
from datetime import datetime

class MedicalImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        label = self.labels[idx]
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

class MedicalClassifier:
    def __init__(self, num_classes=3, model_name='efficientnetv2_s', device='cuda'):
        self.device = device
        self.num_classes = num_classes
        self.model_name = model_name
        self.best_val_acc = 0.0
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []
        
        # Set directories for saving models and logs (use Kaggle's output directory)
        self.checkpoint_dir = '/kaggle/working/checkpoints'
        self.logs_dir = '/kaggle/working/logs'
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        os.makedirs(self.logs_dir, exist_ok=True)
        
        # Initialize model
        self.model = self._initialize_model()
        self.model = self.model.to(self.device)
        
        # Define transforms
        # Progressive resizing: Start with smaller size, increase during training
        self.transform_stage1 = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        self.transform_stage2 = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        self.val_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    
    def _initialize_model(self):
        model = timm.create_model(self.model_name, pretrained=False, num_classes=self.num_classes)
        return model
    
    def load_data(self, data_dir):
        """Load and split data from directory structure data_dir/class/image.jpg"""
        image_paths = []
        labels = []
        class_to_idx = {}
        
        for idx, class_name in enumerate(os.listdir(data_dir)):
            class_dir = os.path.join(data_dir, class_name)
            class_to_idx[class_name] = idx
            
            if os.path.isdir(class_dir):
                for img_name in os.listdir(class_dir):
                    if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                        image_paths.append(os.path.join(class_dir, img_name))
                        labels.append(idx)
        
        # Save class mapping
        with open(os.path.join(self.logs_dir, 'class_mapping.json'), 'w') as f:
            json.dump(class_to_idx, f)
        
        # Split data
        train_paths, val_paths, train_labels, val_labels = train_test_split(
            image_paths, labels, test_size=0.2, random_state=42, stratify=labels
        )
        
        val_paths, test_paths, val_labels, test_labels = train_test_split(
            val_paths, val_labels, test_size=0.5, random_state=42, stratify=val_labels
        )
        
        # Create datasets
        train_dataset = MedicalImageDataset(train_paths, train_labels, self.transform_stage1)
        val_dataset = MedicalImageDataset(val_paths, val_labels, self.val_transform)
        test_dataset = MedicalImageDataset(test_paths, test_labels, self.val_transform)
        
        return train_dataset, val_dataset, test_dataset
    
    def train(self, train_dataset, val_dataset, batch_size=32, num_epochs_stage1=20, num_epochs_stage2=30,
              learning_rate=1e-4, resume_training=False):
        """Train the model with checkpointing and progressive resizing"""
        # Stage 1: Train with smaller images (128x128)
        train_loader_stage1 = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=5, factor=0.1)
        
        # Load previous checkpoint if resuming
        start_epoch = 0
        if resume_training:
            checkpoint_path = os.path.join(self.checkpoint_dir, 'latest_checkpoint.pth')
            if os.path.exists(checkpoint_path):
                checkpoint = torch.load(checkpoint_path)
                self.model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                start_epoch = checkpoint['epoch'] + 1
                self.best_val_acc = checkpoint['best_val_acc']
                print(f"Resuming training from epoch {start_epoch}")
        
        # Training loop
        for epoch in range(start_epoch, num_epochs_stage1):
            self.model.train()
            train_loss = 0
            correct = 0
            total = 0
            
            for images, labels in tqdm(train_loader_stage1, desc=f'Epoch {epoch+1}/{num_epochs_stage1} [Train Stage 1]'):
                images, labels = images.to(self.device), labels.to(self.device)
                
                optimizer.zero_grad()
                outputs = self.model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                
                train_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
            
            epoch_train_loss = train_loss / len(train_loader_stage1)
            epoch_train_acc = 100. * correct / total
            
            scheduler.step(epoch_train_acc)  # Adjust learning rate
        
        # Stage 2: Train with higher-resolution images (224x224)
        train_dataset.transform = self.transform_stage2
        train_loader_stage2 = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
        
        for epoch in range(num_epochs_stage1, num_epochs_stage1 + num_epochs_stage2):
            self.model.train()
            train_loss = 0
            correct = 0
            total = 0
            
            for images, labels in tqdm(train_loader_stage2, desc=f'Epoch {epoch+1}/{num_epochs_stage2} [Train Stage 2]'):
                images, labels = images.to(self.device), labels.to(self.device)
                
                optimizer.zero_grad()
                outputs = self.model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                
                train_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
            
            epoch_train_loss = train_loss / len(train_loader_stage2)
            epoch_train_acc = 100. * correct / total
            
            print(f'Epoch {epoch+1}/{num_epochs_stage1} - Loss: {epoch_train_loss:.4f}, Accuracy: {epoch_train_acc:.2f}%')
            
            scheduler.step(epoch_train_acc)  # Adjust learning rate
                

        # Save best model
        torch.save(self.model.state_dict(), os.path.join(self.checkpoint_dir, 'best_model.pth'))
    
    def evaluate(self, test_dataset, batch_size=32):
        """Evaluate the model on test set"""
        self.model.eval()
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
        
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in tqdm(test_loader, desc='Testing'):
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                _, predicted = outputs.max(1)
                
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        
        accuracy = 100. * correct / total
        print(f'Test Accuracy: {accuracy:.2f}%')
        
        return accuracy

# Usage example
if __name__ == "__main__":
    classifier = MedicalClassifier(num_classes=3, model_name='efficientnetv2_s',
                                   device='cuda' if torch.cuda.is_available() else 'cpu')
    
    data_dir = '/kaggle/input/medical-image-classification/dataset2'  # Update with your Kaggle dataset path
    train_dataset, val_dataset, test_dataset = classifier.load_data(data_dir)
    
    classifier.train(train_dataset, val_dataset, batch_size=32, num_epochs_stage1=10, num_epochs_stage2=15)
    classifier.evaluate(test_dataset)



Epoch 1/10 [Train Stage 1]: 100%|██████████| 979/979 [02:21<00:00,  6.91it/s]
Epoch 2/10 [Train Stage 1]: 100%|██████████| 979/979 [02:03<00:00,  7.91it/s]
Epoch 3/10 [Train Stage 1]: 100%|██████████| 979/979 [02:04<00:00,  7.87it/s]
Epoch 4/10 [Train Stage 1]: 100%|██████████| 979/979 [02:04<00:00,  7.88it/s]
Epoch 5/10 [Train Stage 1]: 100%|██████████| 979/979 [02:03<00:00,  7.92it/s]
Epoch 6/10 [Train Stage 1]: 100%|██████████| 979/979 [02:02<00:00,  7.97it/s]
Epoch 7/10 [Train Stage 1]: 100%|██████████| 979/979 [02:01<00:00,  8.08it/s]
Epoch 8/10 [Train Stage 1]: 100%|██████████| 979/979 [02:03<00:00,  7.93it/s]
Epoch 9/10 [Train Stage 1]: 100%|██████████| 979/979 [02:01<00:00,  8.04it/s]
Epoch 10/10 [Train Stage 1]: 100%|██████████| 979/979 [02:01<00:00,  8.05it/s]
Epoch 11/15 [Train Stage 2]: 100%|██████████| 979/979 [03:46<00:00,  4.32it/s]


Epoch 11/10 - Loss: 0.0643, Accuracy: 98.04%


Epoch 12/15 [Train Stage 2]: 100%|██████████| 979/979 [03:46<00:00,  4.32it/s]


Epoch 12/10 - Loss: 0.0091, Accuracy: 99.65%


Epoch 13/15 [Train Stage 2]: 100%|██████████| 979/979 [03:46<00:00,  4.32it/s]


Epoch 13/10 - Loss: 0.0137, Accuracy: 99.58%


Epoch 14/15 [Train Stage 2]: 100%|██████████| 979/979 [03:46<00:00,  4.32it/s]


Epoch 14/10 - Loss: 0.0093, Accuracy: 99.70%


Epoch 15/15 [Train Stage 2]: 100%|██████████| 979/979 [03:46<00:00,  4.32it/s]


Epoch 15/10 - Loss: 0.0091, Accuracy: 99.74%


Epoch 16/15 [Train Stage 2]: 100%|██████████| 979/979 [03:46<00:00,  4.31it/s]


Epoch 16/10 - Loss: 0.0304, Accuracy: 99.54%


Epoch 17/15 [Train Stage 2]: 100%|██████████| 979/979 [03:46<00:00,  4.32it/s]


Epoch 17/10 - Loss: 0.0162, Accuracy: 99.52%


Epoch 18/15 [Train Stage 2]: 100%|██████████| 979/979 [03:46<00:00,  4.32it/s]


Epoch 18/10 - Loss: 0.0104, Accuracy: 99.81%


Epoch 19/15 [Train Stage 2]: 100%|██████████| 979/979 [03:46<00:00,  4.32it/s]


Epoch 19/10 - Loss: 0.0116, Accuracy: 99.62%


Epoch 20/15 [Train Stage 2]: 100%|██████████| 979/979 [03:46<00:00,  4.32it/s]


Epoch 20/10 - Loss: 0.0029, Accuracy: 99.90%


Epoch 21/15 [Train Stage 2]: 100%|██████████| 979/979 [03:46<00:00,  4.33it/s]


Epoch 21/10 - Loss: 0.0061, Accuracy: 99.82%


Epoch 22/15 [Train Stage 2]: 100%|██████████| 979/979 [03:46<00:00,  4.31it/s]


Epoch 22/10 - Loss: 0.0057, Accuracy: 99.83%


Epoch 23/15 [Train Stage 2]: 100%|██████████| 979/979 [03:47<00:00,  4.31it/s]


Epoch 23/10 - Loss: 0.0049, Accuracy: 99.86%


Epoch 24/15 [Train Stage 2]: 100%|██████████| 979/979 [03:46<00:00,  4.32it/s]


Epoch 24/10 - Loss: 0.0040, Accuracy: 99.89%


Epoch 25/15 [Train Stage 2]: 100%|██████████| 979/979 [03:46<00:00,  4.32it/s]


Epoch 25/10 - Loss: 0.0224, Accuracy: 99.71%


Testing: 100%|██████████| 123/123 [00:19<00:00,  6.47it/s]

Test Accuracy: 100.00%





Testing: 100%|██████████| 123/123 [00:10<00:00, 12.17it/s]

Test Accuracy: 100.00%



