# AI for Health
## Detecting Active Tuberculosis Bacilli on TB Smears
### Group: Luciana, Seohee and Irma

Importing Libraries:

In [None]:
import os
from pathlib import Path
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from PIL import Image
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.models import resnet18
import pytorch_lightning as pl
from tqdm import tqdm
from pytorch_lightning.callbacks import EarlyStopping

Defining Paths:

In [None]:
# Dataset layout: everything lives under ~/datasets/tb-wellgen-smear.
TBWG_ROOT = Path.home().joinpath("datasets", "tb-wellgen-smear")
CONTEST_DIR = TBWG_ROOT.joinpath("supplementary")
IMAGE_ROOT = TBWG_ROOT.joinpath("images")
TABLE_DIR = TBWG_ROOT.joinpath("v1")
# Training logs are written under ~/project/logs.
LOGGER_DIR = Path.home().joinpath("project", "logs")

# Label table; the rest of this file reads its "file_path" and "tb_positive" columns.
tb_labels_df = pd.read_csv(TABLE_DIR.joinpath("tb-labels.csv"))
# Relative path (resolved against the working directory) to the ResNet-18 checkpoint.
weights_path = os.path.join("Pretrained_Weights", "resnet18-f37072fd.pth")

Split Data into Training and Test Sets:

In [None]:
def random_split(df, val_size):
    """Randomly split ``df`` into disjoint training and test tables.

    Rows are selected by position, so the function does not depend on the
    index labels of ``df``.  Both returned frames get a fresh 0..n-1 index.

    Args:
        df: Table to split (one row per sample).
        val_size: Number of rows to hold out for the test table.

    Returns:
        ``(train_df, test_df)`` where ``test_df`` has exactly ``val_size``
        distinct rows of ``df`` and ``train_df`` has the remaining
        ``len(df) - val_size`` rows in their original relative order.

    Raises:
        ValueError: If ``val_size`` is negative or larger than ``len(df)``.
    """
    n = len(df)
    if not 0 <= val_size <= n:
        raise ValueError(
            f"val_size must be between 0 and {n} (the number of rows), got {val_size}"
        )
    # A permutation samples WITHOUT replacement: drawing independent random
    # integers would repeat rows in the test set and leave extra rows in train.
    perm = np.random.permutation(n)
    test_pos = perm[:val_size]
    # Sort so the training rows keep the order they had in ``df``.
    train_pos = np.sort(perm[val_size:])
    test_df = df.iloc[test_pos].copy().reset_index(drop=True)
    train_df = df.iloc[train_pos].reset_index(drop=True)
    return train_df, test_df

# Hold out 1000 rows of the label table as a test set; the rest becomes train_df,
# which the cells below use for the sample image and the normalization statistics.
# NOTE(review): tbDataModule calls random_split on the full table again, so this
# split and the one used for training/testing are independent draws.
train_df, test_df = random_split(tb_labels_df, 1000)

Helper Function to Fetch Image:

In [None]:
def get_image(idx, df, numpy=False, to_float=False):
    """Load the image referenced by row ``idx`` of ``df`` from disk.

    Args:
        idx: Index label of the row in ``df`` (looked up with ``df.loc``).
        df: Table with a ``"file_path"`` column holding the image path.
        numpy: If true, return a NumPy array instead of a PIL image.
        to_float: Only used when ``numpy`` is true; divides the array by
            255.0 (presumably mapping 8-bit pixels to [0, 1] -- depends on
            the image mode).

    Returns:
        A ``PIL.Image.Image`` when ``numpy`` is false, otherwise a NumPy array.

    Raises:
        FileNotFoundError: If the path stored in the table does not exist.
    """
    path = df.loc[idx, "file_path"]
    # Raise explicitly instead of asserting: asserts are stripped under ``-O``.
    if not os.path.exists(path):
        raise FileNotFoundError(f"Image file not found: {path}")
    if not numpy:
        # PIL loads lazily, so the image must stay open for the caller.
        return Image.open(path)
    # The array is a copy of the pixel data, so the file can be closed right away.
    with Image.open(path) as img:
        arr = np.asarray(img)
    if to_float:
        arr = arr / 255.0
    return arr

# Sanity check: load the image for row label 5 of the training table as a PIL image.
img = get_image(5, train_df)
# Bare expression so the notebook renders the image as the cell output.
img

Compute Means and Standard Deviations for Normalization:

In [None]:
# Estimate normalization statistics from a sample of training images.
means = []
stds = []

# Use at most 1000 images, but never more than the table holds; indexing past
# the last row of train_df would raise a KeyError.
n_stat_images = min(1000, len(train_df))

for idx in range(n_stat_images):
    arr = get_image(idx, train_df, numpy=True, to_float=True)
    # Reduce over the first two axes (presumably height and width, leaving one
    # value per channel -- confirm the images are H x W x C).
    means.append(arr.mean(axis=(0, 1)))
    stds.append(arr.std(axis=(0, 1)))

# Average the per-image statistics. Note that STDS is the mean of per-image
# standard deviations, not the standard deviation over all sampled pixels.
MEANS = np.vstack(means).mean(axis=0)
STDS = np.vstack(stds).mean(axis=0)

print(MEANS, STDS)

Create a custom Dataset: 

In [None]:
class tbDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from a label table.

    Args:
        df: Table with a ``"file_path"`` column (image path) and a
            ``"tb_positive"`` column (label, converted with ``int``).
        transform: Optional callable applied to the image array.
        train: If true, the rows are shuffled once at construction time.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, df, transform=None, train=True):
        self.transform = transform
        # Keep ``train`` as the boolean flag that was passed in.
        self.train = train
        if self.train:
            # Shuffle the rows of the table itself. (Assigning the shuffled
            # frame to ``self.train`` would clobber the flag and leave the
            # data order untouched.)
            df = df.sample(frac=1.0)
        # Fresh 0..n-1 index so that the label-based lookups in __getitem__
        # are valid for every idx in range(len(self)).
        self.df = df.reset_index(drop=True)

    def __len__(self):
        """Return the number of rows (samples) in the table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        The image is loaded as a NumPy array divided by 255.0 and then passed
        through ``transform`` if one was given; the label is a Python ``int``.
        """
        img = get_image(idx, self.df, numpy=True, to_float=True)
        label = int(self.df.loc[idx, "tb_positive"])
        if self.transform:
            img = self.transform(img)
        return img, label

Create DataLoaders in Lightning DataModule Class:

In [None]:
class tbDataModule(pl.LightningDataModule):
    """Lightning data module that splits the label table and builds loaders.

    Args:
        df: Full label table; it is split into train and test tables.
        batch_size: Batch size used by both dataloaders.
        num_workers: Number of worker processes used by both dataloaders.
        test_size: Number of rows held out for the test table (default 1000).
    """

    def __init__(self, df, batch_size=64, num_workers=8, test_size=1000):
        super().__init__()
        self.full_df = df
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.test_size = test_size
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224), antialias=True),  # input size for ResNet
                transforms.Normalize(MEANS, STDS),
            ]
        )
        # Filled in by setup(); None means "not split yet".
        self.train_df = None
        self.test_df = None
        # Split eagerly so train_df / test_df are available right after construction.
        self.setup()

    def setup(self, stage=None):
        """Create the train/test split exactly once.

        The trainer calls this hook again with its own stage names; the split
        must not be redrawn then, otherwise rows used for training could end
        up in the test table.
        """
        if self.train_df is None or self.test_df is None:
            self.train_df, self.test_df = random_split(self.full_df, self.test_size)

    def train_dataloader(self):
        """Return the training loader; batches are reshuffled every epoch."""
        return DataLoader(
            tbDataset(self.train_df, transform=self.transform),
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def test_dataloader(self):
        """Return the test loader, built in evaluation mode (``train=False``)."""
        return DataLoader(
            tbDataset(self.test_df, transform=self.transform, train=False),
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

Define LightningModule for Model Training:

In [None]:
class tbModel(pl.LightningModule):
    """ResNet-18 classifier fine-tuned from a local pretrained checkpoint.

    Args:
        learning_rate: Learning rate for the Adam optimizer.
        num_classes: Number of output classes of the replaced final layer.
    """

    def __init__(self, learning_rate=1e-3, num_classes=2):
        super().__init__()
        self.lr = learning_rate
        self.num_classes = num_classes

        # Build the network with its default head so that every tensor in the
        # checkpoint matches in shape. ``strict=False`` only tolerates missing
        # or unexpected keys, not size mismatches, so loading the checkpoint's
        # classifier weights into a ``num_classes``-sized head would raise.
        self.net = resnet18(weights=None)
        # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
        pretrained_weights = torch.load(weights_path, map_location="cpu")
        self.net.load_state_dict(pretrained_weights)
        # Replace the final fully connected layer with a fresh task-specific head.
        self.net.fc = nn.Linear(self.net.fc.in_features, self.num_classes)
        # Cast AFTER swapping the head so the new layer is float64 as well;
        # the dataset in this file divides images by 255.0, which yields float64 arrays.
        self.net.double()

        self.test_accuracy = Accuracy(task="binary", num_classes=self.num_classes)

    def forward(self, X):
        """Return the class logits for a batch of images."""
        return self.net(X)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        # Lightning has already moved the batch to the module's device.
        imgs, labels = batch
        logits = self(imgs)
        loss = F.cross_entropy(logits, labels)
        self.log("train_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Log the loss and update the accuracy metric for one test batch."""
        imgs, labels = batch
        logits = self(imgs)
        loss = F.cross_entropy(logits, labels)
        preds = torch.argmax(logits, dim=-1)
        self.test_accuracy.update(preds, labels)
        self.log("test_loss", loss)
        self.log("test_acc", self.test_accuracy, prog_bar=True)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of the module."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

Train the Model:

In [None]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data = tbDataModule(tb_labels_df)

model = tbModel()
model = model.to(device)

logger = pl.loggers.CSVLogger(save_dir=LOGGER_DIR, name="tb_demo7")

# Using early stopping callback.
# The model logs only "train_loss" and the data module provides no validation
# loader, so monitoring "val_loss" would fail with "metric not available".
# Monitor the logged training loss instead; switch back to "val_loss" once a
# validation loop that logs it exists.
early_stop_callback = EarlyStopping(
    monitor='train_loss',  # loss logged in training_step
    patience=3,  # early stop if loss doesn't improve after 3 epochs
    verbose=True,
    mode='min'
)

trainer = pl.Trainer(
    accelerator="auto",
    max_epochs=50,
    logger=logger,
    callbacks=[early_stop_callback]  # early stopping
)

trainer.fit(model, data)
trainer.test(model, data)