In [2]:
pip install torch torchvision efficientnet-pytorch tqdm scikit-learn seaborn matplotlib

# Training

In [3]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from efficientnet_pytorch import EfficientNet
import os
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

class CTScanClassifier(nn.Module):
    """EfficientNet-B0 backbone whose final layer is replaced by a new head.

    The pretrained network's ``_fc`` layer is swapped for dropout (p=0.2)
    followed by a linear layer with ``num_classes`` outputs.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = EfficientNet.from_pretrained('efficientnet-b0')
        # The new head must accept the same feature width as the layer it
        # replaces, so read that width before overwriting ``_fc``.
        head_inputs = backbone._fc.in_features
        backbone._fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(head_inputs, num_classes),
        )
        self.efficient_net = backbone

    def forward(self, x):
        """Return the backbone's output for the input batch ``x``."""
        return self.efficient_net(x)

def create_dataloaders(data_path, batch_size=32, num_workers=4, image_size=224):
    """Build datasets and loaders for the ``train``/``valid``/``test`` splits.

    Args:
        data_path: Directory containing ``train``, ``valid`` and ``test``
            sub-directories, each readable by ``ImageFolder``.
        batch_size: Number of images per batch, used for every split.
        num_workers: Worker processes per ``DataLoader``. Pass ``0`` to load
            in the main process (useful where multiprocessing is restricted).
        image_size: Side length, in pixels, of the square every image is
            resized to.

    Returns:
        ``(dataloaders, image_datasets)``: two dicts keyed by split name.
    """
    splits = ('train', 'valid', 'test')
    target_size = (image_size, image_size)
    # Per-channel mean/std applied to every split; these are the values
    # commonly used with ImageNet-pretrained backbones.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    # Validation and test share one deterministic pipeline: no augmentation.
    eval_transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    data_transforms = {
        # Only the training split is randomly augmented.
        'train': transforms.Compose([
            transforms.Resize(target_size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]),
        'valid': eval_transform,
        'test': eval_transform,
    }

    image_datasets = {
        split: ImageFolder(os.path.join(data_path, split), data_transforms[split])
        for split in splits
    }

    dataloaders = {
        split: DataLoader(
            image_datasets[split],
            batch_size=batch_size,
            # Shuffle only while training so evaluation order stays fixed.
            shuffle=(split == 'train'),
            num_workers=num_workers,
        )
        for split in splits
    }

    return dataloaders, image_datasets

def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` and restore the weights of its best validation epoch.

    Each epoch runs a ``'train'`` phase followed by a ``'valid'`` phase over
    the corresponding entries of ``dataloaders``. Whenever validation accuracy
    exceeds the best seen so far, the weights are snapshotted and a checkpoint
    is written to ``checkpoint_epoch_{epoch}.pth`` (zero-based epoch index)
    in the current working directory.

    Args:
        model: Network to train; moved to CUDA when available, else CPU.
        dataloaders: Mapping with ``'train'`` and ``'valid'`` loaders that
            yield ``(inputs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: LR scheduler stepped once per epoch, after the train phase.
        num_epochs: Number of epochs to run.

    Returns:
        ``(model, train_losses, val_losses, train_accs, val_accs)``. The model
        carries the best validation weights, or its final weights when no
        epoch improved on the initial accuracy of 0. Each history list has one
        entry per epoch; accuracies are stored as CPU tensors.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    best_acc = 0.0
    best_model_wts = None

    # Per-epoch history, returned for plotting.
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        for phase in ['train', 'valid']:
            is_training = phase == 'train'
            if is_training:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Gradients are only tracked while training.
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_training:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                # criterion returns a batch mean; weight it by batch size so
                # the epoch figure is a true per-sample average.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if is_training:
                scheduler.step()

            num_samples = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / num_samples
            epoch_acc = running_corrects.double() / num_samples

            if is_training:
                train_losses.append(epoch_loss)
                train_accs.append(epoch_acc.cpu())
            else:
                val_losses.append(epoch_loss)
                val_accs.append(epoch_acc.cpu())

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                # state_dict() hands back references to the live tensors, so a
                # shallow dict copy would keep changing as training continues.
                # Clone every tensor to freeze a real snapshot.
                best_model_wts = {
                    name: tensor.detach().clone()
                    for name, tensor in model.state_dict().items()
                }
                # Save checkpoint
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'best_acc': best_acc,
                }, f'checkpoint_epoch_{epoch}.pth')

    print(f'Best val Acc: {best_acc:.4f}')
    # No snapshot exists when validation accuracy never rose above 0 (or when
    # num_epochs is 0); keep the current weights instead of crashing.
    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)
    return model, train_losses, val_losses, train_accs, val_accs

def evaluate_model(model, test_loader):
    """Run ``model`` over ``test_loader`` and report classification metrics.

    Prints a scikit-learn classification report and returns the confusion
    matrix together with the raw predictions and labels.

    Args:
        model: Trained classifier; moved to CUDA when available, else CPU,
            and switched to eval mode.
        test_loader: Loader yielding ``(inputs, labels)`` batches whose
            ``dataset.classes`` lists the class names in label-index order.

    Returns:
        ``(cm, all_preds, all_labels)``. ``cm`` always has one row and column
        per entry of ``test_loader.dataset.classes``, including classes that
        never occur in the data.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    model = model.to(device)

    class_names = test_loader.dataset.classes
    # Fix the label set explicitly. Without it, a class missing from both the
    # labels and the predictions shrinks the confusion matrix and makes
    # classification_report reject target_names for a length mismatch.
    label_ids = list(range(len(class_names)))

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Calculate confusion matrix
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

    # Print classification report
    print("\nClassification Report:")
    print(classification_report(all_labels, all_preds,
                                labels=label_ids,
                                target_names=class_names))

    return cm, all_preds, all_labels

def plot_training_curves(train_losses, val_losses, train_accs, val_accs):
    """Save side-by-side loss and accuracy curves to ``training_curves.png``.

    Each argument is a per-epoch sequence. The left panel plots the two loss
    series and the right panel the two accuracy series. The figure is closed
    after saving.
    """
    # One tuple per panel:
    # (train series, train label, val series, val label, title, y-axis label)
    panels = (
        (train_losses, 'Training Loss', val_losses, 'Validation Loss',
         'Loss Curves', 'Loss'),
        (train_accs, 'Training Accuracy', val_accs, 'Validation Accuracy',
         'Accuracy Curves', 'Accuracy'),
    )

    plt.figure(figsize=(12, 4))

    for position, panel in enumerate(panels, start=1):
        train_series, train_label, val_series, val_label, heading, y_label = panel
        plt.subplot(1, 2, position)
        plt.plot(train_series, label=train_label)
        plt.plot(val_series, label=val_label)
        plt.title(heading)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()

    plt.tight_layout()
    plt.savefig('training_curves.png')
    plt.close()

def plot_confusion_matrix(cm, classes):
    """Save an annotated heatmap of ``cm`` to ``confusion_matrix.png``.

    ``classes`` supplies the tick labels for both axes. Rows are labelled as
    true classes and columns as predicted classes. The figure is closed after
    saving.
    """
    heatmap_options = {
        'annot': True,
        'fmt': 'd',  # cells are rendered as integers
        'cmap': 'Blues',
        'xticklabels': classes,
        'yticklabels': classes,
    }

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, **heatmap_options)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.tight_layout()
    plt.savefig('confusion_matrix.png')
    plt.close()

def main():
    """Train the CT-scan classifier end to end and save its diagnostics.

    Seeds torch, builds the data loaders, trains with AdamW under a cosine
    annealing schedule, plots the training curves, then evaluates on the test
    split and plots the confusion matrix.
    """
    # Seed torch's RNG for reproducibility.
    torch.manual_seed(42)

    # Hyperparameters
    batch_size, num_epochs, learning_rate = 32, 25, 0.001

    loaders, datasets = create_dataloaders(
        "/kaggle/input/chest-ctscan-images/Data", batch_size
    )

    # One output per class folder found in the training split.
    classifier = CTScanClassifier(len(datasets['train'].classes))

    loss_fn = nn.CrossEntropyLoss()
    adamw = torch.optim.AdamW(classifier.parameters(), lr=learning_rate)
    # The cosine schedule spans the full training run.
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(adamw, T_max=num_epochs)

    classifier, *history = train_model(
        classifier, loaders, loss_fn, adamw, cosine, num_epochs
    )

    # history is (train_losses, val_losses, train_accs, val_accs).
    plot_training_curves(*history)

    test_split = 'test'
    cm, _preds, _labels = evaluate_model(classifier, loaders[test_split])
    plot_confusion_matrix(cm, datasets[test_split].classes)


if __name__ == "__main__":
    main()

# Test Inferences

In [19]:
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
from efficientnet_pytorch import EfficientNet
import matplotlib.pyplot as plt

class CTScanClassifier(nn.Module):
    """EfficientNet-B0 with its ``_fc`` layer replaced by dropout + linear.

    The layer layout must match the checkpoint whose ``state_dict`` is loaded
    into this model.
    """

    def __init__(self, num_classes):
        super().__init__()
        net = EfficientNet.from_pretrained('efficientnet-b0')
        # Read the feature width of the stock head before replacing it.
        feature_width = net._fc.in_features
        net._fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(feature_width, num_classes),
        )
        self.efficient_net = net

    def forward(self, x):
        """Pass ``x`` through the wrapped network and return its output."""
        return self.efficient_net(x)

def load_model(checkpoint_path, num_classes):
    """Rebuild the classifier and restore its weights from a checkpoint.

    Args:
        checkpoint_path: File readable by ``torch.load`` that holds a dict
            with a ``'model_state_dict'`` entry.
        num_classes: Output size of the classifier head; must match the
            checkpointed model.

    Returns:
        The ``CTScanClassifier`` with the restored weights, in eval mode.
    """
    classifier = CTScanClassifier(num_classes)
    # Map tensors to CPU so the checkpoint loads on machines without CUDA.
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    cpu = torch.device('cpu')
    saved = torch.load(checkpoint_path, map_location=cpu)
    weights = saved['model_state_dict']
    classifier.load_state_dict(weights)
    classifier.eval()
    return classifier

def predict_and_plot_image(model, image_path, class_names):
    """Classify one image and show it beside a bar chart of class scores.

    Args:
        model: Classifier returning one row of per-class scores for a batch
            of one image. It is moved to CUDA when available (else CPU) and
            switched to eval mode.
        image_path: Path of the image file to classify.
        class_names: Class labels ordered by the model's output index.

    Returns:
        ``(predicted_class, confidence, class_probabilities)``:
        ``confidence`` is the top softmax probability as a percentage, and
        ``class_probabilities`` maps each class name to its percentage.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Close the file handle as soon as the pixels are decoded; convert()
    # returns an independent in-memory image.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')
    # unsqueeze(0) adds the leading batch dimension of size one.
    image_tensor = transform(image).unsqueeze(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # Inference must not depend on the caller having called eval(): dropout
    # and batch-norm behave differently in train mode.
    model.eval()
    image_tensor = image_tensor.to(device)

    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        probability, predicted_idx = torch.max(probabilities, 1)

    predicted_class = class_names[predicted_idx.item()]
    confidence = probability.item() * 100

    all_probabilities = probabilities[0].cpu().numpy() * 100
    class_probabilities = dict(zip(class_names, all_probabilities))

    plt.figure(figsize=(12, 6))

    # Left panel: the input image with the predicted label.
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title(f'Prediction: {predicted_class}\nConfidence: {confidence:.2f}%')
    plt.axis('off')

    # Right panel: probability of every class.
    plt.subplot(1, 2, 2)
    classes = list(class_probabilities.keys())
    probs = list(class_probabilities.values())

    bars = plt.bar(range(len(classes)), probs)
    plt.xlabel('Classes')
    plt.ylabel('Probability (%)')
    plt.title('Class Probabilities')
    plt.xticks(range(len(classes)), classes, rotation=45)

    # Write each bar's value just above it.
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height,
                f'{height:.1f}%',
                ha='center', va='bottom')

    plt.tight_layout()
    plt.show()

    return predicted_class, confidence, class_probabilities

def main():
    """Classify one test image with a saved checkpoint and print the scores."""
    # Must be in the same order as the class indices used during training.
    class_names = ['adenocarcinoma', 'large.cell.carcinoma', 'normal', 'squamous.cell.carcinoma']

    model = load_model('/kaggle/working/checkpoint_epoch_19.pth', len(class_names))
    result = predict_and_plot_image(
        model, '/kaggle/input/ct-test3/test3.jpg', class_names
    )
    predicted_class, confidence, class_probabilities = result

    print(f"\nPredicted class: {predicted_class}")
    print(f"Confidence: {confidence:.2f}%")
    print("\nProbabilities for all classes:")
    for class_name, prob in class_probabilities.items():
        print(f"{class_name}: {prob:.2f}%")


if __name__ == "__main__":
    main()
