In [None]:
# !pip install tensorflow==2.5.0

In [None]:
import matplotlib
# Select the non-interactive Agg backend (renders without a GUI).
# NOTE(review): inside an IPython kernel the `%matplotlib inline` magic below
# switches the backend again, so Agg presumably only matters if this code is
# executed outside the notebook — confirm whether the call is still needed.
matplotlib.use('Agg')
%matplotlib inline

In [None]:
import glob

import os

# Reproducibility / sample-size settings forwarded to the `validate` CLI.
SEED = 13
MAX_VAL = 64

# Root of the data/model tree. Set the BD_VAL_ROOT environment variable to
# relocate everything instead of editing the hard-coded default below.
_ROOT = os.environ.get('BD_VAL_ROOT', '/home/vvasin/ikutukov')

# NOTE: this notebook's own glob.glob() calls do not pass recursive=True, so
# `**` there matches exactly one directory level (it behaves like `*`).
VAL_IMAGES_GLOB = os.path.join(_ROOT, 'bd', 'test', '**', '*.jpeg')

MODELS_GLOB = os.path.join(_ROOT, 'ARU-Net', 'models', '**', 'model-20.pb')
# Trailing '' keeps the trailing separator of the original literal.
OUTPUT = os.path.join(_ROOT, 'bd-val', '')

In [None]:
import glob
import json
import os
import re
import sys
import traceback

import tqdm.notebook as tqdm
import click
from click.testing import CliRunner
import tensorflow as tf
# NOTE(review): presumably the ARU-Net validation code is TF1-style graph code,
# which is why eager execution is disabled globally here — confirm.
tf.compat.v1.disable_eager_execution()

# Make the ARU-Net checkout importable for `pix_lab` below. The membership
# check keeps the cell idempotent: re-running it does not grow sys.path.
_ARU_NET_DIR = '/home/vvasin/ikutukov/ARU-Net/'
if _ARU_NET_DIR not in sys.path:
    sys.path.append(_ARU_NET_DIR)

from pix_lab.main.validate import validate

def path2mn(model_path):
    """Return the name of the directory that directly contains ``model_path``.

    Both ``/`` and ``\\`` are accepted as path separators. A separator must
    also precede that directory, so a bare two-component relative path does
    not match. Examples::

        path2mn('/models/run-a/model-20.pb')  -> 'run-a'
        path2mn('run-a/model-20.pb')          -> None
        path2mn('model-20.pb')                -> None

    Returns ``None`` when no such directory component exists.
    """
    # Raw string: the previous non-raw literal relied on the invalid escape
    # sequence for '/', which warns on modern Python. The regex is unchanged.
    match = re.search(r'[\\/]([^\\/]+)[\\/][^\\/]+$', model_path)
    return match.group(1) if match else None

def _parse_metrics(text):
    """Parse tab-separated metric lines from ``validate`` CLI output.

    A line contributes ``{key: float(fields[1])}`` when it has at least two
    tab-separated fields. ``key`` is the third field if present, otherwise the
    first. Later lines overwrite earlier ones with the same key. Lines whose
    second field is not a number are skipped instead of aborting the parse.
    """
    metrics = {}
    for line in str(text).split('\n'):
        fields = line.split('\t')
        if len(fields) < 2:
            continue
        try:
            value = float(fields[1])
        except ValueError:
            # Not a metric line (e.g. a log message that happens to contain a tab).
            continue
        metrics[fields[2] if len(fields) > 2 else fields[0]] = value
    return metrics


def val(model_path):
    """Run the ``validate`` CLI on one exported model and collect its metrics.

    Parameters
    ----------
    model_path : str
        Path to the model ``.pb`` file. The name of its parent directory (see
        ``path2mn``) names the per-model output directory under ``OUTPUT``.

    Returns
    -------
    dict
        Metric name -> float, parsed from the CLI's stdout by ``_parse_metrics``.
        On CLI failure the error is printed and whatever metrics were emitted
        are still returned.

    Raises
    ------
    ValueError
        If no model name can be derived from ``model_path``.
    """
    model_name = path2mn(model_path)
    if not model_name:
        raise ValueError('Cannot derive a model name from {!r}'.format(model_path))

    # Pass the arguments as a list. A single string would be re-tokenized
    # shell-style, which breaks on paths containing spaces or quotes.
    args = [
        '--path_to_pb', model_path,
        '--input', VAL_IMAGES_GLOB,
        '--max_val', str(MAX_VAL),
        '--seed', str(SEED),
        '--output', os.path.join(OUTPUT, model_name),
    ]
    result = CliRunner().invoke(validate, args, catch_exceptions=True)

    if result.exit_code != 0:
        # `result.output` holds the captured CLI output. Under CliRunner's
        # default settings stderr is folded into it, so `stderr_bytes` is None.
        print('validate failed for {} (exit code {}):\n{}'.format(
            model_path, result.exit_code, result.output))
        if result.exc_info is not None:
            traceback.print_exception(*result.exc_info)

    return _parse_metrics(result.stdout)

# Glob once and sort, so the progress order and the insertion order of
# `performance` are reproducible (raw glob order depends on the filesystem).
model_paths = sorted(glob.glob(MODELS_GLOB))

print('Images: {}, Models: {}'.format(
    len(glob.glob(VAL_IMAGES_GLOB)),
    len(model_paths),
))

# model directory name -> {metric name: value}, as returned by `val`.
performance = dict()
for model_path in tqdm.tqdm(model_paths, unit="model"):
    model_name = path2mn(model_path)
    if not model_name:
        print('Skipping {}: cannot derive a model name from its path'.format(model_path))
        continue
    performance[model_name] = val(model_path)


In [None]:
import pandas
# Collect the report.tsv files written by `validate` (one per model directory
# under OUTPUT) into a single frame.
reports = [
    pandas.read_csv(report_path, sep='\t', header=0)
    for report_path in sorted(glob.glob(os.path.join(OUTPUT, '**', 'report.tsv')))
]
if not reports:
    raise FileNotFoundError('No report.tsv files found under {}'.format(OUTPUT))
report = (
    pandas.concat(reports, ignore_index=True)
    .sort_values(['model', 'loss', 'filename'])
)

# Mean loss per model, best first. Only the numeric loss column is averaged:
# pandas >= 2.0 raises on non-numeric columns (filename) in groupby().mean().
report.groupby('model')['loss'].mean().sort_values()


In [None]:
# The ten images with the highest loss averaged over all models (the hardest
# cases), as {filename: mean loss}. Only the loss column is averaged because
# pandas >= 2.0 rejects the non-numeric `model` column in groupby().mean().
top_confusion = (
    report.groupby('filename')['loss']
    .mean()
    .sort_values(ascending=False)
    .head(10)
    .to_dict()
)
top_confusion

In [None]:
from PIL import Image, ImageOps
import matplotlib.pyplot as plt

# For each of the hardest images, show every model's copy side by side, with
# that model's "<name>_2<ext>" output blended on top. ImageOps.colorize maps the
# output's black to pink and its white to black, drawn at 50% opacity.
for filename in top_confusion:
    # Each model writes its copy into OUTPUT/<model>/. glob.escape makes glob
    # metacharacters in the file name match literally.
    per_model = sorted(glob.glob(
        os.path.join(OUTPUT, '**', glob.escape(os.path.basename(filename)))
    ))
    if not per_model:
        print('No per-model outputs found for {}'.format(filename))
        continue

    # Mean loss of this image per model, keyed like the output directories.
    # Only `loss` is averaged (pandas >= 2.0 rejects non-numeric columns).
    loss_by_model = {
        os.path.basename(os.path.dirname(model)): loss
        for model, loss in (
            report.loc[report['filename'] == filename]
            .groupby('model')['loss']
            .mean()
            .items()
        )
    }

    fig, axes = plt.subplots(
        1, len(per_model), figsize=(len(per_model) * 20, 30), squeeze=False
    )
    for ax, image_path in zip(axes[0], per_model):
        model_name = os.path.basename(os.path.dirname(image_path))
        ax.set_title(
            '{}\n{}'.format(model_name, loss_by_model.get(model_name, float('nan'))),
            fontsize=48,
        )
        stem, ext = os.path.splitext(image_path)
        infer_path = '{}_2{}'.format(stem, ext)
        # Context managers close the file handles that Image.open keeps open;
        # imshow copies the pixel data, so closing afterwards is safe.
        with Image.open(image_path) as im, Image.open(infer_path) as im_infer:
            ax.imshow(im)
            # colorize() only accepts mode "L" images.
            ax.imshow(
                ImageOps.colorize(im_infer.convert('L'), black='pink', white='black'),
                alpha=0.5,
            )

    plt.show()
    plt.close(fig)

In [None]:
# Cut-off for the 8-bit grayscale ('L' mode, 0-255) mask images loaded with
# `imload`. Pixels strictly above it become True in the thresholded masks.
DEFAULT_THRESHOLD: int = 54


In [None]:
import os
import random
import numpy
import numpy as np

from PIL import Image

from skimage.morphology import skeletonize, thin
from skimage.transform import rescale #, resize, downscale_local_mean

from skimage import measure
from skimage import color
from skimage import feature
from skimage import filters
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec


def imload(p, m='L', slice_size=2048):
    """Load an image file as a NumPy array, optionally cropped to its bottom-right corner.

    Parameters
    ----------
    p : str
        Path of the image file.
    m : str
        PIL mode to convert to before building the array (default ``'L'``,
        8-bit grayscale, which yields a 2-D array).
    slice_size : int or None
        If not None, keep only the last ``slice_size`` rows and columns (the
        bottom-right corner). Smaller images are returned whole.

    Returns
    -------
    numpy.ndarray
    """
    # Use a context manager so the file handle opened lazily by Image.open is
    # released. convert() produces an in-memory image, so the array stays valid.
    with Image.open(p) as img:
        im = numpy.asarray(img.convert(m))
    if slice_size is None:
        return im
    return im[-slice_size:, -slice_size:, ...]

    
def _overlay_labels(imp):
    """Overlay connected-component labels of the two masks of ``imp`` on the scan.

    The masks are read from ``<name>-0<ext>`` and ``<name>-1<ext>`` next to
    ``imp``. Both are thresholded at ``DEFAULT_THRESHOLD``, and mask 0 is also
    thinned before labelling. Returns the RGB image from ``color.label2rgb``.
    """
    imp_n, imp_e = os.path.splitext(imp)
    imp0 = imp_n + '-0' + imp_e
    imp1 = imp_n + '-1' + imp_e

    print('Loading src and masks for ' + imp + ', ' + imp0 +', ' + imp1)

    im_scan = imload(imp)
    im0_thresh = imload(imp0) > DEFAULT_THRESHOLD
    im1_thresh = imload(imp1) > DEFAULT_THRESHOLD

    im0_labeled = measure.label(thin(im0_thresh.astype(np.uint8)))
    im1_labeled = measure.label(im1_thresh.astype(np.uint8))

    # NOTE(review): the two label images are numbered independently, so equal
    # ids from both masks get the same colour and overlapping pixels get summed
    # ids. Offsetting im1's non-zero labels by im0_labeled.max() + 1 would keep
    # them distinct; confirm which is intended.
    return color.label2rgb(im0_labeled + im1_labeled, im_scan, alpha=0.7, bg_label=0)


# NOTE(review): PATHS (scan paths) and IMAGE_SET_SIZE (plots per row) are not
# defined in any visible cell of this notebook; define them before running.
# Ceil division: a partial last row still needs its own grid row.
n_rows = -(-len(PATHS) // IMAGE_SET_SIZE)

fig = plt.figure(figsize=(
    24 * IMAGE_SET_SIZE,
    32 * n_rows,
))

spec = gridspec.GridSpec(
    ncols=IMAGE_SET_SIZE,
    nrows=n_rows,
    wspace=0.01,
    hspace=0.01,
    figure=fig
)

for idx, imp in enumerate(PATHS):
    im_color_labels = _overlay_labels(imp)
    print('Plotting result to output')
    ax = fig.add_subplot(
        spec[idx // IMAGE_SET_SIZE, idx % IMAGE_SET_SIZE],
        frameon=False,
        title=imp
    )
    ax.set(xticks=[], yticks=[])
    ax.imshow(im_color_labels)
    print('Done!')

plt.show()


In [None]:
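# NOTE: disabled experiment. It traces skeleton polylines with `trace_skeleton`
# and draws them with `cv2`; neither module is imported in this notebook.
# import matplotlib.pyplot as plt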
# # rects = []
# # polys = trace_skeleton.traceSkeleton(
# #     im1,
# #     0,
# #     0,
# #     im1.shape[1],
# #     im1.shape[0],
# #     10,
# #     999,
# #     rects
# # )

# for l in polys:
#   c = (
#       200 * random.random(),
#       200 * random.random(),
#       200 * random.random()
#   )
#   for i in range(
#       0, 
#       len(l) - 1
#   ):
#     pass
#     cv2.line(
#         im_scan,
#         (
#             l[i][0],
#             l[i][1]
#         ),
#         (
#             l[i+1][0],
#             l[i+1][1]
#         ),
#         c
#     )
# plt.figure(figsize=(32, 32))
# plt.imshow(im_scan)
# plt.show()