In [None]:
# Load (or reload) the autoreload extension and use mode 2, so edited modules
# are re-imported automatically before each cell is executed.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
from fastai.vision import *
from fastai.callbacks.hooks import *
from fastai.utils.mem import *
from pathlib import Path
from functools import partial
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import re
import random

In [None]:
# Where the raw TIFFs are read from and where the cropped tiles are written to.
raw_dir = Path("raw")
processed_dir = Path("processed")

# Every entry of raw_dir whose path contains ".tif" (this also matches ".tiff").
raws = [entry for entry in raw_dir.ls() if ".tif" in entry.as_posix()]

# Images and labels are told apart by their file name and sorted separately;
# later cells pair them up by position in these two lists.
images = sorted(entry for entry in raws if "_image" in entry.name)
labels = sorted(entry for entry in raws if "_label" in entry.name)

# Side length, in pixels, of the square tiles cut from each raw image.
l = 224

In [None]:
# Cut every raw image/label pair into non-overlapping l x l tiles and write the
# tiles worth keeping to processed_dir, counting how many of each kind are kept.
random.seed(23)  # fixed seed so the sub-sampling of empty tiles is repeatable

# An empty tile (no labelled pixel) is kept only when random.random() >= cutoff,
# i.e. with probability 1 - cutoff. random() is always < 1, so cutoff=1 drops
# every empty tile.
cutoff = 1
empty = 0   # kept tiles without labelled pixels (file names get an "_empty" tag)
R_popu = 0  # kept tiles with labelled pixels whose image path contains "_R_"
popu = 0    # kept tiles with labelled pixels from every other image

# Images and labels are paired purely by sorted position and zip() stops at the
# shorter list, so a count mismatch would silently mis-pair or drop files.
if len(images) != len(labels):
    raise ValueError(
        "Found {} image files but {} label files".format(len(images), len(labels)))

# cv.imwrite reports failure by returning False rather than raising (e.g. when
# the output directory is missing), so make sure the directory exists.
processed_dir.mkdir(parents=True, exist_ok=True)

# imread expects IMREAD_* flags. cv.COLOR_BGR2GRAY is a cvtColor conversion code
# and must not be used here; its integer value (6) happens to equal
# IMREAD_ANYDEPTH | IMREAD_ANYCOLOR, so this flag combination reads the files
# exactly as that value did.
# NOTE(review): if 8-bit single-channel input is what is wanted, use
# cv.IMREAD_GRAYSCALE instead -- confirm against the raw data first.
read_flags = cv.IMREAD_ANYDEPTH | cv.IMREAD_ANYCOLOR

for image_path, label_path in zip(images, labels):
    image = cv.imread(image_path.as_posix(), read_flags)
    label = cv.imread(label_path.as_posix(), read_flags)

    # cv.imread returns None instead of raising when a file cannot be read.
    if image is None:
        raise OSError("Could not read image " + image_path.as_posix())
    if label is None:
        raise OSError("Could not read label " + label_path.as_posix())
    if image.shape != label.shape:
        raise ValueError("Shape mismatch: {} is {} but {} is {}".format(
            image_path.as_posix(), image.shape, label_path.as_posix(), label.shape))

    # Number of whole tiles along each axis; leftover border pixels are dropped.
    i_max = image.shape[0] // l
    j_max = image.shape[1] // l

    # Binarise the mask in case the cells were labelled as 255 (or anything
    # else) instead of 1.
    label[label != 0] = 1

    # Tiles from "_R_" files are only ever kept when they contain labelled pixels.
    is_R = "_R_" in image_path.as_posix()

    for i in range(i_max):
        for j in range(j_max):
            cropped_image = image[l*i:l*(i+1), l*j:l*(j+1)]
            cropped_label = label[l*i:l*(i+1), l*j:l*(j+1)]

            if (cropped_label != 0).any():
                if is_R:
                    R_popu += 1
                else:
                    popu += 1
                tag = ""
            elif is_R:
                continue
            elif random.random() >= cutoff:
                empty += 1
                tag = "_empty"
            else:
                continue

            # Image and label tiles share the same "_i<row>_j<col>[_empty]" id.
            tile_id = "_i" + str(i) + "_j" + str(j) + tag
            cropped_image_path = processed_dir/(image_path.stem + tile_id + image_path.suffix)
            cropped_label_path = processed_dir/(label_path.stem + tile_id + label_path.suffix)

            for out_path, tile in ((cropped_image_path, cropped_image),
                                   (cropped_label_path, cropped_label)):
                if not cv.imwrite(out_path.as_posix(), tile):
                    raise OSError("Could not write " + out_path.as_posix())

In [None]:
# Tile counts, one per line: populated "_R_" tiles, other populated tiles, empty tiles.
print(R_popu, popu, empty, sep="\n")

## Train NN

In [None]:
# Make GPU 0 the current CUDA device for everything that follows.
torch.cuda.set_device(0)

In [None]:
# Batch size. GPU memory noted for two configurations, before unfreezing:
#   bs=16, l=224, resnet34 -> ~7300MiB
#   bs=4,  l=224, resnet50 -> ~12145MiB
bs = 16

In [None]:
# Augmentation: horizontal and vertical flips are enabled; rotation, lighting
# and warping are switched off and max_zoom is left at 1.
# NOTE(review): with those switched off, p_affine / p_lighting presumably have
# nothing left to act on in fastai v1 -- confirm against get_transforms.
transforms = get_transforms(
    do_flip=True, flip_vert=True,
    max_rotate=0,
    max_zoom=1,  # consider
    max_lighting=None,
    max_warp=None,
    p_affine=0.75, p_lighting=0.75,
)

In [None]:
def get_label_from_image(path):
    """Return the label path (a POSIX string) that belongs to an image tile path.

    Every occurrence of "_image_" in the path is replaced with "_label_".
    """
    return re.sub(r'_image_', '_label_', path.as_posix())

# Segmentation class names.
# NOTE(review): presumably list index == mask pixel value (0, 1) -- confirm.
codes = ["NOT-CELL", "CELL"]
def filter_empties(fname, cutoff=0.95):
    """Decide whether a tile file should be kept.

    Files whose name does not contain "empty" are always kept. Files whose
    name does contain it are kept only when a fresh random.random() draw
    (uniform in [0.0, 1.0)) is strictly greater than `cutoff`.
    """
    if "empty" not in Path(fname).name:
        return True
    return random.random() > cutoff

# Collect everything under processed_dir and keep only the files whose name
# contains 'image'; the label for each one is looked up via get_label_from_image.
tile_items = SegmentationItemList.from_folder(processed_dir)
tile_items = tile_items.filter_by_func(lambda fname: 'image' in Path(fname).name)
# tile_items = tile_items.filter_by_func(filter_empties)

# Random 80/20 train/validation split with a fixed seed, then attach the labels.
src = (
    tile_items
    .split_by_rand_pct(valid_pct=0.20, seed=2)
    .label_from_func(get_label_from_image, classes=codes)
)

# tfm_y=True applies the transforms to the label masks as well as the images.
data = src.transform(transforms, tfm_y=True).databunch(bs=bs).normalize(imagenet_stats)

In [None]:
# U-Net learner on a resnet34 encoder, reporting dice computed with iou=True.
iou_metric = partial(dice, iou=True)
learn = unet_learner(data, models.resnet34, metrics=iou_metric)
# learn.loss_func = CrossEntropyFlat(axis=1, weight = torch.Tensor([1,1]).cuda())

In [None]:
# Run the learning-rate finder and plot its curve together with a suggested rate.
learn.lr_find()
learn.recorder.plot(suggestion=True)

In [None]:
# Stage 1: 30 epochs of one-cycle training with a maximum learning rate of lr.
lr = 2.5e-3
learn.fit_one_cycle(30, max_lr=lr)

In [None]:
# Plot the losses recorded by the learner during the training run above.
learn.recorder.plot_losses()

In [None]:
# Location for saved weights.
# NOTE(review): fastai v1 presumably resolves a relative save path against
# learn.path/learn.model_dir rather than the working directory -- confirm.
models_path = Path("../../models")
learn.save(models_path.joinpath("2019-07-24_RESNET34_IOU0.66_stage1"))

In [None]:
# Restore the stage-1 weights; the trailing ';' suppresses the cell's output repr.
learn.load(models_path.joinpath("2019-07-24_RESNET34_IOU0.66_stage1"));

In [None]:
# Unfreeze the model so the following training stage can update all of its layers.
learn.unfreeze()

In [None]:
# Learning-rate range for stage 2, from lr/400 up to lr/4.
# NOTE(review): presumably fastai spreads this slice across the layer groups
# (lowest rate for the earliest layers) -- confirm.
lrs = slice(lr/400,lr/4)

In [None]:
# Stage 2: 12 epochs on the unfrozen model with pct_start raised to 0.8.
learn.fit_one_cycle(12, max_lr=lrs, pct_start=0.8)
# learn.fit_one_cycle(12, max_lr=lrs)

In [None]:
# Save the stage-2 (unfrozen) weights next to the stage-1 checkpoint.
learn.save(models_path.joinpath("2019-07-24_RESNET34_IOU0.69_stage2"))

In [None]:
# Export the learner to a .pkl file.
# NOTE(review): this path is written as "../models" while the checkpoints use
# "../../models"; presumably export resolves against learn.path and save
# against learn.path/model_dir, so both land in the same folder -- confirm.
learn.export(file = "../models/2019-07-24_RESNET34_IOU0.69_stage2.pkl")

## Check predictions on the validation set

In [None]:
# Quick look at the validation set: its size, the first (input, target) pair,
# and that pair's target on its own.
valid_ds = learn.data.valid_ds
print(len(valid_ds))   # number of validation samples
print(valid_ds[0])     # tuple of input image and segment
print(valid_ds[0][1])
# type(valid_ds[0][0])

In [None]:
# preds = learn.get_preds(with_loss=True)
# Predictions and targets for the validation set (get_preds' default dataset,
# spelled out here because later cells index preds alongside valid_ds).
# preds = learn.get_preds(ds_type=DatasetType.Valid, with_loss=True)
preds = learn.get_preds(ds_type=DatasetType.Valid)

In [None]:
# Walk down the structure of preds one level at a time.
probs = preds[0]
print(len(preds))           # tuple of list of probs and targets
print(probs.shape)          # predictions
print(probs[0].shape)       # probabilities for each label
print(learn.data.classes)   # what is each label
print(probs[0][0].shape)    # probabilities for label 0
# for i in range(0,N):
#     print(torch.max(probs[i][1]))

# Image(preds[1][0]).show()

In [None]:
# Gather, for every validation sample, the input, the target mask and the
# predictions wrapped as fastai Images so they can be plotted together.
valid_ds = learn.data.valid_ds
n_valid = len(valid_ds)
n_targets = preds[1].shape[0]
# preds must describe the same samples as valid_ds, otherwise the lists built
# below would not line up index by index.
if n_valid != n_targets:
    raise ValueError(
        "Validation set has {} samples but preds holds {} targets".format(n_valid, n_targets))
N = n_valid

# Load each (input, target) pair once instead of indexing the dataset twice.
samples = [valid_ds[i] for i in range(N)]
xs = [sample[0] for sample in samples]
ys = [sample[1] for sample in samples]
p0s = [Image(preds[0][i][0]) for i in range(N)]    # channel 0 of each prediction
p1s = [Image(preds[0][i][1]) for i in range(N)]    # channel 1 of each prediction
argmax = [Image(preds[0][i].argmax(dim=0)) for i in range(N)]  # index of the largest channel per pixel

In [None]:
# Overlay, for every validation sample, the predicted mask (blue) and the
# target mask (orange) in a grid that is ncol panels wide.
ncol = 3
# Rows needed to hold N panels (ceiling division), with at least one row so the
# figure never has zero height.
nrow = max(1, -(-N // ncol))
fig = plt.figure(figsize=(12, nrow*5))
# Iterate over all N samples; subplot positions are 1-based, list indices 0-based.
for i in range(N):
    ax = fig.add_subplot(nrow, ncol, i + 1)
#     ax.imshow(xs[i].px.permute(1, 2, 0), cmap = "Oranges", alpha=0.5)
    ax.imshow(argmax[i].px, cmap = "Blues", alpha=0.7)
#     ax.imshow(p1s[i].px, cmap = "Blues", alpha=0.7)
    ax.imshow(ys[i].px[0], cmap = "Oranges", alpha=0.5)
# plt.savefig('/hpf/largeprojects/MICe/nwang/TissueVision/2019-05-31_Mallar_NeuralNet/figures/2019-06-12_mallar-results.png')
plt.show()

In [None]:
# Note: resnet34 learns the same thing as resnet50 -- it would have been worth
# trying the smaller model earlier.
# Convert the notebook arc-venus-train.ipynb to HTML, written under nbs/ with a
# file name matching the stage-2 model.
!jupyter nbconvert arc-venus-train.ipynb --to html --output nbs/2019-07-24_RESNET34_IOU0.69_stage2.html