<a href="https://colab.research.google.com/github/kirannyaupane11/bacterial-Classification-and-Quantification/blob/main/Bacterial_Classification_and_Quantification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch

print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))


CUDA available: True
GPU: Tesla T4


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
DATA_ROOT = "/content/drive/MyDrive/Staphylocococcus aureus datasets/DeepBacs_Data_Segmentation_Staph_Aureus_dataset/fluorescence_dataset"
import os
print(os.path.exists(DATA_ROOT), DATA_ROOT)


True /content/drive/MyDrive/Staphylocococcus aureus datasets/DeepBacs_Data_Segmentation_Staph_Aureus_dataset/fluorescence_dataset


In [4]:
!pip -q install albumentations opencv-python


In [5]:

import glob
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A

train_img_dir = f"/content/drive/MyDrive/Staphylocococcus aureus datasets/DeepBacs_Data_Segmentation_Staph_Aureus_dataset/fluorescence_dataset/train/patches/fluorescence"
train_msk_dir = f"/content/drive/MyDrive/Staphylocococcus aureus datasets/DeepBacs_Data_Segmentation_Staph_Aureus_dataset/fluorescence_dataset/train/patches/masks"

imgs = sorted(glob.glob(train_img_dir + "/*.tif"))
msks = sorted(glob.glob(train_msk_dir + "/*.tif"))

print("Train patches:", len(imgs), len(msks))

# Added a check to prevent IndexError if lists are empty
if imgs and msks:
    print("Example:", imgs[0], msks[0])
else:
    print("No images or masks found. Check DATA_ROOT path and directory contents.")


Train patches: 28 28
Example: /content/drive/MyDrive/Staphylocococcus aureus datasets/DeepBacs_Data_Segmentation_Staph_Aureus_dataset/fluorescence_dataset/train/patches/fluorescence/JE2NileRed_oilp22_PMP_101220_001_1.tif /content/drive/MyDrive/Staphylocococcus aureus datasets/DeepBacs_Data_Segmentation_Staph_Aureus_dataset/fluorescence_dataset/train/patches/masks/JE2NileRed_oilp22_PMP_101220_001_1.tif


In [14]:
from sklearn.model_selection import train_test_split

train_imgs, val_imgs, train_msks, val_msks = train_test_split(
    imgs, msks, test_size=0.2, random_state=42
)
print(len(train_imgs), len(val_imgs))


22 6


In [15]:
class BacteriaSegDataset(Dataset):
    def __init__(self, img_paths, mask_paths, augment=False):
        self.img_paths = img_paths
        self.mask_paths = mask_paths

        if augment:
            self.transform = A.Compose([
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.Rotate(limit=20, p=0.5),
            ])
        else:
            self.transform = None

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        # Explicitly load image and mask as grayscale to ensure consistent channel dimension
        img = cv2.imread(self.img_paths[idx], cv2.IMREAD_GRAYSCALE)
        msk = cv2.imread(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE)

        # Handle potential None if image reading fails
        if img is None:
            raise FileNotFoundError(f"Could not read image: {self.img_paths[idx]}")
        if msk is None:
            raise FileNotFoundError(f"Could not read mask: {self.mask_paths[idx]}")

        # ensure float32
        img = img.astype(np.float32)
        msk = msk.astype(np.float32)

        # normalise image to 0..1
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)

        # binarise mask to 0/1
        msk = (msk > 0).astype(np.float32)

        if self.transform:
            aug = self.transform(image=img, mask=msk)
            img, msk = aug["image"], aug["mask"]

        # add channel dimension: (1,H,W)
        # Now img and msk should consistently be (H,W) arrays, so unsqueeze(0) will result in (1,H,W)
        img = torch.tensor(img).unsqueeze(0)
        msk = torch.tensor(msk).unsqueeze(0)

        return img, msk


In [16]:
train_ds = BacteriaSegDataset(train_imgs, train_msks, augment=True)
val_ds   = BacteriaSegDataset(val_imgs, val_msks, augment=False)

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_ds, batch_size=4, shuffle=False, num_workers=2)


In [20]:
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
    def forward(self, x): return self.net(x)

class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.d1 = DoubleConv(1, 32)
        self.d2 = DoubleConv(32, 64)
        self.d3 = DoubleConv(64, 128)
        self.pool = nn.MaxPool2d(2)

        self.b = DoubleConv(128, 256)

        self.u3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.c3 = DoubleConv(256, 128)
        self.u2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.c2 = DoubleConv(128, 64)
        self.u1 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.c1 = DoubleConv(64, 32)

        self.out = nn.Conv2d(32, 1, 1)

    def forward(self, x):
        x1 = self.d1(x)
        x2 = self.d2(self.pool(x1))
        x3 = self.d3(self.pool(x2))
        xb = self.b(self.pool(x3))

        x = self.u3(xb)
        x = self.c3(torch.cat([x, x3], dim=1))
        x = self.u2(x)
        x = self.c2(torch.cat([x, x2], dim=1))
        x = self.u1(x)
        x = self.c1(torch.cat([x, x1], dim=1))
        return self.out(x)


In [21]:
def dice_coeff(pred, target, eps=1e-7):
    pred = pred.view(-1)
    target = target.view(-1)
    inter = (pred * target).sum()
    return (2*inter + eps) / (pred.sum() + target.sum() + eps)

def iou_coeff(pred, target, eps=1e-7):
    pred = pred.view(-1)
    target = target.view(-1)
    inter = (pred * target).sum()
    union = pred.sum() + target.sum() - inter
    return (inter + eps) / (union + eps)

bce = nn.BCEWithLogitsLoss()

def loss_fn(logits, target):
    probs = torch.sigmoid(logits)
    dice = dice_coeff((probs > 0.5).float(), target)
    return bce(logits, target) + (1 - dice)


In [22]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = UNet().to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

def run_epoch(loader, train=True):
    model.train() if train else model.eval()
    total_loss, total_dice, total_iou = 0, 0, 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)

        with torch.set_grad_enabled(train):
            logits = model(x)
            loss = loss_fn(logits, y)

            probs = torch.sigmoid(logits)
            preds = (probs > 0.5).float()

            d = dice_coeff(preds, y).item()
            j = iou_coeff(preds, y).item()

            if train:
                opt.zero_grad()
                loss.backward()
                opt.step()

        total_loss += loss.item()
        total_dice += d
        total_iou  += j

    n = len(loader)
    return total_loss/n, total_dice/n, total_iou/n

best_val = 0
for epoch in range(1, 50):
    tr = run_epoch(train_loader, train=True)
    va = run_epoch(val_loader, train=False)

    print(f"Epoch {epoch:02d} | "
          f"train loss {tr[0]:.3f} dice {tr[1]:.3f} iou {tr[2]:.3f} | "
          f"val loss {va[0]:.3f} dice {va[1]:.3f} iou {va[2]:.3f}")

    if va[1] > best_val:
        best_val = va[1]
        torch.save(model.state_dict(), "best_unet.pth")


Epoch 01 | train loss 1.514 dice 0.170 iou 0.097 | val loss 1.675 dice 0.045 iou 0.023
Epoch 02 | train loss 1.293 dice 0.291 iou 0.176 | val loss 1.657 dice 0.051 iou 0.026
Epoch 03 | train loss 1.232 dice 0.286 iou 0.173 | val loss 1.073 dice 0.378 iou 0.251
Epoch 04 | train loss 1.137 dice 0.342 iou 0.217 | val loss 1.441 dice 0.060 iou 0.031
Epoch 05 | train loss 1.134 dice 0.300 iou 0.185 | val loss 1.450 dice 0.000 iou 0.000
Epoch 06 | train loss 1.155 dice 0.250 iou 0.155 | val loss 1.472 dice 0.001 iou 0.001
Epoch 07 | train loss 1.187 dice 0.186 iou 0.106 | val loss 1.507 dice 0.015 iou 0.008
Epoch 08 | train loss 1.226 dice 0.121 iou 0.065 | val loss 1.519 dice 0.000 iou 0.000
Epoch 09 | train loss 1.216 dice 0.114 iou 0.063 | val loss 1.321 dice 0.000 iou 0.000
Epoch 10 | train loss 1.277 dice 0.045 iou 0.023 | val loss 1.291 dice 0.048 iou 0.025
Epoch 11 | train loss 1.117 dice 0.186 iou 0.106 | val loss 1.178 dice 0.247 iou 0.148
Epoch 12 | train loss 1.284 dice 0.006 iou 

In [23]:
test_img_dir = f"/content/drive/MyDrive/Staphylocococcus aureus datasets/DeepBacs_Data_Segmentation_Staph_Aureus_dataset/fluorescence_dataset/test/fluorescence"
test_msk_dir = f"/content/drive/MyDrive/Staphylocococcus aureus datasets/DeepBacs_Data_Segmentation_Staph_Aureus_dataset/fluorescence_dataset/test/masks"

test_imgs = sorted(glob.glob(test_img_dir + "/*.tif"))
test_msks = sorted(glob.glob(test_msk_dir + "/*.tif"))

if not test_imgs or not test_msks:
    print("Warning: No test images or masks found. Skipping test evaluation.")
else:
    test_ds = BacteriaSegDataset(test_imgs, test_msks, augment=False)
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False)

    model.load_state_dict(torch.load("best_unet.pth", map_location=device))
    model.eval()

    test_loss, test_dice, test_iou = run_epoch(test_loader, train=False)
    print("TEST | loss:", test_loss, "dice:", test_dice, "iou:", test_iou)


TEST | loss: 0.8068059086799622 dice: 0.6257587015628815 iou: 0.46699504256248475


In [24]:
def quantify(mask01):
    # mask01: numpy array 0/1
    coverage = mask01.mean() * 100.0

    mask_u8 = (mask01 * 255).astype(np.uint8)
    num_labels, labels = cv2.connectedComponents(mask_u8)
    # subtract background label 0
    count = max(0, num_labels - 1)
    return coverage, count

for i, (x, y) in enumerate(test_loader):
    x = x.to(device)
    with torch.no_grad():
        logits = model(x)
        pred = (torch.sigmoid(logits) > 0.5).float().cpu().numpy()[0,0]

    cov, cnt = quantify(pred)
    print(f"Test image {i+1}: coverage={cov:.2f}% | count={cnt}")


Test image 1: coverage=2.94% | count=31
Test image 2: coverage=4.97% | count=33
Test image 3: coverage=6.40% | count=62
Test image 4: coverage=3.49% | count=58
Test image 5: coverage=1.74% | count=34
