In [1]:
import sys 
sys.path.append('/home/lumargot/trachoma/src/py')

import pandas as pd
import numpy as np
from itertools import combinations
import torch
import SimpleITK as sitk
import matplotlib.pyplot as plt 
import os 
from torchvision.ops import nms
import random
from tqdm import tqdm
from visualization import *

from utils import *
from nets.segmentation import FasterRCNN
from loaders.tt_dataset import TTDatasetBX,TTDataModuleBX, BBXImageTrainTransform, BBXImageEvalTransform, BBXImageTestTransform

  from .autonotebook import tqdm as notebook_tqdm
A new version of Albumentations is available: '2.0.6' (you have '2.0.3'). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.


In [29]:
# --- Paths ----------------------------------------------------------------
mount_point = "/CMF/data/lumargot/trachoma/"

# Fold-0 split: the "train_train" CSV is used as the training frame and the
# "train_test" CSV as the validation frame.
df_train = pd.read_csv(f"{mount_point}csv_updated/mtss_pret_combined_train_fold0_train_train.csv")
df_val = pd.read_csv(f"{mount_point}csv_updated/mtss_pret_combined_train_fold0_train_test.csv")

# --- Label handling (forwarded to remove_labels in the next cell) ---------
concat_labels = ['overcorrection', 'ECA', 'Gap', 'Fleshy']
drop_labels = ['Short Incision', 'Reject']

# --- Column names ---------------------------------------------------------
img_column = "filename"
class_column = 'class'
label_column = 'label'

# Integer class id -> class name.
# NOTE(review): this rebinds the builtin `map` for the rest of the notebook;
# rename (e.g. `class_names`) once no later cell depends on the name `map`.
map = {1: 'Healthy', 2: 'Entropion', 3: 'Overcorrection'}


In [30]:
# Apply the same label cleanup to both splits, using the drop/concat lists from
# the config cell (semantics of the cleanup live in remove_labels).
df_train, df_val = (
    remove_labels(split, class_column, label_column,
                  drop_labels=drop_labels, concat_labels=concat_labels)
    for split in (df_train, df_val)
)


In [31]:
# NOTE(review): df_train is passed a second time as the third (test) frame, so
# `ds = ttdata.test_ds` below is built from training rows — confirm intended.
ttdata = TTDataModuleBX(
    df_train, df_val, df_train,
    batch_size=16,
    num_workers=2,
    img_column=img_column,
    severity_column='sev',
    mount_point=mount_point,
    class_column=class_column,
    train_transform=BBXImageTrainTransform(),
    valid_transform=BBXImageEvalTransform(),
    test_transform=BBXImageTestTransform(),
)
ttdata.setup()

dataload = ttdata.val_dataloader()  # validation loader, iterated further down
ds = ttdata.test_ds

In [None]:
class SimpleBalancedBoxBatcher:
    """Collate helper that subsamples boxes so classes are more balanced per batch.

    Pass ``collate_fn`` as the ``collate_fn`` of a ``torch.utils.data.DataLoader``.
    Each dataset item must be a dict with ``'img'``, ``'mask'``, ``'boxes'`` and
    ``'labels'`` tensors (boxes presumably ``(N, 4)`` and labels ``(N,)`` —
    TODO confirm against the dataset class).
    """

    def __init__(self, batch_size=16):
        # Kept for reference only; the DataLoader decides the real batch size.
        self.batch_size = batch_size

    def collate_fn(self, batch):
        """Stack images/masks and subsample boxes per class across the batch.

        The rarest class in the batch (by box count) sets a budget ``min_count``.
        For every image and every class it contains, about
        ``min_count * share`` boxes are kept, where ``share`` is this image's
        fraction of that class's boxes in the whole batch. At least one box of
        each class present in an image is always kept.

        Args:
            batch: list of dataset items (dicts with 'img', 'mask', 'boxes', 'labels').

        Returns:
            ``(images, targets)`` where ``images`` is ``torch.stack`` of the item
            images and ``targets`` is a list with one dict per item holding the
            kept ``'labels'`` and ``'boxes'`` and that item's own ``'mask'``.
        """
        images = torch.stack([item['img'] for item in batch])
        masks = torch.stack([item['mask'] for item in batch])

        # Box count per class over the whole batch.
        all_labels = torch.cat([item['labels'] for item in batch])
        classes, counts = torch.unique(all_labels, return_counts=True)
        class_totals = dict(zip(classes.tolist(), counts.tolist()))
        # A batch with no boxes at all has nothing to balance; calling .min()
        # on the empty counts tensor would raise.
        min_count = min(class_totals.values()) if class_totals else 0

        targets = []
        for item, img_mask in zip(batch, masks):
            boxes, labels = item['boxes'], item['labels']

            kept_boxes, kept_labels = [], []
            for cls in torch.unique(labels).tolist():
                # Separate name so the per-class selector does not overwrite
                # the image mask that is returned in the target.
                cls_mask = labels == cls
                cls_boxes = boxes[cls_mask]
                cls_labels = labels[cls_mask]

                # This image's share of min_count, proportional to its share
                # of the class's boxes in the batch; never below 1, never more
                # than the boxes available.
                proportion = len(cls_boxes) / class_totals[cls]
                keep_count = min(max(1, int(min_count * proportion)), len(cls_boxes))

                if len(cls_boxes) > keep_count:
                    indices = torch.randperm(len(cls_boxes))[:keep_count]
                    cls_boxes = cls_boxes[indices]
                    cls_labels = cls_labels[indices]

                kept_boxes.append(cls_boxes)
                kept_labels.append(cls_labels)

            if kept_boxes:
                img_boxes = torch.cat(kept_boxes)
                img_labels = torch.cat(kept_labels)
            else:
                # Image without boxes: pass its (empty) annotations through.
                img_boxes, img_labels = boxes, labels

            targets.append({'labels': img_labels, 'boxes': img_boxes, 'mask': img_mask})

        return images, targets


In [106]:
# Training rows with the training transform; fed to the balanced-collate
# DataLoader in the next cell.
dataset = TTDatasetBX(
    df_train,
    transform=BBXImageTrainTransform(),
    img_column=img_column,
    mount_point=mount_point,
    class_column=class_column,
)

In [None]:
from matplotlib.patches import Rectangle

# Visual sanity check of SimpleBalancedBoxBatcher: draw the boxes it keeps for
# the first few training batches.
PREVIEW_BATCH_SIZE = 16
N_PREVIEW_BATCHES = 3
N_COLS = 2
BOX_COLORS = {1: 'green', 2: 'blue'}  # any other label id is drawn in cyan

sampler = SimpleBalancedBoxBatcher(batch_size=PREVIEW_BATCH_SIZE)

loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=PREVIEW_BATCH_SIZE,
    num_workers=0,
    collate_fn=sampler.collate_fn,
)

for IDX, batch in enumerate(tqdm(loader)):
    imgs, targets = batch

    # Grid sized from the actual batch (ceil division) instead of a fixed 8x2
    # that only fits batches of at most 16 images.
    n_rows = max(1, -(-len(targets) // N_COLS))
    fig, axes = plt.subplots(n_rows, N_COLS, figsize=(10, 2.5 * n_rows), squeeze=False)
    panels = axes.ravel()
    for ax in panels[len(targets):]:
        ax.axis('off')  # hide unused panels of a short last batch

    for ax, img, target in zip(panels, imgs, targets):
        # Assumes channel-first images — TODO confirm; imshow wants (H, W, C).
        ax.imshow(img.permute(1, 2, 0))
        # Boxes are unpacked as (x1, y1, x2, y2) corners.
        for (x1, y1, x2, y2), label in zip(target['boxes'].tolist(), target['labels'].tolist()):
            ax.add_patch(Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False,
                                   color=BOX_COLORS.get(label, 'cyan'), linewidth=1.5))

    # Render now and release the figure so figures don't accumulate across batches.
    plt.show()
    plt.close(fig)

    if IDX == N_PREVIEW_BATCHES - 1:
        break

In [None]:
from collections import Counter

# Box counts per class for every validation batch, plus the total over the
# whole validation loader.
total_counts = Counter()
for IDX, batch in enumerate(tqdm(dataload)):
    imgs, targets = batch
    labels = [t['labels'] for t in targets]
    classes, counts = torch.unique(torch.cat(labels), return_counts=True)
    # Pair each count with its class id: a bare `counts` tensor is ambiguous
    # when a batch is missing one of the classes.
    batch_counts = dict(zip(classes.tolist(), counts.tolist()))
    print(batch_counts)
    total_counts.update(batch_counts)

dict(sorted(total_counts.items()))

In [None]:
img_path = 