# Region-segmentation experiments on a COCO-format annotated dataset

In [21]:

import numpy as np
import pycocotools
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
import json
import cv2
import os
import sys
import tqdm
from shapely.geometry import Polygon

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader

from torchvision import io


import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.loggers import TestTubeLogger
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor


from UNeXt.archs import UNext
from UNeXt.losses import BCEDiceLoss
from pycocotools.coco import COCO


%load_ext autoreload
%autoreload 2
from models import *
from datasets import *
from IcyXml import *


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [22]:
# class CocoDataset(Dataset):
#     def __init__(self, root_dir, ann_file, transforms=None):
#         self.root_dir = root_dir
#         self.transforms = transforms
#         self.coco = pycocotools.COCO(ann_file)
#         self.ids = list(self.coco.imgs.keys())
#         self.cats = self.coco.loadCats(self.coco.getCatIds())
#         self.cat_to_id = {cat['name']: cat['id'] for cat in self.cats}
#         self.id_to_cat = {cat['id']: cat['name'] for cat in self.cats}
#         self.cat_to_color = {cat['name']: cat['color'] for cat in self.cats}

#     def __len__(self):
#         return len(self.ids)

#     def __getitem__(self, idx):
#         img_id = self.ids[idx]
#         img_path = os.path.join(self.root_dir, self.coco.imgs[img_id]['file_name'])
#         img = io.read_image(img_path, io.image.ImageReadMode.RGB)
#         # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#         # img = cv2.resize(img, (256, 256))
#         img = img.astype(np.float32) / 255
#         img = torch.from_numpy(img)
#         if self.transforms:
#             img = self.transforms(img)
#         mask = self.coco.anns[self.coco.imgToAnns[img_id]]['segmentation']
#         mask = mask.astype(np.float32)
#         mask = torch.from_numpy(mask)
#         mask = mask.unsqueeze(0)
#         mask = mask.unsqueeze(0)
#         mask = mask.float()
#         return img, mask






In [23]:
# root_dir = "/home/mounib/cell-counting/datasets/"
# os.listdir(root_dir)
# # list all json files in the root directory
# ann_files = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith('.json')]
# coco = COCO(ann_files[0])

In [24]:

# id2color = {0: np.array([255, 255, 0]),
#             1: np.array([255, 0, 0]),
#             2: np.array([0, 255, 230]),
#             3: np.array([255, 0, 255]),
#             4: np.array([218,215,215]),
#             5: np.array([0,0,255]),
#             6: np.array([255,253,224]),
#             7: np.array([20,20,20]),
#             8: np.array([255,128,0]),
#             9: np.array([255,255,184]),
#             10: np.array([100,100,100]),
#             11: np.array([214,237,255])}
# def draw_masks(img, masks):
#     # annotations = np.zeros(img.shape, dtype=np.uint8)
#     plt.figure(figsize=(10, 10))
#     plt.imshow(img)
#     # plt.imshow(annotations, alpha=0.2)

#     for mask in masks:
#         plt.imshow(mask, alpha=0.5)


#     plt.show()


In [25]:
# cat = coco.getCatIds()
# categories = coco.loadCats(cat)
# nimgs = 20

# for i in range(nimgs):
#     id = list(coco.imgs.items())[i][0]
    
#     print("Image ID: ", id)

#     img = cv2.imread(os.path.join(root_dir, list(coco.imgs.items())[i][1]["file_name"]))
#     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#     img = img.astype(np.uint8)
#     # print(img.min(), img.max())
#     annotations = coco.loadAnns(coco.getAnnIds(imgIds=id))
#     # print(id, annotations)
#     # masks = np.zeros((nimgs, len(cat), img.shape[0], img.shape[1]), dtype=np.uint8)
#     plt.figure(figsize=(10, 10))
#     plt.imshow(img)
#     print(list(coco.imgs.items())[i][1]["file_name"])
#     for annotation in annotations:
#         # print(annotation["segmentation"])
#         if(len(annotation["segmentation"][0]) < 4):
#             continue

#         if annotation["category_id"] != -1:
#             print(categories[annotation["category_id"]]["name"])
#             for n, anno in enumerate(annotation["segmentation"]):
#                 # print(anno)
#                 # if(n != 0):
#                 #     anno = anno[2:]
#                 poly = Polygon(np.array(anno).reshape((-1, 2)), closed=True, edgecolor="#000000", facecolor=id2color[annotation["category_id"]]/255, linewidth=5, alpha=0.4)
#                 plt.gca().add_patch(poly)
#                 # break
#         # color = id2color[annotation["category_id"]]
#         # mask[annotation["category_id"]] = np.logical_or(mask[annotation["category_id"]], mask)
#     plt.show()
#     # masks = np.array([coco.annToMask(annotations[j]) for j in range(1, len(annotations)-1)])
#     # draw_masks(img, masks)


In [28]:

SEED = 42
pl.seed_everything(SEED)

# Data root; override with the DATA_DIR environment variable
# (previously used: "/home/mounib/cell-counting/datasets/", "/Datasets/").
root_dir = os.environ.get("DATA_DIR", "/data")
IMAGE_SIZE = 512
TRAIN_FRACTION = 0.8
BATCH_SIZE = 8
NUM_WORKERS = 8

# os.listdir returns entries in arbitrary order; sort so the index below selects the
# same annotation file on every run.
# NOTE(review): confirm that index 1 of the sorted list is the intended annotation file.
ann_files = sorted(os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith('.json'))
dataset = LungTumorDataset(root_dir, ann_files[1], None, imageSize=IMAGE_SIZE)

# Derive the validation size from the remainder so the two lengths always sum to
# len(dataset); random_split raises when they do not.
n_train = int(len(dataset) * TRAIN_FRACTION)
n_valid = len(dataset) - n_train
trainDataset, validDataset = torch.utils.data.random_split(dataset, [n_train, n_valid])

# Train only on the training subset; wrapping `dataset` here would leak the
# validation images into training.
trainDataLoader = DataLoader(trainDataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, prefetch_factor=2)
validDataLoader = DataLoader(validDataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True, prefetch_factor=2)



Global seed set to 42


loading annotations into memory...
Done (t=0.04s)
creating index...
index created!


In [32]:
CHECKPOINT_PATH = "lightning_logs/version_15/checkpoints/epoch=399-step=35175.ckpt"

# Backbone with one output channel per dataset category.
# NOTE(review): positional args assumed to be (num_classes, input_channels, deep_supervision)
# per the UNeXt constructor — confirm.
model = UNext(len(dataset.cats), 3, False)

# load_from_checkpoint is a classmethod that builds and returns a *new* module, so call it
# on the class instead of on a throwaway instance; `backbone` is forwarded to SegModel's
# constructor.
plModel = SegModel.load_from_checkpoint(CHECKPOINT_PATH, backbone=model)
# for module in model.children():
#     print(module)

In [33]:
# trainer = pl.Trainer(max_epochs=400, gpus=1, resume_from_checkpoint="lightning_logs/version_14/checkpoints/epoch=13-step=2365.ckpt")
# trainer.fit(plModel, trainDataLoader, validDataLoader)



In [34]:
# for i in range(5):
#     img, mask = dataset[i]
#     dataset.draw_mask(img, mask)

# for i, batch in enumerate(dataLoader):
#     print(i)

In [75]:
# for i in range(len(dataset)):
#     img, mask = dataset[i]
#     print(img.shape, mask.shape)
    # extract the contours from a mask
def extract_contours(masks, eps=0.002, id2cat=None, name2color=None, kernel_size=20):
    """
    Extract simplified outer contours from a stack of per-category binary masks.

    Each mask is cleaned with a morphological opening (removes specks smaller than the
    structuring element) followed by a closing (fills small holes and gaps). Its
    external contours are then found and simplified with cv2.approxPolyDP.

    :param masks: iterable of 2-D uint8 masks, one per channel; the channel index is
        used as the key into ``id2cat``. (assumes 0/1-valued masks — TODO confirm)
    :param eps: simplification tolerance as a fraction of each contour's perimeter
        (cv2.approxPolyDP receives ``eps * arcLength``).
    :param id2cat: mapping from channel index to category name; defaults to the
        global ``dataset.id2cat``.
    :param name2color: mapping from category name to its colour; defaults to the
        global ``dataset.name2color``.
    :param kernel_size: size in pixels of the elliptical structuring element used for
        the opening/closing.
    :return: list of ``(polygon, name, color)`` tuples, where ``polygon`` is an
        (N, 2) array of closed-polygon vertices.
    """
    # Fall back to the notebook-level dataset for backward compatibility.
    if id2cat is None:
        id2cat = dataset.id2cat
    if name2color is None:
        name2color = dataset.name2color

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    contours = []
    for i, mask in enumerate(masks):
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

        externals, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if len(externals) == 0:
            continue

        name = id2cat[i]
        color = name2color[name]
        for external in externals:
            # OpenCV contours are (N, 1, 2); drop the middle axis once and reuse it.
            points = external.squeeze(1)
            tolerance = eps * cv2.arcLength(points, True)
            polygon = cv2.approxPolyDP(points, tolerance, True).squeeze(1)
            contours.append((polygon, name, color))

    return contours

In [76]:
# Probability above which a pixel is assigned to a category.
PRED_THRESHOLD = 0.5

plModel.eval()
device = plModel.device
# Disable augmentation and fix the input size for inference.
# NOTE(review): this mutates the shared `dataset`, which the training DataLoader also wraps.
dataset.transforms = None
dataset.imageSize = 512

# Predict every image in the dataset, convert each per-category mask to polygons and
# write them into the image's Icy XML annotation file.
with torch.no_grad():
    for idx in tqdm.tqdm(range(len(dataset))):
        img, _ = dataset[idx]
        logits = plModel(img.unsqueeze(0).to(device))
        # Binary mask per category channel; batch axis removed.
        out = (torch.sigmoid(logits) > PRED_THRESHOLD).cpu().squeeze(0)

        masks = out.numpy().astype(np.uint8)
        contours = extract_contours(masks)

        img_id = dataset.ids[idx]
        img_path = dataset.coco.imgs[img_id]['file_name']
        # Replace only the real file extension (robust to dots in directory names).
        annot_file = os.path.splitext(img_path)[0] + ".xml"

        icyFile = IcyXml(root_dir, annot_file)
        # Use a distinct loop variable so the image index `idx` is never shadowed.
        for contour, name, color in contours:
            icyFile.addPolygon(contour, name, color)
        icyFile.save()


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 342/342 [01:21<00:00,  4.17it/s]


In [83]:
# Sanity check on the prediction left over from the final iteration of the inference loop.
# NOTE(review): depends on `out` from the previous cell (hidden kernel state) — run that cell first.
masks = (
    out.cpu()
    .numpy()
    .astype(np.uint8)
)
print(*(masks.dtype, masks.min(), masks.max(), masks.shape))
contours = extract_contours(masks)

uint8 0 1 (12, 512, 512)
1
1
