In [None]:
from __future__ import print_function, division
from glob import glob
import re
import matplotlib.pyplot as plt
from copy import deepcopy
from tqdm.autonotebook import tqdm, trange
import numpy as np
import cPickle as pickle
import re

from caspyr.utils import H5Reader

%matplotlib notebook

In [None]:
# File-name grammar of a reconstruction path: one (name, regex, glob) triple
# per consecutive field. Concatenating the regexes (each as a capture group)
# parses a path into its fields; concatenating the globs builds search patterns.
fParts = (
    ("meta0",  ".*real_",             "*real_"),
    ("rm",     "PET|PETpsf",          "PET*"),
    ("meta1",  "_nosharp_[0-9]+_AD_", "_nosharp_*_AD_"),
    ("pat",    "[0-9]+",              "*"),
    ("meta2",  "-C_",                 "-C_"),
    ("counts", "[0-9.e+]+",           "*"),
    ("meta3",  "_t",                  "_t"),
    ("tum",    "[-0-9]+",             "*"),
    ("meta4",  "_",                   "_"),
    ("it",     "[0-9]+",              "*"),
    ("meta5",  r"\.mat",              ".mat"),
)
fPartsKeys, fPartsVals, fPartsGlob = (list(column) for column in zip(*fParts))
RE_INFO = re.compile(''.join('({})'.format(rx) for rx in fPartsVals))
FOLDERS = 7  # reconstructions live in output/0 ... output/(FOLDERS - 1)
maskThresh = 50  # percentile threshold passed to flatMask

In [None]:
def autoROI(im):
    """
    Return the bounding box of the non-zero elements of `im`.

    Returns a tuple with one `slice` per axis, usable as `im[roi]`. Slice stops
    are exclusive, so each stop is one past the largest non-zero index;
    otherwise the last non-zero plane along every axis would be cropped off.
    If `im` has no non-zero elements, the whole array is selected.
    """
    inds = np.array(np.nonzero(im))  # shape (ndim, nNonZero)
    if inds.size == 0:
        return tuple(slice(None) for _ in range(np.ndim(im)))
    mins = inds.min(axis=1)
    maxs = inds.max(axis=1)
    return tuple(slice(int(lo), int(hi) + 1) for lo, hi in zip(mins, maxs))


def imNorm(im, assumeZeroMin=True):
    """
    Scale `im` in place to an average of 1 count/voxel, and return it.

    im  : float ndarray, modified in place
    assumeZeroMin  : bool, if False first shift `im` so its minimum is 0

    An all-zero image cannot be normalised and is returned unchanged.
    """
    if not assumeZeroMin:
        im -= im.min()
    total = im.sum()
    if total:
        # size / sum (not size / max) is what makes the mean exactly 1
        im *= im.size / total
    return im


def flatMask(im, thresh=50, hw=1):
    """
    Boolean mask of voxels lying in flat (low local gradient), bright regions.

    im  : ndarray, any number of dimensions
    thresh  : int, percentile. Voxels whose edge response is below the
        `thresh`-th percentile count as flat; voxels brighter than the
        `(100 - thresh)`-th percentile of the positive values count as bright.
    hw  : int, half width of the (2*hw+1)**ndim edge-detection kernel
    """
    from scipy.signal import convolve
    # Zero-sum kernel: the centre weight balances all the -1 neighbours, so a
    # locally constant region gives (numerically) zero response. The centre
    # must be indexed on every axis; indexing only two axes would set a whole
    # row in 3-D and break the zero-sum property.
    m = -np.ones([2 * hw + 1] * im.ndim)
    m[(hw,) * im.ndim] = m.size - 1
    edges = np.abs(convolve(im, m, mode='same'))
    msk = edges < np.percentile(edges, thresh)  # flat: small edge response
    msk &= im > np.percentile(im[im > 0], 100 - thresh)  # bright
    return msk

In [None]:
# Discover every reconstruction file and the distinct values of each varying
# file-name field.
files = sorted(
    fn
    for folder in range(FOLDERS)
    for fn in glob("output/%d/*.mat" % folder))
parts = [RE_INFO.findall(fn)[0] for fn in files]


def _fieldValues(name, key=None):
    """Sorted distinct values of file-name field `name` across `parts`."""
    col = fPartsKeys.index(name)
    return sorted(set(p[col] for p in parts), key=key)


rms = _fieldValues("rm")
pats = _fieldValues("pat", key=int)
counts = _fieldValues("counts", key=float)
tums = _fieldValues("tum", key=int)
its = _fieldValues("it", key=int)
print(rms, pats, counts, tums)  # , its

In [None]:
# Load previously computed per-iteration bias^2 / variance curves.
# To start from scratch instead use: bsqDict, varDict = {}, {}
# The save cell below writes "output/biasVar-flat.pkl"; this loads "output/biasVar.pkl".
# N.B: pickle can execute arbitrary code -- only load files you produced yourself.
with open("output/biasVar.pkl", "rb") as fd:  # binary mode: pickled with protocol -1
    bsqDict, varDict = pickle.load(fd)

In [None]:
# For every (method, patient, counts, tumour) combination found on disk,
# compute per-iteration normalised squared bias and variance across the noise
# realisations, inside a flat mask of the ground truth:
#   bsqDict[key][it-1] = mean((mean_r(recon) - truth)^2) / mean(truth^2)
#   varDict[key][it-1] = mean(var_r(recon, ddof=1))      / mean(truth^2)
# where means are over masked voxels and var_r/mean_r are over realisations.
# NOTE(review): this cell overwrites the globals `files`, `parts`, `p`, `ROI`,
# `mask`, `imGnd`, `ims`, `bsq`, `var` left by earlier cells.
#ROI = slice(0, None), slice(100, 250), slice(100, 250)
fn = deepcopy(fPartsGlob)
for rm in tqdm(rms, desc="RM"):
    fn[fPartsKeys.index("rm")] = rm
    for pat in tqdm(pats, desc="Patient"):
        fn[fPartsKeys.index("pat")] = pat
        for cnt in tqdm(counts, desc="Counts"):
            fn[fPartsKeys.index("counts")] = cnt
            for tum in tqdm(tums, desc="Tumours"):
                fn[fPartsKeys.index("tum")] = tum
                # all iteration files of all realisations of this combination,
                # searched in output/[0..FOLDERS-1]/
                files = sorted(glob("output/[%s]/%s" % (''.join(map(str, range(FOLDERS))), ''.join(fn))))
                if not files:
                    continue
                #print(len(files))
                #print('\n'.join(files))
                parts = [RE_INFO.findall(i)[0] for i in files]
                #print(len(parts))
                # one entry per distinct file stem (iteration field, index -2,
                # blanked); each entry is treated as one realisation
                parts = sorted({i[:-2] + ("",) + i[-1:] for i in parts})  # ignore iteration
                parts = [list(p) for p in parts]
                #print('\n'.join(':'.join(i) for i in parts))
                # per-iteration results vectors
                bsq = []
                var = []
                # ground truth image: wildcard method, counts and meta1 of the
                # first realisation and look up its iteration-000 file
                p = deepcopy(parts[0])  # first realisation
                p[fPartsKeys.index("rm")] = fPartsGlob[fPartsKeys.index("rm")]
                p[fPartsKeys.index("counts")] = fPartsGlob[fPartsKeys.index("counts")]
                p[fPartsKeys.index("meta1")] = fPartsGlob[fPartsKeys.index("meta1")]
                p[fPartsKeys.index("it")] = "000"  # 0^th iter (metadata)
                p = ''.join(p)
                if not glob(p):
                    raise IOError("Could not find " + p)
                # disabled self-check: all matching truth files hold the same tAct
                if False:  # TEST
                    ROI = autoROI(H5Reader(glob(p)[0]).tAct[:])  # non-zero ROI
                    imGnd = imNorm(H5Reader(glob(p)[0]).tAct[ROI])
                    tmp = [(imGnd == imNorm(H5Reader(i).tAct[ROI])).all() for i in glob(p)]
                    #print(len(tmp), glob(p))
                    assert all(tmp)
                    #continue
                p = glob(p)[0]
                ROI = autoROI(H5Reader(p).tAct[:])  # non-zero ROI
                imGnd = imNorm(H5Reader(p).tAct[ROI])  # PET truth
                # imGnd >= np.percentile(imGnd[imGnd != 0].flat, maskThresh)
                mask = flatMask(imGnd, maskThresh)
                imGnd = imGnd[mask]  # 1-D: masked truth voxels only
                imScale = (imGnd ** 2).mean()  # normaliser for bias^2 and variance
                # pre-allocate one set of realisations [R, nMaskedVoxels]
                #ims = np.zeros((len(parts),) + imGnd.shape, dtype=imGnd.dtype)
                ims = np.zeros((len(parts), imGnd.size), dtype=imGnd.dtype)
                # iteration counts are hard-coded: 300 for "psf" methods, else 100
                with trange(1, 1 + (300 if "psf" in rm else 100),
                                 desc="%d reals iter" % len(parts)) as tIters:
                  for it in tIters:
                    # N.B: `p` here is each realisation's field list (shadowing
                    # the truth filename above); setting p[-2] mutates `parts`
                    # in place, so after the loop parts[0] carries the LAST
                    # iteration number, which ends up in the dict key below
                    # (later cells read it back as key field 9).
                    for i, p in enumerate(tqdm(parts, desc="realisations", disable=True)):
                        p[-2] = "%03d" % it
                        ims[i] = imNorm(H5Reader(''.join(p)).Img[ROI])[mask]
                    imMean = ims.mean(axis=0)  # per-voxel mean over realisations
                    #bsq.append(((imMean - imGnd) / imGnd).mean())
                    #var.append((np.std(ims, axis=0, ddof=1) / imMean).mean())
                    bsq.append(((imMean - imGnd) ** 2).mean() / imScale)
                    var.append((np.var(ims, axis=0, ddof=1)).mean() / imScale)
                    tIters.set_postfix(
                        # N.B: mseTest will be nonzero due to var(..., ddof=1)
                        # mseTest = ((ims - imGnd) ** 2).mean() / imScale - bsq[-1] - var[-1],
                        mse=bsq[-1] + var[-1], bsq=bsq[-1], var=var[-1], refresh=False)
                # key: ':'-joined fields of the first realisation (see note above)
                bsqDict[':'.join(parts[0])] = bsq
                varDict[':'.join(parts[0])] = var

In [None]:
# Persist the bias^2 / variance dictionaries. Guarded by `if 0:`; change to
# `if 1:` to write. Alternative file name used elsewhere: "output/biasVar.pkl".
if 0:
    with open("output/biasVar-flat.pkl", "wb") as fd:  # binary mode for protocol -1
        pickle.dump((bsqDict, varDict), fd, -1)

In [None]:
# Average the per-iteration curves over all keys sharing the same
# (method, patient, counts), i.e. over tumours and the remaining key fields.
# Keys are ':'-joined fParts fields, so field positions follow fPartsKeys.
# Groups are iterated in sorted order so `rmPatCnt` (and plot order) is
# reproducible rather than dependent on set iteration order.
print(sorted(varDict.keys()))
iRm, iPat, iCnt = (fPartsKeys.index(f) for f in ("rm", "pat", "counts"))
rmPatCnt = []
groups = sorted({(i[iRm], i[iPat], i[iCnt]) for i in (k.split(':') for k in varDict)})
for i1, i3, i5 in groups:
    avgBiasVar = [
        np.array([src[k] for k in src for i in [k.split(':')]
                  if (i[iRm], i[iPat], i[iCnt]) == (i1, i3, i5)]).mean(axis=0)
        for src in (bsqDict, varDict)]
    rmPatCnt.append((i1, i3, i5, avgBiasVar))

In [None]:
# One NRMSE-vs-iteration figure per patient (patient "1" is titled "Train").
# Uses a local name so the global `pats` (numerically sorted, used later)
# is not clobbered.
plotPats = sorted({r[1] for r in rmPatCnt})
for pat in plotPats:
    plt.figure()
    plt.title("Train" if pat == "1" else "Validation")
    for rm, patM, cnt, (bsq, var) in sorted(rmPatCnt, key=lambda r: (r[0], float(r[2]))):
        if patM != pat:
            continue
        mcnt = float(cnt) / 1e6
        ls = ':' if mcnt <= 43 else '-'  # dotted for runs with <= 43M counts
        # bsq and var are normalised by mean(truth^2), so sqrt(bsq + var) is NRMSE
        nrmse = np.hypot(np.sqrt(bsq), np.sqrt(var)) * 100
        iters = np.arange(1, len(nrmse) + 1)  # first point is iteration 1
        i = nrmse.argmin()  # index of minimal NRMSE
        plt.plot(iters, nrmse, ls,
                 label="MLEM{psf} {mcnt:.0f}M count (min NRMSE at {i}/{nit} iters)".format(
                     psf="+RM" if "psf" in rm else "", mcnt=mcnt, i=i + 1, nit=len(nrmse)))
        plt.plot(iters[i:i + 1], nrmse[i:i + 1], 'ko', ms=6)
    plt.xlabel(r"Iteration")
    plt.ylabel("NRMSE/[%]")
    plt.ylim(0, 200)
    plt.xlim(0, None)
    plt.legend()

In [None]:
import os
from os import path

# Bias-vs-standard-deviation trade-off curves: one figure per high-count key,
# overlaid with its matching low-count key, each saved as a PNG.
rImgDir = path.join("output", "maskThresh", str(maskThresh))
if not path.exists(rImgDir):
    os.makedirs(rImgDir)  # also creates output/maskThresh if it is missing

HI_CNT, LO_CNT = ':3.01e+08:', ':4.3e+07:'
for k1 in sorted(varDict.keys(), key=lambda x: x.split(':')):
    if HI_CNT not in k1:
        continue
    plt.figure()
    plt.title(k1)
    for k, ls in zip([k1, k1.replace(HI_CNT, LO_CNT)], ['x-', 'o-']):
        if k not in varDict or k not in bsqDict:
            continue  # no matching run for this count level
        y, x = [np.array(d[k]) ** 0.5 * 100 for d in (bsqDict, varDict)]
        # NRMSE = sqrt(bias^2 + std^2), consistent with the NRMSE-vs-iteration plot
        i = np.hypot(x, y).argmin()  # index of minimal NRMSE
        meta = k.split(':')
        meta[1] = "+RM" if "psf" in meta[1] else ""
        meta[5] = float(meta[5]) / 1e6
        # k[9] is the key's iteration field, i.e. the number of iterations run
        plt.plot(x, y, ls, label="MLEM{k[1]} {k[5]:.0f}M count (min NRMSE at {i}/{k[9]} iters)".format(
            k=meta, i=i + 1))
        plt.plot(x[i:i + 1], y[i:i + 1], 'ro', ms=6)
    plt.xlabel(r"Standard deviation, $\sigma$/[%]")
    plt.ylabel("Bias/[%]")
    plt.ylim(0, None)
    plt.xlim(0, None)
    plt.legend()
    plt.savefig(path.join(rImgDir, re.sub(r"\W", '', k1)) + ".png")

In [None]:
sorted(varDict, key=lambda key: key.split(':'))  # all result keys, field by field

In [None]:
# TODO: collect raw per-realisation, per-iteration images per (method, counts):
#   PET_10    = [[r1i1, ..., r1i100], ..., [rNi1, ..., rNi100]]
#   PETpsf_10 = ...
#   PET_70    = ...
#   PETpsf_70 = ...
# Placeholder only; the collection loop is not implemented yet.
PET_10 = None

In [None]:
# Preview the flat mask on one axial slice of the first (truth, reconstruction)
# pair: final-iteration reconstructions (_300 for all methods, _100 for PET).
SLICE = 63  # absolute slice index; converted to ROI coordinates below
files = sorted(fn for i in range(FOLDERS)
               for fn in glob("output/%d/*_300.mat" % i) + glob("output/%d/*_PET_*_100.mat" % i))
parts = [RE_INFO.findall(fn)[0] for fn in files]
print(len(parts))
for p in tqdm(parts[:1]):
    p = list(p)
    imFn = ''.join(p)
    # ground truth: wildcard method/counts/meta1 and take the iteration-000 file
    p[fPartsKeys.index("rm")] = fPartsGlob[fPartsKeys.index("rm")]
    p[fPartsKeys.index("counts")] = fPartsGlob[fPartsKeys.index("counts")]
    p[fPartsKeys.index("meta1")] = fPartsGlob[fPartsKeys.index("meta1")]
    p[fPartsKeys.index("it")] = "000"  # 0^th iter (metadata)
    truthGlob = ''.join(p)
    truthFns = glob(truthGlob)
    if not truthFns:
        raise IOError("Could not find " + truthGlob)
    truthFn = truthFns[0]
    print(truthFn, imFn)
    # load, cropped to the non-zero bounding box of the truth
    ROI = autoROI(H5Reader(truthFn).tAct[:])
    print(ROI)
    truth = imNorm(H5Reader(truthFn).tAct[ROI])
    im = imNorm(H5Reader(imFn).Img[ROI])
    # plot: mask, reconstruction, masked reconstruction
    z = SLICE - ROI[0].start
    plt.figure()
    msk = flatMask(truth[z], thresh=40)
    plt.subplot(131); plt.imshow(msk)
    plt.subplot(132); plt.imshow(im[z])
    plt.subplot(133); plt.imshow(im[z] * msk)

In [None]:
# Visual check: the truth, then every `step`-th MLEM iteration (1, 1+step, ...)
# of one realisation, each titled with its NRMSE inside the flat mask.
step = 5
gndFn = sorted(glob("output/?/*_000.mat"))[0]  # sorted: reproducible choice
ROI = autoROI(H5Reader(gndFn).tAct[:])  # non-zero ROI
# iteration 001 of the same realisation (searched across all output folders)
it1Fn = sorted(glob("output/?/" + gndFn.split('/', 2)[2].replace("_000.mat", "_001.mat")))[0]
# all iterations in iteration order; drop the _000 metadata file BEFORE
# subsampling so that itFns[::step] are iterations 1, 1+step, 1+2*step, ...
itFns = [f for f in sorted(glob(it1Fn.replace("_001.mat", "_*.mat")))
         if not f.endswith("_000.mat")]

imGnd = imNorm(H5Reader(gndFn).tAct[ROI])  # PET truth
ims = [imGnd] + [imNorm(H5Reader(f).Img[ROI]) for f in tqdm(itFns[::step])]

# imGnd >= np.percentile(imGnd[imGnd != 0].flat, maskThresh)
mask = flatMask(imGnd, maskThresh)
imGnd = imGnd[mask]
imScale = (imGnd ** 2).mean()
l = len(ims)
nCols = int(l ** .5) + 1  # subplot grid sizes must be ints
z = 63 - ROI[0].start  # absolute slice 63 in ROI-cropped coordinates
plt.figure(figsize=(14, 14))
for i, im in enumerate(ims):
    nrmse = (((im[mask] - imGnd) ** 2).mean() / imScale) ** 0.5
    plt.subplot(nCols, nCols, i + 1)
    plt.imshow(im[z], origin="lower", cmap="Greys_r")
    plt.axis("off")
    plt.title("truth" if i == 0 else "it {} NRMSE {:.3g}".format((i - 1) * step + 1, nrmse))

In [None]:
# Intensity histogram of each image shown above (truth first), one panel each.
nCols = int(l ** .5) + 1
_, axs = plt.subplots(nCols, nCols, figsize=(14, 14), sharex=True, sharey=True)
for i, im in enumerate(tqdm(ims)):
    # NRMSE = sqrt(MSE / mean(truth^2)) over masked voxels
    nrmse = (((im[mask] - imGnd) ** 2).mean() / imScale) ** 0.5
    ax = axs.flat[i]
    ax.hist(np.ravel(im), bins=50)
    ax.set_title("truth" if i == 0 else "it {} NRMSE {:.3g}".format((i - 1) * step + 1, nrmse))
for ax in axs.flat[len(ims):]:
    ax.axis("off")  # hide unused panels

In [None]:
# Pick one set of noise realisations: first method/patient/counts/tumour,
# iteration 001 of every realisation.
fn = deepcopy(fPartsGlob)
fn[fPartsKeys.index("rm")] = rms[0]
fn[fPartsKeys.index("pat")] = pats[0]
fn[fPartsKeys.index("counts")] = counts[0]
fn[fPartsKeys.index("tum")] = tums[0]
fn[fPartsKeys.index("it")] = "001"
reals = sorted(glob("output/?/" + ''.join(fn)))
if not reals:
    raise IOError("No realisations match output/?/" + ''.join(fn))

# Corresponding ground truth: method and counts wildcarded, iteration 000
# (the metadata file, whose tAct is used as the truth elsewhere).
p = deepcopy(fn)
p[fPartsKeys.index("rm")] = fPartsGlob[fPartsKeys.index("rm")]
p[fPartsKeys.index("counts")] = fPartsGlob[fPartsKeys.index("counts")]
p[fPartsKeys.index("it")] = "000"  # 0^th iter (metadata)
gndFns = sorted(glob("output/?/" + ''.join(p)))
if not gndFns:
    raise IOError("Could not find output/?/" + ''.join(p))
gndAct = H5Reader(gndFns[0]).tAct[:]
# ROI from THIS truth rather than whatever an earlier cell left in `ROI`;
# the next cell crops the reconstructions with it as well.
ROI = autoROI(gndAct)  # non-zero ROI
imGnd = gndAct[ROI]

# Mask: voxels above the 99.9th percentile of values above the image minimum.
MASK_PCT = 99.9
mask = imGnd > np.percentile(imGnd[imGnd > imGnd.min()], MASK_PCT)
print(mask.sum(), "pixels")

In [None]:
# Per-iteration histograms of masked voxel values across the realisations in
# `reals` (cropped with `ROI`, masked with `mask`). Each realisation's
# iteration files are sorted so that, after the transpose, row k holds the
# same iteration of every realisation.
rizyx = [[j for j in sorted(glob(i.replace("_001.mat", "_*.mat")))
          if not j.endswith("_000.mat")]
         for i in reals]
irzyx = np.array(rizyx).T  # [iteration, realisation] file names

plt.figure()  # don't draw into whatever figure happens to be current
for i, rzyx in enumerate(irzyx):
    data = np.array([H5Reader(f).Img[ROI][mask] for f in rzyx])
    plt.hist(data.ravel())
plt.xlabel("Voxel value (masked)")
plt.ylabel("Count")

In [None]:
# Histogram per iteration of the masked voxels pooled over all realisations.
# `irzyx` holds file names ([iteration, realisation]), so each image must be
# loaded and masked before histogramming.
plt.figure()
for it in irzyx:
    plt.hist(np.array([H5Reader(f).Img[ROI][mask] for f in it]).ravel())