In [1]:
# Mount Google Drive at /content/drive so the dataset folders read later
# (under /content/drive/MyDrive/...) are reachable from this runtime.
# NOTE(review): google.colab is presumably only importable inside Colab — confirm
# before running this cell elsewhere.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# CBAMWDnet TB Detection - Full Pipeline with Model Comparison

import os
import shutil
import random
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, precision_score, recall_score, f1_score

# === Step 1: Prepare Dataset ===

def prepare_dataset(max_per_class=2500, image_size=(224, 224), seed=None):
    """Collect labelled chest X-ray images and write resized RGB JPEG copies.

    Images are gathered from three source folders on the mounted drive:

    * the Montgomery and Shenzhen ``CXR_png`` folders, where the label is read
      from the file name suffix (``*_0.png`` -> "Normal", ``*_1.png`` -> "TB");
    * the TBX11K folder, where the label is given by the sub-directory
      (``PULMONARY_TUBERCULOSIS`` -> "TB", ``NORMAL`` -> "Normal").

    The combined list is shuffled, at most ``max_per_class`` images are kept
    per label, and each kept image is saved as
    ``/content/filtered_dataset/<label>/<label>_<index>.jpg``.

    Args:
        max_per_class: Upper bound on the number of images kept per label.
        image_size: ``(width, height)`` every image is resized to.
        seed: Optional seed for the shuffle. ``None`` keeps the original
            behaviour of shuffling with the global ``random`` state.

    Returns:
        None. Unreadable images are reported on stdout and skipped.
    """
    MONT_PATH = "/content/drive/MyDrive/DataSet/MontgomerySet/CXR_png"
    SHENZHEN_PATH = "/content/drive/MyDrive/DataSet/ChinaSet_AllFiles/CXR_png"
    TBX11K_PATH = "/content/drive/MyDrive/TB_Dataset/TBnNormal"

    OUTPUT_PATH = "/content/filtered_dataset"
    TB_DIR = os.path.join(OUTPUT_PATH, "TB")
    NORMAL_DIR = os.path.join(OUTPUT_PATH, "Normal")

    os.makedirs(TB_DIR, exist_ok=True)
    os.makedirs(NORMAL_DIR, exist_ok=True)

    def copy_resize(img_path, label, count):
        """Save a resized RGB JPEG copy of ``img_path`` into the label's folder."""
        target_dir = TB_DIR if label == "TB" else NORMAL_DIR
        try:
            # The context manager closes the source file handle even when
            # conversion fails; the resized copy no longer depends on it.
            with Image.open(img_path) as src:
                rgb = src if src.mode == 'RGB' else src.convert('RGB')
                resized = rgb.resize(image_size)
            resized.save(os.path.join(target_dir, f"{label}_{count}.jpg"), format='JPEG')
        except Exception as e:
            # Best effort: one corrupt file must not abort the whole preparation.
            print(f"Skipped {img_path}: {e}")

    def label_from_suffix(path):
        """Map the ``_0`` / ``_1`` file-name suffix to a label, else ``None``."""
        stem = path.stem
        if stem.endswith("_0"):
            return "Normal"
        if stem.endswith("_1"):
            return "TB"
        return None

    def collect(root, pattern, label_for):
        """Return ``(path, label)`` pairs for files under ``root`` matching ``pattern``."""
        root = Path(root)
        if not root.is_dir():
            # Make a missing or mistyped source folder visible instead of
            # silently contributing zero images.
            print(f"Warning: source directory not found: {root}")
            return []
        found = []
        # Sorting gives a stable starting order so a seeded shuffle is repeatable.
        for path in sorted(root.rglob(pattern)):
            label = label_for(path)
            if label is not None:
                found.append((path, label))
        return found

    image_records = []
    image_records += collect(MONT_PATH, "*.png", label_from_suffix)
    image_records += collect(SHENZHEN_PATH, "*.png", label_from_suffix)
    image_records += collect(
        os.path.join(TBX11K_PATH, "PULMONARY_TUBERCULOSIS"), "*.jpg", lambda _path: "TB"
    )
    image_records += collect(
        os.path.join(TBX11K_PATH, "NORMAL"), "*.jpg", lambda _path: "Normal"
    )

    # With no seed, fall back to the module-level generator exactly as before.
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(image_records)

    tb_images = [rec for rec in image_records if rec[1] == "TB"][:max_per_class]
    normal_images = [rec for rec in image_records if rec[1] == "Normal"][:max_per_class]

    # A single running index across both classes keeps output names unique.
    for i, (img_path, label) in enumerate(tb_images + normal_images):
        copy_resize(img_path, label, i)

# === Step 2: Dataset Loader ===

def get_dataloaders(data_dir, batch_size=32, train_fraction=0.8, seed=None):
    """Build train/test ``DataLoader``s from an ``ImageFolder`` directory.

    The folder is split randomly into a training part and a held-out part.
    Random horizontal flipping is applied to the training part only, so the
    held-out images are evaluated deterministically.

    Args:
        data_dir: Root directory with one sub-directory per class.
        batch_size: Batch size used by both loaders.
        train_fraction: Fraction of the images assigned to the training split.
        seed: Optional seed for the split. ``None`` uses the global torch RNG.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    eval_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    # Two views of the same folder: identical file listing, different transform.
    # Splitting one shared dataset would apply the random flip to test images too.
    train_view = datasets.ImageFolder(data_dir, transform=train_transform)
    eval_view = datasets.ImageFolder(data_dir, transform=eval_transform)

    total = len(train_view)
    train_size = int(train_fraction * total)

    if seed is None:
        order = torch.randperm(total)
    else:
        order = torch.randperm(total, generator=torch.Generator().manual_seed(seed))
    order = order.tolist()

    # Disjoint index sets guarantee that no image appears in both splits.
    train_dataset = torch.utils.data.Subset(train_view, order[:train_size])
    test_dataset = torch.utils.data.Subset(eval_view, order[train_size:])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

# === Step 3: CBAMWDnet Model ===

class ChannelAttention(nn.Module):
    """Channel gate built from globally average- and max-pooled descriptors.

    Both pooled descriptors pass through the same two-layer 1x1-conv
    bottleneck (``in_planes -> in_planes // reduction -> in_planes``); their
    sum is squashed by a sigmoid. For a batched ``(N, C, H, W)`` input the
    result has shape ``(N, C, 1, 1)``.
    """

    def __init__(self, in_planes, reduction=16):
        super().__init__()
        bottleneck = in_planes // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared bottleneck: the same weights score both pooled descriptors.
        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, bottleneck, kernel_size=1, bias=False),
            nn.ReLU(),
            nn.Conv2d(bottleneck, in_planes, kernel_size=1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid channel weights for ``x``."""
        from_avg = self.fc(self.avg_pool(x))
        from_max = self.fc(self.max_pool(x))
        combined = from_avg + from_max
        return self.sigmoid(combined)

class SpatialAttention(nn.Module):
    """Spatial gate computed from per-pixel channel statistics.

    The mean and the maximum over the channel dimension are stacked into a
    2-channel map, convolved down to one channel with a ``kernel_size``
    filter (padding ``kernel_size // 2``), and passed through a sigmoid.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the single-channel sigmoid attention map for ``x``."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = torch.max(x, dim=1, keepdim=True)[0]
        stats = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv(stats))

class CBAM(nn.Module):
    """Channel attention followed by spatial attention on a feature map.

    The input is first scaled by the ``ChannelAttention`` weights; the scaled
    tensor is then scaled again by the ``SpatialAttention`` map computed from it.
    """

    def __init__(self, in_planes, reduction=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(in_planes, reduction)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        """Return ``x`` re-weighted by the channel gate, then the spatial gate."""
        channel_refined = self.ca(x) * x
        spatial_gate = self.sa(channel_refined)
        return spatial_gate * channel_refined

class CBAMWDnet(nn.Module):
    """DenseNet-121 feature extractor with a CBAM block and a small MLP head.

    Pipeline: ``densenet121.features`` -> ``CBAM(1024)`` -> global average
    pooling -> ``Linear(1024, 256)`` -> ReLU -> ``Dropout(0.4)`` ->
    ``Linear(256, num_classes)``.

    Args:
        num_classes: Number of output logits.
        pretrained: Load ImageNet weights for the DenseNet backbone (default,
            as before). ``False`` builds a randomly initialised backbone and
            avoids the weight download.
    """

    def __init__(self, num_classes=2, pretrained=True):
        super().__init__()
        # Newer torchvision replaces the deprecated ``pretrained=`` flag with a
        # ``weights=`` enum; fall back to the legacy flag when it is missing.
        weights_enum = getattr(models, "DenseNet121_Weights", None)
        if weights_enum is not None:
            base_model = models.densenet121(
                weights=weights_enum.IMAGENET1K_V1 if pretrained else None
            )
        else:
            base_model = models.densenet121(pretrained=pretrained)

        self.features = base_model.features  # no conv0 modification
        self.cbam = CBAM(1024)               # Apply CBAM at the end of DenseNet features
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """Return class logits for a batch of images."""
        x = self.features(x)
        x = self.cbam(x)
        return self.classifier(x)

# === Step 4: Train, Evaluate, and Compare ===

def _predict_labels(model, loader, device):
    """Run ``model`` over ``loader``; return ``(true, predicted)`` lists of ints."""
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for imgs, labels in loader:
            outputs = model(imgs.to(device))
            # Plain Python ints: sklearn gets clean inputs rather than a list
            # of 0-d tensors.
            y_pred.extend(torch.argmax(outputs, dim=1).cpu().tolist())
            y_true.extend(labels.cpu().tolist())
    return y_true, y_pred


def _save_training_plots(model_name, cm, train_losses):
    """Write the confusion-matrix heatmap and the loss curve as PNG files."""
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=["Normal", "TB"], yticklabels=["Normal", "TB"])
    plt.title(f"Confusion Matrix - {model_name}")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.savefig(f"{model_name}_confusion_matrix.png")
    plt.close()

    plt.figure()
    plt.plot(train_losses, 'bo-', label='Loss')
    plt.title(f"Training Loss - {model_name}")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True)
    plt.savefig(f"{model_name}_training_loss.png")
    plt.close()


def train_model(model, train_loader, test_loader, device, model_name, num_epochs=10, lr=1e-4):
    """Train ``model`` with Adam and cross-entropy, then score it on ``test_loader``.

    Side effects: prints the mean training loss after every epoch and saves
    ``<model_name>_confusion_matrix.png`` and ``<model_name>_training_loss.png``
    in the current working directory.

    Args:
        model: Network already placed on ``device``; it is updated in place.
        train_loader: Loader yielding ``(images, labels)`` training batches.
        test_loader: Loader yielding ``(images, labels)`` evaluation batches.
        device: Device the input batches are moved to.
        model_name: Label used in log lines, plot titles and file names.
        num_epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        ``(model_name, accuracy, precision, recall, f1)``. Precision, recall
        and F1 use sklearn's binary defaults, i.e. class index 1 is positive.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_losses = []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        seen = 0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Weight each batch mean by its size so a short final batch does
            # not skew the epoch average.
            batch_count = labels.size(0)
            running_loss += loss.item() * batch_count
            seen += batch_count
        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        avg_loss = running_loss / max(seen, 1)
        train_losses.append(avg_loss)
        print(f"{model_name} - Epoch {epoch+1}, Loss: {avg_loss:.4f}")

    all_labels, all_preds = _predict_labels(model, test_loader, device)

    acc = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds)
    recall = recall_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)

    # Fixing the label order keeps the matrix 2x2, matching the two tick
    # labels, even when one class is missing from the evaluation results.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    _save_training_plots(model_name, cm, train_losses)

    return model_name, acc, precision, recall, f1

def _load_imagenet_backbone(arch, weights_enum_name):
    """Return ``models.<arch>`` with ImageNet weights on old or new torchvision.

    Uses the ``weights=`` enum when this torchvision exposes it, and the
    legacy ``pretrained=True`` flag otherwise.
    """
    factory = getattr(models, arch)
    weights_enum = getattr(models, weights_enum_name, None)
    if weights_enum is None:
        return factory(pretrained=True)
    return factory(weights=weights_enum.IMAGENET1K_V1)


def train_and_compare():
    """Train CBAMWDnet and three baselines on one split and print a metric table.

    All models share the same train/test loaders built from
    ``/content/filtered_dataset``. Each model is constructed just before it is
    trained and released afterwards, so only one network is held at a time.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, test_loader = get_dataloaders("/content/filtered_dataset", batch_size=32)

    def build_densenet121():
        net = _load_imagenet_backbone("densenet121", "DenseNet121_Weights")
        net.classifier = nn.Linear(1024, 2)
        return net

    def build_resnet50():
        net = _load_imagenet_backbone("resnet50", "ResNet50_Weights")
        net.fc = nn.Linear(2048, 2)
        return net

    def build_vgg16():
        net = _load_imagenet_backbone("vgg16", "VGG16_Weights")
        net.classifier[6] = nn.Linear(4096, 2)
        return net

    # Builders are called lazily so weights download right before each run.
    candidates = [
        ("CBAMWDnet", lambda: CBAMWDnet(num_classes=2)),
        ("DenseNet121", build_densenet121),
        ("ResNet50", build_resnet50),
        ("VGG16", build_vgg16),
    ]

    results = []
    for model_name, build in candidates:
        model = build().to(device)
        results.append(train_model(model, train_loader, test_loader, device, model_name))
        # Drop the finished model before the next one is built; previously the
        # first network stayed referenced for the whole comparison.
        del model
        if device.type == "cuda":
            torch.cuda.empty_cache()

    print("\n=== Model Comparison ===")
    print("{:<12} {:<10} {:<10} {:<10} {:<10}".format("Model", "Accuracy", "Precision", "Recall", "F1"))
    for name, acc, prec, rec, f1 in results:
        print(f"{name:<12} {acc*100:<10.2f} {prec:<10.2f} {rec:<10.2f} {f1:<10.2f}")

# === MAIN ===

if __name__ == "__main__":
    # Stage 1 writes the resized images; stage 2 trains and compares the models.
    print("Preparing dataset...")
    prepare_dataset()
    # The status text previously had a long run of spaces inside "Training".
    print("Training and comparing models...")
    train_and_compare()


Preparing dataset...
Training and comparing models...


Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth
100%|██████████| 30.8M/30.8M [00:00<00:00, 75.9MB/s]


CBAMWDnet - Epoch 1, Loss: 0.6192
CBAMWDnet - Epoch 2, Loss: 0.4057
CBAMWDnet - Epoch 3, Loss: 0.2577
CBAMWDnet - Epoch 4, Loss: 0.1745
CBAMWDnet - Epoch 5, Loss: 0.1013
CBAMWDnet - Epoch 6, Loss: 0.1152
CBAMWDnet - Epoch 7, Loss: 0.0680
CBAMWDnet - Epoch 8, Loss: 0.0248
CBAMWDnet - Epoch 9, Loss: 0.0362
CBAMWDnet - Epoch 10, Loss: 0.0341




DenseNet121 - Epoch 1, Loss: 0.5247
DenseNet121 - Epoch 2, Loss: 0.3073
DenseNet121 - Epoch 3, Loss: 0.1944
DenseNet121 - Epoch 4, Loss: 0.1354
DenseNet121 - Epoch 5, Loss: 0.0566
DenseNet121 - Epoch 6, Loss: 0.0554
DenseNet121 - Epoch 7, Loss: 0.0504
DenseNet121 - Epoch 8, Loss: 0.0409
DenseNet121 - Epoch 9, Loss: 0.0184
DenseNet121 - Epoch 10, Loss: 0.0180


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 153MB/s]


ResNet50 - Epoch 1, Loss: 0.4635
ResNet50 - Epoch 2, Loss: 0.2275
ResNet50 - Epoch 3, Loss: 0.1178
ResNet50 - Epoch 4, Loss: 0.1010
ResNet50 - Epoch 5, Loss: 0.0636
ResNet50 - Epoch 6, Loss: 0.0544
ResNet50 - Epoch 7, Loss: 0.0516
ResNet50 - Epoch 8, Loss: 0.0304
ResNet50 - Epoch 9, Loss: 0.0417
ResNet50 - Epoch 10, Loss: 0.0189


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:07<00:00, 70.6MB/s]


VGG16 - Epoch 1, Loss: 0.5287
VGG16 - Epoch 2, Loss: 0.3789
VGG16 - Epoch 3, Loss: 0.3270
VGG16 - Epoch 4, Loss: 0.2521
VGG16 - Epoch 5, Loss: 0.2334
VGG16 - Epoch 6, Loss: 0.1955
VGG16 - Epoch 7, Loss: 0.1770
VGG16 - Epoch 8, Loss: 0.0865
VGG16 - Epoch 9, Loss: 0.1164
VGG16 - Epoch 10, Loss: 0.0999

=== Model Comparison ===
Model        Accuracy   Precision  Recall     F1        
CBAMWDnet    86.88      0.90       0.79       0.84      
DenseNet121  86.25      0.84       0.86       0.85      
ResNet50     86.88      0.92       0.78       0.84      
VGG16        83.75      0.80       0.85       0.82      
