In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
"""
[V1]
* Resolution: Resized to 512x512 from 768x768
* Extract cell masks and create individual cell images (512x512)
* No random crop
* No TTA
* Update normalization mean and std with 2021 training and test sets

[V2]
* Use INTER_AREA for resize

[V3]
* Refactor to use less memory

[V4]
* Run batch cell segmentator from script

[V5]
* Add data augmentation


Note: HPA-Cell-Segmentatior assume that all input images are of the same shape!
"""

kernel_mode = False
debug = False

import sys
if kernel_mode:
    sys.path.insert(0, "../input/hpa-bestfitting-solution/src")
    sys.path.insert(0, "../input/hpa-cell-segmentation")

In [3]:
# !ls ../input/hpa-bestfitting-solution/

In [4]:
# !cp ../input/hpa-bestfitting-solution/densenet121-a639ec97.pth .
# !ls -la

In [5]:
# !pip install -q "../input/pycocotools/pycocotools-2.0-cp37-cp37m-linux_x86_64.whl"
# !pip install -q "../input/hpapytorchzoozip/pytorch_zoo-master"

In [6]:
import sys
import argparse
from tqdm import tqdm
import os
import numpy as np
import pandas as pd
import time
import random
import math
import pickle
from pickle import dump, load
import glob
import time
import collections

import torch
import torch.optim
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SequentialSampler
from torch.nn import DataParallel
import torch.nn.functional as F
from torch.autograd import Variable

from config.config import *
from utils.common_util import *
from networks.imageclsnet import init_network
from datasets.protein_dataset import ProteinDataset
from utils.augment_util import *
from datasets.tool import *

import hpacellseg.cellsegmentator as cellsegmentator
from hpacellseg.utils import label_cell, label_nuclei

import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.metrics.functional import classification

from pycocotools import _mask as coco_mask
import typing as t
import base64
import zlib

import cv2
from PIL import Image
import imagehash

import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings('ignore')

pd.options.display.max_columns = None
pd.options.display.max_rows = 100

import gc
gc.enable()

rand_seed = 1120

print(f"PyTorch Version: {torch.__version__}")
print(f"PyTorch Lightning Version: {pl.__version__}")

run on 2fb1688be1a5
PyTorch Version: 1.6.0+cu101
PyTorch Lightning Version: 1.1.1


In [7]:
if kernel_mode:
    dataset_folder = "/kaggle/input/hpa-single-cell-image-classification"
    bestfitting_folder = "/kaggle/input/hpa-bestfitting-solution"
    test_image_folder = f"{dataset_folder}/test/"
    cell_mask_folder = "/kaggle/working/test_cell_masks"
    NUC_MODEL = "/kaggle/input/hpa-cell-segmentation/dpn_unet_nuclei_v1.pth"
    CELL_MODEL = "/kaggle/input/hpa-cell-segmentation/dpn_unet_cell_3ch_v1.pth"
else:
    dataset_folder = "/workspace/Kaggle/HPA/hpa_2020"
    bestfitting_folder = "/workspace/Github/HPA-competition-solutions/bestfitting"
    test_image_folder = f"{dataset_folder}/test/"
    cell_mask_folder = f"{dataset_folder}/test_cell_masks"
    NUC_MODEL = "/workspace/Github/HPA-Cell-Segmentation/dpn_unet_nuclei_v1.pth"
    CELL_MODEL = "/workspace/Github/HPA-Cell-Segmentation/dpn_unet_cell_3ch_v1.pth"

model_folder = "external_crop512_focal_slov_hardlog_class_densenet121_dropout_i768_aug2_5folds"

# image_size = 2048
# image_size = 768
crop_size = 512

batch_size = 8 if kernel_mode else 4
# batch_size = 4
num_workers = 2 if kernel_mode else 3

# scale_factor = 1.0
# scale_factor = 0.1
scale_factor = 0.25
confidence_threshold = 0.5

In [8]:
old_classes = {
    0: 'Nucleoplasm',
    1: 'Nuclear membrane',
    2: 'Nucleoli',
    3: 'Nucleoli fibrillar center',
    4: 'Nuclear speckles',
    5: 'Nuclear bodies',
    6: 'Endoplasmic reticulum',
    7: 'Golgi apparatus',
    8: 'Peroxisomes',
    9: 'Endosomes',
    10: 'Lysosomes',
    11: 'Intermediate filaments',
    12: 'Actin filaments',
    13: 'Focal adhesion sites',
    14: 'Microtubules',
    15: 'Microtubule ends',
    16: 'Cytokinetic bridge',
    17: 'Mitotic spindle',
    18: 'Microtubule organizing center',
    19: 'Centrosome',
    20: 'Lipid droplets',
    21: 'Plasma membrane',
    22: 'Cell junctions',
    23: 'Mitochondria',
    24: 'Aggresome',
    25: 'Cytosol',
    26: 'Cytoplasmic bodies',
    27: 'Rods & rings'
}
old_class_indices = {v: k for k, v in old_classes.items()}

# All label names in the public HPA and their corresponding index.
all_locations = dict({
    "Nucleoplasm": 0,
    "Nuclear membrane": 1,
    "Nucleoli": 2,
    "Nucleoli fibrillar center": 3,
    "Nuclear speckles": 4,
    "Nuclear bodies": 5,
    "Endoplasmic reticulum": 6,
    "Golgi apparatus": 7,
    "Intermediate filaments": 8,
    "Actin filaments": 9,
    "Focal adhesion sites": 9,
    "Microtubules": 10,
    "Mitotic spindle": 11,
    "Centrosome": 12,
    "Centriolar satellite": 12,
    "Plasma membrane": 13,
    "Cell Junctions": 13,
    "Mitochondria": 14,
    "Aggresome": 15,
    "Cytosol": 16,
    "Vesicles": 17,
    "Peroxisomes": 17,
    "Endosomes": 17,
    "Lysosomes": 17,
    "Lipid droplets": 17,
    "Cytoplasmic bodies": 17,
    "Rods & rings": 18,
    # markpeng
    "No staining": 18,
})

old_class_mappings = {}
for i, (k, v) in enumerate(old_class_indices.items()):
    if k in all_locations:
        old_class_mappings[v] = all_locations[k]
    else:
        # No staining
        old_class_mappings[v] = 18
assert len(old_class_mappings.values()) == len(old_classes.values())
print(old_class_mappings)

{0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 17, 9: 17, 10: 17, 11: 8, 12: 9, 13: 9, 14: 10, 15: 18, 16: 18, 17: 11, 18: 18, 19: 12, 20: 17, 21: 13, 22: 18, 23: 14, 24: 15, 25: 16, 26: 17, 27: 18}


In [9]:
!ls {dataset_folder}

inference	       test		test_tfrecords	train.csv
sample_submission.csv  test_cell_masks	train		train_tfrecords


In [10]:
train_df = pd.read_csv(f"{dataset_folder}/train.csv")
submit_df = pd.read_csv(f"{dataset_folder}/sample_submission.csv")

In [11]:
print(train_df.shape)
train_df.head()

(21806, 2)


Unnamed: 0,ID,Label
0,5c27f04c-bb99-11e8-b2b9-ac1f6b6435d0,8|5|0
1,5fb643ee-bb99-11e8-b2b9-ac1f6b6435d0,14|0
2,60b57878-bb99-11e8-b2b9-ac1f6b6435d0,6|1
3,5c1a898e-bb99-11e8-b2b9-ac1f6b6435d0,16|10
4,5b931256-bb99-11e8-b2b9-ac1f6b6435d0,14|0


In [12]:
print(submit_df.shape, submit_df.ImageWidth.min(), submit_df.ImageWidth.max())
submit_df.head()

(559, 4) 1728 3072


Unnamed: 0,ID,ImageWidth,ImageHeight,PredictionString
0,0040581b-f1f2-4fbe-b043-b6bfea5404bb,2048,2048,0 1 eNoLCAgIMAEABJkBdQ==
1,004a270d-34a2-4d60-bbe4-365fca868193,2048,2048,0 1 eNoLCAgIMAEABJkBdQ==
2,00537262-883c-4b37-a3a1-a4931b6faea5,2048,2048,0 1 eNoLCAgIMAEABJkBdQ==
3,00c9a1c9-2f06-476f-8b0d-6d01032874a2,2048,2048,0 1 eNoLCAgIMAEABJkBdQ==
4,0173029a-161d-40ef-af28-2342915b22fb,3072,3072,0 1 eNoLCAgIsAQABJ4Beg==


In [13]:
colors = ["red", "green", "blue", "yellow"]

test_ids = submit_df["ID"].values.tolist()
print(len(test_ids))

# Estimated number of private test images (RGBY): 2236 x 2.3 ~= 5143 (for 9 hours we have 6.2 secs per image)
# Estimated number of private test images: 559 x 2.3 ~= 1286 (for 9 hours we have 25.2 secs per image)

559


## Utility Functions

In [14]:
# Reference: https://www.kaggle.com/dschettler8845/hpa-cellwise-classification-inference/notebook
def binary_mask_to_ascii(mask, mask_val=1):
    """Converts a binary mask into OID challenge encoding ascii text."""
    mask = np.where(mask == mask_val, 1, 0).astype(np.bool)

    # check input mask --
    if mask.dtype != np.bool:
        raise ValueError(
            f"encode_binary_mask expects a binary mask, received dtype == {mask.dtype}"
        )

    mask = np.squeeze(mask)
    if len(mask.shape) != 2:
        raise ValueError(
            f"encode_binary_mask expects a 2d mask, received shape == {mask.shape}"
        )

    # convert input mask to expected COCO API input --
    mask_to_encode = mask.reshape(mask.shape[0], mask.shape[1], 1)
    mask_to_encode = mask_to_encode.astype(np.uint8)
    mask_to_encode = np.asfortranarray(mask_to_encode)

    # RLE encode mask --
    encoded_mask = coco_mask.encode(mask_to_encode)[0]["counts"]

    # compress and base64 encoding --
    binary_str = zlib.compress(encoded_mask, zlib.Z_BEST_COMPRESSION)
    base64_str = base64.b64encode(binary_str)
    return base64_str.decode()


def rle_encoding(img, mask_val=1):
    """
    Turns our masks into RLE encoding to easily store them
    and feed them into models later on
    https://en.wikipedia.org/wiki/Run-length_encoding
    
    Args:
        img (np.array): Segmentation array
        mask_val (int): Which value to use to create the RLE
        
    Returns:
        RLE string
    
    """
    dots = np.where(img.T.flatten() == mask_val)[0]
    run_lengths = []
    prev = -2
    for b in dots:
        if (b > prev + 1): run_lengths.extend((b + 1, 0))
        run_lengths[-1] += 1
        prev = b

    return ' '.join([str(x) for x in run_lengths])


def rle_to_mask(rle_string, height, width):
    """ Convert RLE sttring into a binary mask 
    
    Args:
        rle_string (rle_string): Run length encoding containing 
            segmentation mask information
        height (int): Height of the original image the map comes from
        width (int): Width of the original image the map comes from
    
    Returns:
        Numpy array of the binary segmentation mask for a given cell
    """
    rows, cols = height, width
    rle_numbers = [int(num_string) for num_string in rle_string.split(' ')]
    rle_pairs = np.array(rle_numbers).reshape(-1, 2)
    img = np.zeros(rows * cols, dtype=np.uint8)
    for index, length in rle_pairs:
        index -= 1
        img[index:index + length] = 255
    img = img.reshape(cols, rows)
    img = img.T
    return img


def create_segmentation_maps(list_of_image_lists, segmentator, batch_size=8):
    """ Function to generate segmentation maps using CellSegmentator tool 
    
    Args:
        list_of_image_lists (list of lists):
            - [[micro-tubules(red)], [endoplasmic-reticulum(yellow)], [nucleus(blue)]]
        batch_size (int): Batch size to use in generating the segmentation masks
        
    Returns:
        List of lists containing RLEs for all the cells in all images
    """

    all_mask_rles = {}
    for i in tqdm(range(0, len(list_of_image_lists[0]), batch_size),
                  total=len(list_of_image_lists[0]) // batch_size):

        # Get batch of images
        sub_images = [
            img_channel_list[i:i + batch_size]
            for img_channel_list in list_of_image_lists
        ]  # 0.000001 seconds

        # Do segmentation
        cell_segmentations = segmentator.pred_cells(sub_images)
        nuc_segmentations = segmentator.pred_nuclei(sub_images[2])

        # post-processing
        for j, path in enumerate(sub_images[0]):
            img_id = path.replace("_red.png", "").rsplit("/", 1)[1]
            nuc_mask, cell_mask = label_cell(nuc_segmentations[j],
                                             cell_segmentations[j])
            new_name = os.path.basename(path).replace('red', 'mask')
            all_mask_rles[img_id] = [
                rle_encoding(cell_mask, mask_val=k)
                for k in range(1,
                               np.max(cell_mask) + 1)
            ]
    return all_mask_rles


def get_img_list(img_dir, return_ids=False, sub_n=None):
    """ Get image list in the format expected by the CellSegmentator tool """
    if sub_n is None:
        sub_n = len(glob(img_dir + '/' + f'*_red.png'))
    if return_ids:
        images = [
            sorted(glob(img_dir + '/' + f'*_{c}.png'))[:sub_n]
            for c in ["red", "yellow", "blue"]
        ]
        return [
            x.replace("_red.png", "").rsplit("/", 1)[1] for x in images[0]
        ], images
    else:
        return [
            sorted(glob(img_dir + '/' + f'*_{c}.png'))[:sub_n]
            for c in ["red", "yellow", "blue"]
        ]


def get_contour_bbox_from_rle(
    rle,
    width,
    height,
    return_mask=True,
):
    """ Get bbox of contour as `xmin ymin xmax ymax`
    
    Args:
        rle (rle_string): Run length encoding containing 
            segmentation mask information
        height (int): Height of the original image the map comes from
        width (int): Width of the original image the map comes from
    
    Returns:
        Numpy array for a cell bounding box coordinates
    """
    mask = rle_to_mask(rle, height, width).copy()
    cnts = grab_contours(
        cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE))
    x, y, w, h = cv2.boundingRect(cnts[0])

    if return_mask:
        return (x, y, x + w, y + h), mask
    else:
        return (x, y, x + w, y + h)


def get_contour_bbox_from_raw(raw_mask):
    """ Get bbox of contour as `xmin ymin xmax ymax`
    
    Args:
        raw_mask (nparray): Numpy array containing segmentation mask information
    
    Returns:
        Numpy array for a cell bounding box coordinates
    """
    cnts = grab_contours(
        cv2.findContours(raw_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE))
    xywhs = [cv2.boundingRect(cnt) for cnt in cnts]
    xys = [(xywh[0], xywh[1], xywh[0] + xywh[2], xywh[1] + xywh[3])
           for xywh in xywhs]
    return sorted(xys, key=lambda x: (x[1], x[0]))


def pad_to_square(a):
    """ Pad an array `a` evenly until it is a square """
    if a.shape[1] > a.shape[0]:  # pad height
        n_to_add = a.shape[1] - a.shape[0]
        top_pad = n_to_add // 2
        bottom_pad = n_to_add - top_pad
        a = np.pad(a, [(top_pad, bottom_pad), (0, 0), (0, 0)], mode='constant')

    elif a.shape[0] > a.shape[1]:  # pad width
        n_to_add = a.shape[0] - a.shape[1]
        left_pad = n_to_add // 2
        right_pad = n_to_add - left_pad
        a = np.pad(a, [(0, 0), (left_pad, right_pad), (0, 0)], mode='constant')
    else:
        pass
    return a


def cut_out_cells(rgby,
                  rles,
                  resize_to=(256, 256),
                  square_off=True,
                  return_masks=False,
                  from_raw=True):
    """ Cut out the cells as padded square images 
    
    Args:
        rgby (np.array): 4 Channel image to be cut into tiles
        rles (list of RLE strings): List of run length encoding containing 
            segmentation mask information
        resize_to (tuple of ints, optional): The square dimension to resize the image to
        square_off (bool, optional): Whether to pad the image to a square or not
        
    Returns:
        list of square arrays representing squared off cell images
    """
    w, h = rgby.shape[:2]
    contour_bboxes = [
        get_contour_bbox(rle, w, h, return_mask=return_masks) for rle in rles
    ]
    if return_masks:
        masks = [x[-1] for x in contour_bboxes]
        contour_bboxes = [x[:-1] for x in contour_bboxes]

    arrs = [
        rgby[bbox[1]:bbox[3], bbox[0]:bbox[2], ...] for bbox in contour_bboxes
    ]
    if square_off:
        arrs = [pad_to_square(arr) for arr in arrs]

    if resize_to is not None:
        arrs = [
            cv2.resize(pad_to_square(arr).astype(np.float32),
                       resize_to,
                       interpolation=cv2.INTER_CUBIC) \
            for arr in arrs
        ]
    if return_masks:
        return arrs, masks
    else:
        return arrs


def grab_contours(cnts):
    # if the length the contours tuple returned by cv2.findContours
    # is '2' then we are using either OpenCV v2.4, v4-beta, or
    # v4-official
    if len(cnts) == 2:
        cnts = cnts[0]

    # if the length of the contours tuple is '3' then we are using
    # either OpenCV v3, v4-pre, or v4-alpha
    elif len(cnts) == 3:
        cnts = cnts[1]

    # otherwise OpenCV has changed their cv2.findContours return
    # signature yet again and I have no idea WTH is going on
    else:
        raise Exception(
            ("Contours tuple must have length 2 or 3, "
             "otherwise OpenCV changed their cv2.findContours return "
             "signature yet again. Refer to OpenCV's documentation "
             "in that case"))

    # return the actual contours array
    return cnts

In [15]:
# https://www.kaggle.com/c/human-protein-atlas-image-classification/discussion/72534
def generate_hash(img_dir,
                  colors,
                  dataset='train',
                  imread_func=None,
                  is_update=False):
    meta = meta.copy()
    hash_maps = {}
    for color in colors:
        hash_maps[color] = []
        for idx in tqdm(range(len(meta)), desc='train %s' % color):
            img = imread_func(img_dir, meta.iloc[idx][ID], color)
            hash = imagehash.phash(img)
            hash_maps[color].append(hash)

    for color in colors:
        meta[color] = hash_maps[color]

    return meta


def calc_hash(params):
    color, threshold, base_test_hash1, base_test_hash2, test_ids1, test_ids2 = params

    test_hash1 = base_test_hash1.reshape(1, -1)  # 1*m

    test_idxes_list1 = []
    test_idxes_list2 = []
    hash_list = []

    step = 5
    for test_idx in tqdm(range(0, len(base_test_hash2), step), desc=color):
        test_hash2 = base_test_hash2[test_idx:test_idx + step].reshape(
            -1, 1)  # n*1
        hash = test_hash2 - test_hash1  # n*m
        test_idxes2, test_idxes1 = np.where(hash <= threshold)
        hash = hash[test_idxes2, test_idxes1]

        test_idxes2 = test_idxes2 + test_idx

        test_idxes_list1.extend(test_idxes1.tolist())
        test_idxes_list2.extend(test_idxes2.tolist())
        hash_list.extend(hash.tolist())

    df = pd.DataFrame({
        'Test1': test_ids1[test_idxes_list1],
        'Test2': test_ids2[test_idxes_list2],
        'Sim%s' % color[:1].upper(): hash_list
    })
    df = df[df['Test1'] != df['Test2']]
    return df

## Preprocessing

### Extract Cell Segmentations as Numpy Files

In [16]:
%%writefile hpa_cell_segment.py

kernel_mode = False

from tqdm import tqdm
import os
import numpy as np
import pandas as pd
import cv2
import gc

IMAGE_SIZES = [1728, 2048, 3072, 4096]
if kernel_mode:
    dataset_folder = "/kaggle/input/hpa-single-cell-image-classification"
    img_dir = f"{dataset_folder}/test"
    output_folder = "/kaggle/working/test_cell_masks"
    NUC_MODEL = "/kaggle/input/hpa-cell-segmentation/dpn_unet_nuclei_v1.pth"
    CELL_MODEL = "/kaggle/input/hpa-cell-segmentation/dpn_unet_cell_3ch_v1.pth"
    BATCH_SIZE = {1728: 24, 2048: 24, 3072: 12, 4096: 12}
else:
    dataset_folder = "/workspace/Kaggle/HPA/hpa_2020"
    img_dir = f"{dataset_folder}/test"
    output_folder = f"{dataset_folder}/test_cell_masks"
    NUC_MODEL = "/workspace/Github/HPA-Cell-Segmentation/dpn_unet_nuclei_v1.pth"
    CELL_MODEL = "/workspace/Github/HPA-Cell-Segmentation/dpn_unet_cell_3ch_v1.pth"
    BATCH_SIZE = {1728: 18, 2048: 18, 3072: 8, 4096: 8}

os.makedirs(output_folder, exist_ok=True)

submit_df = pd.read_csv(f"{dataset_folder}/sample_submission.csv")

predict_df_1728 = submit_df[submit_df.ImageWidth == IMAGE_SIZES[0]]
predict_df_2048 = submit_df[submit_df.ImageWidth == IMAGE_SIZES[1]]
predict_df_3072 = submit_df[submit_df.ImageWidth == IMAGE_SIZES[2]]
predict_df_4096 = submit_df[submit_df.ImageWidth == IMAGE_SIZES[3]]

predict_ids_1728 = predict_df_1728.ID.to_list()
predict_ids_2048 = predict_df_2048.ID.to_list()
predict_ids_3072 = predict_df_3072.ID.to_list()
predict_ids_4096 = predict_df_4096.ID.to_list()

import os
import sys

import cv2
import imageio
import numpy as np
import torch
import torch.nn
import torch.nn.functional as F

from skimage import transform, util
"""Shared constants for the HPA Cell Segmentation package."""

NUCLEI_MODEL_URL = (
    "https://zenodo.org/record/4665863/files/dpn_unet_nuclei_v1.pth")

MULTI_CHANNEL_CELL_MODEL_URL = (
    "https://zenodo.org/record/4665863/files/dpn_unet_cell_3ch_v1.pth")

TWO_CHANNEL_CELL_MODEL_URL = (
    "https://zenodo.org/record/4665863/files/dpn_unet_cell_v2.pth")
"""Utility functions for the HPA Cell Segmentation package."""
import os.path
import urllib
import zipfile

import numpy as np
import scipy.ndimage as ndi
from skimage import filters, measure, segmentation
from skimage.morphology import (binary_erosion, closing, disk,
                                remove_small_holes, remove_small_objects)

HIGH_THRESHOLD = 0.4
LOW_THRESHOLD = HIGH_THRESHOLD - 0.25


def download_with_url(url_string, file_path, unzip=False):
    """Download file with a link."""
    with urllib.request.urlopen(url_string) as response, open(
            file_path, "wb") as out_file:
        data = response.read()  # a `bytes` object
        out_file.write(data)

    if unzip:
        with zipfile.ZipFile(file_path, "r") as zip_ref:
            zip_ref.extractall(os.path.dirname(file_path))


def __fill_holes(image):
    """Fill_holes for labelled image, with a unique number."""
    boundaries = segmentation.find_boundaries(image)
    image = np.multiply(image, np.invert(boundaries))
    image = ndi.binary_fill_holes(image > 0)
    image = ndi.label(image)[0]
    return image


def label_nuclei(nuclei_pred):
    """Return the labeled nuclei mask data array.
    This function works best for Human Protein Atlas cell images with
    predictions from the CellSegmentator class.
    Keyword arguments:
    nuclei_pred -- a 3D numpy array of a prediction from a nuclei image.
    Returns:
    nuclei-label -- An array with unique numbers for each found nuclei
                    in the nuclei_pred. A value of 0 in the array is
                    considered background, and the values 1-n is the
                    areas of the cells 1-n.
    """
    img_copy = np.copy(nuclei_pred[..., 2])
    borders = (nuclei_pred[..., 1] > 0.05).astype(np.uint8)
    m = img_copy * (1 - borders)

    img_copy[m <= LOW_THRESHOLD] = 0
    img_copy[m > LOW_THRESHOLD] = 1
    img_copy = img_copy.astype(np.bool)
    img_copy = binary_erosion(img_copy)
    # TODO: Add parameter for remove small object size for
    #       differently scaled images.
    # img_copy = remove_small_objects(img_copy, 500)
    img_copy = img_copy.astype(np.uint8)
    markers = measure.label(img_copy).astype(np.uint32)

    mask_img = np.copy(nuclei_pred[..., 2])
    mask_img[mask_img <= HIGH_THRESHOLD] = 0
    mask_img[mask_img > HIGH_THRESHOLD] = 1
    mask_img = mask_img.astype(np.bool)
    mask_img = remove_small_holes(mask_img, 1000)
    # TODO: Figure out good value for remove small objects.
    # mask_img = remove_small_objects(mask_img, 8)
    mask_img = mask_img.astype(np.uint8)
    nuclei_label = segmentation.watershed(mask_img,
                                          markers,
                                          mask=mask_img,
                                          watershed_line=True)
    nuclei_label = remove_small_objects(nuclei_label, 2500)
    nuclei_label = measure.label(nuclei_label)
    return nuclei_label


def label_cell(nuclei_pred, cell_pred):
    """Label the cells and the nuclei.
    Keyword arguments:
    nuclei_pred -- a 3D numpy array of a prediction from a nuclei image.
    cell_pred -- a 3D numpy array of a prediction from a cell image.
    Returns:
    A tuple containing:
    nuclei-label -- A nuclei mask data array.
    cell-label  -- A cell mask data array.
    0's in the data arrays indicate background while a continous
    strech of a specific number indicates the area for a specific
    cell.
    The same value in cell mask and nuclei mask refers to the identical cell.
    NOTE: The nuclei labeling from this function will be sligthly
    different from the values in :func:`label_nuclei` as this version
    will use information from the cell-predictions to make better
    estimates.
    """
    def __wsh(
        mask_img,
        threshold,
        border_img,
        seeds,
        threshold_adjustment=0.35,
        small_object_size_cutoff=10,
    ):
        img_copy = np.copy(mask_img)
        m = seeds * border_img  # * dt
        img_copy[m <= threshold + threshold_adjustment] = 0
        img_copy[m > threshold + threshold_adjustment] = 1
        img_copy = img_copy.astype(np.bool)
        img_copy = remove_small_objects(
            img_copy, small_object_size_cutoff).astype(np.uint8)

        mask_img[mask_img <= threshold] = 0
        mask_img[mask_img > threshold] = 1
        mask_img = mask_img.astype(np.bool)
        mask_img = remove_small_holes(mask_img, 1000)
        mask_img = remove_small_objects(mask_img, 8).astype(np.uint8)
        markers = ndi.label(img_copy, output=np.uint32)[0]
        labeled_array = segmentation.watershed(mask_img,
                                               markers,
                                               mask=mask_img,
                                               watershed_line=True)
        return labeled_array

    nuclei_label = __wsh(
        nuclei_pred[..., 2] / 255.0,
        0.4,
        1 - (nuclei_pred[..., 1] + cell_pred[..., 1]) / 255.0 > 0.05,
        nuclei_pred[..., 2] / 255,
        threshold_adjustment=-0.25,
        small_object_size_cutoff=500,
    )

    # for hpa_image, to remove the small pseduo nuclei
    nuclei_label = remove_small_objects(nuclei_label, 2500)
    nuclei_label = measure.label(nuclei_label)
    # this is to remove the cell borders' signal from cell mask.
    # could use np.logical_and with some revision, to replace this func.
    # Tuned for segmentation hpa images
    threshold_value = max(
        0.22,
        filters.threshold_otsu(cell_pred[..., 2] / 255) * 0.5)
    # exclude the green area first
    cell_region = np.multiply(
        cell_pred[..., 2] / 255 > threshold_value,
        np.invert(np.asarray(cell_pred[..., 1] / 255 > 0.05, dtype=np.int8)),
    )
    sk = np.asarray(cell_region, dtype=np.int8)
    distance = np.clip(cell_pred[..., 2], 255 * threshold_value, cell_pred[...,
                                                                           2])
    cell_label = segmentation.watershed(-distance, nuclei_label, mask=sk)
    cell_label = remove_small_objects(cell_label, 5500).astype(np.uint8)
    selem = disk(6)
    cell_label = closing(cell_label, selem)
    cell_label = __fill_holes(cell_label)
    # this part is to use green channel, and extend cell label to green channel
    # benefit is to exclude cells clear on border but without nucleus
    sk = np.asarray(
        np.add(
            np.asarray(cell_label > 0, dtype=np.int8),
            np.asarray(cell_pred[..., 1] / 255 > 0.05, dtype=np.int8),
        ) > 0,
        dtype=np.int8,
    )
    cell_label = segmentation.watershed(-distance, cell_label, mask=sk)
    cell_label = __fill_holes(cell_label)
    cell_label = np.asarray(cell_label > 0, dtype=np.uint8)
    cell_label = measure.label(cell_label)
    cell_label = remove_small_objects(cell_label, 5500)
    cell_label = measure.label(cell_label)
    cell_label = np.asarray(cell_label, dtype=np.uint16)
    nuclei_label = np.multiply(cell_label > 0, nuclei_label) > 0
    nuclei_label = measure.label(nuclei_label)
    nuclei_label = remove_small_objects(nuclei_label, 2500)
    nuclei_label = np.multiply(cell_label, nuclei_label > 0)

    return nuclei_label, cell_label


def label_cell2(cell_pred):
    """label cell with only cell predition"""
    cell_pred = cell_pred / cell_pred.max()
    size = cell_pred.shape[0]
    img = cell_pred.copy()
    cell_pred[..., 2] = filters.gaussian(cell_pred[..., 2], sigma=8)
    threshold_value = max(0.22, filters.threshold_otsu(cell_pred[..., 2]))
    threshold_value1 = max(0.6, filters.threshold_otsu(img[..., 2]))
    # exclude the green area first
    cell_region = np.multiply(
        cell_pred[..., 2],
        np.logical_and(np.invert(np.asarray(cell_pred[..., 1] > 0.01)),
                       cell_pred[..., 2] > threshold_value))
    cell_region1 = np.multiply(
        img[..., 2] > threshold_value1,
        np.invert(np.asarray(cell_pred[..., 1] > 0.01)),
    )
    cell_region_eroded = morphology.erosion(cell_region1,
                                            morphology.square(25))
    cell_region_eroded = np.asarray(cell_region_eroded, dtype=np.uint8)
    cell_region_eroded = ndi.label(cell_region_eroded)[0]
    remove_size_ratio = int((size / 512)**2)
    cell_region_eroded = remove_small_objects(cell_region_eroded,
                                              10 * remove_size_ratio)
    cell_region_eroded = np.asarray(cell_region_eroded > 0, dtype=np.uint8)
    distance = np.clip(cell_pred[..., 2], threshold_value, cell_pred[..., 2])
    local_maxi = feature.peak_local_max(cell_region_eroded,
                                        indices=False,
                                        footprint=np.ones((1, 1)))
    markers = ndi.label(local_maxi)[0]
    cell_label = segmentation.watershed(-distance, markers, mask=cell_region)
    cell_label = remove_small_objects(cell_label, 1000 *
                                      remove_size_ratio).astype(np.uint8)
    selem = disk(6)
    cell_label = closing(cell_label, selem)
    # this part is to use green channel, and extend cell label to green channel
    # benefit is to exclude cells clear on border but without nucleus
    sk = np.logical_or(
        cell_label > 0,
        cell_pred[..., 1] > 0.1,
    )
    sk = np.asarray(sk, dtype=np.uint8)
    cell_label = segmentation.watershed(-sk, cell_label, mask=sk)
    cell_label = __fill_holes(cell_label)
    cell_label = measure.label(cell_label)
    #cell_label = np.asarray(cell_label, dtype=np.uint16)

    return cell_label


NORMALIZE = {
    "mean": [124 / 255, 117 / 255, 104 / 255],
    "std": [1 / (0.0167 * 255)] * 3
}


class CellSegmentator(object):
    """Uses pretrained DPN-Unet models to segment cells from images."""
    def __init__(
        self,
        nuclei_model="./nuclei_model.pth",
        cell_model="./cell_model.pth",
        scale_factor=0.25,
        device="cuda",
        padding=False,
        multi_channel_model=True,
    ):
        """Class for segmenting nuclei and whole cells from confocal microscopy images.
        It takes lists of images and returns the raw output from the
        specified segmentation model. Models can be automatically
        downloaded if they are not already available on the system.
        When working with images from the Huan Protein Cell atlas, the
        outputs from this class' methods are well combined with the
        label functions in the utils module.
        Note that for cell segmentation, there are two possible models
        available. One that works with 2 channeled images and one that
        takes 3 channels.
        Keyword arguments:
        nuclei_model -- A loaded torch nuclei segmentation model or the
                        path to a file which contains such a model.
                        If the argument is a path that points to a non-existant file,
                        a pretrained nuclei_model is going to get downloaded to the
                        specified path (default: './nuclei_model.pth').
        cell_model -- A loaded torch cell segmentation model or the
                      path to a file which contains such a model.
                      The cell_model argument can be None if only nuclei
                      are to be segmented (default: './cell_model.pth').
        scale_factor -- How much to scale images before they are fed to
                        segmentation models. Segmentations will be scaled back
                        up by 1/scale_factor to match the original image
                        (default: 0.25).
        device -- The device on which to run the models.
                  This should either be 'cpu' or 'cuda' or pointed cuda
                  device like 'cuda:0' (default: 'cuda').
        padding -- Whether to add padding to the images before feeding the
                   images to the network. (default: False).
        multi_channel_model -- Control whether to use the 3-channel cell model or not.
                               If True, use the 3-channel model, otherwise use the
                               2-channel version (default: True).
        """
        if device != "cuda" and device != "cpu" and "cuda" not in device:
            raise ValueError(f"{device} is not a valid device (cuda/cpu)")
        if device != "cpu":
            try:
                assert torch.cuda.is_available()
            except AssertionError:
                print("No GPU found, using CPU.", file=sys.stderr)
                device = "cpu"
        self.device = device

        if isinstance(nuclei_model, str):
            if not os.path.exists(nuclei_model):
                print(
                    f"Could not find {nuclei_model}. Downloading it now",
                    file=sys.stderr,
                )
                download_with_url(NUCLEI_MODEL_URL, nuclei_model)
            nuclei_model = torch.load(nuclei_model,
                                      map_location=torch.device(self.device))
        if isinstance(nuclei_model, torch.nn.DataParallel) and device == "cpu":
            nuclei_model = nuclei_model.module

        self.nuclei_model = nuclei_model.to(self.device)

        self.multi_channel_model = multi_channel_model
        if isinstance(cell_model, str):
            if not os.path.exists(cell_model):
                print(f"Could not find {cell_model}. Downloading it now",
                      file=sys.stderr)
                if self.multi_channel_model:
                    download_with_url(MULTI_CHANNEL_CELL_MODEL_URL, cell_model)
                else:
                    download_with_url(TWO_CHANNEL_CELL_MODEL_URL, cell_model)
            cell_model = torch.load(cell_model,
                                    map_location=torch.device(self.device))
        self.cell_model = cell_model.to(self.device)
        self.scale_factor = scale_factor
        self.padding = padding

    def _image_conversion(self, images):
        """Convert/Format images to RGB image arrays list for cell predictions.
        Intended for internal use only.
        Keyword arguments:
        images -- list of lists of image paths/arrays. It should following the
                 pattern if with er channel input,
                 [
                     [microtubule_path0/image_array0, microtubule_path1/image_array1, ...],
                     [er_path0/image_array0, er_path1/image_array1, ...],
                     [nuclei_path0/image_array0, nuclei_path1/image_array1, ...]
                 ]
                 or if without er input,
                 [
                     [microtubule_path0/image_array0, microtubule_path1/image_array1, ...],
                     None,
                     [nuclei_path0/image_array0, nuclei_path1/image_array1, ...]
                 ]
        """
        microtubule_imgs, er_imgs, nuclei_imgs = images
        if self.multi_channel_model:
            if not isinstance(er_imgs, list):
                raise ValueError(
                    "Please speicify the image path(s) for er channels!")
        else:
            if not er_imgs is None:
                raise ValueError(
                    "second channel should be None for two channel model predition!"
                )

        if not isinstance(microtubule_imgs, list):
            raise ValueError("The microtubule images should be a list")
        if not isinstance(nuclei_imgs, list):
            raise ValueError("The microtubule images should be a list")

        if er_imgs:
            if not len(microtubule_imgs) == len(er_imgs) == len(nuclei_imgs):
                raise ValueError(
                    "The lists of images needs to be the same length")
        else:
            if not len(microtubule_imgs) == len(nuclei_imgs):
                raise ValueError(
                    "The lists of images needs to be the same length")

        if not all(isinstance(item, np.ndarray) for item in microtubule_imgs):
            microtubule_imgs = [
                os.path.expanduser(item)
                for _, item in enumerate(microtubule_imgs)
            ]
            nuclei_imgs = [
                os.path.expanduser(item) for _, item in enumerate(nuclei_imgs)
            ]

            microtubule_imgs = list(
                map(lambda item: imageio.imread(item), microtubule_imgs))
            nuclei_imgs = list(
                map(lambda item: imageio.imread(item), nuclei_imgs))
            if er_imgs:
                er_imgs = [
                    os.path.expanduser(item) for _, item in enumerate(er_imgs)
                ]
                er_imgs = list(map(lambda item: imageio.imread(item), er_imgs))

        if not er_imgs:
            er_imgs = [
                np.zeros(item.shape, dtype=item.dtype)
                for _, item in enumerate(microtubule_imgs)
            ]
        cell_imgs = list(
            map(
                lambda item: np.dstack((item[0], item[1], item[2])),
                list(zip(microtubule_imgs, er_imgs, nuclei_imgs)),
            ))

        return cell_imgs

    def pred_nuclei(self, images, bs=24):
        """Predict the nuclei segmentation.
        Keyword arguments:
        images -- A list of image arrays or a list of paths to images.
                  If as a list of image arrays, the images could be 2d images
                  of nuclei data array only, or must have the nuclei data in
                  the blue channel; If as a list of file paths, the images
                  could be RGB image files or gray scale nuclei image file
                  paths.
        Returns:
        predictions -- A list of predictions of nuclei segmentation for each nuclei image.
        """
        def _preprocess(image):
            if isinstance(image, str):
                image = imageio.imread(image)
            self.target_shape = image.shape
            if len(image.shape) == 2:
                image = np.dstack((image, image, image))
            image = transform.rescale(image,
                                      self.scale_factor,
                                      multichannel=True)
            nuc_image = np.dstack((image[..., 2], image[..., 2], image[...,
                                                                       2]))
            if self.padding:
                rows, cols = nuc_image.shape[:2]
                self.scaled_shape = rows, cols
                nuc_image = cv2.copyMakeBorder(
                    nuc_image,
                    32,
                    (32 - rows % 32),
                    32,
                    (32 - cols % 32),
                    cv2.BORDER_REFLECT,
                )
            nuc_image = nuc_image.transpose([2, 0, 1])
            return nuc_image

        def _segment_helper(imgs):
            with torch.no_grad():
                mean = torch.as_tensor(NORMALIZE["mean"], device=self.device)
                std = torch.as_tensor(NORMALIZE["std"], device=self.device)
                imgs = torch.tensor(imgs).float()
                imgs = imgs.to(self.device)
                imgs = imgs.sub_(mean[:, None, None]).div_(std[:, None, None])

                imgs = self.nuclei_model(imgs)
                imgs = F.softmax(imgs, dim=1)
                return imgs

        # preprocessed_imgs = map(_preprocess, images)
        # predictions = map(lambda x: _segment_helper([x]), preprocessed_imgs)
        # predictions = map(lambda x: x.to("cpu").numpy()[0], predictions)
        # predictions = map(util.img_as_ubyte, predictions)
        # predictions = list(map(self._restore_scaling_padding, predictions))
        preprocessed_imgs = list(map(_preprocess, images))
        predictions = []
        for i in range(0, len(preprocessed_imgs), bs):
            start = i
            end = min(len(preprocessed_imgs), i + bs)
            x = preprocessed_imgs[start:end]
            pred = _segment_helper(x).cpu().numpy()
            predictions.append(pred)
        predictions = list(np.concatenate(predictions, axis=0))
        predictions = map(util.img_as_ubyte, predictions)
        predictions = list(map(self._restore_scaling_padding, predictions))
        return predictions

    def _restore_scaling_padding(self, n_prediction):
        """Restore an image from scaling and padding.
        This method is intended for internal use.
        It takes the output from the nuclei model as input.
        """
        n_prediction = n_prediction.transpose([1, 2, 0])
        if self.padding:
            n_prediction = n_prediction[32:32 + self.scaled_shape[0],
                                        32:32 + self.scaled_shape[1], ...]
        n_prediction[..., 0] = 0
        if not self.scale_factor == 1:
            n_prediction = cv2.resize(
                n_prediction,
                (self.target_shape[0], self.target_shape[1]),
                interpolation=cv2.INTER_AREA,
            )
        return n_prediction

    def pred_cells(self, images, precombined=False, bs=24):
        """Predict the cell segmentation for a list of images.
        Keyword arguments:
        images -- list of lists of image paths/arrays. It should following the
                  pattern if with er channel input,
                  [
                      [microtubule_path0/image_array0, microtubule_path1/image_array1, ...],
                      [er_path0/image_array0, er_path1/image_array1, ...],
                      [nuclei_path0/image_array0, nuclei_path1/image_array1, ...]
                  ]
                  or if without er input,
                  [
                      [microtubule_path0/image_array0, microtubule_path1/image_array1, ...],
                      None,
                      [nuclei_path0/image_array0, nuclei_path1/image_array1, ...]
                  ]
                  The ER channel is required when multichannel is True
                  and required to be None when multichannel is False.
                  The images needs to be of the same size.
        precombined -- If precombined is True, the list of images is instead supposed to be
                       a list of RGB numpy arrays (default: False).
        Returns:
        predictions -- a list of predictions of cell segmentations.
        """
        def _preprocess(image):
            self.target_shape = image.shape
            if not len(image.shape) == 3:
                raise ValueError("image should has 3 channels")
            cell_image = transform.rescale(image,
                                           self.scale_factor,
                                           multichannel=True)
            if self.padding:
                rows, cols = cell_image.shape[:2]
                self.scaled_shape = rows, cols
                cell_image = cv2.copyMakeBorder(
                    cell_image,
                    32,
                    (32 - rows % 32),
                    32,
                    (32 - cols % 32),
                    cv2.BORDER_REFLECT,
                )
            cell_image = cell_image.transpose([2, 0, 1])
            return cell_image

        def _segment_helper(imgs):
            with torch.no_grad():
                mean = torch.as_tensor(NORMALIZE["mean"], device=self.device)
                std = torch.as_tensor(NORMALIZE["std"], device=self.device)
                imgs = torch.tensor(imgs).float()
                imgs = imgs.to(self.device)
                imgs = imgs.sub_(mean[:, None, None]).div_(std[:, None, None])

                imgs = self.cell_model(imgs)
                imgs = F.softmax(imgs, dim=1)
                return imgs

        if not precombined:
            images = self._image_conversion(images)
        # preprocessed_imgs = map(_preprocess, images)
        # predictions = map(lambda x: _segment_helper([x]), preprocessed_imgs)
        # predictions = map(lambda x: x.to("cpu").numpy()[0], predictions)
        # predictions = map(self._restore_scaling_padding, predictions)
        # predictions = list(map(util.img_as_ubyte, predictions))
        preprocessed_imgs = list(map(_preprocess, images))
        predictions = []
        for i in range(0, len(preprocessed_imgs), bs):
            start = i
            end = min(len(preprocessed_imgs), i + bs)
            x = preprocessed_imgs[start:end]
            pred = _segment_helper(x).cpu().numpy()
            predictions.append(pred)
        predictions = list(np.concatenate(predictions, axis=0))
        predictions = map(self._restore_scaling_padding, predictions)
        predictions = list(map(util.img_as_ubyte, predictions))

        return predictions


segmentator = CellSegmentator(
    NUC_MODEL,
    CELL_MODEL,
    scale_factor=0.25,
    device="cuda",
    padding=True,
    multi_channel_model=True,
)


def get_segment_mask(batch_image_paths, bs=24):
    nuc_segmentations = segmentator.pred_nuclei(batch_image_paths[2],
                                                bs=bs)  # blue
    cell_segmentations = segmentator.pred_cells(batch_image_paths, bs=bs)
    batch_cell_masks = [
        label_cell(nuc_seg, cell_seg)[1].astype(np.uint8)
        for nuc_seg, cell_seg in zip(nuc_segmentations, cell_segmentations)
    ]
    return batch_cell_masks


for size_idx, submission_ids in tqdm(enumerate(
    [predict_ids_1728, predict_ids_2048, predict_ids_3072, predict_ids_4096]),
                                     total=4):
    size = IMAGE_SIZES[size_idx]
    if submission_ids == []:
        print(f"\n...SKIPPING SIZE {size} AS THERE ARE NO IMAGE IDS ...\n")
        continue
    else:
        print(f"\n...WORKING ON IMAGE IDS FOR SIZE {size} ...\n")
    for i in tqdm(range(0, len(submission_ids), BATCH_SIZE[size]),
                  total=int(np.ceil(len(submission_ids) / BATCH_SIZE[size]))):

        r, y, b = [], [], []
        image_ids = submission_ids[i:(i + BATCH_SIZE[size])]
        for img_id in image_ids:
            r.append(os.path.join(img_dir, f'{img_id}_red.png'))
            y.append(os.path.join(img_dir, f'{img_id}_yellow.png'))
            b.append(os.path.join(img_dir, f'{img_id}_blue.png'))
        batch_image_paths = [r, y, b]
        batch_cell_masks = get_segment_mask(batch_image_paths,
                                            bs=BATCH_SIZE[size])

        torch.cuda.empty_cache()
        gc.collect()

        for index, img_id in enumerate(image_ids):
            np.save(os.path.join(output_folder, f'{img_id}_cell_mask.npy'),
                    batch_cell_masks[index])

Overwriting hpa_cell_segment.py


In [17]:
# %%time
# !python hpa_cell_segment.py

### Calculate Mean and Std From Training and Test Sets

In [18]:
total_mean = [0.081018, 0.052349, 0.054012, 0.08106] # rgby
total_std = [0.133235, 0.08948, 0.143813, 0.130265]

## Load Pretrained Model from Bestfitting

In [19]:
parser = argparse.ArgumentParser(description='PyTorch Protein Classification')
parser.add_argument('--out_dir', type=str, help='destination where predicted result should be saved')
parser.add_argument('--gpu_id', default='0', type=str, help='gpu id used for predicting (default: 0)')
parser.add_argument('--arch', default='class_densenet121_dropout', type=str,
                    help='model architecture (default: class_densenet121_dropout)')
parser.add_argument('--num_classes', default=28, type=int, help='number of classes (default: 28)')
parser.add_argument('--in_channels', default=4, type=int, help='in channels (default: 4)')
parser.add_argument('--img_size', default=768, type=int, help='image size (default: 768)')
parser.add_argument('--crop_size', default=512, type=int, help='crop size (default: 512)')
parser.add_argument('--batch_size', default=32, type=int, help='train mini-batch size (default: 32)')
parser.add_argument('--workers', default=3, type=int, help='number of data loading workers (default: 3)')
parser.add_argument('--fold', default=0, type=int, help='index of fold (default: 0)')
parser.add_argument('--augment', default='default', type=str, help='test augmentation (default: default)')
parser.add_argument('--seed', default=100, type=int, help='random seed (default: 100)')
parser.add_argument('--seeds', default=None, type=str, help='predict seed')
parser.add_argument('--predict_epoch', default=None, type=int, help='number epoch to predict')

_StoreAction(option_strings=['--predict_epoch'], dest='predict_epoch', nargs=None, const=None, default=None, type=<class 'int'>, choices=None, help='number epoch to predict', metavar=None)

In [20]:
args = parser.parse_args([
    "--arch",
    "class_densenet121_dropout",
    #     "--img_size", str(image_size),
    "--crop_size",
    str(crop_size),
])
args

Namespace(arch='class_densenet121_dropout', augment='default', batch_size=32, crop_size=512, fold=0, gpu_id='0', img_size=768, in_channels=4, num_classes=28, out_dir=None, predict_epoch=None, seed=100, seeds=None, workers=3)

In [21]:
def load_model(network_path, args, print_model=False):
    # setting up the visible GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id

    model_params = {}
    model_params['architecture'] = args.arch
    model_params['num_classes'] = args.num_classes
    model_params['in_channels'] = args.in_channels
    model_params['pretrained_path'] = f"{bestfitting_folder}"
    model = init_network(model_params)

    checkpoint = torch.load(network_path)
    model.load_state_dict(checkpoint['state_dict'])

    # moving network to gpu and eval mode
    # model = DataParallel(model)
    model.cuda()
    model.eval()

    if print_model:
        print(model)

    return model

## Inference

In [22]:
def read_crop_img(img):
    random_crop_size = int(np.random.uniform(self.crop_size, self.img_size))
    x = int(np.random.uniform(0, self.img_size - random_crop_size))
    y = int(np.random.uniform(0, self.img_size - random_crop_size))
    crop_img = img[x:x + random_crop_size, y:y + random_crop_size]
    return crop_img


def read_rgby(
    img_dir,
    img_id,
    random_crop=False,
):
    suffix = '.png'
    colors = ['red', 'green', 'blue', 'yellow']

    flags = cv2.IMREAD_GRAYSCALE
    rgby_img = [
        cv2.imread(opj(img_dir, img_id + '_' + color + suffix), flags)
        for color in colors
    ]
    rgby_img = np.stack(rgby_img, axis=-1)
    if random_crop and crop_size > 0:
        rgby_img = read_crop_img(rgby_img)

    return rgby_img

In [23]:
def process_image(img_id, img_dir, transform=None):
    if img_id not in global_cache:
        rgby_img = read_rgby(img_dir, img_id)
        if rgby_img[0] is None:
            print(self.img_dir, img_id)

        h, w = rgby_img.shape[:2]

        if crop_size > 0:
            if crop_size != h or crop_size != w:
                resized_rgby_img = cv2.resize(rgby_img, (crop_size, crop_size),
                                              interpolation=cv2.INTER_LINEAR)

        # TODO: add cache masks
        full_mask = np.load(f"{cell_mask_folder}/{img_id}_cell_mask.npy")

        full_mask = cv2.resize(full_mask, (crop_size, crop_size),
                               interpolation=cv2.INTER_NEAREST)

        cell_masks = [
            rle_encoding(full_mask, mask_val=k)
            for k in range(1,
                           np.max(full_mask) + 1)
        ]
        if len(cell_masks) == 0:
            print(f"No cell masks found for {img_id}")

        if transform is not None:
            resized_rgby_img = transform(resized_rgby_img)

        resized_rgby_img = resized_rgby_img / 255.0
        resized_rgby_img = image_to_tensor(resized_rgby_img)

        global_cache[img_id] = (resized_rgby_img, cell_masks)
        return resized_rgby_img, cell_masks
    else:
        print(f"Cache hit for {img_id}!")
        return global_cache[img_id]

In [24]:
def collate_cells_fn(x):
    images, masks = [], []
    # For each full image, extract cell images
    image = x[0]

    for rle_string in x[1]:
        cell_mask = rle_to_mask(rle_string, crop_size, crop_size)
        # Important: set 255 to 1
        cell_mask[cell_mask > 0] = 1

        cell_image = torch.clone(image)
        for i in range(4):
            cell_image[i, ...] = cell_image[i, ...] * cell_mask
        images.append(cell_image.unsqueeze(0))

        cell_mask = rle_encoding(cell_mask, mask_val=1)
        masks.append(cell_mask)

    images = torch.cat(images)
    return images, masks

In [25]:
def predict(df, seed, fold=0, generate_meta=False):
    all_probs = []
    all_meta = {}

    seed_everything(rand_seed + 1000 * seed)
    network_path = f"{bestfitting_folder}/{model_folder}/fold{fold}/final.pth"
    model = load_model(network_path, args)

    global_processed = 0
    for index, row in tqdm(df.iterrows(), total=df.shape[0]):
        id = row["ID"]
        width = row["ImageWidth"]
        height = row["ImageHeight"]

        # Data augmentation
        aug_probs = None
        for augment in augment_list:
            transform = eval(f"augment_{augment}")
            images, masks = process_image(id,
                                          test_image_folder,
                                          transform=transform)
            images, masks = collate_cells_fn((images, masks))

            cell_probs = []
            processed_count = 0
            for batch_i in range(0, len(masks), batch_size):
                batch_images = images[batch_i:batch_i + batch_size, ...]
                batch_images = Variable(batch_images.cuda(), volatile=True)
                outputs = model(batch_images)
                logits = outputs

                probs = F.sigmoid(logits).data
                probs = probs.detach().cpu().numpy().tolist()
                cell_probs += probs

                processed_count += len(probs)

            cell_probs = np.array(cell_probs).reshape(processed_count, -1)

            if aug_probs is None:
                aug_probs = cell_probs / len(augment_list)
            else:
                aug_probs += cell_probs / len(augment_list)

        all_probs.append(aug_probs)

        if generate_meta:
            masks = np.array(masks)

            if masks.shape[0] > 0:
                # Generate RLE string for each cell mask
                # https://www.kaggle.com/dschettler8845/hpa-cellwise-classification-inference/notebook?scriptVersionId=55714434
                submit_strings = []
                for i in range(masks.shape[0]):
                    mask = masks[i]
                    mask = rle_to_mask(mask, crop_size, crop_size)
                    # Important: set 255 to 1
                    mask[mask > 0] = 1

                    # Important: resize to orignal resolution to submit correct mask RLE string
                    # https://www.kaggle.com/linshokaku/faster-hpa-cell-segmentation/comments#1251082
                    mask = cv2.resize(mask, (width, height),
                                      interpolation=cv2.INTER_NEAREST)

                    rle_string = binary_mask_to_ascii(mask, mask_val=1)
                    submit_strings.append(rle_string)

                if len(submit_strings) > 0:
                    all_meta[id] = submit_strings
                else:
                    all_meta[id] = []
            else:
                all_meta[id] = []

        del images, masks, aug_probs, cell_probs, batch_images
        gc.collect()

        if debug and global_processed == batch_size - 1:
            break

        global_processed += 1

    del model
    torch.cuda.empty_cache()
    gc.collect()

    if generate_meta:
        return all_probs, all_meta
    else:
        return all_probs

In [26]:
%%time

all_probs = []
all_meta = {}

augment_list = [
    'default', 'flipud', 'fliplr', 'transpose', 'flipud_lr',
    'flipud_transpose', 'fliplr_transpose', 'flipud_lr_transpose'
]
seeds = [0, 1, 2, 3]
fold = 0

cache_size = batch_size if debug else 100
global_cache = {}
batch_rounds = 0
for batch_i in range(0, submit_df.shape[0], cache_size):
    print(f"[Batch Processing {batch_rounds}]")
    sub_df = submit_df.iloc[batch_i:batch_i + cache_size, :].copy()
    print(sub_df.shape)

    batch_probs = [0] * sub_df.shape[0]
    for i, s in enumerate(seeds):
        print(f"Inferencing with seed {rand_seed+1000*s} ......")

        if i == 0:
            seed_probs, meta = predict(sub_df,
                                       s,
                                       fold=fold,
                                       generate_meta=True)
            print(len(seed_probs), len(batch_probs), sub_df.shape[0])
            for j in range(len(seed_probs)):
                batch_probs[j] = seed_probs[j] / len(seeds)
            all_meta.update(meta)
        else:
            seed_probs = predict(sub_df, s, fold=fold)
            print(len(seed_probs), len(batch_probs), sub_df.shape[0])
            for j in range(len(seed_probs)):
                batch_probs[j] += seed_probs[j] / len(seeds)

    if batch_probs is not None:
        all_probs.extend(batch_probs)

    if debug:
        break

    # Reset cache
    del global_cache
    gc.collect()
    global_cache = {}
    batch_rounds += 1

[Batch Processing 0]
(4, 4)
Inferencing with seed 1120 ......
>> Using pre-trained model.


  0%|          | 0/4 [00:00<?, ?it/s]

Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!


 25%|██▌       | 1/4 [00:02<00:07,  2.36s/it]

Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!


 50%|█████     | 2/4 [00:04<00:04,  2.24s/it]

Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!


 75%|███████▌  | 3/4 [00:08<00:02,  2.92s/it]

Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!


 75%|███████▌  | 3/4 [00:11<00:03,  3.89s/it]


4 4 4
Inferencing with seed 2120 ......
>> Using pre-trained model.


  0%|          | 0/4 [00:00<?, ?it/s]

Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!


 25%|██▌       | 1/4 [00:01<00:05,  1.81s/it]

Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!


 50%|█████     | 2/4 [00:03<00:03,  1.73s/it]

Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!


 75%|███████▌  | 3/4 [00:07<00:02,  2.36s/it]

Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!


 75%|███████▌  | 3/4 [00:09<00:03,  3.18s/it]


4 4 4
Inferencing with seed 3120 ......
>> Using pre-trained model.


  0%|          | 0/4 [00:00<?, ?it/s]

Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!


 25%|██▌       | 1/4 [00:01<00:05,  1.88s/it]

Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!


 50%|█████     | 2/4 [00:03<00:03,  1.78s/it]

Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!


 75%|███████▌  | 3/4 [00:07<00:02,  2.37s/it]

Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!


 75%|███████▌  | 3/4 [00:09<00:03,  3.19s/it]


4 4 4
Inferencing with seed 4120 ......
>> Using pre-trained model.


  0%|          | 0/4 [00:00<?, ?it/s]

Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!
Cache hit for 0040581b-f1f2-4fbe-b043-b6bfea5404bb!


 25%|██▌       | 1/4 [00:01<00:05,  1.87s/it]

Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!
Cache hit for 004a270d-34a2-4d60-bbe4-365fca868193!


 50%|█████     | 2/4 [00:03<00:03,  1.77s/it]

Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!
Cache hit for 00537262-883c-4b37-a3a1-a4931b6faea5!


 75%|███████▌  | 3/4 [00:07<00:02,  2.38s/it]

Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!
Cache hit for 00c9a1c9-2f06-476f-8b0d-6d01032874a2!


 75%|███████▌  | 3/4 [00:09<00:03,  3.18s/it]

4 4 4
CPU times: user 2min 26s, sys: 5.93 s, total: 2min 32s
Wall time: 44.2 s





In [27]:
# all_probs = np.concatenate(all_probs, axis=0)
# all_probs.shape

In [28]:
len(all_probs)

4

In [29]:
len(all_meta)

4

In [30]:
# augment_list = [
#     'default', 'flipud', 'fliplr', 'transpose', 'flipud_lr',
#     'flipud_transpose', 'fliplr_transpose', 'flipud_lr_transpose'
# ]
# seeds = [0, 1, 2, 3]
# for seed in seeds:
#     seed_everything(seed)

#     for augment in augment_list:
#         transform = eval(f"augment_{augment}")
#         if args.crop_size > 0:
#             sub_submit_out_dir = opj(submit_out_dir,
#                                      '%s_seed%d' % (augment, seed))
#         else:
#             sub_submit_out_dir = opj(submit_out_dir, augment)
#         if not ope(sub_submit_out_dir):
#             os.makedirs(sub_submit_out_dir)
#         with torch.no_grad():
#             predict(test_loader, model, sub_submit_out_dir, dataset)

In [36]:
for p in all_probs:
    print(p.shape)

(15, 28)
(12, 28)
(36, 28)
(22, 28)


In [41]:
all_predictions = []
for index, row in tqdm(submit_df.iterrows(), total=submit_df.shape[0]):
    id = row["ID"]
    width = row["ImageWidth"]
    height = row["ImageHeight"]

    cell_probs = all_probs[index]
    rle_strings = all_meta[id]

    new_preds = np.zeros((cell_probs.shape[0], 19))
    for i in range(cell_probs.shape[0]):
        for j in range(28):
            new_class_i = old_class_mappings[j]
            # Take maximum prob.
            if cell_probs[i, j] > new_preds[i, new_class_i]:
                new_preds[i, new_class_i] = cell_probs[i, j]

    submit_strings = []
    for i in range(new_preds.shape[0]):
        confidence = new_preds[i, ...]
        rle_string = rle_strings[i]
        for l in range(19):
            submit_strings.append(f"{l} {confidence[l]:.6f} {rle_string}")

    if len(submit_strings) > 0:
        all_predictions.append(" ".join(submit_strings))
    else:
        all_predictions.append("")

    if debug and index == batch_size - 1:
        break

  1%|          | 3/559 [00:00<00:00, 792.28it/s]


In [42]:
len(all_predictions)

4

In [43]:
all_predictions

['0 0.376270 eNq1lcFuwyAMhl8JL3TKoYceekhXRNGEWjoxCalsTVWr738cmZJsIcUh6pYLn7F/Q6zENgsP9gbs98OZaKFb3atEiVfQoHWCFAI+rQe6OE+X7gcYuZ98BL8bPpK3cb2/4rQ/dXzKwQWjLjKuY97a62a+/+z37fyC/2uekT+KEyzhb9fY38sn4vXGMf0uV+oEjqkUIPNWqp07gQJHUeHeJAr/AbqwBLkCj9IJ8xmoJggLc5RaYEOaIMXVUeJWn0FzNGmy3Bvptu4MjjuCPDdG2i02pElSjULVQUuR5/UhKGxDnqCam0NQ1I2WonBuo9CXQLgnaEF/lX9uPxqX++R2wdzulKuLbXyxe4mFv4AWSXJXsLv7bTzOJ0pxtz6j/dYe1Tne54LOE+/zibxR/mo50a4fLfhEe8/2535ok/OSzbuf2oQ/3zyb8G+nyV5gnXlc3N7j8RDZVTm042nQ+efWqawS5ao4PfWrvIYyNb5yx2x/j2o4Zperob9cD/2x3ce3ds1QLr/J4OEGX7QUS28= 1 0.104639 eNq1lcFuwyAMhl8JL3TKoYceekhXRNGEWjoxCalsTVWr738cmZJsIcUh6pYLn7F/Q6zENgsP9gbs98OZaKFb3atEiVfQoHWCFAI+rQe6OE+X7gcYuZ98BL8bPpK3cb2/4rQ/dXzKwQWjLjKuY97a62a+/+z37fyC/2uekT+KEyzhb9fY38sn4vXGMf0uV+oEjqkUIPNWqp07gQJHUeHeJAr/AbqwBLkCj9IJ8xmoJggLc5RaYEOaIMXVUeJWn0FzNGmy3Bvptu4MjjuCPDdG2i02pElSjULVQUuR5/UhKGxDnqCam0NQ1I2WonBuo9CXQLgnaEF/lX9uPxqX++R2wdzulKuLbXyxe4mFv4AWSXJXsLv7bTzOJ0pxtz6j/dYe1Tne54LOE+/zibxR/mo50a4fLfhEe8/2535ok/OSzbuf2oQ/3zyb8G+nyV5

In [44]:
if debug:
    submit_df.iloc[:batch_size, :]["PredictionString"] = all_predictions
else:
    submit_df["PredictionString"] = all_predictions
submit_df

Unnamed: 0,ID,ImageWidth,ImageHeight,PredictionString
0,0040581b-f1f2-4fbe-b043-b6bfea5404bb,2048,2048,0 0.376270 eNq1lcFuwyAMhl8JL3TKoYceekhXRNGEWjo...
1,004a270d-34a2-4d60-bbe4-365fca868193,2048,2048,0 0.400621 eNoLCMhIMAgxzMkzNACDBAMHCCPFIMYfwrL...
2,00537262-883c-4b37-a3a1-a4931b6faea5,2048,2048,0 0.282796 eNqFUC0PQyEM/Eslq6hAIBCIZjyJIBkCQWb...
3,00c9a1c9-2f06-476f-8b0d-6d01032874a2,2048,2048,0 0.323579 eNrtVLsKwzAM/CULBPGQIaMG46UmyZDBgwe...
4,0173029a-161d-40ef-af28-2342915b22fb,3072,3072,0 1 eNoLCAgIsAQABJ4Beg==
...,...,...,...,...
554,fea47298-266a-4cf4-93bd-55d1bcc2fc7d,1728,1728,0 1 eNoLCAjJNgIABNkBkg==
555,feb955db-6c07-4717-a98b-92236c8e01d8,2048,2048,0 1 eNoLCAgIMAEABJkBdQ==
556,fefb9bb7-934a-40d1-8d2f-210265857388,2048,2048,0 1 eNoLCAgIMAEABJkBdQ==
557,ff069fa2-d948-408e-91b3-034cfea428d1,3072,3072,0 1 eNoLCAgIsAQABJ4Beg==


In [45]:
submit_df.to_csv("submission.csv", index=False)

In [49]:
def prob_to_result(probs, img_ids, th=0.5):
    probs = np.concatenate(probs, axis=0)
    predicted_probs = probs.copy()
    probs[np.arange(len(probs)), np.argmax(probs, axis=1)] = 1

    pred_list = []
    pred_list_new = []
    for line in probs:
        # Map old classes to new ones
        predicted_old_classes = sorted(
            list(set([i for i in np.nonzero(line > th)[0]])))
        predicted_new_classes = sorted(
            list(set([old_class_mappings[i]
                      for i in np.nonzero(line > th)[0]])))
        # print(predicted_classes)
        s = '|'.join([str(i) for i in predicted_old_classes])
        s_new = '|'.join([str(i) for i in predicted_new_classes])
        pred_list.append(s)
        pred_list_new.append(s_new)
    result_df = pd.DataFrame({
        # "ID": img_ids,
        "Predicted": pred_list,
        "Predicted_New": pred_list_new
    })
    return result_df

In [50]:
result_df = prob_to_result(all_probs, None, th=confidence_threshold)
result_df.to_csv("result_comparison.csv", index=False)
result_df.head()

Unnamed: 0,Predicted,Predicted_New
0,25,16
1,25,16
2,0,0
3,25,16
4,7,7


In [51]:
result_df.head(100)

Unnamed: 0,Predicted,Predicted_New
0,25,16
1,25,16
2,0,0
3,25,16
4,7,7
5,25,16
6,0,0
7,25,16
8,25,16
9,25,16


In [52]:
result_df["Predicted_New"].value_counts()

14    32
16    25
7     18
0      6
18     2
15     1
8      1
Name: Predicted_New, dtype: int64

In [53]:
# test_dataset.release_gpu()
# del model, test_dataset, test_loader
torch.cuda.empty_cache()
gc.collect()

1525

In [54]:
# !rm densenet121-a639ec97.pth
# !rm -rf inference test_cell_masks
# !ls -la

### EOF