In [1]:
import os
import json
import csv
import random
import pickle
import cv2
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from scipy.ndimage import label
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, roc_curve
import matplotlib.pyplot as plt
import cv2
from dataset import GlaucomaDataset

from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, array_to_img, img_to_array
from PIL import Image
from sklearn.model_selection import train_test_split



from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.models.detection import maskrcnn_resnet50_fpn

import os
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision
from torchvision.transforms import functional as F


In [2]:


class GlaucomaDataset(Dataset):
    def __init__(self, root_dir, split='train', output_size=(256,256), max_images=None):
        self.output_size = output_size
        self.root_dir = root_dir
        self.split = split
        self.images = []
        self.segs = []
        self.max_images = max_images
        # Load data index
        for direct in self.root_dir:
                    self.image_filenames = []
                    for path in os.listdir(os.path.join(direct, "Images_Square")):
                        if(not path.startswith('.')):
                            self.image_filenames.append(path)

                    num_images = 0
                    for k in range(len(self.image_filenames)):
                        # Skip loading if max_images is specified and the limit has been reached
                        if max_images is not None and num_images >= max_images:
                            break

                        print('Loading {} image {}/{}...'.format(split, k, len(self.image_filenames)), end='\r')
                        img_name = os.path.join(direct, "Images_Square", self.image_filenames[k])
                        img = np.array(Image.open(img_name).convert('RGB'))

                        if split != 'test':
                            seg_name = os.path.join(direct, "Masks_Square", self.image_filenames[k][:-3] + "png")
                            mask = np.array(Image.open(seg_name, mode='r'))
                            od = (mask==1.).astype(np.float32)
                            oc = (mask==2.).astype(np.float32)
                            
                            # Check if both masks are not empty, i.e., they contain at least one non-zero pixel
                            if np.any(od) and np.any(oc):
                                img = transforms.functional.to_tensor(img)
                                img = transforms.functional.resize(img, output_size, interpolation=Image.BILINEAR)
                                self.images.append(img)
                                od = torch.from_numpy(od[None,:,:])
                                oc = torch.from_numpy(oc[None,:,:])
                                od = transforms.functional.resize(od, output_size, interpolation=Image.Resampling.NEAREST)
                                oc = transforms.functional.resize(oc, output_size, interpolation=Image.Resampling.NEAREST)
                                self.segs.append(torch.cat([od, oc], dim=0))
                                num_images += 1

                    print('Succesfully loaded {} dataset.'.format(split) + ' '*50)

    def __len__(self):
        return len(self.images)
   


    def __getitem__(self, idx):
        # load image
        img = self.images[idx]
        # load segmentation masks (for both optic disk and optic cup)
        seg = self.segs[idx]
        # For instance segmentation, each mask should be a binary mask of shape (H, W).
        # Therefore, we need to split the combined mask into two separate masks.
        od_mask, oc_mask = seg[0], seg[1]

        # Find bounding boxes around each mask. The bounding box is represented as
        # [xmin, ymin, width, height], which is the format expected by Mask R-CNN.
        od_bbox = self.mask_to_bbox(od_mask.numpy())
        oc_bbox = self.mask_to_bbox(oc_mask.numpy())

        # The labels are a tensor of class IDs. In this case, you might want to use
        # 1 for optic disk and 2 for optic cup, as you did when creating the masks.
        labels = torch.tensor([1, 2], dtype=torch.int64)

        # Now, we need to put the masks and bounding boxes into the right format.
        # The masks should be a tensor of shape (num_objs, H, W),
        # and the bounding boxes should be in a (num_objs, 4) tensor.
        masks = torch.stack([od_mask, oc_mask])
        boxes = torch.tensor([od_bbox, oc_bbox])

        # Pack the bounding boxes and labels into a dictionary
        target = {"boxes": boxes, "labels": labels, "masks": masks}

        return img, target

    @staticmethod
    def mask_to_bbox(mask):
        # Find the bounding box of a binary mask.
        # This method assumes that the input is a binary mask with 0s and 1s.
        pos = np.where(mask)
        xmin = np.min(pos[1])
        xmax = np.max(pos[1])
        ymin = np.min(pos[0])
        ymax = np.max(pos[0])
        return [xmin, ymin, xmax - xmin, ymax - ymin]


In [3]:

def get_instance_segmentation_model(num_classes):
    # Load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2()

    # Get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # Replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # And replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

    return model


In [4]:
root_dirs = [ "ORIGA","G1020"]
val_dir = [ "REFUGE"]
lr = 1e-4
batch_size = 8
num_workers = 0
total_epoch = 1

train_set = GlaucomaDataset(root_dirs, 
                          split='train', max_images=50)

val_set = GlaucomaDataset(val_dir, 
                        split='val', max_images=50)

train_loader = DataLoader(train_set, 
                          batch_size=batch_size, 
                          shuffle=True, 
                          num_workers=num_workers,
                          pin_memory=True,
                         )
val_loader = DataLoader(val_set, 
                        batch_size=batch_size, 
                        shuffle=False, 
                        num_workers=num_workers,
                        pin_memory=True,
                        )


Loading train image 32/650...

  img = transforms.functional.resize(img, output_size, interpolation=Image.BILINEAR)


Succesfully loaded train dataset.                                                  
Succesfully loaded train dataset.                                                  
Succesfully loaded val dataset.                                                  


In [5]:
# Device
# device = torch.device("cuda:0")
device = torch.device("mps")


# Network
model = get_instance_segmentation_model(num_classes=2)

# Loss
seg_loss = torch.nn.BCELoss(reduction='mean')

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
#optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

In [6]:
# Move model to the device
model.to(device)

# If more than one class is used (background is also considered a class)
num_classes = 3 

# Initialize the metric trackers
train_loss = []
val_loss = []

# Epoch loop
for epoch in range(total_epoch):
    model.train()
    print("Epoch {}/{}".format(epoch+1, total_epoch))
    print('-' * 10)
    
    running_loss = 0.0
    
    # Train loop
    for i, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        # targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        targets = [{k: v.squeeze(0).to(device) for k,v in  targets.items()}]

        # Forward pass
        output = model(inputs, targets)
        loss = sum(loss for loss in output.values())

        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_loss += loss.item() * inputs.size(0)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch+1, 
                                                                       i * len(inputs), 
                                                                       len(train_loader.dataset),
                                                                       100. * i / len(train_loader),
                                                                       loss.item()))
    epoch_loss = running_loss / len(train_loader.dataset)
    train_loss.append(epoch_loss)
    print('Training Loss: {:.4f}'.format(epoch_loss))

    # Validation loop
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(val_loader):
            inputs = list(img.to(device) for img in inputs)
            targets = [{k: v.to(device) for k, v in t.items} for t in targets]

            # Forward pass
            output = model(inputs, targets)
            loss = sum(loss for loss in output.values())

            running_loss += loss.item() * inputs.size(0)
            print('Validate: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(i * len(inputs), 
                                                                     len(val_loader.dataset),
                                                                     100. * i / len(val_loader),
                                                                     loss.item()))
        epoch_loss = running_loss / len(val_loader.dataset)
        val_loss.append(epoch_loss)
        print('Validation Loss: {:.4f}'.format(epoch_loss))
    print()

# Save the trained model
# torch.save(model.state_dict(), 'model.pth')


Epoch 1/1
----------


AssertionError: Expected target boxes to be a tensor of shape [N, 4], got torch.Size([8, 2, 4]).

In [7]:
for i, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        # targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        targets = [{k: v.squeeze(1).to(device) for k,v in  targets.items()}]

        for target in targets:
            print(target["boxes"])
            break
        break

tensor([[[ 20, 125,  28,  35],
         [ 32, 136,  10,  14]],

        [[ 75, 111,  33,  35],
         [ 81, 117,  22,  22]],

        [[ 82, 111,  37,  38],
         [ 94, 121,  17,  19]],

        [[209, 109,  33,  35],
         [219, 117,  14,  17]],

        [[121, 108,  26,  28],
         [126, 114,  16,  16]],

        [[193,  98,  33,  36],
         [203, 109,  14,  14]],

        [[167, 106,  37,  39],
         [173, 113,  22,  24]],

        [[ 56,  99,  35,  39],
         [ 67, 107,  19,  19]]], device='mps:0')
