In [1]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, WeightedRandomSampler
from torchvision import transforms
from PIL import Image
from natsort import natsorted
from sklearn.metrics import classification_report, confusion_matrix

# Configuration
IMG_SIZE = (256, 256)
BATCH_SIZE = 16
EPOCHS = 100
VAL_SPLIT = 0.2
DATA_PATH = "output_images/"
LABELS = ["0", "1"]  # Benign (0) and Malignant (1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Enhanced Dataset Class with proper transform handling
class UltrasoundDataset(Dataset):
    def __init__(self, data_path, transform=None):
        self.rf_images = []
        self.labels = []
        self.transform = transform
        valid_extensions = ('.png', '.jpg', '.jpeg')

        for label in LABELS:
            rf_path = os.path.join(data_path, label, "rf")
            if not os.path.exists(rf_path):
                continue
                
            rf_files = natsorted([f for f in os.listdir(rf_path) if f.lower().endswith(valid_extensions)])

            for rf_file in rf_files:
                rf_img = cv2.imread(os.path.join(rf_path, rf_file), cv2.IMREAD_GRAYSCALE)
                
                if rf_img is None:
                    continue
                
                rf_img = cv2.resize(rf_img, IMG_SIZE)
                self.rf_images.append(rf_img)
                self.labels.append(int(label))

        print(f"Loaded {len(self.rf_images)} images")
        print(f"Class distribution: {np.bincount(self.labels)}")

    def __len__(self):
        return len(self.rf_images)

    def __getitem__(self, idx):
        img = self.rf_images[idx]
        
        # Convert numpy array to PIL Image only if transforms are specified
        if self.transform:
            img = Image.fromarray(img)
            img = self.transform(img)
        else:
            img = torch.tensor(img, dtype=torch.float32).unsqueeze(0) / 255.0
            
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return img, label

# Enhanced Model Architecture
class DenseCNN(nn.Module):
    def __init__(self, base_channels=32):
        super().__init__()
        
        self.features = nn.Sequential(
            nn.Conv2d(1, base_channels, 3, padding=1),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),
            
            nn.Conv2d(base_channels, base_channels*2, 3, padding=1),
            nn.BatchNorm2d(base_channels*2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            
            nn.Conv2d(base_channels*2, base_channels*4, 3, padding=1),
            nn.BatchNorm2d(base_channels*4),
            nn.ReLU(),
            nn.MaxPool2d(2),
            
            nn.Conv2d(base_channels*4, base_channels*8, 3, padding=1),
            nn.BatchNorm2d(base_channels*8),
            nn.ReLU(),
        )
        
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.2),
            nn.Linear(base_channels*8, 2)
        )
        
    def forward(self, x):
        x = self.features(x)
        return self.classifier(x)

# Training Function with Early Stopping
def train_model():
    # Data augmentation and normalization
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    
    # Load dataset
    full_dataset = UltrasoundDataset(DATA_PATH)
    
    # Split dataset
    train_size = int((1 - VAL_SPLIT) * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
    
    # Apply transforms to subsets
    train_dataset.dataset.transform = train_transform
    val_dataset.dataset.transform = val_transform
    
    # Handle class imbalance
    labels = [full_dataset.labels[i] for i in train_dataset.indices]
    class_counts = torch.bincount(torch.tensor(labels))
    class_weights = 1. / class_counts.float()
    sample_weights = class_weights[labels]
    sampler = WeightedRandomSampler(sample_weights, len(sample_weights))
    
    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    
    # Initialize model
    model = DenseCNN().to(device)
    
    # Loss function with class weighting
    class_weights = class_weights.to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    
    # Optimizer and scheduler
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)
    
    # Training loop
    best_val_acc = 0
    patience = 30
    no_improve = 0
    
    for epoch in range(EPOCHS):
        model.train()
        train_loss, correct, total = 0, 0, 0
        
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
        
        train_acc = correct / total
        
        # Validation
        model.eval()
        val_loss, correct, total = 0, 0, 0
        all_preds = []
        all_labels = []
        
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
                
                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        
        val_acc = correct / total
        scheduler.step(val_loss)
        
        # Print metrics
        print(f"Epoch {epoch+1}/{EPOCHS} | "
              f"Train Loss: {train_loss/len(train_loader):.4f} | "
              f"Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss/len(val_loader):.4f} | "
              f"Val Acc: {val_acc:.4f} | "
              f"LR: {optimizer.param_groups[0]['lr']:.2e}")
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            no_improve = 0
            torch.save(model.state_dict(), 'best_model_des.pth')
            print("Saved new best model")
            
            # Print classification report
            print("\nClassification Report:")
            print(classification_report(all_labels, all_preds, target_names=LABELS))
            print("Confusion Matrix:")
            print(confusion_matrix(all_labels, all_preds))
        else:
            no_improve += 1
            if no_improve >= patience:
                print(f"\nEarly stopping at epoch {epoch+1}")
                break
    
    print(f"\nTraining complete. Best validation accuracy: {best_val_acc:.4f}")

if __name__ == "__main__":
    train_model()

Using device: cuda
Loaded 200 images
Class distribution: [ 96 104]
Epoch 1/100 | Train Loss: 0.5836 | Train Acc: 0.6562 | Val Loss: 0.5170 | Val Acc: 0.8000 | LR: 1.00e-03
Saved new best model

Classification Report:
              precision    recall  f1-score   support

           0       0.76      0.90      0.83        21
           1       0.87      0.68      0.76        19

    accuracy                           0.80        40
   macro avg       0.81      0.79      0.80        40
weighted avg       0.81      0.80      0.80        40

Confusion Matrix:
[[19  2]
 [ 6 13]]
Epoch 2/100 | Train Loss: 0.6181 | Train Acc: 0.6312 | Val Loss: 1.6074 | Val Acc: 0.5250 | LR: 1.00e-03
Epoch 3/100 | Train Loss: 0.5451 | Train Acc: 0.7063 | Val Loss: 2.3512 | Val Acc: 0.5250 | LR: 1.00e-03
Epoch 4/100 | Train Loss: 0.4978 | Train Acc: 0.7562 | Val Loss: 0.8309 | Val Acc: 0.6000 | LR: 1.00e-03
Epoch 5/100 | Train Loss: 0.5115 | Train Acc: 0.7250 | Val Loss: 2.2023 | Val Acc: 0.5500 | LR: 5.00e-04