<a href="https://colab.research.google.com/github/ego-alt/segmentation-from-scratch/blob/master/instance_seg.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from google.colab import drive
# Mount Google Drive (Colab only) so the dataset archives in MyDrive are reachable under /content/drive.
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# Extract the MP6843 images and instance labels from Drive into local directories.
# -j junks archive paths so every file lands flat in the target dir (ArrayMaker lists it
# with a single listdir), -o overwrites without prompting, -q suppresses per-file output.
!unzip -o -j -q -d ./full_images /content/drive/MyDrive/MP6843_img_full.zip
!unzip -o -j -q -d ./labels /content/drive/MyDrive/MP6843_inst.zip

In [4]:
import numpy as np
from os import listdir
from os.path import join
import cv2
import re

In [116]:
class ArrayMaker:
    """Load every image under a directory into numpy arrays, grouped by sample name.

    Files are grouped by the ``F0[1-4]_<digits>`` prefix of their filename, so all
    files belonging to one sample (e.g. several channels or label layers) share a key.

    Attributes:
        root: directory the files are read from.
        files: ``{sample_name: [filename, ...]}`` built in sorted-filename order.
        arrdict: ``{sample_name: [array, ...]}`` filled by :meth:`main`; each list is
            in the same order as ``files[sample_name]``. :meth:`common_elements`
            may later drop keys from it.
    """

    def __init__(self, root_path):
        self.root = root_path
        self.files = {}
        self.arrdict = {}

        self.org_files()

    def main(self, dim=None, crop=None, greyscale=False):
        """Read every grouped file, optionally resize and centre-crop it, into ``arrdict``.

        Any previously loaded arrays are discarded first, so calling this again does
        not duplicate entries.

        Args:
            dim: optional ``(width, height)`` passed as ``dsize`` to ``cv2.resize``.
            crop: optional side length of a square centre crop, applied after resizing.
            greyscale: read files as single-channel images when True, colour otherwise.

        Raises:
            FileNotFoundError: if OpenCV cannot read a file (``cv2.imread`` returns None).
            ValueError: if ``crop`` exceeds the image height or width.
        """
        flag = cv2.IMREAD_GRAYSCALE if greyscale else cv2.IMREAD_COLOR
        self.arrdict = {}
        for name, filenames in self.files.items():
            for filename in filenames:
                path = join(self.root, filename)
                im = cv2.imread(path, flag)
                if im is None:
                    raise FileNotFoundError(f"cv2 could not read image: {path}")
                if dim:
                    im = cv2.resize(im, dim, interpolation=cv2.INTER_AREA)
                if crop:
                    im = self._center_crop(im, crop)
                self.listdict(self.arrdict, name, im)

    @staticmethod
    def _center_crop(im, crop):
        """Return the central ``crop x crop`` window of an array shaped (rows, cols, ...)."""
        height, width = im.shape[:2]
        if crop > height or crop > width:
            raise ValueError(f"crop {crop} exceeds image size {height}x{width}")
        y0 = (height - crop) // 2
        x0 = (width - crop) // 2
        # Axis 0 is rows (y) and axis 1 is columns (x).
        return im[y0:y0 + crop, x0:x0 + crop]

    def org_files(self):
        """Group the non-hidden files under ``root`` by their ``F0[1-4]_<digits>`` prefix.

        Hidden files and files whose name does not start with that prefix are skipped.
        """
        pattern = re.compile(r"^F0[1-4]_[0-9]+")
        for file in sorted(listdir(self.root)):
            if file.startswith('.'):
                continue
            match = pattern.match(file)
            if match is None:
                continue
            self.listdict(self.files, match.group(), file)

    def common_elements(self, other):
        """Keep only the samples in ``arrdict`` whose name is also in ``other``."""
        self.arrdict = {k: v for k, v in self.arrdict.items() if k in other}

    def filtering(self, keyword):
        """Return the loaded arrays whose source filename contains ``keyword``.

        Iterates over ``arrdict`` (not ``files``), so samples removed by
        :meth:`common_elements` are skipped instead of raising ``KeyError``.
        """
        filtered = []
        for name, arrays in self.arrdict.items():
            for filename, arr in zip(self.files[name], arrays):
                if keyword in filename:
                    filtered.append(arr)
        return filtered

    def stacking(self):
        """Stack each sample's arrays along a new last axis; one result per sample."""
        return [np.stack(arrays, axis=-1) for arrays in self.arrdict.values()]

    def listdict(self, dictionary, key, value):
        """Append ``value`` to the list at ``dictionary[key]``, creating the list if needed."""
        dictionary.setdefault(key, []).append(value)

In [117]:
lb_root = './labels'
im_root = './full_images'
dim = (696, 520)  # (width, height) passed to cv2.resize
crop_size = 256

labels = ArrayMaker(lb_root)
images = ArrayMaker(im_root)

# NOTE(review): images are resized to `dim` but labels are not; the centre crops only
# line up if the label files already have that size — confirm against the dataset.
labels.main(crop=crop_size, greyscale=True)
images.main(dim=dim, crop=crop_size)

# Keep only samples present in BOTH sets, then give both dicts the same key order:
# each is ordered by its own sorted filenames, which can differ (e.g. 'F01_10_w1' vs
# 'F01_1.png'), and filtering()/stacking() return lists in dict order.
images.common_elements(labels.arrdict)
labels.common_elements(images.arrdict)
labels.arrdict = {name: labels.arrdict[name] for name in images.arrdict}

w1 = images.filtering('w1')
w2 = images.filtering('w2')
labels3d = labels.stacking()
assert len(w1) == len(w2) == len(labels3d), (len(w1), len(w2), len(labels3d))

# Alternate channels across samples: even indices use w1, odd indices use w2.
alternating = [w1[i] if i % 2 == 0 else w2[i] for i in range(len(w1))]

In [150]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

In [None]:
!pip install pycocotools

In [165]:
class CellImages(Dataset):
    """Dataset of cell images with per-instance masks, boxes and labels.

    Each item is ``(image_tensor, target)``, where ``target`` is a dict with the
    keys used by torchvision detection models: ``masks``, ``boxes``, ``labels``,
    ``image_id``, ``area`` and ``iscrowd``.

    Args:
        images: sequence of images accepted by ``transforms.ToTensor``.
        labels: sequence of ``(H, W, K)`` integer label stacks, one per image. Each of
            the ``K`` layers holds instance ids, and 0 is treated as background.

    Raises:
        ValueError: if ``images`` and ``labels`` differ in length.
    """

    def __init__(self, images, labels):
        if len(images) != len(labels):
            raise ValueError(
                f"images and labels differ in length: {len(images)} vs {len(labels)}"
            )
        self.images = images
        self.labels = labels
        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, ind):
        """Return the transformed image and the detection target for sample ``ind``.

        Target tensors: ``masks`` uint8 ``(N, H, W)``; ``boxes`` float32 ``(N, 4)`` as
        ``[x0, y0, x1, y1]`` from the min/max pixel coordinates of each mask; ``labels``
        and ``iscrowd`` int64 ``(N,)``; ``area`` float32 ``(N,)`` computed from the boxes.
        """
        img = self.transform(self.images[ind])
        lbl = self.labels[ind]
        height, width, stack_num = lbl.shape

        masks, boxes = [], []
        for i in range(stack_num):
            layer = lbl[:, :, i]
            obj_ids = np.unique(layer)
            # Drop background explicitly instead of assuming it is the smallest value
            # present: a layer fully covered by instances has no 0 at all.
            obj_ids = obj_ids[obj_ids != 0]
            for obj_mask in layer == obj_ids[:, None, None]:
                rows, cols = np.nonzero(obj_mask)
                x0, x1 = int(cols.min()), int(cols.max())
                y0, y1 = int(rows.min()), int(rows.max())
                # torchvision detection models raise on boxes with non-positive width
                # or height during training (e.g. 1-pixel slivers left by the crop).
                if x1 <= x0 or y1 <= y0:
                    continue
                masks.append(obj_mask)
                boxes.append([x0, y0, x1, y1])
        mask_num = len(masks)

        if masks:
            masks = torch.from_numpy(np.stack(masks).astype(np.uint8))
        else:
            masks = torch.zeros((0, height, width), dtype=torch.uint8)
        # reshape keeps a (0, 4) shape when an image has no usable objects.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        widths = boxes[:, 2] - boxes[:, 0]
        heights = boxes[:, 3] - boxes[:, 1]

        target = {}
        target["masks"] = masks
        target["boxes"] = boxes
        target["labels"] = torch.ones((mask_num,), dtype=torch.int64)
        target["image_id"] = torch.as_tensor([ind])
        target["area"] = widths * heights
        target["iscrowd"] = torch.zeros((mask_num,), dtype=torch.int64)

        return img, target

In [237]:
def collate_detection(batch):
    """Turn ``[(img, target), ...]`` into ``[imgs, targets]`` without stacking.

    Detection targets have a different number of objects per image, so the default
    collate (which stacks tensors) cannot be used.
    """
    return list(zip(*batch))


N_TRAIN = 40
N_TEST = 5
BATCH_SIZE = 5
SHUFFLE_SEED = 0

train_loader = DataLoader(
    CellImages(alternating[:N_TRAIN], labels3d[:N_TRAIN]),
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),  # reproducible batch order
    collate_fn=collate_detection,
)

# Evaluation data is not shuffled so metrics and visualisations are reproducible.
test_loader = DataLoader(
    CellImages(alternating[N_TRAIN:N_TRAIN + N_TEST], labels3d[N_TRAIN:N_TRAIN + N_TEST]),
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_detection,
)