In [None]:
import os
import cv2
import numpy as np
from imblearn.over_sampling import SMOTE
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import torchvision.transforms as T

# Dataset layout: ./tuberculosis-dataset/{Normal,Tuberculosis}
DATASET_DIR = os.path.join('.', 'tuberculosis-dataset')
NORMAL_DIR, TB_DIR = (
    os.path.join(DATASET_DIR, class_folder)
    for class_folder in ('Normal', 'Tuberculosis')
)

# Sanity check: report whether each class folder is present before loading.
print("NORMAL_DIR exists?", os.path.exists(NORMAL_DIR))
print("TB_DIR exists?", os.path.exists(TB_DIR))


NORMAL_DIR exists? True
TB_DIR exists? True


In [None]:
IMAGE_SIZE = 256


def load_grayscale_images(directory, label, image_size):
    """Load every decodable image in `directory` as a resized grayscale array.

    Args:
        directory (str): Folder whose entries are read with cv2.imread.
        label (int): Class label recorded once per successfully loaded image.
        image_size (int): Side length; each image is resized to
            (image_size, image_size).

    Returns:
        tuple[list, list]: The resized images and a parallel list of labels.
    """
    images, labels = [], []
    # os.listdir returns entries in arbitrary order; sorting keeps the sample
    # order stable so the seeded train/test split downstream is reproducible.
    for fname in sorted(os.listdir(directory)):
        img = cv2.imread(os.path.join(directory, fname), cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread signals an unreadable entry with None; skip it.
            continue
        images.append(cv2.resize(img, (image_size, image_size)))
        labels.append(label)
    return images, labels


# Normal images get label 0, Tuberculosis images get label 1.
normal_images, normal_labels = load_grayscale_images(NORMAL_DIR, 0, IMAGE_SIZE)
tb_images, tb_labels = load_grayscale_images(TB_DIR, 1, IMAGE_SIZE)

all_images = np.array(normal_images + tb_images)
all_labels = np.array(normal_labels + tb_labels)

print("Images shape:", all_images.shape)
print("Labels shape:", all_labels.shape)


Images shape: (4200, 256, 256)
Labels shape: (4200,)


In [None]:
# The two classes are heavily imbalanced, so the split is stratified on the
# labels: train and test then keep the same Normal/Tuberculosis ratio instead
# of whatever ratio a purely random draw happens to produce.
X_train, X_test, y_train, y_test = train_test_split(
    all_images, all_labels, test_size=0.3, random_state=42, stratify=all_labels
)

# Convert from [0,255] -> [0,1] for training
X_train = X_train.astype('float32') / 255.0
X_test  = X_test.astype('float32')  / 255.0

print("X_train shape:", X_train.shape)
print("X_test shape:", X_test.shape)


X_train shape: (2940, 256, 256)
X_test shape: (1260, 256, 256)


In [None]:
# SMOTE operates on a 2-D feature matrix, so every training image is
# flattened to one row of pixel values before oversampling.
smote_engine = SMOTE(random_state=42)

num_train_samples = len(X_train)
X_train_2D = X_train.reshape((num_train_samples, -1))

X_train_resampled_2D, y_train_resampled = smote_engine.fit_resample(X_train_2D, y_train)

# Restore the image layout with an explicit trailing channel axis: (N, H, W, 1)
X_train_resampled = X_train_resampled_2D.reshape((-1, IMAGE_SIZE, IMAGE_SIZE, 1))

print("After SMOTE, X_train_resampled shape:", X_train_resampled.shape)
print("After SMOTE, y_train_resampled shape:", y_train_resampled.shape)

# Per-class sample counts after oversampling.
unique_vals, counts = np.unique(y_train_resampled, return_counts=True)
print("Label distribution:", {label: count for label, count in zip(unique_vals, counts)})


After SMOTE, X_train_resampled shape: (4914, 256, 256, 1)
After SMOTE, y_train_resampled shape: (4914,)
Label distribution: {np.int64(0): np.int64(2457), np.int64(1): np.int64(2457)}


In [None]:
class TBChestXrayDataset(Dataset):
    """
    PyTorch Dataset over in-memory grayscale chest X-ray arrays and their
    binary labels, with an optional torchvision-style transform pipeline
    (e.g. data augmentation).

    Each item is a pair ``(image_tensor, label_tensor)`` where the label is a
    float32 scalar tensor.
    """
    def __init__(self, images, labels, transform=None):
        """
        Args:
            images (np.array): (N, H, W, 1) or (N, H, W) array of grayscale
                images. The uint8 conversion used on the transform path
                multiplies by 255, i.e. it expects values scaled to [0, 1].
            labels (np.array): (N,) array (or plain sequence) of labels (0 or 1).
            transform (callable, optional): A torchvision-like transform
                pipeline that receives a PIL image.

        Raises:
            ValueError: If `images` and `labels` have different lengths.
        """
        if len(images) != len(labels):
            raise ValueError(
                f"images and labels must have the same length, "
                f"got {len(images)} and {len(labels)}"
            )
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _to_uint8(raw_image):
        """Scale a [0, 1] float image to uint8 gray levels.

        Values are rounded to the nearest level rather than truncated: a plain
        ``astype('uint8')`` drops the fractional part, so a product that lands
        just below an integer because of float rounding would lose a whole
        gray level. Clipping keeps out-of-range inputs from wrapping around.
        """
        scaled = np.rint(np.asarray(raw_image) * 255.0)
        return np.clip(scaled, 0, 255).astype('uint8')

    def __getitem__(self, idx):
        raw_image = self.images[idx]  # shape: (H, W, 1) or (H, W)

        # Drop a trailing singleton channel axis so the image is plain (H, W).
        if raw_image.ndim == 3 and raw_image.shape[2] == 1:
            raw_image = np.squeeze(raw_image, axis=-1)

        if self.transform:
            # The transform pipeline consumes PIL images; build one only on
            # this path so the no-transform path skips the conversion.
            pil_image = Image.fromarray(self._to_uint8(raw_image), mode='L')
            img_tensor = self.transform(pil_image)
        else:
            # Default: turn into PyTorch tensor with shape [1,H,W]
            img_tensor = torch.tensor(raw_image, dtype=torch.float32).unsqueeze(0)

        # float() accepts NumPy scalars as well as plain Python numbers.
        label_tensor = torch.tensor(float(self.labels[idx]), dtype=torch.float32)
        return img_tensor, label_tensor

# Training-time augmentation: random horizontal mirror and a small rotation,
# then conversion of the PIL image to a tensor (which also maps [0,255] -> [0,1]).
train_transforms = T.Compose([
    T.RandomHorizontalFlip(p=0.5),
    T.RandomRotation(degrees=10),
    T.ToTensor(),
])

# Evaluation applies the tensor conversion only (no augmentation).
test_transforms = T.Compose([T.ToTensor()])

# Give the test images the same trailing channel axis as the resampled
# training array so both datasets share the (N, H, W, 1) layout.
test_4D = X_test.reshape(-1, IMAGE_SIZE, IMAGE_SIZE, 1)

train_dataset = TBChestXrayDataset(
    images=X_train_resampled,
    labels=y_train_resampled,
    transform=train_transforms,
)
test_dataset = TBChestXrayDataset(
    images=test_4D,
    labels=y_test,
    transform=test_transforms,
)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

print("Number of training samples:", len(train_dataset))
print("Number of testing samples:", len(test_dataset))


Number of training samples: 4914
Number of testing samples: 1260


In [None]:
class SpatialAttention(nn.Module):
    """
    CBAM-inspired spatial attention.

    Given a feature map of shape [B,C,H,W], the module pools across the
    channel axis (mean and max), stacks the two pooled maps, and runs them
    through a single bias-free convolution followed by a sigmoid. The result
    is an attention map of shape [B,1,H,W] that the caller can multiply into
    the feature map to emphasize important regions.
    """
    def __init__(self, kernel_size=7):
        super().__init__()
        # "Same" padding for an odd kernel, so H and W are preserved.
        self.conv2d = nn.Conv2d(
            in_channels=2,
            out_channels=1,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=False,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_tensor):
        # Channel-wise statistics, each [B,1,H,W]
        channel_mean = input_tensor.mean(dim=1, keepdim=True)
        channel_max = input_tensor.max(dim=1, keepdim=True).values

        # Stack to [B,2,H,W], reduce to one channel, squash into (0, 1)
        pooled = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv2d(pooled))

class SimpleAttentionCNN(nn.Module):
    """
    Small three-stage CNN whose feature map is re-weighted by a
    SpatialAttention block before a fully connected binary classifier.

    Args:
        input_size (int): Side length of the square, single-channel input
            images. Defaults to 256, which yields a 30x30 feature map and
            the original 64 * 30 * 30 classifier input width.

    Raises:
        ValueError: If `input_size` is too small to survive the three
            conv + pool stages.
    """
    def __init__(self, input_size=256):
        super(SimpleAttentionCNN, self).__init__()

        # Validate before any layer is built so a bad size fails fast.
        feature_side = self._feature_map_side(input_size)
        if feature_side < 1:
            raise ValueError(
                f"input_size={input_size} is too small for three conv/pool stages"
            )

        # Feature extractor
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),

            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        # Spatial attention module
        self.spatial_attention = SpatialAttention(kernel_size=7)

        # Classifier; the flattened width follows the feature-map size instead
        # of being hard-coded for 256x256 inputs.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * feature_side * feature_side, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    @staticmethod
    def _feature_map_side(input_size):
        """Side length left after three (3x3 unpadded conv -> 2x2 max-pool) stages."""
        side = input_size
        for _ in range(3):
            # The unpadded 3x3 conv trims 2 pixels; the 2x2 pool halves (floor).
            side = (side - 2) // 2
        return side

    def forward(self, x):
        """Return ``(output, attention_map)`` for a batch of images.

        `output` is the sigmoid probability with shape [B,1]; `attention_map`
        is the [B,1,h,w] map produced by the spatial attention block.
        """
        # Extract deep features
        features = self.feature_extractor(x)             # [B,64,h,w] (30x30 for 256x256 input)
        attention_map = self.spatial_attention(features) # [B,1,h,w]

        # Apply attention (broadcast over the channel axis)
        attended_features = features * attention_map

        # Classify
        output = self.classifier(attended_features)      # [B,1]
        return output, attention_map

# Seed PyTorch's global RNG before any weights are created so that weight
# initialisation and the later random draws that use it (shuffling,
# augmentation, dropout) come from a repeatable stream.
# NOTE: some CUDA kernels may still be non-deterministic.
torch.manual_seed(42)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleAttentionCNN().to(device)

print(model)


SimpleAttentionCNN(
  (feature_extractor): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (7): ReLU()
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (spatial_attention): SpatialAttention(
    (conv2d): Conv2d(2, 1, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (sigmoid): Sigmoid()
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=57600, out_features=64, bias=True)
    (2): ReLU()
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=64, out_features=1, bias=True)
    (5): Sigmoid()
  )
)


In [None]:
# The model ends in a Sigmoid, so plain binary cross-entropy on probabilities
# is the matching loss.
loss_function = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# We reduce LR on plateau of "accuracy" (maximizing).
# `verbose=True` is intentionally not passed: the argument is deprecated in
# PyTorch's LR schedulers since 2.2. Read the current learning rate from
# optimizer.param_groups[0]['lr'] when it needs to be reported.
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.1, patience=1, min_lr=1e-5
)




In [None]:
def train_one_epoch(network, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over `dataloader`.

    The network is put in training mode and is expected to return a pair
    whose first element holds per-sample probabilities of shape [B,1].

    Returns:
        tuple: (sample-weighted mean loss, accuracy at a 0.5 threshold).
    """
    network.train()

    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    for images, targets in dataloader:
        images = images.to(device)
        targets = targets.to(device).view(-1, 1)  # match the [B,1] output shape

        optimizer.zero_grad()

        # Only the probabilities are needed here; the second output is ignored.
        probabilities, _ = network(images)
        batch_loss = criterion(probabilities, targets)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so the epoch figure is a true
        # per-sample mean even when the last batch is smaller.
        loss_sum += batch_loss.item() * images.size(0)

        hits = (probabilities >= 0.5).float() == targets
        num_correct += hits.sum().item()
        num_seen += targets.size(0)

    return loss_sum / num_seen, num_correct / num_seen

def evaluate_model(network, dataloader, criterion, device):
    """Score `network` on `dataloader` without updating any weights.

    The network is put in eval mode and gradients are disabled. It is
    expected to return a pair whose first element holds per-sample
    probabilities of shape [B,1].

    Returns:
        tuple: (sample-weighted mean loss, accuracy at a 0.5 threshold).
    """
    network.eval()

    loss_sum, num_correct, num_seen = 0.0, 0, 0

    with torch.no_grad():
        for images, targets in dataloader:
            images = images.to(device)
            targets = targets.to(device).view(-1, 1)  # match the [B,1] output shape

            probabilities, _ = network(images)

            # Weight by batch size so a smaller final batch does not skew the mean.
            loss_sum += criterion(probabilities, targets).item() * images.size(0)

            hits = (probabilities >= 0.5).float() == targets
            num_correct += hits.sum().item()
            num_seen += targets.size(0)

    return loss_sum / num_seen, num_correct / num_seen

num_epochs = 10
best_accuracy = 0.0
best_model_path = "tb_chest_xray_attention_best.pt"

for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(
        network=model,
        dataloader=train_loader,
        optimizer=optimizer,
        criterion=loss_function,
        device=device,
    )

    # NOTE(review): both the plateau scheduler and the checkpoint selection
    # below are driven by training accuracy; no held-out data is consulted
    # anywhere in this loop.
    lr_scheduler.step(train_acc)

    print(f"[Epoch {epoch+1}/{num_epochs}] "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

    # Checkpoint only on a strict improvement over the best accuracy so far.
    improved = train_acc > best_accuracy
    if improved:
        best_accuracy = train_acc
        torch.save(model.state_dict(), best_model_path)
        print(f"Model improved at epoch {epoch+1}, saved to {best_model_path}")


[Epoch 1/10] Train Loss: 0.3811, Train Acc: 0.8319
Model improved at epoch 1, saved to tb_chest_xray_attention_best.pt
[Epoch 2/10] Train Loss: 0.2269, Train Acc: 0.9109
Model improved at epoch 2, saved to tb_chest_xray_attention_best.pt
[Epoch 3/10] Train Loss: 0.1895, Train Acc: 0.9239
Model improved at epoch 3, saved to tb_chest_xray_attention_best.pt
[Epoch 4/10] Train Loss: 0.1502, Train Acc: 0.9442
Model improved at epoch 4, saved to tb_chest_xray_attention_best.pt
[Epoch 5/10] Train Loss: 0.1358, Train Acc: 0.9491
Model improved at epoch 5, saved to tb_chest_xray_attention_best.pt
[Epoch 6/10] Train Loss: 0.1123, Train Acc: 0.9552
Model improved at epoch 6, saved to tb_chest_xray_attention_best.pt
[Epoch 7/10] Train Loss: 0.0961, Train Acc: 0.9638
Model improved at epoch 7, saved to tb_chest_xray_attention_best.pt
[Epoch 8/10] Train Loss: 0.1115, Train Acc: 0.9668
Model improved at epoch 8, saved to tb_chest_xray_attention_best.pt
[Epoch 9/10] Train Loss: 0.1059, Train Acc: 0.96

In [None]:
# Restore the best checkpoint written by the training loop.
# - weights_only=True restricts unpickling to tensors and plain containers,
#   which is all a state_dict holds, instead of arbitrary pickled objects.
# - map_location=device lets a checkpoint saved on a GPU load on a CPU-only
#   machine (and vice versa).
best_state_dict = torch.load(best_model_path, map_location=device, weights_only=True)
model.load_state_dict(best_state_dict)
model.eval()

all_predictions = []
all_ground_truths = []

with torch.no_grad():
    for images_batch, labels_batch in test_loader:
        images_batch = images_batch.to(device)
        labels_batch = labels_batch.to(device)

        outputs, attention_map = model(images_batch)
        # Same 0.5 decision threshold as used during training.
        preds = (outputs >= 0.5).float()

        all_predictions.extend(preds.cpu().numpy().flatten())
        all_ground_truths.extend(labels_batch.cpu().numpy().flatten())

print("CLASSIFICATION REPORT:")
print(classification_report(all_ground_truths, all_predictions, digits=4))

# Rows are true classes, columns are predicted classes.
print("CONFUSION MATRIX:")
print(confusion_matrix(all_ground_truths, all_predictions))


  model.load_state_dict(torch.load(best_model_path))


CLASSIFICATION REPORT:
              precision    recall  f1-score   support

         0.0     0.9884    0.9779    0.9831      1043
         1.0     0.8991    0.9447    0.9213       217

    accuracy                         0.9722      1260
   macro avg     0.9437    0.9613    0.9522      1260
weighted avg     0.9730    0.9722    0.9725      1260

CONFUSION MATRIX:
[[1020   23]
 [  12  205]]
