In [None]:
from __future__ import print_function, division
%matplotlib notebook
import matplotlib.pyplot as plt

import re
from copy import deepcopy
import numpy as np
import cPickle as pickle
import re
from os import path
import csv
import pandas as pd
from time import strftime
from scipy.ndimage.filters import gaussian_filter as gauss

import warnings
with warnings.catch_warnings():
    from tqdm import TqdmExperimentalWarning
    warnings.simplefilter("ignore", category=TqdmExperimentalWarning)
    from tqdm.auto import tqdm, trange
warnings.filterwarnings("ignore", category=RuntimeWarning, message="numpy.dtype size changed", module="h5py")
from caspyr.utils import H5Reader, Globber, glob

from viper.plot_style import toLab, plt_kwargs, savefig, plt_title, font_prop
#from viper.utils.stats import movingAvg, nrmse, biasStd, biasStdMask
#import viper.utils.stats as vus; reload(vus); nrmse=vus.nrmse; biasStd = vus.biasStd  # , biasStdMask
from viper.utils.stats import nrmse, biasStd, biasStdMask
from viper.utils.mlab import imguidedfilter
from viper.constants import SIGMA2FWHM_MMR
#SIGMA2FWHM_MMR = (8 * np.log(2)) ** .5 * 2.08626
from viper.imsample import nonLocMn

In [None]:
# Multiplier applied to the base font sizes (axis labels, annotations) in the plotting cells.
fntScl = 1.5
def imshow(im, cmap="Greys", origin="lower", title=None, **kwargs):
    """
    Draw `im` on the current axes with the axes decorations switched off.

    im  : array-like, handed straight to `plt.imshow`
    cmap, origin  : forwarded to `plt.imshow`
    title  : str, optional axes title (ignored when falsy)
    **kwargs  : any further `plt.imshow` options
    """
    plt.imshow(im, cmap=cmap, origin=origin, **kwargs)
    plt.axis("off")
    if not title:
        return
    plt.title(title)

def autoROI(im):
    """
    Return the bounding box of the non-zero elements of `im`.

    im  : array-like, any number of dimensions
    @return  : tuple(slice), one per axis, such that `im[autoROI(im)]`
        contains every non-zero element
    @raise ValueError  : if `im` has no non-zero elements
    """
    inds = np.array(np.asarray(im).nonzero())
    if not inds.size:
        raise ValueError("autoROI: image has no non-zero elements")
    # slice stops are exclusive, so the last non-zero index needs `+ 1`
    # (otherwise the far edge is cropped and thin objects give empty slices)
    return tuple(slice(int(lo), int(hi) + 1)
                 for lo, hi in zip(inds.min(axis=1), inds.max(axis=1)))

def imNorm(im, assumeZeroMin=True):
    """
    Rescale `im` in-place so that its maximum becomes `im.size`.

    im  : numpy array, modified in-place and also returned
    assumeZeroMin  : bool, when False the minimum is subtracted first

    NOTE(review): this was previously described as "average 1 count/voxel";
    that would require dividing by `im.sum()` rather than `im.max()` -
    confirm which normalisation is intended.
    """
    if not assumeZeroMin:
        lowest = im.min()
        im -= lowest
    scale = im.size / im.max()
    im *= scale
    return im

def flatMask(im, thresh=50, hw=1):
    """
    Boolean mask of the bright, locally flat parts of `im`.

    A voxel is kept when both hold:
    - its absolute response to a zero-sum "centre minus neighbours" kernel
      is below the `thresh`-th percentile of all responses, and
    - its intensity exceeds the `(100 - thresh)`-th percentile of the
      strictly positive intensities.

    im  : array-like, any number of dimensions
    thresh  : int, percentile
    hw  : int, half width (the kernel is `2 * hw + 1` wide along each axis)
    @return  : bool array of the same shape as `im`
    """
    from scipy.signal import convolve
    im = np.asarray(im)
    kern = -np.ones([2 * hw + 1] * im.ndim)
    # centre tap; indexing with one `hw` per axis keeps the kernel zero-sum
    # for any dimensionality (`kern[hw, hw]` is only the centre in 2-D)
    kern[(hw,) * im.ndim] = kern.size - 1
    edges = np.abs(convolve(im, kern, mode='same'))
    flat = edges < np.percentile(edges, thresh)
    positive = im[im > 0]
    if not positive.size:
        # nothing can be "bright" without positive intensities
        return np.zeros(im.shape, dtype=bool)
    bright = im > np.percentile(positive, 100 - thresh)
    return flat & bright

def isinstanceStr(s, cls):
    """
    Test whether `s` can be converted by `cls`.

    s  : value to convert (typically a str)
    cls  : type, called as `cls(s)`
    @return  : `isinstance(cls(s), cls)`, or False if the conversion raises
    """
    try:
        return isinstance(cls(s), cls)
    except Exception:
        # conversion failures only; KeyboardInterrupt/SystemExit still propagate
        return False

In [None]:
# Filename parser/matcher for the reconstruction outputs.
# Each entry looks like (part name, regular expression, glob pattern) -
# NOTE(review): inferred from the values; confirm against caspyr.utils.Globber.
# The part names ("rm", "pat", "counts", "tum", "it", ...) are the keys used by the
# `globber[...]`, `globber.filterRe(...)` and `globber.filterReUnlike(...)` calls below.
globber = Globber(
    ("meta0", "output/[0-9]+/brainweb_", "output/*/brainweb_"),
    ("rm", "PET|PETpsf", "PET*"),
    ("meta1", "_[0-9]+_subject_", "_*_subject_"),
    ("pat", "[0-9]+", "*"),
    ("meta6", "-S_", "-S_"),
    ("sigma", "[0-9.e+]+", "*"),
    ("meta7", "-NP_", "-NP_"),
    ("fsPET", "[0-9.e+]+", "*"),
    ("meta8", "-NT1_", "-NT1_"),
    ("fsT1", "[0-9.e+]+", "*"),
    ("meta2", "-C_", "-C_"),
    ("counts", "-?[0-9.e+]+", "*"),
    ("meta3", "_t", "_t"),
    ("tum", "-?[0-9]+", "*"),
    ("meta4", "_", "_"),
    ("it", "[0-9]+", "*"),
    ("meta5", r"\.mat", ".mat"))

In [None]:
# percentile handed to `flatMask` when building an evaluation mask
maskThresh = 50
# NOTE(review): not referenced in the visible cells (the ROI shrink in the
# accumulation loop hard-codes 3) - confirm whether this is still needed
PAD = 3
# ROI tuple from viper's BrainWeb data module; the trailing comment suggests
# (z, x, y) axis order - TODO confirm
from viper.classifier.data.bw import ROI as ROI_BW  # zxy

In [None]:
def getGnd(fn, debug=False):
    """
    Find the ground-truth file (iteration part "000") that belongs to `fn`.

    fn  : str, path of a reconstruction matched by `globber`
    debug  : bool, additionally check that all candidate truth files hold
        identical `tAct` data inside `ROI_BW`
    @return  : str, path of the first matching truth file
    @raise IOError  : if no truth file is found
    @raise ValueError  : (debug only) if the candidates disagree
    """
    # turn the counts part of `fn` into a regex: escape `+`/`.` and make a
    # leading minus sign optional
    counts = globber.split(fn)[globber["counts"]]\
        .replace('-', "-?").replace('+', '\\+').replace('.', '\\.')
    # first try a truth file differing from `fn` only in the iteration part;
    # otherwise relax the meta0/meta1/rm parts as well
    files = globber.filterReUnlike(fn, it="000") or globber.filterReUnlike(
        fn, meta0=None, meta1=None, rm=None, it="000", counts=counts)  # counts=None
    if not files:
        raise IOError("Could not find truth for:%s" % fn)
    if debug:
        ROI = ROI_BW
        imGnd = H5Reader(files[0]).tAct[ROI]
        same = [(imGnd == H5Reader(i).tAct[ROI]).all() for i in tqdm(files, desc="debug")]
        if not all(same):
            # report the offending input (previously referenced an undefined name)
            raise ValueError("mismatched ground truths:" + fn)
    return files[0]

In [None]:
#globber.partsVals[:-2] + ['000 -> 300', globber.partsVals[-1]]
# Pull the values of each filename part from the globber
# (presumably the distinct values found on disk - confirm against caspyr.utils.Globber).
rms = globber.partsVals[globber["rm"]]
pats = globber.partsVals[globber["pat"]]  # int
counts = globber.partsVals[globber["counts"]]  # float
tums = globber.partsVals[globber["tum"]]  # int
its = globber.partsVals[globber["it"]]  # int
print(rms, pats, counts, tums)  # , its

In [None]:
# Manual override of the globbed values: restrict the analysis to one patient,
# three count levels (in units of 1e6 counts) and one tumour index.
pats = ["54"]
counts = [format(mega * 1e6, ".3g") for mega in (4.3, 43, 301)]
tums = ["1"]
#brainweb_PET_*_54-*_4.3*_t1_001.mat

In [None]:
# show the selection after the manual override above
print(rms, pats, counts, tums)  # , its

In [None]:
# Result stores keyed by (rm, pat, counts, tum); each value is an array with
# one entry per iteration (or per filter setting for the "PS"/"GF" keys).
varDict = {}  # variance
bsqDict = {}  # squared bias

In [None]:
# pickle file holding a previously computed (bsqDict, varDict) pair
#outputPkl = "output/biasVar.pkl"
#outputPkl = "output/biasVar-flat.pkl"
outputPkl = "output/biasVar-brainweb.pkl"

In [None]:
# Restore previously computed results instead of recomputing them.
# NOTE: unpickling can execute arbitrary code - only load files you produced yourself.
with open(outputPkl, "rb") as fd:  # pickles are binary data
    bsqDict, varDict = pickle.load(fd)

In [None]:
# Accumulate squared bias and variance per iteration for every
# (resolution model, patient, counts, tumour) combination, then sweep two
# post-filters (Gaussian post-smoothing "PS" and non-local-means guided
# filtering "GF") over the final-iteration images of the lower-count data.
from scipy.interpolate import interp1d

# debugging toggle: True skips the per-iteration statistics (only the final
# iteration is then loaded for the post-filter sweeps)
HACK_SKIP = False
# radius handed to `nonLocMn`; also the half-thickness of the slab it runs on
nLMRad = 2


def _itFiles(files, it):
    """Existing files of iteration `it` for the realisations in `files` (names end in `NNN.mat`)."""
    return [p for p in (f[:-7] + "%03d.mat" % it for f in files) if path.exists(p)]


def _loadIms(fnames, roi):
    """Read the `Img` volume of every file, cropped to `roi`."""
    return [H5Reader(p).Img[roi] for p in fnames]


def _bsqVar(ims, truth):
    """
    Squared bias and variance of the realisations `ims` w.r.t. `truth`.
    `biasStdMask` outputs are divided by 100 before squaring
    (presumably percentages -> fractions; confirm against viper.utils.stats).
    """
    bias, std = biasStdMask(np.array(ims), truth[None])
    return (bias / 100) ** 2, (std / 100) ** 2


for rm in tqdm(rms, desc="RM"):
    for pat in tqdm(pats, desc="Patient"):
        for cnt in tqdm(counts, desc="Counts"):
            for tum in tqdm(tums, desc="Tumours"):
                files = globber.filterRe(
                    escape=True,
                    rm=rm, pat=pat, counts=cnt, tum=tum,
                    it="001")
                if not files:
                    print("cannot find globber files", rm, pat, cnt, tum)
                    continue

                # ground truth, cropped to the BrainWeb ROI shrunk by 3 voxels per side
                gnd = H5Reader(getGnd(files[0]))
                ROI = tuple(slice(i.start + 3, i.stop - 3) for i in ROI_BW)
                imT1 = gnd.T1[ROI]
                imGnd = gnd.tAct[ROI] * gnd.scale_factor[0, 0]

                # per-iteration results vectors
                bsq, var = [], []
                with tqdm(total=len(its),
                          desc="%d reals, %.3g scale" % (len(files), imGnd.sum())) as tIters:
                    it = 1
                    while True:
                        usedfiles = _itFiles(files, it)
                        if not usedfiles:
                            break
                        if not HACK_SKIP:
                            last_it_ims = _loadIms(usedfiles, ROI)
                            b, v = _bsqVar(last_it_ims, imGnd)
                            bsq.append(b)
                            var.append(v)
                            # N.B: mse is simply the sum of the two terms
                            tIters.set_postfix(mse=b + v, bsq=b, var=v, refresh=False)
                        tIters.update()
                        it += 1

                num = it - 1  # number of iterations found on disk
                if not num:
                    # nothing was loaded: do not fall through with stale images
                    print("cannot find iteration files", rm, pat, cnt, tum)
                    continue
                usedfiles = _itFiles(files, num)
                if HACK_SKIP:
                    last_it_ims = _loadIms(usedfiles, ROI)

                key = rm, pat, cnt, tum
                if not HACK_SKIP:
                    bsqDict[key] = np.array(bsq)
                    varDict[key] = np.array(var)

                # post-filter sweeps only for noisy data below the 301M-count level
                if not 0 < float(cnt) < 301e6:  # and "psf" in rm:
                    continue
                last_it_ims = np.array(last_it_ims)

                # --- Gaussian post-smoothing: `num` widths from 0 to 25, divided by
                # SIGMA2FWHM_MMR (presumably FWHM in mm -> sigma in voxels; confirm)
                bsq, var = [], []
                for s in tqdm(np.linspace(0, 25, num=num) / SIGMA2FWHM_MMR, desc="PS"):
                    b, v = _bsqVar([gauss(im, s) for im in last_it_ims], imGnd)
                    bsq.append(b)
                    var.append(v)
                key1 = rm, pat + "PS", cnt, tum
                bsqDict[key1] = np.array(bsq)
                varDict[key1] = np.array(var)

                # --- guided (non-local means) filtering on a thin central slab
                # slab centre comes from the image itself (was `len(mask)`, a
                # name only defined by later cells)
                half = len(imGnd) // 2
                ROI_GF = slice(half - nLMRad, half + nLMRad)
                imGndGF = imGnd[ROI_GF]
                eps = np.logspace(-5, 0, num=num)
                # the filter is slow: evaluate ~10x fewer settings and interpolate;
                # at least 2 nodes are required by `interp1d`
                epsUndersamp = np.logspace(np.log10(eps[0]), np.log10(eps[-1]),
                                           num=max(2, num // 10))
                print(usedfiles[-1])
                bsq, var = [], []
                with tqdm(epsUndersamp, desc="GF") as prog:
                    for s in prog:
                        ims = [nonLocMn(im[ROI_GF], imT1[ROI_GF], r=nLMRad, sigma=float(s))
                               for im in tqdm(last_it_ims, desc="guidedfilter", leave=False)]
                        b, v = _bsqVar(ims, imGndGF)
                        bsq.append(b)
                        var.append(v)
                        prog.set_postfix(
                            rmse=(b + v)**0.5, bsq=b, var=v, s=s, refresh=False)
                bsq = interp1d(epsUndersamp, bsq, kind='linear')(eps)
                var = interp1d(epsUndersamp, var, kind='linear')(eps)
                # prefix raw: the first sample is the unfiltered full-ROI result
                bsq[0], var[0] = _bsqVar(last_it_ims, imGnd)
                # save
                key1 = rm, pat + "GF", cnt, tum
                bsqDict[key1] = np.array(bsq)
                varDict[key1] = np.array(var)

In [None]:
# display the available result keys
sorted(varDict.keys())

In [None]:
# (rm, pat, cnt) -> array [bias^2 | var, iteration], averaged across tumours.
# Only plain patient ids are used: post-filter keys such as "54PS"/"54GF"
# fail the `int` conversion and are skipped.
bsqVarAvg = {}  # average across tumours
for rm, pat, cnt in {i[:3] for i in varDict.keys() if isinstanceStr(i[1], int)}:
    # dimensions: bias|var, tumour entry, iteration
    series = [[src[k] for k in src if k[:3] == (rm, pat, cnt)]
              for src in (bsqDict, varDict)]
    nitMax = max(len(i) for src in series for i in src)
    avgBiasVar = np.array([[np.pad(i, (0, nitMax - len(i)), 'constant') for i in src]
                           for src in series])
    # count contributors per iteration from the true lengths, so that a
    # genuine value of 0 is not mistaken for padding
    lens = np.array([[len(i) for i in src] for src in series])
    nitScale = (np.arange(nitMax) < lens[..., None]).sum(axis=1)
    nitScale[nitScale == 0] = 1  # avoid 0/0 where nothing contributes
    avgBiasVar = avgBiasVar.sum(axis=1) / nitScale
    bsqVarAvg[(rm, pat, cnt)] = avgBiasVar

In [None]:
# Development aid: re-import viper.plot_style and rebind `toLab` so edits to that
# module are picked up without restarting the kernel (`reload` is the Python 2 builtin).
import viper.plot_style as vps; reload(vps); toLab = vps.toLab

# Named groups of network run directories (whitespace-separated names, relative
# to `rootDir`). The plotting cells append '/201*' to each name and glob for it.
fMap = dict(
    drop_43="""
15-15-web-tum1-nlm-1-He-ssr
15-15-web-tum1-r1-nlm-1-He-ssr
15-15-web-tum1-r3-nlm-1-He-ssr
15-15-web-tum1-zMR-nlm-1-He-ssr
15-15-web-tum1-zHad-nlm-1-He-ssr
15-15-web-tum1-zPSF-nlm-1-He-ssr
15-15-web-tum1-l1-nlm-1-He-ssr
15-15-web-tum1-nlm-1-lt-vt-He-ssr
""",
    drop_4p3="""
7-7-web-tum1-c4.3e6-nlm-1-He-ssr
7-7-web-tum1-c4.3e6-r1-nlm-1-He-ssr
7-7-web-tum1-c4.3e6-r3-nlm-1-He-ssr
7-7-web-tum1-c4.3e6-zMR-nlm-1-He-ssr
7-7-web-tum1-c4.3e6-zHad-nlm-1-He-ssr
7-7-web-tum1-c4.3e6-zPSF-nlm-1-He-ssr
7-7-web-tum1-c4.3e6-l1-nlm-1-He-ssr
15-15-web-tum1-c4.3e6-nlm-1-lt-vt-He-ssr
""",
    old="""
63-63-web-r1-tum1-He-ssr
63-63-web-tum1-He-ssr
63-63-web-r3-tum1-He-ssr
63-63-web-r4-tum1-He-ssr
63-63-web-r5-tum1-He-ssr
63-63-web-zHad-tum1-He-ssr
63-63-web-zMR-tum1-He-ssr
63-63-web-zPSF-tum1-He-ssr
127-127-web-tum1-He-ssr
31-31-web-tum1-He-ssr
15-15-web-tum1-He-ssr
7-7-web-tum1-He-ssr
""",
    real="31-31-nlm-1-He-ssr")


# One NRMSE-vs-iteration figure per patient: a curve per (resolution model,
# count level) with its minimum marked, plus - for patient '54' - one
# horizontal line per trained network read from its `biasVar.csv`.
pats = sorted({i[1] for i in bsqVarAvg})
for pat in pats:
    plt.figure()
    plt.title("Patient:" + pat)
    # ascending count level; keys whose counts start with '-' are sorted last
    for k in sorted((k for k in bsqVarAvg if k[1] == pat),
                    key=lambda i: float(i[2]) if i[2][0] != '-' else 9e9):
        bsq, var = bsqVarAvg[k]
        rm, patM, cnt = k
        mcnt = float(cnt) / 1e6
        # line style encodes the count level, colour the resolution model
        ls = '--' if mcnt > 43 else '-.' if mcnt > 4.3 else ':' if mcnt > 0 else '-'
        ls += 'b' if "psf" in rm else 'r'
        y = (np.array(bsq) + var) ** 0.5 * 100  # NRMSE
        i = y.argmin()  # index of minimal NRMSE
        plt.plot(range(len(y)), y, ls,
                 label="MLEM{psf} {mcnt} (min NRMSE at {i}/{nit} iters)".format(
                     psf="+RM" if "psf" in rm else "",
                     mcnt="%.3gM count" % mcnt if mcnt > 0 else "noise-free",
                     i=i + 1, nit=len(y)))
        plt.plot([i], y[i:i+1], 'ko', ms=6)

    if pat == '54':
        # NOTE: machine-specific absolute path
        rootDir = "/home/cc16/viper-tf"
        data = [i + '/201*' for i in fMap['drop_4p3'].strip().split()]
        # report run directories that cannot be found (the `max` below then raises)
        for i in data:
            if not glob(path.join(rootDir, i)):
                print(i)
        # lexicographically last match per pattern, relative to rootDir
        data = [max(glob(path.join(rootDir, i)))[len(rootDir)+1:] for i in data]
        COLOURS = "cmykgbr"
        # repeat the colour cycle so that `zip` never drops a run
        for (fn, c) in zip(data, COLOURS * (len(data) // len(COLOURS) + 1)):
            label = fn.split('/', 1)[0].split('-', 2)[-1]
            label = toLab(label, ("301M count", "noise-free"), ("c4.3e6", "4.3M"), tum1="")
            fn = path.join(rootDir, fn, "biasVar.csv")
            with open(fn) as fd:
                d = list(csv.DictReader(fd))
            for k in ["trim"]:
                # exactly one row per prefix is expected
                l = [row for row in d if row["prefix"] == k]
                if len(l) != 1:
                    msg = ' '.join([label, k, '\n'.join(map(str, d))])
                    raise ValueError(msg)
                l = l[-1]
                print(fn, l["NRMSE"])
                plt.axhline(float(l["NRMSE"]), label=r'$\mathrm{\mu}$-net trained on' + label, c=c)
    plt.xlabel(r"Iteration")
    plt.ylabel("NRMSE/[%]")
    plt.ylim(None, 250)
    plt.xlim(0, None)
    plt.legend()

In [None]:
import os
from os import path

# Which `fMap` group of networks to plot; the second assignment wins
# (comment one of the two lines out to switch).
data_key = 'drop_43'
data_key = 'drop_4p3'

# The bare string below is an inert expression kept as a record of an older run list.
'''
    data = """
    63-63-web-r1-tum1-He-ssr/201*
    63-63-web-tum1-He-ssr/20181106-104757
    63-63-web-r3-tum1-He-ssr/201*
    63-63-web-r4-tum1-He-ssr/201*
    63-63-web-r5-tum1-He-ssr/20181107-024615
    63-63-web-zMR-tum1-He-ssr/20181106-203953
    63-63-web-zHad-tum1-He-ssr/20181106-220722
    63-63-web-zPSF-tum1-He-ssr/20181106-202153
    127-127-web-tum1-He-ssr/20181106-152545
    31-31-web-tum1-He-ssr/20181106-154657
    15-15-web-tum1-He-ssr/20181106-160605
    7-7-web-tum1-He-ssr/20181106-161603
    """.strip().split()
'''
#rImgDir = path.join("output", "maskThresh", str(maskThresh))
#if not path.exists(rImgDir):
#  os.mkdir(rImgDir)

# Bias-vs-standard-deviation figure for patient "54", drawn once for the
# 301M-count 'PET' / tumour '1' key: MLEM curves (one point per iteration),
# the post-filter sweeps ("PS"/"GF" keys) and one scatter point per trained
# network read from its biasVar.csv.
for k1 in sorted(varDict.keys()):
 rm, pat, cnt, tum = k1
 #if '3.01e+08' == cnt and tum == '1' and rm == 'PET' and isinstanceStr(pat, int) and pat != "05":
 if '3.01e+08' == cnt and tum == '1' and rm == 'PET' and isinstanceStr(pat, int) and pat == "54":
  scl = 3
  plt.figure(figsize=(16/scl, 16/scl), dpi=60*scl)
  labelled = set()
  #plt_title("Patient {1}.{3}".format(*k1), size=16)
  # (key into bsqDict/varDict, matplotlib format string) pairs
  toPlot = []
  # low-count level matching the selected network group
  # NOTE(review): `lowC` stays undefined for any other `data_key`
  if data_key == 'drop_4p3':
    lowC = '4.3e+06'
  elif data_key == 'drop_43':
    lowC = '4.3e+07'
  toPlot.extend([
      ((rm, pat, lowC, tum), '.:'),
      (('PETpsf', pat, lowC, tum), '.:'),
  ])
  toPlot.extend([
      (k1, 'o-'),
      #(('PETpsf', pat, cnt, tum), 'o-'),
      ((rm, pat, '-3.01e+08', tum), 'o-'),
      (('PETpsf', pat, '-3.01e+08', tum), 'o-'),
      (('PET', pat, '-3.01e+08', '0'), 'o-'),
  ])
  toPlot.extend([
      (('PETpsf', pat + 'PS', lowC, tum), 'kx-'),
      (('PET', pat + 'PS', lowC, tum), 'kx-'),
      (('PETpsf', pat + 'GF', lowC, tum), 'k+--'),
      (('PET', pat + 'GF', lowC, tum), 'k+--'),
  ])
  for k, ls in toPlot:
    # missing keys are reported and skipped
    if k not in varDict:
        print(k)
        continue
    # NOTE(review): rebinds the outer-loop names rm/pat/cnt; `pat` is read
    # again after this loop (`'54' in pat`) and may then carry a PS/GF suffix
    rm, pat, cnt = k[:3]
    rm = "+RM" if "psf" in rm else ""
    cnt = float(cnt) / 1e6
    # y: bias, x: standard deviation (square roots of the stored values, times 100)
    y, x = [np.array(d[k]) ** 0.5 * 100 for d in (bsqDict, varDict)]
    if "GF" in pat:
        #x = x[np.logspace(np.log10(1), np.log10(len(x)), num=25).astype(int) - 1]
        #y = y[np.logspace(np.log10(1), np.log10(len(y)), num=25).astype(int) - 1]
        # keep the first sample plus a fixed tail window; `list + range(...)`
        # relies on Python 2 `range` returning a list
        x = x[[0] + range(-57, -1, 1)]
        y = y[[0] + range(-57, -1, 1)]
    try:
        i = np.hypot(x, y).argmin()  # index of minimal NRMSE
    except:
        # name the offending key, then re-raise the original error
        print(k)
        raise
    # plt.title("MLEM {k[5]} count, {k[9]} iters (min NRMSE at {i})".format(k=k.split(':'), i=i + 1))
    #plt.ticklabel_format(style="sci", scilimits=(-3, 3))
    #y = 10 * np.log10(y / 100.0)
    #label = ("MLEM{rm} {cnt} (min NRMSE at {i}/{it} iters)").format(
    label = ("{cnt} MLEM{rm} ({it} iter)").format(
        cnt="%.3gM count" % cnt if cnt > 0 else "Noise-free",
        rm=rm, i=i + 1, it=len(x))
    alpha = 1
    if "PS" in pat:
        label = "Gaussian post-smoothing"  # " upto 25mm FWHM"
        alpha = 0.5
    if "GF" in pat:
        #label = "Guided Filtering"
        label = "NLM guided filtering"
        alpha = 0.5
    # each distinct label appears only once in the legend
    if label in labelled:
        label = None
    else:
        labelled.add(label)
    #plt.plot(x, y, ls, label=label)
    #markevery = sorted(set(np.logspace(np.log10(0.49), np.log10(len(x) - 1), num=len(x) // 10).astype(int)))
    # marker on the final point, or on the minimum-NRMSE point for the filter sweeps
    markevery = [len(x) - 1]
    if "GF" in pat or "PS" in pat:
        markevery = [i]
    plt.plot(x, y, ls, label=label, markevery=markevery, alpha=alpha)
    #if cnt < 0:
    #    plt.axhline(y[i], c='k', zorder=0)
    if False:  # plotting minimum point on line
        plt.plot(x[i:i+1], y[i:i+1], 'ko', ms=3, zorder=9)
        #plt.text(x[i:i+1], y[i:i+1], "%.03g" % (i * 25 / len(x) if "PS" in pat else i))

  #plt.axhline((bsqDict[('PETpsf', '54', '-3.01e+08', '0')] ** 0.5 * 100).min(), c='k', alpha=.5, label="Noise-free MLEM+RM")
  #y = bsqDict[('PET', pat[:2], '-3.01e+08', '0')]
  #plt.plot([0, 0], [y.max() ** 0.5 * 100, y.min() ** 0.5 * 100], 'kx-', markevery=[1], label="Noise-free MLEM")
  if '54' in pat:
    # NOTE: machine-specific absolute path
    rootDir = "/home/cc16/viper-tf"
    data = [i + '/201*' for i in fMap[data_key].strip().split()]
    # report run directories that cannot be found (the `max` below then raises)
    for i in data:
        if not glob(path.join(rootDir, i)):
            print(i)
    # lexicographically last biasVar.csv per pattern, relative to rootDir
    data = [max(glob(path.join(rootDir, i, 'biasVar.csv')))[len(rootDir)+1:] for i in data]
    #data = data[-4:-3] + data[1:2] + data[-3:]  # network size
    #data = data[:5]  # number of training realisations
    #data = data[5:6] + data[1:2] + data[6:8]  # dropout
    #data = data[1:2]  # just optimal
    #data=[]  # nothing

    def fn2label(fn):
        """Legend label for a network, derived from its run-directory name via `toLab`."""
        # label = '-'.join(fn.split('/', 1)[0].split('-')[4:-1])
        # label = label if re.search("(-|^)r[0-9]+($|-)", label) else ("r2-" + label).rstrip('-')
        # for i in "123":
        #     label = label.replace("r%s-oneT" % i, "Train %s:1" % i)
        # for i in "123":
        #     label = label.replace("r%s" % i, "Train {0}:{0}".format(i))
        # return label
        label = fn.split('/', 1)[0].split('-', 2)
        label = r"$\mathrm{\mu}$-net" + (r" 43M$\rightarrow$301M" if '-c' not in label[2] else '') + toLab(
            label[2],
            ("301M count", "noise-free"), ("c4.3e6", r"4.3M$\rightarrow$301M"),
            ("c4.3e7", r"43M$\rightarrow$301M"), tum1="")
        return label\
              .replace("301M ground truth", r"$\mathbf{\tau}$")\
              .replace("No products", "No NLM")\
              .replace("No RM", "No NLM no RM")\
              .replace("1 Realisation", "$R=1$")\
              .replace("3 Realisations", "$R=3$")\
              .replace(r"2\, \,Realisations", "")
              # + ' (%s,%s)' % tuple(label[:2]
    legLines = []
    legLabs = []
    # colour and marker strings are repeated so that `zip` never truncates `data`
    for (fn, c, m) in zip(data, "kcmygbr" * (len(data) // 7 + 1),
                          ".,ov^<>1234sp*hH+xDd|_" * (len(data) // 22 + 1)):
        label = fn2label(fn)
        alpha = 1 if fn == data[0] else 0.5
        fn = path.join(rootDir, fn)
        # `filter(...)[0]` relies on Python 2 `filter` returning a list; the
        # file handle is never closed explicitly
        d = csv.DictReader(open(fn))
        d = filter(lambda i:i["prefix"]=="trim", d)[0]
        #if label == "No products":
        #    d["nStd"] = 4.8
        d["nBias"] = float(d["nBias"])
        d["nStd"] = float(d["nStd"])
        # NOTE(review): hard-coded offset added to the plotted bias; the origin
        # of 1.2 is not visible here - confirm
        if "No NLM" in label and data_key == "drop_43":
            d["nBias"] += 1.2
        #label = label + (" [%.3g%%]" % np.hypot(d["nBias"], d["nStd"]))
        lines = plt.scatter(d["nStd"], d["nBias"], label=label if data_key == 'drop_4p3' else None,
                    c=c, marker=m,
                    zorder=10,
                    #s=10,
                    alpha=alpha
                   )
        # NOTE(review): handles are only collected for 'drop_4p3', yet
        # `legLines`/`legLabs` are only consumed by the final `else` branch
        # (any other data_key) - confirm intent
        if data_key == 'drop_4p3':
            legLines.append(lines)
            legLabs.append(label)
        #plt.text(float(d["nStd"]), float(d["nBias"]),
        #         label[0] if isinstanceStr(label[0], int) else label[5],
        #         verticalalignment='bottom' if isinstanceStr(label[0], int) else 'top',
        #         horizontalalignment='left' if isinstanceStr(label[0], int) else 'right',
        #        )

  #plt.xlabel(r"Standard deviation, $\sigma$/[%]", fontproperties=font_prop, size=12 * fntScl)
  #plt.ylabel("Bias, $b$/[%]", fontproperties=font_prop, size=12 * fntScl)
  plt.xlabel(r"Standard deviation, $\sigma$/[%]", size=12 * fntScl)
  plt.ylabel("Bias, $b$/[%]", size=12 * fntScl)
  #plt.ylim(0, None)
  #plt.xlim(0, None)
  #minPlt = 30
  #plt.ylim(minPlt, minPlt + minPlt*9/16)
  #plt.xlim(0, minPlt)
  #plt.xlim(3.5, 9); plt.ylim(31, 42)  # zoom
  #plt.axes().set_aspect('equal')
  # hand-tuned axis limits per network group
  if data_key == "drop_43":
    plt.ylim(16.5, 86)
    plt.xlim(-0.3, 33)
  else:
    plt.ylim(19, 87)
    plt.xlim(-2, 170)
  if data_key == 'drop_4p3':
    plt.legend(loc=4)
  else:
    # two legends: the default one (loc=1) is re-added after the second call replaces it
    leg1 = plt.legend(loc=1)
    plt.legend(legLines, legLabs, loc=4)
    plt.gca().add_artist(leg1)
  plt.tight_layout(0, 0, 0)
  #plt.savefig(path.join(rImgDir, re.sub(r"\W", '', ''.join(k1)) + ".png"))

In [None]:
#savefig("bias-var-all_6-dropout.png")
# file name: today's date (YYYYMMDD) + the current `data_key`
savefig(strftime("../images/trpms/%Y%m%d-") + data_key + "-bias-var-all-pred.png")

In [None]:
# scratch calculation: Euclidean norm of two hand-entered number pairs
np.hypot(55.8,17), np.hypot(6,58.8)

In [None]:
# Bias-vs-standard-deviation curve per network group; every point is one run
# and is annotated with the leading number of its directory name.
# A leading '*' flags the run that gets a marker (it also acts as a glob wildcard).
fMap['n_43'] = """
1-1-web-tum1-nlm-1-He-ssr
3-3-web-tum1-nlm-1-He-ssr
7-7-web-tum1-nlm-1-He-ssr
*15-15-web-tum1-nlm-1-He-ssr
31-31-web-tum1-nlm-1-He-ssr
63-63-web-tum1-nlm-1-He-ssr
"""
fMap['n_4p3'] = """
1-1-web-tum1-c4.3e6-nlm-1-He-ssr
3-3-web-tum1-c4.3e6-nlm-1-He-ssr
*7-7-web-tum1-c4.3e6-nlm-1-He-ssr
15-15-web-tum1-c4.3e6-nlm-1-He-ssr
31-31-web-tum1-c4.3e6-nlm-1-He-ssr
63-63-web-tum1-c4.3e6-nlm-1-He-ssr
"""

# figure scale, set here so the cell does not rely on a value leaked from the
# bias/variance plotting loop (which only sets it inside a conditional branch)
scl = 3
plt.figure(figsize=(16/scl, 16/scl), dpi=60*scl)
# NOTE: machine-specific absolute path
rootDir = "/home/cc16/viper-tf"
labels = {'n_43': r"43M$\rightarrow$301M", 'n_4p3': r"4.3M$\rightarrow$301M"}
# a dedicated loop name keeps the global `data_key` (used for file names) intact
for nKey in ['n_43', 'n_4p3']:
    names = fMap[nKey].strip().split()
    markevery = [i for i, name in enumerate(names) if '*' in name]
    data = [name + '/201*' for name in names]
    # report run directories that cannot be found (the `max` below then raises)
    for i in data:
        if not glob(path.join(rootDir, i)):
            print(i)
    # lexicographically last biasVar.csv per pattern, relative to rootDir
    data = [max(glob(path.join(rootDir, i, 'biasVar.csv')))[len(rootDir)+1:] for i in data]
    res = []
    for fn in data:
        label = fn.split('-', 1)[0]
        with open(path.join(rootDir, fn)) as fd:
            d = [row for row in csv.DictReader(fd) if row["prefix"] == "trim"][0]
        res.append((label, float(d["nBias"]), float(d["nStd"])))
    plt.plot([i[2] for i in res], [i[1] for i in res], 'o-',
             markevery=markevery, label=labels[nKey])
    for label, bias, std in res:
        plt.gca().annotate(label, (std, bias), size=6 * fntScl)
plt.xlabel(r"Standard deviation, $\sigma$/[%]", size=12 * fntScl)
plt.ylabel("Bias, $b$/[%]", size=12 * fntScl)
plt.legend()
#plt.axes().set_aspect('equal')

In [None]:
#savefig("bias-var-all_6-dropout.png")
# file name: today's date (YYYYMMDD) + "-n-all.png"
savefig(strftime("../images/trpms/%Y%m%d-n-all.png"))

In [None]:
# Quick visual check on the two most recently modified reconstructions:
# flat-region mask, reconstructed slice, and their product.
# NOTE(review): rebinds the module-level names `ROI`, `im`, `truth` and `msk`.
files = globber.filterRe(escape=True, it="001")
files = sorted(files, key=path.getmtime)[-2:]  # newest
# NOTE(review): `[-1]` is the last iteration only if `globber.glob` returns sorted names - confirm
files = [globber.glob(f.replace("_001.mat", "_*.mat"))[-1] for f in files]  # last iter
for im in files:
    truth = getGnd(im)
    print(truth, '\n', im)
    # load
    ROI = autoROI(H5Reader(truth).tAct[:])
    #print(ROI)
    ##print(H5Reader(truth).tAct.shape)
    truth = H5Reader(truth)
    truth = truth.tAct[ROI] * truth.scale_factor[0, 0]
    #truth = imNorm(truth)
    ##print(H5Reader(im).Img.shape)
    im = H5Reader(im).Img[ROI]
    #im = imNorm(im)
    # plot
    plt.figure()
    # index 63 of the uncropped volume, shifted into ROI coordinates
    msk = truth[63 - ROI[0].start]
    #msk = msk != 0
    #msk = msk > np.percentile(msk[msk > 0].flat, maskThresh)
    msk = flatMask(msk, thresh=40)
    plt.subplot(131); imshow(msk)
    #print(indices)
    plt.subplot(132); imshow(im[63 - ROI[0].start])
    plt.subplot(133); imshow(im[63 - ROI[0].start] * msk)

In [None]:
# Grid of one slice per iteration of the most recently modified matching
# reconstruction; each panel is titled "<iteration> <masked normalised RMSE>".
# The truth image is prepended, so the first panel reads "000".
im = sorted(glob("output/*/*_0_*_001.mat"), key=path.getmtime)[-1]

#imGnd = glob("output/?/*e+08*_000.mat")[-1]
imGnd = getGnd(im)
ROI = autoROI(H5Reader(imGnd).tAct[:])  # non-zero ROI
ims = im[:-7] + "*.mat"
print(ims)
# sorted: the panel titles assume ascending iteration order
ims = sorted(glob(ims))
assert all([im[:-7] == i[:-7] for i in ims])

# PET truth
imGnd = H5Reader(imGnd)
imGnd = imGnd.tAct[ROI] * imGnd.scale_factor[0, 0]
step = 1
print([i[-7:-4] for i in ims[::step]])
ims = [H5Reader(i).Img[ROI] for i in tqdm(ims[::step]) if not i.endswith("_000.mat")]
ims = [imGnd] + ims

# evaluate only inside the bright, flat regions of the truth
mask = flatMask(imGnd, maskThresh)
imGnd = imGnd[mask]
imScale = (imGnd ** 2).mean()
l = len(ims)
nGrid = int(l**.5) + 1  # integer grid size, as required by `plt.subplot`
plt.figure(figsize=(14, 14))
for i, im in enumerate(ims):
    # named `err` so the imported `nrmse` function is not shadowed
    err = (((im[mask] - imGnd) ** 2).mean() / imScale) ** 0.5
    plt.subplot(nGrid, nGrid, i + 1)
    # NOTE(review): index 63 is applied to the ROI-cropped volume here, whereas the
    # preview cell uses `63 - ROI[0].start` - confirm which slice is intended
    imshow(im[63], vmin=0, vmax=imGnd.max(),
           title="{:03d} {:.3g}".format((i - 1) * step + 1, err))

In [None]:
# pick a set of realisations (first iteration of each), using the same globber
# query as the accumulation loop; the earlier `fPartsGlob`/`fPartsKeys` names
# are not defined anywhere in this notebook
reals = globber.filterRe(
    escape=True,
    rm=rms[0], pat=pats[0], counts=counts[1], tum=tums[0],
    it="001")

# get corresponding ground truth
# NOTE(review): `ROI` is whatever an earlier cell left behind - confirm it
# matches these files
imGnd = H5Reader(getGnd(reals[-1]))
imGnd = imGnd.tAct[ROI] * imGnd.scale_factor[0, 0]

# produce mask
#mask = imGnd > np.percentile(imGnd[imGnd > imGnd.min()].flat, 99.9)
mask = flatMask(imGnd)
print(mask.sum(), "pixels")
#plt.subplot(121); imshow(mask[len(mask) // 2])
#plt.subplot(122); imshow(imGnd[len(mask) // 2])

In [None]:
# Histogram of the masked voxel values over all realisations, for every
# `step`-th iteration.
# file names indexed [realisation][iteration]; sorted so that rows line up by iteration
rizyx = [sorted(j for j in glob(i.replace("_001.mat", "_*.mat"))
                if not j.endswith("_000.mat"))
         for i in reals]
# -> [iteration][realisation]
irzyx = np.array(rizyx).T

step = 5
irzyx = irzyx[::step]

l = len(irzyx)
nGrid = int(l**.5) + 1
_, axs = plt.subplots(nGrid, nGrid, figsize=(14, 14), sharex=True, sharey=True)
for i, it in enumerate(tqdm(irzyx)):
    plt.sca(axs.flat[i])
    plt.hist(np.array([H5Reader(im).Img[ROI][mask] for im in it]).flat)
    # iteration number read from the file name (`..._NNN.mat`)
    plt.title(str(int(it[0][-7:-4])))