In [None]:
from __future__ import print_function, division
%matplotlib notebook
import matplotlib.pyplot as plt

import re
from copy import deepcopy
import numpy as np
import cPickle as pickle
import re
from os import path
import csv
#from viper.constants import SIGMA2FWHM_MMR
SIGMA2FWHM_MMR = (8 * np.log(2)) ** .5 * 2.08626
from scipy.ndimage.filters import gaussian_filter as gauss

import warnings
with warnings.catch_warnings():
    from tqdm import TqdmExperimentalWarning
    warnings.simplefilter("ignore", category=TqdmExperimentalWarning)
    from tqdm.auto import tqdm, trange
warnings.filterwarnings("ignore", category=RuntimeWarning, message="numpy.dtype size changed", module="h5py")
from caspyr.utils import H5Reader, Globber, glob

from viper.plot_style import toLab

In [None]:
def imshow(im, cmap="Greys", origin="lower", title=None, **kwargs):
    """
    Show `im` on the current matplotlib axes with the axis decorations hidden.

    im  : image, passed to `plt.imshow`
    cmap  : str, colour map
    origin  : str, image origin
    title  : str, optional axes title (skipped when falsy)
    kwargs  : forwarded to `plt.imshow` (e.g. vmin, vmax)
    """
    plt.imshow(im, cmap=cmap, origin=origin, **kwargs)
    plt.axis("off")
    if not title:
        return
    plt.title(title)

def autoROI(im):
    """
    Return the bounding box of the non-zero elements of `im`.

    im  : array-like, any number of dimensions
    Returns a tuple(slice), one per dimension, such that `im[autoROI(im)]`
    contains every non-zero element. If `im` has no non-zero elements every
    slice is empty (`slice(0, 0)`).
    """
    im = np.asarray(im)
    inds = np.array(im.nonzero())
    if not inds.size:
        # nothing to bound: min/max of an empty array would raise
        return tuple(slice(0, 0) for _ in range(im.ndim))
    mins = inds.min(axis=1)
    # slice stops are exclusive: +1 so that the last non-zero index is kept
    maxs = inds.max(axis=1) + 1
    return tuple(slice(int(i), int(j)) for i, j in zip(mins, maxs))

def imNorm(im, assumeZeroMin=True):
    """
    Normalise `im` in-place to an average of 1 count/voxel.

    im  : numpy array, modified in-place (should be of floating point type
      so that the in-place scaling is representable)
    assumeZeroMin  : bool, if False the minimum is subtracted first
    Returns `im` (the same object). An image whose total is zero cannot be
    normalised and is returned unscaled.
    """
    if not assumeZeroMin:
        im -= im.min()
    total = im.sum()
    if total:
        # mean == 1  <=>  sum == number of voxels
        # (previously scaled by `size / max`, which sets the maximum to the
        # number of voxels instead of the mean to 1)
        im *= im.size / float(total)
    return im

def flatMask(im, thresh=50, hw=1):
    """
    Mask of the bright, locally flat elements of `im`.

    im  : array, any number of dimensions
    thresh  : int, percentile
    hw  : int, half width
    Returns a boolean array of the same shape as `im`, True where the edge
    magnitude is below its `thresh`-th percentile and the intensity is above
    the `100 - thresh`-th percentile of the positive intensities.
    """
    from scipy.signal import convolve
    # edge detection kernel: -1 everywhere, centre chosen so that the sum is zero
    m = np.ones([2*hw+1] * im.ndim) * -1
    # full N-D index: `m[hw, hw]` only addresses the centre element in 2D
    m[(hw,) * im.ndim] = m.size - 1
    edges = np.abs(convolve(im, m, mode='same'))
    msk = edges < np.percentile(edges, thresh)  # keep flat regions
    msk &= im > np.percentile(im[im > 0], 100 - thresh)  # keep bright regions
    return msk

In [None]:
# Filename matcher for the reconstruction outputs.
# NOTE(review): each tuple looks like (part name, regular expression, glob
# pattern) judging by the argument syntax -- confirm against
# `caspyr.utils.Globber`.
# Concatenating the glob patterns gives paths of the form
#   output/*/brainweb_PET*_*_subject_*-S_*-NP_*-NT1_*-C_*_t*_*.mat
# Parts are looked up by name elsewhere in this notebook, e.g.
# `globber["counts"]`, and passed as keywords to `globber.filterRe`.
globber = Globber(
    ("meta0", "output/[0-9]+/brainweb_", "output/*/brainweb_"),
    ("rm", "PET|PETpsf", "PET*"),  # "psf" variants are labelled "+RM" in plots
    ("meta1", "_[0-9]+_subject_", "_*_subject_"),
    ("pat", "[0-9]+", "*"),  # patient
    ("meta6", "-S_", "-S_"),
    ("sigma", "[0-9.e+]+", "*"),
    ("meta7", "-NP_", "-NP_"),
    ("fsPET", "[0-9.e+]+", "*"),
    ("meta8", "-NT1_", "-NT1_"),
    ("fsT1", "[0-9.e+]+", "*"),
    ("meta2", "-C_", "-C_"),
    ("counts", "-?[0-9.e+]+", "*"),  # non-positive values are plotted as "noise-free"
    ("meta3", "_t", "_t"),
    ("tum", "-?[0-9]+", "*"),  # tumour
    ("meta4", "_", "_"),
    ("it", "[0-9]+", "*"),  # iteration; "000" files hold the ground truth
    ("meta5", r"\.mat", ".mat"))

In [None]:
maskThresh = 50  # percentile handed to `flatMask`
PAD = 3  # indices trimmed from both ends of every ROI axis
# padded region of interest built from (start, stop) pairs, zxy order
ROI_BW = tuple(
    slice(start + PAD, stop - PAD)
    for start, stop in ((33, 93), (115, 230), (115, 230)))

In [None]:
def getGnd(fn, debug=False):
    """
    Return the filename of the ground truth (iteration "000") matching `fn`.

    fn  : str, reconstruction filename which `globber` can split
    debug  : bool, if True load every candidate and check that they all
      agree with the first one within `ROI_BW`
    Raises IOError if no candidate is found and (debug only) ValueError if
    the candidates differ.
    """
    # turn the counts value into a regex: escape `+` and `.`, and make the
    # sign optional so that either sign of the same count level matches
    counts = globber.split(fn)[globber["counts"]]\
        .replace('-', "-?").replace('+', '\\+').replace('.', '\\.')
    # NOTE(review): presumably tries an exact match differing only in
    # iteration first, then a relaxed search -- confirm `filterReUnlike`
    files = globber.filterReUnlike(fn, it="000") or globber.filterReUnlike(
        fn, meta0=None, meta1=None, rm=None, it="000", counts=counts)  # counts=None
    if not files:
        raise IOError("Could not find truth for:%s" % fn)
    if debug:
        #ROI = autoROI(H5Reader(files[0]).tAct[:])  # non-zero ROI
        ROI = ROI_BW
        imGnd = H5Reader(files[0]).tAct[ROI]
        same = [(imGnd == H5Reader(i).tAct[ROI]).all() for i in tqdm(files, desc="debug")]
        if not all(same):
            # report the requested file (the former `+ p` was a NameError)
            raise ValueError("mismatched ground truths:" + fn)
    return files[0]

In [None]:
# values available for each filename part of interest (see `globber`);
# these drive the nested loops over RM, patient, counts, tumour and iteration
rms, pats, counts, tums, its = (
    globber.partsVals[globber[part]]
    for part in ("rm", "pat", "counts", "tum", "it"))
print(rms, pats, counts, tums)  # , its

In [None]:
# manual override: restrict the analysis to a subset of the data
pats = "54".split()  # whitespace-separated patient IDs
counts = ["%.3g" % (i * 1e6) for i in (43, 301)]  # count levels, in millions
tums = ['1']
# example of a matching file: brainweb_PET_*_54-*_4.3*_t1_001.mat

In [None]:
# show the selection which the processing loops will iterate over
print(rms, pats, counts, tums)  # , its

In [None]:
# results, keyed by (rm, pat, counts, tum) tuples; each value is an array with
# one entry per iteration (variance and squared bias, both divided by the
# mean squared ground truth)
varDict = {}
bsqDict = {}

In [None]:
# pickle file from which the `(bsqDict, varDict)` pair is loaded;
# the commented-out lines are alternative result sets
#outputPkl = "output/biasVar.pkl"
#outputPkl = "output/biasVar-flat.pkl"
outputPkl = "output/biasVar-brainweb.pkl"

In [None]:
# load previously computed results (replaces the empty dictionaries)
# NOTE: unpickling can execute arbitrary code -- only load trusted files
# binary mode: pickles are byte streams and may be corrupted by text mode
with open(outputPkl, "rb") as fd:
    bsqDict, varDict = pickle.load(fd)

In [None]:
# Squared bias and variance per iteration for every selected
# (rm, patient, counts, tumour) combination.
# Results go into `bsqDict`/`varDict` under the key (rm, pat, cnt, tum);
# Gaussian post-smoothing of the last iteration is stored under
# (rm, pat + "PS", cnt, tum).

# True: skip the per-iteration metrics and only load the last iteration.
# Defined before the loops since it is also read after the `while` loop.
HACK_SKIP = False


def _itFiles(files, it):
    """existing files of iteration `it` corresponding to the "001" `files`"""
    # p[:-7] strips the trailing "001.mat"; a real list (rather than
    # `filter`) so that emptiness can be tested on Python 2 and 3 alike
    candidates = (p[:-7] + "%03d.mat" % it for p in files)
    return [f for f in candidates if path.exists(f)]


def _biasVar(ims, imGnd, imScale):
    """(squared bias, variance) of the realisations `ims`, both / `imScale`"""
    imMean = np.mean(ims, axis=0)
    bsq = ((imMean - imGnd) ** 2).mean() / imScale
    # N.B: ddof=0, using ddof=1 would make MSE != bias^2 + variance
    var = (np.var(ims, axis=0)).mean() / imScale
    return bsq, var


for rm in tqdm(rms, desc="RM"):
    for pat in tqdm(pats, desc="Patient"):
        for cnt in tqdm(counts, desc="Counts"):
            for tum in tqdm(tums, desc="Tumours"):
                files = globber.filterRe(
                    escape=True,
                    rm=rm, pat=pat, counts=cnt, tum=tum,
                    it="001")
                if not files:
                    print("cannot find globber files", rm, pat, cnt, tum)
                    continue

                # ground truth image, cropped to the ROI and scaled
                imGnd = H5Reader(getGnd(files[0]))
                #ROI = autoROI(imGnd.tAct[:])  # non-zero ROI
                ROI = ROI_BW
                imGnd = imGnd.tAct[ROI] * imGnd.scale_factor[0, 0]
                # alternative: mask = flatMask(imGnd, maskThresh)
                mask = imGnd > 0
                imGnd = imGnd[mask]
                imScale = (imGnd ** 2).mean()

                # per-iteration results vectors
                bsq = []
                var = []
                # reset so that images of a previous combination are never reused
                last_it_ims = []
                with tqdm(total=len(its),
                          desc="%d reals, %.3g scale" % (len(files), imGnd.sum())) as tIters:
                    it = 1
                    while True:
                        usedfiles = _itFiles(files, it)
                        if not usedfiles:
                            # ran out of iterations; when skipping, the last
                            # iteration has not been loaded yet
                            if HACK_SKIP and it > 1:
                                last_it_ims = [H5Reader(p).Img[ROI]
                                               for p in _itFiles(files, it - 1)]
                            break
                        if not HACK_SKIP:
                            last_it_ims = [H5Reader(p).Img[ROI] for p in usedfiles]
                            b, v = _biasVar([i[mask] for i in last_it_ims], imGnd, imScale)
                            bsq.append(b)
                            var.append(v)
                            tIters.set_postfix(
                                mse=bsq[-1] + var[-1], bsq=bsq[-1], var=var[-1], refresh=False)
                        tIters.update()
                        it += 1

                key = rm, pat, cnt, tum
                if not HACK_SKIP:
                    bsqDict[key] = np.array(bsq)
                    varDict[key] = np.array(var)

                # post-smoothing
                if 0 < float(cnt) < 301e6:  # and "psf" in rm:
                    num = it - 1  # as many smoothing levels as iterations
                    bsq = []
                    var = []
                    # 0 to 25 FWHM converted to Gaussian sigmas
                    # NOTE(review): presumably mm FWHM -> sigma in voxels; confirm
                    for s in tqdm(np.linspace(0, 25, num=num) / SIGMA2FWHM_MMR, desc="PS"):
                        b, v = _biasVar([gauss(im, s)[mask] for im in last_it_ims],
                                        imGnd, imScale)
                        bsq.append(b)
                        var.append(v)
                    key = key[0], key[1] + "PS", key[2], key[3]
                    bsqDict[key] = np.array(bsq)
                    varDict[key] = np.array(var)

In [None]:
# list the available (rm, pat, counts, tum) result keys
sorted(varDict)

In [None]:
# Average the per-iteration squared bias and variance across tumours.
# `bsqVarAvg[(rm, pat, cnt)]` is an array of shape (2, iterations):
# row 0 is the squared bias, row 1 the variance.
bsqVarAvg = {}  # average across tumours
for rm, pat, cnt in {k[:3] for k in varDict if "PS" not in k[1]}:
    # dimensions: bias|var, tumour, iteration (lengths may differ per tumour)
    groups = [[np.asarray(src[k], dtype=float) for k in src if k[:3] == (rm, pat, cnt)]
              for src in (bsqDict, varDict)]
    nitMax = max(len(vec) for vecs in groups for vec in vecs)
    avgBiasVar = np.zeros((len(groups), nitMax))
    for avg, vecs in zip(avgBiasVar, groups):  # `avg` is a view of one row
        # count the vectors which actually reach each iteration, rather than
        # counting non-zero values (a genuine `0` is not padding)
        nitScale = np.zeros(nitMax)
        for vec in vecs:
            avg[:len(vec)] += vec
            nitScale[:len(vec)] += 1
        nitScale[nitScale == 0] = 1
        avg /= nitScale
    bsqVarAvg[(rm, pat, cnt)] = avgBiasVar

In [None]:
# re-import the `viper` helpers so that edits to them are picked up without
# restarting the kernel
import viper.plot_style as vps
reload(vps)
toLab = vps.toLab

#from viper.utils.stats import movingAvg, nrmse, biasStd, biasStdMask
import viper.utils.stats as vus
reload(vus)
nrmse = vus.nrmse
biasStd = vus.biasStd  # , biasStdMask

In [None]:
pats = sorted({i[1] for i in bsqVarAvg})
for pat in pats:
  plt.figure()
  plt.title("Patient:" + pat)
  for k in [k for k in bsqVarAvg if k[1] == pat]:
      bsq, var = bsqVarAvg[k]
      rm, patM, cnt = k
      mcnt = float(cnt) / 1e6
      ls = '--' if mcnt > 43 else ':' if mcnt > 0 else '-'
      ls += 'b' if "psf" in rm else 'r'
      #nbias, nstd = [i ** 0.5 * 100 for i in (bsq, var)]
      y = (np.array(bsq) + var) ** 0.5 * 100  # NRMSE
      i = y.argmin()  # index of minimal NRMSE
      #plt.ticklabel_format(style="sci", scilimits=(-3, 3))
      #y = 10 * np.log10(y / 100.0)
      plt.plot(range(len(y)), y, ls,
               label="MLEM{psf} {mcnt:.0f}M count (min NRMSE at {i}/{nit} iters)".format(
                   psf="+RM" if "psf" in rm else "", mcnt=mcnt, i=i + 1, nit=len(y)))
      plt.plot([i], y[i:i+1], 'ko', ms=6)
      #plt.text(i, y[i], str(i))

  if pat == '54':
    SIGMA2FWHM_MMR
    """
    127-127-web-He-oneT-ssr/20181010-173911
    127-127-web-He-r3-oneT-ssr/20181010-184930
    127-127-web-He-r3-ssr/20181011-091754
    127-127-web-He-ssr/20181011-100609
    127-127-web-He-r1-ssr/20181011-110026
    """
    data = """
    63-63-web-r1-tum1-He-ssr/20181106-204916
    63-63-web-tum1-He-ssr/20181106-104757
    63-63-web-r3-tum1-He-ssr/20181106-205550
    63-63-web-zHad-tum1-He-ssr/20181106-220722
    63-63-web-zMR-tum1-He-ssr/20181106-203953
    63-63-web-zPSF-tum1-He-ssr/20181106-202153
    """.strip().split()
    data = [data[1]]
    #data = []
    COLOURS = "cmykgbr"
    #LSTYLES = ""
    for (fn, c) in zip(data, COLOURS):
        label = fn.split('/', 1)[0].split('-', 2)[-1]
        label = toLab(label, tum1="")
        fn = path.join("/home/cc16/viper-tf", fn, "biasVar.csv")
        d = csv.DictReader(open(fn))
        d = [i for i in d]
        # for k, c in zip(["trim", "low", "full", "net"], "cmykgbr"):
        for k in ["trim"]:
            l = [i for i in d if i["prefix"]==k]
            if len(l) != 1:
                msg = ' '.join([label, k, '\n'.join(map(str, d))])
                raise ValueError(msg)
            #plt.axhline(float(l[-1]["NRMSE"]), label=label + ' ' + k, c=c)
            plt.axhline(float(l[-1]["NRMSE"]), label=r'$\mu$Net trained on ' + label, c=c)
  plt.xlabel(r"Iteration")
  plt.ylabel("NRMSE/[%]")
  plt.ylim(None, 46)
  plt.xlim(0, None)
  plt.legend()

In [None]:
# Bias against standard deviation curves, one figure per patient, saved as
# PNG files in `rImgDir`.
import os
from os import path

rImgDir = path.join("output", "maskThresh", str(maskThresh))
if not path.exists(rImgDir):
    # makedirs: the parent directory "output/maskThresh" may not exist yet
    os.makedirs(rImgDir)


def fn2label(fn):
    """legend label for a run given as `<n>-<n>-<description>/<timestamp>`"""
    label = fn.split('/', 1)[0].split('-')
    return toLab('-'.join(label[2:]), tum1="") + ' (%s,%s)' % tuple(label[:2])


for k1 in sorted(varDict.keys()):
    rm, pat, cnt, tum = k1
    # one figure per patient, anchored on the 3.01e+08 count "PET" entry
    if not ('3.01e+08' == cnt and tum == '1' and rm == 'PET'
            and "PS" not in pat and pat != "05"):
        continue
    plt.figure(figsize=(16/2, 9/2), dpi=60*2)
    plt.title("Patient {1}.{3}".format(*k1))
    # (results key, line style) of every curve in the figure
    curves = [
        (k1, 'o-'),
        ((rm, pat, '4.3e+07', tum), '.-'),
        #(('PETpsf', pat, cnt, tum), 'o-'),
        (('PETpsf', pat, '4.3e+07', tum), '.-'),
        ((rm, pat, '-3.01e+08', tum), 'o-'),
        (('PETpsf', pat, '-3.01e+08', tum), 'o-'),
        (('PET', pat + 'PS', '4.3e+07', tum), 'k-'),
        (('PETpsf', pat + 'PS', '4.3e+07', tum), 'k-'),
    ]
    for k, ls in curves:
        if k not in varDict:
            print(k)
            continue
        # percentages: y is the bias, x the standard deviation
        y, x = [np.array(d[k]) ** 0.5 * 100 for d in (bsqDict, varDict)]
        try:
            i = np.hypot(x, y).argmin()  # index of minimal NRMSE
        except Exception:
            print(k)  # identify the offending entry before re-raising
            raise
        # separate names: the outer `rm`, `pat`, `cnt` must stay intact
        kRm, kPat, kCnt = k[:3]
        kRm = "+RM" if "psf" in kRm else ""
        kCnt = float(kCnt) / 1e6
        label = "{cnt} MLEM{rm}".format(
            cnt="%.0fM count" % kCnt if kCnt > 0 else "noise-free", rm=kRm)
        if "PS" in kPat:
            label = "Gaussian post-smoothing"  # " upto 25mm FWHM"
            if kRm == "+RM":
                label = None  # a single legend entry for both PS curves
        # logarithmically spaced markers
        markevery = sorted(set(np.logspace(
            np.log10(0.49), np.log10(len(x) - 1), num=len(x) // 10).astype(int)))
        plt.plot(x, y, ls, label=label, markevery=markevery)
        if kCnt < 0:
            plt.axhline(y[i], c='k', zorder=0)
        plt.plot(x[i:i+1], y[i:i+1], 'ko', ms=3, zorder=9)  # minimal NRMSE

    if '54' in pat:
        data = """
        63-63-web-r1-tum1-He-ssr/20181106-204916
        63-63-web-tum1-He-ssr/20181106-104757
        63-63-web-r3-tum1-He-ssr/20181106-205550
        63-63-web-r4-tum1-He-ssr/20181107-012338
        63-63-web-r5-tum1-He-ssr/20181107-024615
        63-63-web-zHad-tum1-He-ssr/20181106-220722
        63-63-web-zMR-tum1-He-ssr/20181106-203953
        63-63-web-zPSF-tum1-He-ssr/20181106-202153
        127-127-web-tum1-He-ssr/20181106-152545
        31-31-web-tum1-He-ssr/20181106-154657
        15-15-web-tum1-He-ssr/20181106-160605
        7-7-web-tum1-He-ssr/20181106-161603
        """.strip().split()
        #data = data[1:2] + data[-4:]
        data = data[:5]

        for (fn, c) in zip(data, "cmykgbr" * (len(data) // 7 + 1)):
            label = fn2label(fn)
            # NOTE: machine-specific absolute path
            fn = path.join("/home/cc16/viper-tf", fn, "biasVar.csv")
            with open(fn) as fd:  # read everything, then close the file
                d = [i for i in csv.DictReader(fd) if i["prefix"]=="maskScaled"][0]
            plt.scatter(float(d["nStd"]), float(d["nBias"]), label=label,
                        c=c, marker='x' if '-oneT' in fn else '*' if "\\bf" in label else '+',
                        zorder=10,
                        #s=100
                       )

    plt.xlabel(r"Standard deviation, $\sigma$/[%]")
    plt.ylabel("Bias/[%]")
    #plt.ylim(0, None)
    plt.xlim(0, None)
    #plt.axes().set_aspect('equal')
    plt.legend()

    plt.savefig(path.join(rImgDir, re.sub(r"\W", '', ''.join(k1)) + ".png"))

In [None]:
# preview `flatMask` on the two most recently modified reconstructions:
# mask | reconstruction | masked reconstruction
newest = sorted(globber.filterRe(escape=True, it="001"), key=path.getmtime)[-2:]
# last file matching any iteration of the same reconstruction
files = [globber.glob(f.replace("_001.mat", "_*.mat"))[-1] for f in newest]
for fnIm in files:
    fnTruth = getGnd(fnIm)
    print(fnTruth, '\n', fnIm)
    # load, cropped to the non-zero bounding box of the truth
    ROI = autoROI(H5Reader(fnTruth).tAct[:])
    truth = H5Reader(fnTruth)
    truth = truth.tAct[ROI] * truth.scale_factor[0, 0]
    im = H5Reader(fnIm).Img[ROI]
    # index 63 of the un-cropped first axis
    z = 63 - ROI[0].start
    # plot
    plt.figure()
    msk = flatMask(truth[z], thresh=40)
    for pos, panel in zip((131, 132, 133), (msk, im[z], im[z] * msk)):
        plt.subplot(pos)
        imshow(panel)

In [None]:
# Montage of one reconstruction across iterations (ground truth first),
# each panel titled with its iteration and NRMSE within the flat mask.
im = sorted(glob("output/*/*_0_*_001.mat"), key=path.getmtime)[-1]

#imGnd = glob("output/?/*e+08*_000.mat")[-1]
imGnd = getGnd(im)
ROI = autoROI(H5Reader(imGnd).tAct[:])  # non-zero ROI
ims = im[:-7] + "*.mat"
print(ims)
# sorted: zero-padded iteration numbers give iteration order
ims = sorted(glob(ims))
assert all([im[:-7] == i[:-7] for i in ims])

# PET truth
imGnd = H5Reader(imGnd)
imGnd = imGnd.tAct[ROI] * imGnd.scale_factor[0, 0]
step = 1
fns = [i for i in ims[::step] if not i.endswith("_000.mat")]
# iteration labels taken from the filenames; "000" is the ground truth
labels = ["000"] + [i[-7:-4] for i in fns]
print(labels[1:])
ims = [H5Reader(i).Img[ROI] for i in tqdm(fns)]
ims = [imGnd] + ims

# imGnd >= np.percentile(imGnd[imGnd != 0].flat, maskThresh)
mask = flatMask(imGnd, maskThresh)
imGnd = imGnd[mask]
imScale = (imGnd ** 2).mean()
l = len(ims)
nGrid = int(l ** .5) + 1  # subplot grid dimensions must be integers
# index 63 of the un-cropped first axis, as in the mask preview
zSlice = 63 - ROI[0].start
plt.figure(figsize=(14, 14))
for i, im in enumerate(ims):
    # `err` rather than `nrmse`, which is a function imported from viper
    err = (((im[mask] - imGnd) ** 2).mean() / imScale) ** 0.5
    plt.subplot(nGrid, nGrid, i + 1)
    imshow(im[zSlice], vmin=0, vmax=imGnd.max(),
           title="{} {:.3g}".format(labels[i], err))

In [None]:
# pick a set of realisations
# (formerly built from `fPartsGlob`/`fPartsKeys`, which no preceding cell
# defines; `globber.filterRe` is used as in the bias/variance loop)
reals = globber.filterRe(
    escape=True,
    rm=rms[0], pat=pats[0], counts=counts[1], tum=tums[0],
    it="001")
if not reals:
    raise IOError("cannot find realisations for: %s" % ' '.join(
        (rms[0], pats[0], counts[1], tums[0])))

# get corresponding ground truth
imGnd = H5Reader(getGnd(reals[-1]))
# computed here rather than reusing whichever `ROI` an earlier cell left behind
ROI = autoROI(imGnd.tAct[:])  # non-zero ROI
imGnd = imGnd.tAct[ROI] * imGnd.scale_factor[0, 0]

# produce mask
#mask = imGnd > np.percentile(imGnd[imGnd > imGnd.min()].flat, 99.9)
mask = flatMask(imGnd)
print(mask.sum(), "pixels")
#plt.subplot(121); imshow(mask[len(mask) // 2])
#plt.subplot(122); imshow(imGnd[len(mask) // 2])

In [None]:
# Histograms of the masked values over all realisations, every `step`
# iterations. Uses `reals`, `ROI` and `mask` from the preceding cell.
# rows: realisations, columns: iterations ("000" truth files excluded);
# sorted so that columns are in iteration order (zero-padded numbers)
rizyx = [sorted(j for j in glob(i.replace("_001.mat", "_*.mat"))
                if not j.endswith("_000.mat"))
         for i in reals]
irzyx = np.array(rizyx).T  # rows: iterations, columns: realisations

step = 5
irzyx = irzyx[::step]

l = len(irzyx)
nGrid = int(l**.5) + 1
fig, axs = plt.subplots(nGrid, nGrid, figsize=(14, 14), sharex=True, sharey=True)
for i, it in enumerate(tqdm(irzyx)):
    plt.sca(axs.flat[i])
    plt.hist(np.array([H5Reader(im).Img[ROI][mask] for im in it]).ravel())
    # row `i` holds iteration `i * step + 1` since row 0 is iteration 001
    # (the former `(i - 1) * step + 1` labelled the first panel `1 - step`)
    plt.title(str(i * step + 1))