### Official, more updated code is in Kaggle

In [1]:
import base64
import numpy as np
# from pycocotools import _mask as coco_mask
import typing as t
import zlib
import torch
import shutil
import os
import pandas as pd
import tifffile as tiff
import matplotlib.pyplot as plt
from pathlib import Path
from glob import glob
from collections import defaultdict
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from IPython.display import Image as show_image
from torch.utils.data import DataLoader
from skimage.measure import regionprops_table, label, regionprops
from pycocotools import _mask as coco_mask
import gc
import warnings
import torchvision.transforms as T
import segmentation_models_pytorch as smp
from PIL import Image
import torch.nn as nn
import random
import cv2
warnings.filterwarnings("ignore")

In [121]:
# model configs, note that key is for the different type of resnet50's available for use from lunit
class model_config:
    seed = 42
    key = "MoCoV2"
    train_batch_size = 8
    valid_batch_size = 8
    epochs = 5
    CV_fold = 5
    learning_rate = 0.001
    scheduler = "CosineAnnealingLR"
    T_max = int(30000 / train_batch_size * epochs)  # for cosineannealingLR, explore different values
    weight_decay = 1e-6  # explore different weight decay (Adam optimizer)
    n_accumulate = 1
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    iters_to_accumulate = max(1, 32 // train_batch_size)  # for scaling accumulated gradients
    eta_min = 1e-5
    dice_alpha = 0.5
    bce_alpha = 0.5
    binary_threshold = 0.3
    model_save_directory = os.path.join(os.getcwd(), "model",
                                        str(key))  #assuming os.getcwd is the current training script directory

In [3]:
# sets the seed of the entire notebook so results are the same every time we run for reproducibility. no randomness, everything is controlled.
def set_seed(seed=42):
    np.random.seed(seed)  #numpy specific random
    random.seed(seed)  # python specific random (also for albumentation augmentations)
    torch.manual_seed(seed)  # torch specific random
    torch.cuda.manual_seed(seed)  # cuda specific random
    # when running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # when deterministic = true, benchmark = False, otherwise might not be deterministic
    os.environ['PYTHONHASHSEED'] = str(seed)  # set a fixed value for the hash seed, for hases like dictionary
set_seed(model_config.seed)

In [191]:
# configurations:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
debug = True
model_paths = glob(r'C:\Users\Kevin\PycharmProjects\hubmap\unet++_resnet50\model\MoCoV2\best_epoch*.pt') #ensembles
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if debug:
    test_paths = glob(r'\\fatherserverdw\Kevin\hubmap\unet++\images\*.tif')[0:100] #load all train and debug or do just first hundreds
else:
    test_paths = glob(r'\\fatherserverdw\Kevin\hubmap\test*.tif')

In [192]:
model_paths = model_paths[2:4]
model_paths

['C:\\Users\\Kevin\\PycharmProjects\\hubmap\\unet++_resnet50\\model\\MoCoV2\\best_epoch-02.pt',
 'C:\\Users\\Kevin\\PycharmProjects\\hubmap\\unet++_resnet50\\model\\MoCoV2\\best_epoch-03.pt']

In [182]:
test_paths

['\\\\fatherserverdw\\Kevin\\hubmap\\unet++\\images\\0a993633aa5e.tif',
 '\\\\fatherserverdw\\Kevin\\hubmap\\unet++\\images\\0a43459733e7.tif',
 '\\\\fatherserverdw\\Kevin\\hubmap\\unet++\\images\\0ab9d193fcf6.tif',
 '\\\\fatherserverdw\\Kevin\\hubmap\\unet++\\images\\0acd70e887b3.tif',
 '\\\\fatherserverdw\\Kevin\\hubmap\\unet++\\images\\0ae9282b7594.tif',
 '\\\\fatherserverdw\\Kevin\\hubmap\\unet++\\images\\0b89ab7f9f07.tif',
 '\\\\fatherserverdw\\Kevin\\hubmap\\unet++\\images\\0b935dd9ef6a.tif',
 '\\\\fatherserverdw\\Kevin\\hubmap\\unet++\\images\\0b989fe8238f.tif',
 '\\\\fatherserverdw\\Kevin\\hubmap\\unet++\\images\\0b8029db1fb4.tif',
 '\\\\fatherserverdw\\Kevin\\hubmap\\unet++\\images\\0ba172f33ea6.tif',
 '\\\\fatherserverdw\\Kevin\\hubmap\\unet++\\images\\0bd23d24a875.tif',
 '\\\\fatherserverdw\\Kevin\\hubmap\\unet++\\images\\0be9b14718b9.tif',
 '\\\\fatherserverdw\\Kevin\\hubmap\\unet++\\images\\0c5c322a104a.tif',
 '\\\\fatherserverdw\\Kevin\\hubmap\\unet++\\images\\0c3086bd8ef

In [166]:
def encode_binary_mask(mask: np.ndarray) -> t.Text:
  """Converts a binary mask into OID challenge encoding ascii text."""

  # check input mask --
  if mask.dtype != bool:
    raise ValueError(
        "encode_binary_mask expects a binary mask, received dtype == %s" %
        mask.dtype)

  mask = np.squeeze(mask)
  if len(mask.shape) != 2:
    raise ValueError(
        "encode_binary_mask expects a 2d mask, received shape == %s" %
        mask.shape)

  # convert input mask to expected COCO API input --
  mask_to_encode = mask.reshape(mask.shape[0], mask.shape[1], 1)
  mask_to_encode = mask_to_encode.astype(np.uint8)
  mask_to_encode = np.asfortranarray(mask_to_encode)

  # RLE encode mask --
  encoded_mask = coco_mask.encode(mask_to_encode)[0]["counts"]

  # compress and base64 encoding --
  binary_str = zlib.compress(encoded_mask, zlib.Z_BEST_COMPRESSION)
  base64_str = base64.b64encode(binary_str)
  return base64_str

In [31]:
class HubmapDataset(torch.utils.data.Dataset):
    def __init__(self, imgs, transforms):
        self.transforms = transforms
        self.imgs = imgs
        self.name_indices = [os.path.splitext(os.path.basename(i))[0] for i in imgs]

    def __getitem__(self, idx):
        # load images and masks
        img_path = self.imgs[idx]
        name = self.name_indices[idx]
        img = tiff.imread(img_path)
        img = Image.fromarray(img)
        img = self.transforms(img)
        return img, name

    def __len__(self):
        return len(self.imgs)

In [175]:
def remove_elements_below_threshold(arr, threshold):
    return [x for x in arr if x >= threshold]

In [176]:
@torch.no_grad()
def process_pred(prob_prediction):
    binary_prediction = prob_prediction > model_config.binary_threshold #(probability threshold)
    binary_prediction = binary_prediction.cpu().numpy()
    num_classes = len(np.unique(binary_prediction))
    if num_classes == 1: #empty predictions:
        label_img = np.zeros((512,512),dtype=np.uint8)
        confidences = []
    else:
        label_img = label(binary_prediction.astype('bool'))
        tb = regionprops_table(label_img, properties=['bbox', 'coords'])
        tt = pd.DataFrame(tb)
        bboxes = tt[['bbox-1', 'bbox-0', 'bbox-3', 'bbox-2']].values #min_x,min_y,max_x,max_y
        confidences = []
        for bbox in bboxes:
            min_x = bbox[0]
            min_y = bbox[1]
            max_x = bbox[2]
            max_y = bbox[3]
            label_img_probs = prob_prediction[min_y:max_y,min_x:max_x]
            confidence = label_img_probs.cpu().numpy().flatten()
            confidence = remove_elements_below_threshold(confidence,threshold=0.001)
            confidence = np.mean(confidence)
            confidences.append(confidence)
    return label_img, confidences

In [193]:
test_transforms = T.Compose([T.PILToTensor(),T.ConvertImageDtype(torch.float32), T.Normalize(mean=[0.6801, 0.4165, 0.6313], std=[0.1308, 0.2094, 0.1504])]) # Size C x H x W tensor with float dtype
# maybe try including normalization to images for test_transform

dataset_test = HubmapDataset(test_paths, transforms = test_transforms)
test_dl = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)

In [194]:
def build_model():
    model = smp.UnetPlusPlus(encoder_name="resnet50", encoder_weights=None, activation=None,
                             in_channels=3, classes=1, decoder_use_batchnorm=True)
    model.to(device)  # model to gpu
    return model

In [209]:
model = build_model()
id_list, heights, widths, prediction_strings = [],[],[],[]
with torch.no_grad():
    for model_path in model_paths:
        for img, idx in tqdm(test_dl):
            model.load_state_dict(torch.load(model_path))
            model.eval() #eval stage
            img = img.to(device)
            prediction = model(img)
            prob_prediction = nn.Sigmoid()(prediction)
            prob_prediction = torch.squeeze(prob_prediction)
            # print("mean is: {}".format(torch.mean(prob_prediction)))
            # print("max is: {}".format(torch.max(prob_prediction)))
            # print("min is: {}".format(torch.min(prob_prediction)))
            label_img, confidences = process_pred(prob_prediction)
            binmasks = []
            for idx1 in range(len(confidences)):
                binmask = label_img == idx1
                binmasks.append(binmask)

            pred_string = ""
            for idx2, (binmask, confidence) in enumerate(zip(binmasks, confidences)):
                encoded = encode_binary_mask(binmask)
                if idx2 == 0: #beginning, no space
                    pred_string += f"0 {confidence:0.4f} {encoded.decode('utf-8')}"
                else:
                    pred_string += f" 0 {confidence:0.4f} {encoded.decode('utf-8')}"
            h = img.size()[2]
            w = img.size()[3]
            id_list.append(idx[0])
            heights.append(h)
            widths.append(w)
            prediction_strings.append(pred_string)


 28%|██▊       | 28/100 [00:15<00:39,  1.80it/s]


KeyboardInterrupt: 

In [196]:
submission = pd.DataFrame()
submission['id'] = id_list
submission['height'] = heights
submission['width'] = widths
submission['prediction_string'] = prediction_strings
# submission.to_csv("submission.csv",index=False)
print(submission)

               id  height  width  \
0    0a993633aa5e     512    512   
1    0a43459733e7     512    512   
2    0ab9d193fcf6     512    512   
3    0acd70e887b3     512    512   
4    0ae9282b7594     512    512   
..            ...     ...    ...   
195  2e0c92f0c9df     512    512   
196  2e02a3e00059     512    512   
197  2e3a658a8c8e     512    512   
198  2e51dff130b7     512    512   
199  2e7951162645     512    512   

                                     prediction_string  
0    0 0.3550 eNodjskOwjAMRH/JztJF4opIU8c+sRZEWwptJ...  
1                                                       
2    0 0.7958 eNozSApMNzSNiIv0M/I38Dc0AGF/I7+snDBTv...  
3                                                       
4                                                       
..                                                 ...  
195  0 0.9126 eNpFUG13sjAM/UsJBR/HDr6wMUehBB/FiS/Tg...  
196  0 0.9144 eNoljEsKgDAMBa/0mtbfBQTRJF2IICqIUsGV9...  
197  0 0.9009 eNpFjM0KAjEMhF8pP93uQVnRgwvdJnEvCgrSg...