## 1) Import Libraries

In [None]:
import os
import cv2
import random
import torch
import warnings
import pandas as pd
import torch.nn as nn
import torch.optim
import numpy as np
import matplotlib.pyplot as plt
import albumentations as A
import seaborn as sns
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import torch.optim.lr_scheduler as scheduler

from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from albumentations.pytorch import ToTensorV2

In [None]:
warnings.filterwarnings('ignore')

In [None]:
# Global plot styling: diverging red/blue colour cycle on a dark-grid background.
sns.set_palette(palette="RdBu")
sns.set_style(style="darkgrid")

## 2) Basic Config and Functions

In [None]:
def generate_df(directory: str, categories: list[str]) -> pd.DataFrame:
    """Index every entry found under ``directory/<category>`` for each category.

    Args:
        directory: Root folder containing one sub-folder per category.
        categories: Names of the sub-folders to scan, in the order given.

    Returns:
        DataFrame with columns ``path`` (the bare file name, not the full
        path) and ``label`` (the category folder the entry was found in).
    """
    rows = [
        [file_name, label]
        for label in categories
        for file_name in os.listdir(os.path.join(directory, label))
    ]
    return pd.DataFrame(rows, columns=['path', 'label'])

In [None]:
def read_img(path: str, size: tuple[int, int]):
    """Load an image from disk with OpenCV and resize it.

    Args:
        path: File path of the image to read.
        size: Target size, passed straight to ``cv2.resize`` (OpenCV
            interprets it as ``(width, height)``).

    Returns:
        The resized image as returned by ``cv2.imread`` (BGR channel order).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``. ``cv2.imread``
            signals failure by returning ``None`` rather than raising.
    """
    # Bug fix: the original read the *global* ``img_path`` instead of the
    # ``path`` argument. It only worked inside loops that happened to use a
    # variable of that name, and silently loaded the wrong file elsewhere.
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f'Could not read image: {path}')

    return cv2.resize(img, size)

In [None]:
EPOCHS = 30
LEARNING_RATE = 3.5e-4
BATCH_SIZE = 32
# MultiStepLR milestones: the LR is multiplied by LR_DECAY after these epochs.
# They only mean "epochs" if the scheduler is stepped once per epoch.
LR_DECAY_EPOCH = [15, 30]
LR_DECAY = 0.1

SEED = 42

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

IMAGE_SIZE = (224, 224)

# Class name -> integer target used by the loss.
CATEGORIES = {
    'PNEUMONIA': 1,
    'NORMAL': 0
}

BASE_PATH = '/kaggle/input/chest-xray-pneumonia/chest_xray'

TRAIN_PATH = f'{BASE_PATH}/train'
TEST_PATH = f'{BASE_PATH}/test'
# The Kaggle chest-xray-pneumonia dataset names its validation split `val`, not `valid`.
VALID_PATH = f'{BASE_PATH}/val'

# Seed every RNG the notebook touches (sample shuffling, DataLoader shuffling,
# head initialisation) so Restart & Run All is reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

train_df = generate_df(TRAIN_PATH, list(CATEGORIES))

# The pretrained torchvision VGG16 backbone expects inputs normalised with
# ImageNet statistics (not 0.5/0.5).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

basic_transforms = A.Compose([
    A.Resize(
        height=IMAGE_SIZE[0],
        width=IMAGE_SIZE[1]
    ),
    # A.CLAHE(),
    A.Normalize(
        mean=IMAGENET_MEAN,
        std=IMAGENET_STD
    ),
    ToTensorV2()
])

## 4) Load data

In [None]:
class ChestXRay(Dataset):
    """Chest X-ray images stored as ``root/<category>/<file>``.

    Each item is ``(image, target)``: ``image`` is a float tensor (3 identical
    grey channels) and ``target`` is the int64 class id of its category.
    """

    def __init__(self, root, categories, transforms=None):
        """
        Args:
            root: Split directory (e.g. TRAIN_PATH) with one sub-folder per category.
            categories: Category folder names. If a dict (like ``CATEGORIES``),
                its values are used as class ids; otherwise ids are looked up
                in the global ``CATEGORIES``.
            transforms: Optional albumentations pipeline, called as
                ``transforms(image=...)``, that returns a tensor under ``'image'``.
        """
        self.root = root
        self.transforms = transforms
        # Use the mapping we were given. The original ignored it and always
        # read the global CATEGORIES when building targets.
        if isinstance(categories, dict):
            self.class_to_idx = dict(categories)
        else:
            self.class_to_idx = {category: CATEGORIES[category] for category in categories}
        self.images = []

        for category in self.class_to_idx:
            category_dir = os.path.join(self.root, category)
            # Sort so dataset indices are the same on every machine
            # (os.listdir order is arbitrary).
            for img in sorted(os.listdir(category_dir)):
                self.images.append({
                    'path': os.path.join(category_dir, img),
                    'label': category
                })

    def __getitem__(self, item):
        """Return ``(image, target)`` for index ``item``."""
        record = self.images[item]

        # Context manager closes the file handle; Image.open keeps it open otherwise.
        with Image.open(record['path']) as pil_img:
            gray = np.array(pil_img.convert('L'))

        # Replicate the single grey channel to (H, W, 3) for the 3-channel backbone.
        img = np.repeat(gray[..., np.newaxis], 3, axis=-1)

        if self.transforms is not None:
            img = self.transforms(image=img)['image']
        else:
            # Without transforms, still return a (C, H, W) tensor: a NumPy
            # array has no ``.float()`` and the original crashed here.
            img = torch.from_numpy(img).permute(2, 0, 1)

        target = torch.tensor(self.class_to_idx[record['label']], dtype=torch.int64)
        return img.float(), target

    def __len__(self):
        """Number of indexed images."""
        return len(self.images)

In [None]:
# Train and test splits share the same categories and preprocessing pipeline.
train_dataset, test_dataset = (
    ChestXRay(root=split_root, categories=CATEGORIES, transforms=basic_transforms)
    for split_root in (TRAIN_PATH, TEST_PATH)
)

*Train*: Found **5216** images belonging to 2 classes.

*Test*: Found **624** images belonging to 2 classes.

## 5) EDA

In [None]:
# First rows of the training index: file name and class label.
train_df.head(n=5)

In [None]:
# How many sample files to take from each class for the preview grid.
count_imgs = 16

normal_path, pneumonia_path = (f'{TRAIN_PATH}/{name}' for name in ('NORMAL', 'PNEUMONIA'))

sample_normal_imgs, sample_pneumonia_imgs = (
    os.listdir(folder)[:count_imgs] for folder in (normal_path, pneumonia_path)
)

In [None]:
# Subplot counter consumed by the plotting cell that follows.
counter = 0

normal_imgs_path = [f'{normal_path}/{name}' for name in sample_normal_imgs]
pneumonia_imgs_path = [f'{pneumonia_path}/{name}' for name in sample_pneumonia_imgs]

# Interleave both classes so the preview grid is not split into two halves.
all_imgs = [*normal_imgs_path, *pneumonia_imgs_path]
random.shuffle(all_imgs)

In [None]:
# Grid of sample images, titled with their class.
# The row count follows len(all_imgs), so changing count_imgs no longer breaks subplot().
grid_cols = 8
grid_rows = max(1, -(-len(all_imgs) // grid_cols))  # ceiling division

plt.figure(figsize=(28, 10))

# Keep the loop variable named `img_path`: read_img as originally written
# looks up this global instead of using its `path` argument.
for counter, img_path in enumerate(all_imgs):
    plt.subplot(grid_rows, grid_cols, counter + 1)

    img = read_img(img_path, IMAGE_SIZE)

    # The class is the name of the folder containing the image.
    label = os.path.basename(os.path.dirname(img_path))

    plt.imshow(img)
    plt.title(label)
    plt.axis('off')

plt.show()

In [None]:
# Pixel-intensity histograms of the first 8 preview images.
plt.figure(figsize=(28, 20))

# Keep the loop variable named `img_path`: read_img as originally written
# looks up this global instead of using its `path` argument.
for counter, img_path in enumerate(all_imgs[:8]):
    plt.subplot(4, 2, counter + 1)

    img = read_img(img_path, IMAGE_SIZE)

    # One bin per intensity level, assuming 8-bit images (0..255).
    plt.hist(img.ravel(), bins=256, range=(0, 255))
    # Axes stay visible: the original turned them off, which hid the
    # intensity scale and counts a histogram is read by.
    plt.title(f'{counter + 1}: {os.path.basename(os.path.dirname(img_path))}')
    plt.xlabel('Pixel intensity')
    plt.ylabel('Count')

plt.tight_layout()
plt.show()

In [None]:
# Number of training images per class (shows the class imbalance).
plt.figure(figsize=(15, 6))

# NOTE: seaborn returns an Axes here, despite the variable name.
fig = sns.countplot(data=train_df, x='label')

plt.xticks(rotation=0)
plt.show()

What do we see?

1. The data is imbalanced.
2. There is not enough data to train a CNN from scratch.

### 5.1) Applying CLAHE Filter

In [None]:
# The CLAHE cells below show a before/after view of ONE image, so both
# "samples" deliberately refer to the same (first) training row.
sample1 = train_df.iloc[0]

sample1_label = sample1['label']
sample1_path = sample1['path']

sample2 = sample1
sample2_label = sample1_label
sample2_path = sample1_path

# train_df stores bare file names; rebuild the full path as split/label/file.
sample1_img = read_img(f'{TRAIN_PATH}/{sample1_label}/{sample1_path}', IMAGE_SIZE)
# Independent copy, so edits to one image never leak into the other, without
# reading the same file from disk a second time.
sample2_img = sample1_img.copy()

In [None]:
# Side by side: original image, CLAHE-enhanced image, and a plain binary
# threshold for reference.
CLAHE_CLIP_LIMIT = 5
THRESHOLD_VALUE = 155

plt.figure(figsize=(28, 10))

plt.subplot(1, 3, 1)

plt.imshow(sample1_img)
plt.title('Without Filter')
plt.axis('off')

plt.subplot(1, 3, 2)

# Take the grayscale source from the untouched `sample1_img` (the same row as
# sample 2) rather than from `sample2_img`. This cell overwrites `sample2_img`,
# so on a re-run COLOR_BGR2GRAY would be applied to an already single-channel
# image and fail.
sample2_img_bw = cv2.cvtColor(sample1_img, cv2.COLOR_BGR2GRAY)

clahe = cv2.createCLAHE(clipLimit=CLAHE_CLIP_LIMIT)
sample2_img = clahe.apply(sample2_img_bw)

# cmap='gray': matplotlib otherwise colours a 2-D array with viridis.
plt.imshow(sample2_img, cmap='gray')
plt.title('With Filter')
plt.axis('off')

plt.subplot(1, 3, 3)

_, ordinary_img = cv2.threshold(sample2_img_bw, THRESHOLD_VALUE, 255, cv2.THRESH_BINARY)

plt.imshow(ordinary_img, cmap='gray')
plt.title(f'Binary Threshold ({THRESHOLD_VALUE})')
plt.axis('off')

plt.show()

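As we can see, the contrast of the image is enhanced. I expect this filter to make the relevant structures easier to distinguish. Note, though, that CLAHE (Contrast Limited Adaptive Histogram Equalization) does not remove noise by itself: `clipLimit` only limits how much noise gets amplified.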

In [None]:
# Intensity histograms before vs. after CLAHE, grayscale against grayscale.
# `sample1_img` is 3-channel BGR (presumably identical channels for these
# X-rays), which would triple its counts relative to the 1-channel CLAHE output.
sample1_gray = (
    cv2.cvtColor(sample1_img, cv2.COLOR_BGR2GRAY) if sample1_img.ndim == 3 else sample1_img
)

plt.figure(figsize=(28, 10))

plt.subplot(1, 2, 1)
plt.hist(sample1_gray.ravel(), bins=256, range=(0, 255))
plt.title('Without Filter')
plt.xlabel('Pixel intensity')
plt.ylabel('Count')

plt.subplot(1, 2, 2)
plt.hist(sample2_img.ravel(), bins=256, range=(0, 255))
plt.title('With Filter (CLAHE)')
plt.xlabel('Pixel intensity')

plt.show()

## 6) CNN Model

In [None]:
class VGG16(nn.Module):
    """Frozen pretrained VGG16 backbone that returns multi-scale feature maps.

    ``forward(images)`` returns ``(f2, f3, f4)``, the outputs of successive
    slices of ``vgg16().features``. For a 224x224 input these are
    128@56x56, 256@28x28 and 512@14x14 (channels@HxW), which are the
    shapes ``Classifier`` is wired for.
    """

    def __init__(self):
        super(VGG16, self).__init__()

        self.model = models.vgg16(weights='DEFAULT')
        # Only the convolutional `features` are used. Drop the large fully
        # connected head so it is not moved to and kept on the device for nothing.
        self.model.classifier = nn.Identity()
        self.features = self.model.features

        # Slices of torchvision's VGG16 `features` Sequential.
        self.features1 = self.features[:6]     # block 1 + first conv of block 2
        self.features2 = self.features[6:10]   # rest of block 2 (ends with max-pool)
        self.features3 = self.features[10:17]  # block 3 (ends with max-pool)
        self.features4 = self.features[17:30]  # block 4 + block 5 without its final max-pool

        # The backbone is never trained (forward runs under no_grad). Make that
        # explicit so its parameters are not optimised by accident.
        self.requires_grad_(False)

    def forward(self, images):
        """Return ``(f2, f3, f4)`` for a batch ``images`` of shape (N, 3, H, W)."""
        with torch.no_grad():
            f1 = self.features1(images)
            f2 = self.features2(f1)
            f3 = self.features3(f2)
            f4 = self.features4(f3)

        return f2, f3, f4

In [None]:
class Classifier(nn.Module):
    """Trainable head over the VGG16 feature pyramid that also produces CAMs.

    ``forward(f2, f3, f4)`` returns ``(logits, cams)``:
      * ``logits``: raw scores of shape (N, num_classes), suitable for
        ``nn.CrossEntropyLoss`` and ``argmax``.
      * ``cams``: class activation maps (N, num_classes, H2, W2) at the
        spatial resolution of ``f2``, each min-max normalised to [0, 1].
    """

    def __init__(self, num_classes=2):
        super(Classifier, self).__init__()

        self.num_classes = num_classes

        self.conv1 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=(1, 1), padding=0),
            nn.ReLU(),
        )

        self.fc = nn.Linear(128, self.num_classes)

    def forward(self, f2, f3, f4):
        """
        Args:
            f4: (N, 512, H, W), the deepest feature map.
            f3: (N, 256, 2H, 2W) skip connection.
            f2: (N, 128, 4H, 4W) skip connection.

        Returns:
            ``(logits, cams)`` as described on the class.
        """
        features = self.conv2(self.conv1(f4))

        # Upsample x2 and fuse with the skip connection at that resolution.
        # align_corners=False is the default, stated explicitly.
        features = F.interpolate(features, scale_factor=2, mode='bilinear', align_corners=False) + f3

        features = self.conv3(features)

        features = F.interpolate(features, scale_factor=2, mode='bilinear', align_corners=False) + f2

        pooled = torch.flatten(F.adaptive_avg_pool2d(features, 1), start_dim=1)
        # Bug fix: the original returned the 128-d pooled vector instead of
        # the fc output. The loss was therefore taken over 128 "classes" and
        # `self.fc` never received gradients. Return raw logits with no
        # softmax, since CrossEntropyLoss applies log-softmax itself.
        logits = self.fc(pooled)

        # Class activation maps: project every spatial feature vector onto
        # each class's fc weights.
        with torch.no_grad():
            batch, channels, height, width = features.shape

            flat = features.reshape(batch, channels, height * width)
            # (K, C) @ (N, C, HW) broadcasts to (N, K, HW).
            cams = torch.matmul(self.fc.weight, flat)

            cams = self._normalize_cams(cams)
            cams = cams.view(batch, self.num_classes, height, width)

        return logits, cams

    def _normalize_cams(self, cam, eps=1e-8):
        """Min-max normalise ``cam`` to [0, 1] along its last dimension.

        ``eps`` prevents division by zero for constant maps, which produced
        NaNs in the original.
        """
        cam = cam - cam.min(dim=-1, keepdim=True)[0]
        return cam / (cam.max(dim=-1, keepdim=True)[0] + eps)

### 6.1) Functions

In [None]:
def draw_heatmap(image, vgg16, model, class_idx=None):
    """Plot the class activation map(s) that ``model`` produces for one image.

    Args:
        image: One preprocessed image tensor (C, H, W), e.g. ``dataset[i][0]``.
        vgg16: Backbone returning ``(f2, f3, f4)``.
        model: Head returning ``(scores, cams)`` with ``cams`` shaped
            (1, num_classes, h, w) for a single image.
        class_idx: Class whose CAM to draw. ``None`` (default) draws one
            panel per class.

    Returns:
        The CAMs for the image as a CPU tensor of shape (num_classes, h, w).
    """
    model.eval()
    vgg16.eval()

    # The dataset yields CPU tensors; the models may live on another device.
    device = next(model.parameters()).device

    with torch.no_grad():
        f2, f3, f4 = vgg16(image.unsqueeze(0).to(device))
        _, cams = model(f2, f3, f4)

    # Bug fix: the original plotted an undefined name `heatmap` (NameError).
    cams = cams[0].cpu()
    class_ids = list(range(cams.shape[0])) if class_idx is None else [class_idx]

    fig, axes = plt.subplots(1, len(class_ids), figsize=(6 * len(class_ids), 6), squeeze=False)
    for ax, k in zip(axes[0], class_ids):
        ax.matshow(cams[k].numpy())
        ax.set_title(f'CAM for class {k}')
        ax.axis('off')
    plt.show()

    return cams

### 6.2) Training

In [None]:
# Shuffled mini-batches over the training split; the final smaller batch is kept.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=False,
    pin_memory=False,
)

In [None]:
# Frozen feature extractor followed by the trainable 2-class head, both on DEVICE.
vgg16_model = VGG16()
vgg16_model = vgg16_model.to(DEVICE)
classifier_model = Classifier(num_classes=2)
classifier_model = classifier_model.to(DEVICE)

In [None]:
def train_epoch(loader, vgg16, model, optimizer, loss_fn, epoch, lr_scheduler=None):
    """Train ``model`` for one epoch on features from the frozen ``vgg16``.

    Args:
        loader: DataLoader yielding ``(images, labels)`` batches.
        vgg16: Backbone returning ``(f2, f3, f4)``; kept in eval mode.
        model: Trainable head returning ``(scores, cams)``.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Loss called as ``loss_fn(scores, labels)``.
        epoch: Epoch number, used only for logging.
        lr_scheduler: Optional scheduler, stepped ONCE at the end of the epoch.

    Returns:
        Mean batch loss over the epoch (NaN if ``loader`` is empty).
    """
    model.train()
    vgg16.eval()

    losses = []

    progress = tqdm(loader, desc=f'Epoch {epoch}', leave=False)
    for images, labels in progress:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        f2, f3, f4 = vgg16(images)
        scores, _ = model(f2, f3, f4)

        loss = loss_fn(scores, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        losses.append(batch_loss)
        # Progress-bar postfix instead of printing one line per batch.
        progress.set_postfix(loss=f'{batch_loss:.4f}')

    # Bug fix: the scheduler used to be stepped after every *batch*, so the
    # epoch milestones (LR_DECAY_EPOCH = [15, 30]) fired after batches 15 and
    # 30 of the first epoch. Step it once per epoch instead.
    if lr_scheduler is not None:
        lr_scheduler.step()

    mean_loss = float(np.mean(losses)) if losses else float('nan')
    print(f'|| Epoch[{epoch}]: Loss[{mean_loss}]')

    return mean_loss

In [None]:
def check_accuracy(loader, vgg16, model, epoch):
    """Return the classification accuracy (percent) of ``model`` over ``loader``.

    Accuracy is computed over all samples (correct / total) rather than as a
    mean of per-batch accuracies, so a smaller last batch is not over-weighted.

    Args:
        loader: DataLoader yielding ``(images, labels)`` batches.
        vgg16: Backbone returning ``(f2, f3, f4)``.
        model: Head returning ``(scores, cams)``.
        epoch: Epoch number, used only for logging.

    Returns:
        Accuracy in percent, or NaN if ``loader`` yields no samples.
    """
    model.eval()
    vgg16.eval()

    correct = 0
    total = 0

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            f2, f3, f4 = vgg16(images)
            predicted, _ = model(f2, f3, f4)

            correct += (predicted.argmax(1) == labels).sum().item()
            total += labels.numel()

    accuracy = 100.0 * correct / total if total else float('nan')
    print(f'=> Epoch[{epoch}]: Accuracy[{accuracy}]')

    return accuracy

In [None]:
# Only the classifier head is optimised; the VGG16 backbone runs under no_grad.
optimizer = optim.Adam(params=classifier_model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()
lr_scheduler = scheduler.MultiStepLR(optimizer, milestones=LR_DECAY_EPOCH, gamma=LR_DECAY)

In [None]:
starting_epoch: int = 1  # first epoch number used by the training loop (1-based)

In [None]:
# Train epochs starting_epoch..EPOCHS inclusive. The original
# `range(starting_epoch, EPOCHS)` stopped one short (29 epochs for EPOCHS = 30).
train_losses = []

for epoch in range(starting_epoch, EPOCHS + 1):
    epoch_loss = train_epoch(
        train_loader,
        vgg16_model,
        classifier_model,
        optimizer,
        loss_fn,
        epoch,
        lr_scheduler
    )
    # Keep the per-epoch mean loss so it can be plotted or inspected later.
    train_losses.append(epoch_loss)

In [None]:
# Class activation map(s) for one fixed training example (index 5).
draw_heatmap(image=train_dataset[5][0], vgg16=vgg16_model, model=classifier_model);