In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt
import time
import copy

# Configuration
# All data lives under a single root. Set LUNG_CANCER_DATA_ROOT to run on a
# different machine without editing the notebook; the default is the original path.
DATA_ROOT = os.environ.get(
    'LUNG_CANCER_DATA_ROOT',
    '/media/prashant/12 TB HDD 21/cancer_portal/lung_cancer/latest_multiclass',
)
TRAIN_IMG_DIR = os.path.join(DATA_ROOT, 'train')
VAL_IMG_DIR = os.path.join(DATA_ROOT, 'valid')
TRAIN_CSV = os.path.join(DATA_ROOT, 'train.csv')
VAL_CSV = os.path.join(DATA_ROOT, 'val.csv')
BATCH_SIZE = 128
NUM_EPOCHS = 200
TRAIN_AUGMENT_FACTOR = 5  # How many augmented versions to create per original image
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"Using device: {DEVICE}")
print(f"Batch size: {BATCH_SIZE}, Epochs: {NUM_EPOCHS}")
print(f"Augmentation factor: {TRAIN_AUGMENT_FACTOR}x\n")

# Custom Dataset Class with verification and augmentation
class AugmentedLungCancerDataset(Dataset):
    """Lung-nodule image dataset read from a CSV of ``filename`` / ``Nodule`` rows.

    On construction the CSV, its labels, the image directory and every
    referenced image file are validated, and class-balance statistics are
    printed.

    With ``augment=True`` the dataset advertises ``augment_factor`` items per
    CSV row: index ``i`` maps to row ``i % original_size`` and every access
    applies a freshly sampled random augmentation.

    Args:
        csv_path: CSV file with at least ``filename`` and ``Nodule`` columns;
            ``Nodule`` must be 0 or 1.
        img_dir: Directory that the ``filename`` entries are relative to.
        transform: Applied to each image when ``augment`` is False. When
            ``augment`` is True it only acts as an on/off switch and the
            built-in augmentation pipeline is used instead.
        augment: Expose the enlarged, randomly augmented view of the data.
        augment_factor: Items per CSV row when ``augment`` is True. ``None``
            (default) uses the module-level ``TRAIN_AUGMENT_FACTOR``.

    Raises:
        FileNotFoundError: The CSV or any referenced image file is missing.
        ValueError: A required column is missing or a label is not 0/1.
        NotADirectoryError: ``img_dir`` is not an existing directory.
    """

    def __init__(self, csv_path, img_dir, transform=None, augment=False, augment_factor=None):
        # Read and verify CSV
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"CSV file not found: {csv_path}")
        self.df = pd.read_csv(csv_path)

        # Verify CSV structure
        for col in ('filename', 'Nodule'):
            if col not in self.df.columns:
                raise ValueError(f"CSV missing required column: '{col}'")

        # Verify labels
        invalid_labels = self.df[~self.df['Nodule'].isin([0, 1])]
        if not invalid_labels.empty:
            raise ValueError(f"Invalid labels found in CSV:\n{invalid_labels}")

        # Verify image directory. isdir (not exists) so a regular file passed
        # by mistake is rejected here with the matching error type.
        if not os.path.isdir(img_dir):
            raise NotADirectoryError(f"Image directory not found: {img_dir}")

        self.img_dir = img_dir
        self.transform = transform
        self.augment = augment
        self.augment_factor = augment_factor
        self.original_size = len(self.df)

        # torchvision's random transforms re-sample their parameters on every
        # call, so a single Compose built once still yields a different
        # augmentation per __getitem__ (rebuilding it per item was pure overhead).
        self._aug_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(10),
            transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]) if augment else None

        # Print dataset stats (the denominator guard keeps an empty CSV from
        # dividing by zero).
        num_pos = (self.df['Nodule'] == 1).sum()
        num_neg = (self.df['Nodule'] == 0).sum()
        denom = self.original_size or 1
        print(f"Dataset: {os.path.basename(csv_path)}")
        print(f"  Original samples: {self.original_size}")
        print(f"  Positive (cancer): {num_pos} ({num_pos/denom*100:.1f}%)")
        print(f"  Negative (healthy): {num_neg} ({num_neg/denom*100:.1f}%)")

        # Verify image files: collect (row index, filename) for every missing file.
        missing_files = [
            (idx, fname)
            for idx, fname in self.df['filename'].items()
            if not os.path.exists(os.path.join(self.img_dir, fname))
        ]

        if missing_files:
            print("\nWARNING: Missing image files:")
            for idx, fname in missing_files[:5]:  # Show first 5 missing files
                print(f"  Row {idx}: {fname}")
            if len(missing_files) > 5:
                print(f"  ... and {len(missing_files)-5} more")
            raise FileNotFoundError(f"Total missing images: {len(missing_files)}")
        else:
            print("  All image files verified successfully")

    def __len__(self):
        """Return ``original_size``, multiplied by the augmentation factor when augmenting."""
        if not self.augment:
            return self.original_size
        # Resolved at call time so the default tracks the module-level constant.
        factor = TRAIN_AUGMENT_FACTOR if self.augment_factor is None else self.augment_factor
        return self.original_size * factor

    def __getitem__(self, idx):
        """Return ``(image, label)`` for CSV row ``idx % original_size``.

        ``label`` is a float32 scalar tensor. If the image cannot be read, a
        black 224x224 image is used instead and the label is -1. The
        placeholder goes through the same transform as real images so it
        still collates into a batch, and the training loop can drop it.
        """
        # Calculate original index for non-augmented dataset
        orig_idx = idx % self.original_size

        img_name = self.df.iloc[orig_idx]['filename']
        label = self.df.iloc[orig_idx]['Nodule']
        img_path = os.path.join(self.img_dir, img_name)

        try:
            # Context manager closes the underlying file handle promptly.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as e:
            print(f"\nERROR loading image: {img_path}")
            print(f"Exception: {str(e)}")
            # Blank placeholder + invalid label; transformed below like any other
            # image so the DataLoader's default collate does not choke on a PIL image.
            image = Image.new('RGB', (224, 224), color='black')
            label = -1

        if self.transform:
            if self.augment:
                image = self._aug_transform(image)
            else:
                image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.float32)

# Image Transformations
# ImageNet channel mean/std, the standard normalization for torchvision's
# pretrained weights. Normalize is stateless, so one instance can be shared.
_imagenet_normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Training pipeline: random geometric + photometric jitter, then tensor + normalize.
# NOTE: AugmentedLungCancerDataset(augment=True) applies its own augmentation
# pipeline; there this Compose only serves to switch transforming on.
_train_steps = [
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    _imagenet_normalize,
]
train_transform = transforms.Compose(_train_steps)

# Evaluation pipeline: deterministic resize + center crop to 224x224.
_val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    _imagenet_normalize,
]
val_transform = transforms.Compose(_val_steps)

# Create Datasets and DataLoaders with verification
print("\n" + "="*50)
print("Creating training dataset with augmentation:")
train_dataset = AugmentedLungCancerDataset(TRAIN_CSV, TRAIN_IMG_DIR, train_transform, augment=True)
print(f"  Augmented training size: {len(train_dataset)} images ({TRAIN_AUGMENT_FACTOR}x original)")

print("\n" + "="*50)
print("Creating validation dataset:")
val_dataset = AugmentedLungCancerDataset(VAL_CSV, VAL_IMG_DIR, val_transform, augment=False)
print(f"  Validation size: {len(val_dataset)} images")

# Pinned host memory only speeds up host->GPU copies. Without CUDA, PyTorch
# warns and skips pinning, so request it only when training on a GPU.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=4,
    pin_memory=DEVICE.type == 'cuda',
)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

print("\n" + "="*50)
print(f"Training batches: {len(train_loader)} ({len(train_dataset)} images)")
print(f"Validation batches: {len(val_loader)} ({len(val_dataset)} images)")

# Initialize Model
print("\n" + "="*50)
print("Initializing ResNet50 model...")
model = models.resnet50(weights='DEFAULT')  # pretrained torchvision weights
num_ftrs = model.fc.in_features
# Replace the classifier head with a single logit; BCEWithLogitsLoss below
# applies the sigmoid internally.
model.fc = nn.Linear(num_ftrs, 1)
model = model.to(DEVICE)

# Loss and Optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# mode='max': the training loop steps this scheduler with validation accuracy.
# `verbose` is deprecated in recent PyTorch, and the loop already prints the
# current LR each epoch, so it is not passed.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=3
)

print("\nModel architecture:")
print(model)
print(f"\nTrainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# Function to visualize augmented images
def _denormalize(tensor):
    """Undo the ImageNet Normalize step; return an HxWxC numpy array clipped to [0, 1]."""
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return (tensor * std + mean).clamp(0, 1).permute(1, 2, 0).numpy()


def visualize_augmentations(dataset, num_samples=5):
    """Save ``augmentation_samples.png``: random originals beside two augmented views.

    Args:
        dataset: An ``AugmentedLungCancerDataset`` whose items are normalized
            3xHxW tensors.
        num_samples: Number of rows (randomly chosen CSV rows) to plot.
    """
    print("\nVisualizing augmented images...")
    # squeeze=False keeps `axes` 2-D even when num_samples == 1.
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, num_samples*3), squeeze=False)
    fig.suptitle('Image Augmentation Examples', fontsize=16)

    for i in range(num_samples):
        # Show original image
        orig_idx = np.random.randint(0, dataset.original_size)
        img_name = dataset.df.iloc[orig_idx]['filename']
        img_path = os.path.join(dataset.img_dir, img_name)
        with Image.open(img_path) as img:
            orig_image = img.convert('RGB')

        # The dataset maps index i to row i % original_size, so the augmented
        # copies of row `orig_idx` sit at orig_idx, orig_idx + original_size, ...
        # (the previous orig_idx * TRAIN_AUGMENT_FACTOR pointed at a different image).
        aug_image1, _ = dataset[orig_idx]
        aug_image2, _ = dataset[orig_idx + dataset.original_size]

        # Plot images
        axes[i, 0].imshow(orig_image)
        axes[i, 0].set_title(f"Original\n{img_name}")
        axes[i, 0].axis('off')

        axes[i, 1].imshow(_denormalize(aug_image1))
        axes[i, 1].set_title("Augmentation 1")
        axes[i, 1].axis('off')

        axes[i, 2].imshow(_denormalize(aug_image2))
        axes[i, 2].set_title("Augmentation 2")
        axes[i, 2].axis('off')

    plt.tight_layout()
    plt.savefig('augmentation_samples.png', dpi=150)
    plt.close(fig)
    print("Saved augmentation_samples.png")


# Visualize augmentations
visualize_augmentations(train_dataset)

# Training Function
def _drop_invalid_samples(inputs, labels):
    """Remove samples labelled < 0 (the dataset's marker for an unreadable image).

    Returns the filtered ``(inputs, labels)`` pair, or ``None`` when no valid
    sample remains in the batch.
    """
    valid = labels >= 0
    if bool(valid.all()):
        return inputs, labels
    print(f"Skipping {int((~valid).sum())} sample(s) with invalid labels")
    if not bool(valid.any()):
        return None
    return inputs[valid], labels[valid]


def _train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader``; return mean per-sample loss (0.0 if none)."""
    model.train()
    running_loss = 0.0
    processed = 0

    for inputs, labels in loader:
        batch = _drop_invalid_samples(inputs, labels)
        if batch is None:
            continue
        inputs, labels = batch

        inputs = inputs.to(DEVICE)
        labels = labels.unsqueeze(1).to(DEVICE)  # (B,) -> (B, 1) to match the single-logit output

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        processed += inputs.size(0)

    # Guard: every batch may have been skipped.
    return running_loss / processed if processed > 0 else 0.0


def _evaluate(model, loader, criterion):
    """Return ``(mean loss, accuracy)`` on ``loader`` using a 0.5 sigmoid threshold; ``(0, 0)`` if empty."""
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in loader:
            batch = _drop_invalid_samples(inputs, labels)
            if batch is None:
                continue
            inputs, labels = batch

            inputs = inputs.to(DEVICE)
            labels = labels.unsqueeze(1).to(DEVICE)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * inputs.size(0)

            preds = (torch.sigmoid(outputs) > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        return 0, 0
    return val_loss / total, correct / total


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs):
    """Train ``model`` for ``num_epochs`` epochs, keeping the best-validation-accuracy weights.

    Each epoch trains on ``train_loader``, evaluates on ``val_loader``, and
    steps the module-level ``scheduler`` with the validation accuracy.
    Whenever accuracy improves, the weights are written to
    ``best_resnet50_augmented_model.pth``. Samples labelled -1 (unreadable
    images) are dropped individually rather than discarding their whole batch.

    Returns:
        ``(model, history)``: the model loaded with the best weights, and a dict
        of per-epoch lists ``'train_loss'``, ``'val_loss'`` and ``'val_acc'``.
    """
    best_acc = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())
    history = {'train_loss': [], 'val_loss': [], 'val_acc': []}
    start_time = time.time()

    print("\n" + "="*50)
    print("Starting training...")

    for epoch in range(num_epochs):
        epoch_start = time.time()
        epoch_loss = _train_one_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_acc = _evaluate(model, val_loader, criterion)

        # Update history
        history['train_loss'].append(epoch_loss)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        # Update learning rate (module-level ReduceLROnPlateau, mode='max').
        scheduler.step(val_acc)

        epoch_time = time.time() - epoch_start
        print(f"\nEpoch {epoch+1}/{num_epochs} ({epoch_time:.1f}s)")
        print(f"  Train Loss: {epoch_loss:.4f}")
        print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
        print(f"  Current LR: {optimizer.param_groups[0]['lr']:.6f}")

        # Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(best_model_wts, 'best_resnet50_augmented_model.pth')
            print(f"  Saved new best model with accuracy: {best_acc:.4f}")

    total_time = time.time() - start_time
    print(f"\nTraining complete in {total_time//60:.0f}m {total_time%60:.0f}s")
    print(f"Best Validation Accuracy: {best_acc:.4f}")

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model, history

# Start training
trained_model, history = train_model(model, train_loader, val_loader, criterion, optimizer, NUM_EPOCHS)

# Final summary
print("\n" + "="*50)
print("Training Summary:")
print(f"Original training samples: {train_dataset.original_size}")
print(f"Augmented training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
# history is empty when NUM_EPOCHS == 0; indexing [-1] / max() would raise.
if history['val_acc']:
    print(f"Final Validation Accuracy: {history['val_acc'][-1]:.4f}")
    print(f"Best Validation Accuracy: {max(history['val_acc']):.4f}")
else:
    print("No epochs were run; no validation accuracy to report.")
print("Model saved as: 'best_resnet50_augmented_model.pth'")

# Plot training history
def plot_history(history):
    """Write ``training_history.png``: loss curves (left) and validation accuracy (right).

    Args:
        history: Mapping with per-epoch lists under ``'train_loss'``,
            ``'val_loss'`` and ``'val_acc'``.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: training and validation loss on shared axes.
    for key, label in (('train_loss', 'Train Loss'), ('val_loss', 'Validation Loss')):
        loss_ax.plot(history[key], label=label)
    loss_ax.set_title('Training and Validation Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()
    loss_ax.grid(True)

    # Right panel: only validation accuracy is tracked.
    acc_ax.plot(history['val_acc'], label='Validation Accuracy', color='green')
    acc_ax.set_title('Validation Accuracy')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.legend()
    acc_ax.grid(True)

    fig.tight_layout()
    fig.savefig('training_history.png', dpi=150)
    plt.close(fig)
    print("Saved training_history.png")


plot_history(history)

Using device: cuda
Batch size: 128, Epochs: 200
Augmentation factor: 5x


Creating training dataset with augmentation:
Dataset: train.csv
  Original samples: 728
  Positive (cancer): 393 (54.0%)
  Negative (healthy): 335 (46.0%)
  All image files verified successfully
  Augmented training size: 3640 images (5x original)

Creating validation dataset:
Dataset: val.csv
  Original samples: 208
  Positive (cancer): 116 (55.8%)
  Negative (healthy): 92 (44.2%)
  All image files verified successfully
  Validation size: 208 images

Training batches: 29 (3640 images)
Validation batches: 2 (208 images)

Initializing ResNet50 model...





Model architecture:
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_siz

KeyboardInterrupt: 