<a href="https://colab.research.google.com/github/ego-alt/segmentation-from-scratch/blob/master/instance_seg.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive inside the Colab runtime; the dataset archives unpacked in
# the next cell are read from /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Unpack the image and label archives from Drive into local working directories.
# unzip flags: -o overwrite existing files without prompting, -j drop the
# directory structure stored in the archive (all files land directly in the
# target directory), -q quiet, -d destination directory.
!unzip -o -j -q -d ./full_images /content/drive/MyDrive/MP6843_img_full.zip
!unzip -o -j -q -d ./labels /content/drive/MyDrive/MP6843_ins.zip

In [3]:
import numpy as np
from os import listdir
from os.path import join
import cv2
import re

In [43]:
class ArrayMaker:
    """Load the image files of a directory, grouped by file-name prefix.

    Files whose name starts with ``F0<1-4>_<digits>`` are grouped under that
    prefix. ``main`` reads every indexed file with OpenCV, optionally resizes
    and centre-crops it, and stores the arrays per group in ``arrdict``.
    """

    def __init__(self, root_path):
        """Index the files found directly under ``root_path``."""
        self.root = root_path
        self.files = {}    # group prefix -> sorted list of file names
        self.arrdict = {}  # group prefix -> list of image arrays (filled by main)

        self.org_files()

    def org_files(self):
        """(Re)build ``self.files`` from the current content of ``self.root``.

        Hidden files and files whose name does not start with the expected
        prefix are ignored instead of aborting the scan with an IndexError.
        """
        regex = "^F0[1-4]_[0-9]+"
        grouped = {}
        for file in sorted(listdir(self.root)):
            if file.startswith('.'):
                continue
            match = re.match(regex, file)
            if match is None:
                # Not part of the dataset naming scheme (e.g. a stray README).
                continue
            grouped.setdefault(match.group(0), []).append(file)
        self.files = grouped

    def _read_image(self, path):
        """Read one image with OpenCV, failing loudly if it cannot be decoded."""
        im = cv2.imread(path)
        if im is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError("could not read image: {!r}".format(path))
        return im

    @staticmethod
    def _center_crop(im, length):
        """Return the central ``length`` x ``length`` window of ``im``.

        Offsets are derived from the array's own shape (rows, cols), not from
        the (width, height) tuple given to ``cv2.resize``, so the window is
        centred on both axes and also works when no resize was requested.
        """
        height, width = im.shape[:2]
        if length > height or length > width:
            raise ValueError(
                "crop length {} exceeds image size {}x{}".format(length, height, width)
            )
        top = (height - length) // 2
        left = (width - length) // 2
        return im[top:top + length, left:left + length]

    def main(self, dim=None, length=None):
        """Load every indexed file into ``self.arrdict``.

        Args:
            dim: optional ``(width, height)`` passed to ``cv2.resize``
                (``INTER_AREA`` interpolation). Falsy values skip resizing.
            length: optional side of a square centre crop applied after the
                resize. Falsy values skip cropping.

        Raises:
            OSError: if a file cannot be read as an image.
            ValueError: if ``length`` is larger than an image dimension.

        ``arrdict`` is rebuilt from scratch on every call, so re-running the
        cell does not append duplicate arrays.
        """
        loaded = {}
        for name, filenames in self.files.items():
            arrays = []
            for filename in filenames:
                im = self._read_image(join(self.root, filename))
                if dim:
                    im = cv2.resize(im, dim, interpolation=cv2.INTER_AREA)
                im = np.array(im)
                if length:
                    im = self._center_crop(im, length)
                arrays.append(im)
            loaded[name] = arrays
        # Assign only after everything loaded so a failure leaves the old state.
        self.arrdict = loaded

    def common_elements(self, other):
        """Keep only the groups of ``arrdict`` whose key is also in ``other``."""
        self.arrdict = {k: v for k, v in self.arrdict.items() if k in other}

In [45]:
# Locations of the unpacked archives.
im_root = './full_images'
lb_root = './labels'

# Target size handed to cv2.resize as (width, height), and the side of the
# square crop applied to the images afterwards.
dim = (696, 520)
length = 256

# Labels are loaded as stored on disk (no resize, no crop).
labels = ArrayMaker(lb_root)
labels.main()

# Images are resized and cropped, then reduced to the groups that also have
# labels.
# NOTE(review): `labels` is not filtered against `images` in turn — confirm
# every label group has a matching image group.
images = ArrayMaker(im_root)
images.main(dim=dim, length=length)
images.common_elements(labels.arrdict)

In [5]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

In [None]:
# NOTE(review): `im_files` and `lb_files` are not defined in any cell visible in
# this notebook, so this cell raises NameError on a fresh kernel. It also rebinds
# `dim`, `images` and `labels`, replacing the ArrayMaker instances created
# earlier. Presumably a leftover from before ArrayMaker existed — confirm and
# either remove it or define the file lists.
dim = (696, 520)
# Read every image, resize it to `dim` (cv2.resize takes (width, height)) with
# area interpolation, and stack the results into one array.
images = np.array([cv2.resize(cv2.imread(img), dim, interpolation = cv2.INTER_AREA) for img in im_files])
# Read every label with flag 0 (cv2.IMREAD_GRAYSCALE) and divide by 255.
labels = np.array([cv2.imread(lbl, 0) for lbl in lb_files]) / 255

In [None]:
# Number of output classes for both replacement heads.
num_classes = 2
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

# Read the input widths of the existing heads before swapping them out.
in_features = model.roi_heads.box_predictor.cls_score.in_features
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels

# Replace the box and mask predictors with fresh heads sized for num_classes.
model.roi_heads.box_predictor = FastRCNNPredictor(
    in_channels=in_features,
    num_classes=num_classes,
)
model.roi_heads.mask_predictor = MaskRCNNPredictor(
    in_channels=in_features_mask,
    dim_reduced=256,
    num_classes=num_classes,
)

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [None]:
class CellImages(Dataset):
    """Dataset that yields ``(image, label)`` pairs by index.

    Both items go through the same pipeline: ``ToTensor`` followed by a
    256-pixel ``CenterCrop``.

    NOTE(review): ``ToTensor`` rescales uint8 input to [0, 1] but leaves other
    dtypes unscaled — confirm the label arrays have the dtype/range the training
    code expects.
    """

    def __init__(self, images, labels):
        """Store the paired sequences and build the shared transform.

        Args:
            images: indexable collection of images accepted by ``ToTensor``.
            labels: indexable collection of labels, same length as ``images``.

        Raises:
            ValueError: if ``images`` and ``labels`` differ in length.
        """
        if len(images) != len(labels):
            # Fail early: a length mismatch would otherwise mis-pair samples
            # or surface as an IndexError in the middle of an epoch.
            raise ValueError(
                "images and labels must have the same length, got {} and {}".format(
                    len(images), len(labels)
                )
            )
        self.images = images
        self.labels = labels
        # Use the already imported `torchvision` package; the bare name
        # `transforms` is never imported in this notebook.
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.CenterCrop(256),
        ])

    def __len__(self):
        """Number of samples, taken from ``images``."""
        return len(self.images)

    def __getitem__(self, ind):
        """Return the transformed ``(image, label)`` pair at index ``ind``."""
        img = self.transform(self.images[ind])
        lbl = self.transform(self.labels[ind])

        return img, lbl