In [1]:
import json
import os
from urllib.parse import unquote

import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms.v2 as T
from matplotlib.patches import Rectangle
from PIL import Image
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import tv_tensors
from ultralytics import YOLO


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Project paths, relative to the notebook's working directory.
# ANN_DIR holds the *.json annotation files (passed as labels_dir when the
# dataset is built).
# NOTE(review): BASE_DIR and DATA_DIR are not referenced in the visible cells.
BASE_DIR, DATA_DIR, ANN_DIR = "./", "stool", "labels"

In [3]:

class StoolDataset(Dataset):
    """Object-detection dataset of stool microscopy images with JSON labels.

    Every ``*.json`` file in ``labels_dir`` maps an image key to a record with
    a ``filename`` (a URL under ``BASE_URL``) and a list of ``regions`` (dicts
    with ``shape_attributes`` and ``region_attributes``). Records whose image
    is not present under ``images_dir`` are skipped at load time, so a
    partially downloaded copy of the data still works.

    Each item is ``(img, target)`` where ``target`` has:
        ``boxes``    -- one ``[x_min, y_min, x_max, y_max]`` per rectangle
                        (a list without a transform, a ``tv_tensors.BoundingBoxes``
                        with one),
        ``labels``   -- one class name per box, index-aligned with ``boxes``,
        ``image_id`` -- the local image path.
    """

    # Prefix of the remote URLs stored in the annotations' ``filename`` field;
    # stripping it yields the path relative to ``images_dir``.
    BASE_URL = "http://sitoscope.naamii.org.np/media/"

    def __init__(self, labels_dir, transform=None, images_dir=None):
        """
        Args:
            labels_dir: directory containing the ``*.json`` annotation files.
            transform: optional callable invoked as ``transform(img, boxes)``
                and returning ``(img, boxes)``. Boxes are passed as
                ``tv_tensors.BoundingBoxes`` so torchvision v2 geometric
                transforms move them together with the image.
            images_dir: local root that relative image paths are joined to;
                ``None`` resolves them against the current working directory.
        """
        self.labels_dir = labels_dir
        self.transform = transform
        self.images_dir = images_dir
        self.annotations = self._load_annotations()

    def _resolve_path(self, img_url):
        """Map an annotation ``filename`` URL to a local image path.

        ``BASE_URL`` is stripped and the remainder joined to ``images_dir``.
        The URLs can be percent-encoded (e.g. ``Janak%20Nandini``), so when
        the literal path does not exist the URL-decoded path is tried. If
        neither exists, the literal path is returned.
        """
        relative_path = img_url
        if relative_path.startswith(self.BASE_URL):
            relative_path = relative_path[len(self.BASE_URL):]
        root = self.images_dir if self.images_dir is not None else ""
        img_path = os.path.join(root, relative_path)
        if not os.path.exists(img_path):
            decoded_path = os.path.join(root, unquote(relative_path))
            if os.path.exists(decoded_path):
                return decoded_path
        return img_path

    def _load_annotations(self):
        """Collect the annotation records whose image exists locally.

        Files are visited in sorted order so dataset indices (and therefore any
        seeded random split) are stable across machines. A file that cannot be
        read or parsed is reported and skipped (best effort).
        """
        annotations = []
        for file in sorted(os.listdir(self.labels_dir)):
            if not file.endswith('.json'):
                continue
            file_path = os.path.join(self.labels_dir, file)
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                for img_data in data.values():
                    img_path = self._resolve_path(img_data['filename'])
                    if os.path.exists(img_path):
                        annotations.append(img_data)
                    else:
                        print(f"Skipping {img_path} - image not found")
            except Exception as e:
                print(f"Error loading {file_path}: {e}")
        return annotations

    def __len__(self):
        return len(self.annotations)

    @staticmethod
    def _parse_regions(regions):
        """Extract index-aligned ``(boxes, labels)`` from an annotation's regions.

        Only ``rect`` shapes yield a box, converted from ``(x, y, width, height)``
        to ``[x_min, y_min, x_max, y_max]``. A label is appended for exactly
        those regions: a ``RECTANGLE`` attribute of ``'Giardia'`` maps to
        ``'Giardia'``, any other value to ``'Cryptosporidium'``.
        """
        boxes = []
        labels = []
        for region in regions:
            shape = region['shape_attributes']
            if shape['name'] != 'rect':
                # No box for non-rectangular regions, so no label either;
                # otherwise labels[i] would stop describing boxes[i].
                continue
            x, y, w, h = shape['x'], shape['y'], shape['width'], shape['height']
            boxes.append([x, y, x + w, y + h])
            if region['region_attributes']['RECTANGLE'] == 'Giardia':
                labels.append('Giardia')
            else:
                labels.append('Cryptosporidium')
        return boxes, labels

    def __getitem__(self, idx):
        """Return ``(img, target)`` for ``idx``, or ``None`` if the image can't be loaded.

        ``None`` items are expected to be dropped by the DataLoader's collate_fn.
        """
        ann = self.annotations[idx]
        # Annotation filenames are remote URLs; map them to the local copy.
        img_path = self._resolve_path(ann['filename'])

        if not os.path.exists(img_path):
            print(f"Skipping {img_path} - image not found")
            return None

        try:
            img = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image at {img_path}: {e}")
            return None

        boxes, labels = self._parse_regions(ann['regions'])

        if self.transform:
            # torchvision v2 transforms only update boxes wrapped as
            # BoundingBoxes; a plain list passes through unchanged and would
            # no longer match a flipped or cropped image.
            boxes = tv_tensors.BoundingBoxes(
                torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
                format="XYXY",
                canvas_size=(img.height, img.width),
            )
            img, boxes = self.transform(img, boxes)

        target = {
            'boxes': boxes,
            'labels': labels,
            'image_id': img_path,
        }
        return img, target


In [4]:
# Preprocessing / augmentation applied jointly to each image and its boxes.
# ToImage() + ToDtype(float32, scale=True) is the torchvision v2 replacement
# for the deprecated ToTensor(): PIL image -> float32 CHW tensor in [0, 1].
# Geometric transforms (flip, crop, resize) also update the boxes provided they
# are passed as tv_tensors.BoundingBoxes; adding a crop would additionally
# need T.SanitizeBoundingBoxes() to drop boxes that end up empty.
# No Normalize step here, so the tensors can be plotted directly.
transform = T.Compose([
    T.RandomHorizontalFlip(p=0.5),
    T.ToImage(),
    T.ToDtype(torch.float32, scale=True),
])

def skipper(batch):
    """Collate function that drops ``None`` samples before batching.

    Only part of the dataset was downloaded locally, and its annotations were
    not pruned, so the dataset can return ``None`` for items whose image is
    missing or unreadable.

    Returns:
        ``None`` if every sample was dropped; otherwise a tuple of per-field
        tuples, e.g. ``((img0, img1, ...), (target0, target1, ...))``.
    """
    kept = list(filter(lambda sample: sample is not None, batch))
    return tuple(zip(*kept)) if kept else None




In [5]:
def split_and_load(dataset, batch_size=16, seed=None, collate_fn=None):
    """Randomly split ``dataset`` 80/20 and wrap both parts in DataLoaders.

    Args:
        dataset: map-style dataset to split.
        batch_size: batch size for both loaders (default 16).
        seed: if given, seeds the split so the same items land in train/test
            on every run; ``None`` keeps the split non-deterministic.
        collate_fn: batch collation function; ``None`` uses ``skipper`` so
            items the dataset returned as ``None`` are dropped.

    Returns:
        ``(train_loader, test_loader)``, in that order. Only the training
        loader is shuffled; the evaluation order stays fixed.
    """
    collate = skipper if collate_fn is None else collate_fn
    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    train_set, test_set = random_split(dataset, [0.8, 0.2], **split_kwargs)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=collate)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, collate_fn=collate)
    return train_loader, test_loader

In [6]:
def draw_bounding_box(img, label, text_offset=50):
    """Show ``img`` side by side: raw, and with ``label``'s boxes drawn.

    Args:
        img: image accepted by ``Axes.imshow``, e.g. an (H, W, C) array/tensor.
        label: target dict with index-aligned ``'boxes'`` (one
            ``[x_min, y_min, x_max, y_max]`` per box, as lists or tensor rows)
            and ``'labels'`` (class names).
        text_offset: how many pixels above each box's top edge its class name
            is written (default 50).
    """
    fig, (ax_raw, ax_annotated) = plt.subplots(1, 2, figsize=(20, 10))
    for ax, title in ((ax_raw, 'Raw Image'), (ax_annotated, 'Annotated Image')):
        ax.imshow(img)
        ax.axis('off')
        ax.set_title(title)

    for box, name in zip(label['boxes'], label['labels']):
        # float() accepts both plain numbers and 0-d tensor elements.
        x_min, y_min, x_max, y_max = (float(v) for v in box)
        ax_annotated.add_patch(Rectangle(
            (x_min, y_min), x_max - x_min, y_max - y_min,
            edgecolor='b', facecolor='none', label=name,
        ))
        # Class name centred horizontally above the box; may extend past the
        # image edge, hence clip_on=False.
        ax_annotated.text((x_min + x_max) / 2, y_min - text_offset, name,
                          ha='center', va='bottom', color='blue', clip_on=False)


In [10]:
# Root of the local image copy. Set the STOOL_IMAGES_DIR environment variable
# to point elsewhere instead of editing this machine-specific default.
IMAGES_DIR = os.environ.get("STOOL_IMAGES_DIR", "/Volumes/HDD/ayyp/")
dataset = StoolDataset(labels_dir=ANN_DIR, transform=transform, images_dir=IMAGES_DIR)

Skipping /Volumes/HDD/ayyp/standard/2022-12-22/Standard_S_20221222_D1_9E0BJ/3/Standard_S_20221222_D1_9E0BJ_S3_I15_B.png - image not found
Skipping /Volumes/HDD/ayyp/stool/Janak%20Nandini/2023-08-07/S_MP-janak-nandini_20230807_IP5OQ/2/S_MP-janak-nandini_20_rcPAxrU.jpg - image not found
Skipping /Volumes/HDD/ayyp/stool/Saphebagar/2024-02-26/S_SP-saphebagar_20240226_8YTKS/3/S_SP-saphebagar_20240226_8Y_US0jJZs.jpg - image not found
Skipping /Volumes/HDD/ayyp/stool/Damak/2023-08-17/S_P1-damak_20230817_HEP80/1/S_P1-damak_20230817_HEP80_S1_I9_S.jpg - image not found
Skipping /Volumes/HDD/ayyp/stool/Loharpatti/2023-09-26/S_MP-loharpatti_20230926_2UGIG/2/S_MP-loharpatti_20230926_2U_eCNygLr.jpg - image not found
Skipping /Volumes/HDD/ayyp/stool/Kamal/2023-10-14/S_P1-kamal_20231014_QROIX/3/S_P1-kamal_20231014_QROIX_S3_I12_S.jpg - image not found
Skipping /Volumes/HDD/ayyp/stool/Kamal/2023-10-14/S_P1-kamal_20231014_QROIX/3/S_P1-kamal_20231014_QROIX_S3_I12_S.jpg - image not found
Skipping /Volumes/

In [1]:

# split_and_load returns (train_loader, test_loader) in that order.
train_loader, test_loader = split_and_load(dataset)

# skipper() yields None for a batch whose items were all skipped, so take the
# first usable training batch instead of blindly unpacking the first one.
batch = next((b for b in train_loader if b is not None), None)
assert batch is not None, "no loadable images in the training split"
image, label = batch
print(len(image))

# Preview up to five images with their annotated boxes.
for idx, (img, lbl) in enumerate(zip(image, label)):
    if idx >= 5:
        break
    # Tensors are (C, H, W); imshow expects channels last.
    draw_bounding_box(img.permute(1, 2, 0), lbl)


NameError: name 'split_and_load' is not defined

In [9]:
## TODO
# - [x] write the class name next to each bounding box
# - [ ] train a pretrained ViT
# - [ ] train our own ViT
