In [None]:
from __future__ import print_function, division
%matplotlib notebook
import matplotlib.pyplot as plt

import re
from copy import deepcopy
import numpy as np
import cPickle as pickle
import re
from os import path

import warnings
with warnings.catch_warnings():
    from tqdm import TqdmExperimentalWarning
    warnings.simplefilter("ignore", category=TqdmExperimentalWarning)
    from tqdm.auto import tqdm, trange
warnings.filterwarnings("ignore", category=RuntimeWarning, message="numpy.dtype size changed", module="h5py")
from caspyr.utils import H5Reader

In [None]:
def imshow(im, cmap="Greys", origin="lower", title=None, **kwargs):
    """Show a 2-D image on the current axes with the axis decorations hidden.

    im                     : image passed to `plt.imshow`
    cmap, origin, **kwargs : forwarded to `plt.imshow`
    title                  : optional title for the current axes

    Returns the image handle created by `plt.imshow` (useful e.g. for a
    colorbar).
    """
    img = plt.imshow(im, cmap=cmap, origin=origin, **kwargs)
    plt.axis("off")
    if title:
        plt.title(title)
    return img

def autoROI(im):
    """Return the bounding box of the non-zero elements of `im`.

    Returns a tuple with one `slice` per axis such that `im[autoROI(im)]`
    contains every non-zero element (each stop is one past the last non-zero
    index, following Python's exclusive-stop convention).

    Raises ValueError if `im` has no non-zero elements.
    """
    inds = np.array(im.nonzero())  # shape (ndim, number of non-zero elements)
    if inds.size == 0:
        raise ValueError("autoROI: image has no non-zero elements")
    mins = inds.min(axis=1)
    maxs = inds.max(axis=1)
    # slice stops are exclusive: +1 so the last non-zero index is kept
    return tuple(slice(int(lo), int(hi) + 1) for lo, hi in zip(mins, maxs))

def imNorm(im, assumeZeroMin=True):
    """Scale `im` in place so that it averages 1 count/voxel.

    im            : float ndarray; modified in place and also returned
    assumeZeroMin : if False, first shift `im` so that its minimum is 0

    Raises ValueError if the (shifted) image sums to zero, since it cannot
    then be scaled to a unit mean.
    """
    if not assumeZeroMin:
        im -= im.min()
    total = im.sum()
    if total == 0:
        raise ValueError("imNorm: cannot normalise an image that sums to zero")
    # mean == sum / size, so scaling by size / sum gives a mean of exactly 1
    im *= im.size / total
    return im

def flatMask(im, thresh=50, hw=1):
    """Boolean mask of elements that are both locally flat and relatively bright.

    im      : ndarray of any dimensionality
    thresh  : int, percentile (0-100)
    hw      : int, half width of the (2*hw+1)^ndim edge-detection kernel

    An element is kept when its edge response is below the `thresh`
    percentile of all edge responses AND its value is above the
    `100 - thresh` percentile of the positive values of `im`.
    Returns a boolean ndarray with the same shape as `im`.

    Raises ValueError if `im` has no positive values.
    """
    from scipy.signal import convolve
    positive = im[im > 0]
    if positive.size == 0:
        raise ValueError("flatMask: image has no positive values")
    # -1 everywhere, centre weight chosen so the kernel sums to zero:
    # constant regions give no response. Index the centre on *every* axis
    # (`m[hw, hw]` alone would set a whole row for 3-D input).
    m = -np.ones([2 * hw + 1] * im.ndim)
    m[(hw,) * im.ndim] = m.size - 1
    edges = np.abs(convolve(im, m, mode='same'))
    flat = edges < np.percentile(edges, thresh)  # weak local edges
    bright = im > np.percentile(positive, 100 - thresh)  # high values
    return flat & bright

class Globber(object):
    """Find files whose names are a concatenation of named parts.

    Each part is a `(name, regex, glob)` tuple. Joining all globs gives the
    pattern used to list candidate files on disk; joining all regexes (each
    wrapped in a capturing group) gives the pattern those files must match.
    """

    def __init__(self, *parts):
        """
        parts  : (tuple(name, regex, glob), ...)
        """
        self.partsKeys = [i[0] for i in parts]
        self.partsVals = [i[1] for i in parts]
        self.partsGlob = [i[2] for i in parts]
        self.val = self.parseRe()
        self.valRe = re.compile(self.val)
        self.glb = self.parseGlob()
        # glob all files
        self.files = self.glob(self.glb)
        # ensure files pass the regex
        self.files = self.filterRe()

    def filterGlob(self, **keys):
        """Deprecated: use `filterRe` instead (always raises)."""
        raise DeprecationWarning("Globber.filterGlob is deprecated; use filterRe")

    def filterRe(self, **keys):
        """Return the list of self.files matching the regex in which each
        given key=regex pair replaces that part's default regex."""
        p = list(self.partsVals)
        for k, v in keys.items():
            p[self.index(k)] = v
        r = re.compile(self.parseRe(p))
        return [f for f in self.files if r.match(f)]

    def filterReLike(self, fname, **keys):
        """Return the list of self.files that are like `fname` but with the
        given key=regex pairs.

        Every part not named in `keys` must equal the corresponding part of
        `fname` literally. Use `None` for regex for default (all).

        Raises ValueError if `fname` does not match the parts' regex.
        """
        found = self.valRe.search(fname)
        if found is None:
            raise ValueError("%r does not match %s" % (fname, self.val))
        # captured parts are literal text: escape regex metacharacters such as
        # the '.' and '+' in '3.01e+08'. A list, so parts can be overridden.
        p = [re.escape(part) for part in found.groups()]
        for k, v in keys.items():
            i = self.index(k)
            p[i] = self.partsVals[i] if v is None else v
        r = re.compile(self.parseRe(p))
        return [f for f in self.files if r.match(f)]

    def index(self, key):
        """Position of the part called `key`."""
        return self.partsKeys.index(key)

    def parseGlob(self, parts=None):
        """Concatenate glob parts (default: self.partsGlob)."""
        return ''.join(self.partsGlob if parts is None else parts)

    def parseRe(self, parts=None):
        """Concatenate regex parts (default: self.partsVals), one group each."""
        return ''.join('(' + i + ')' for i in
                       (self.partsVals if parts is None else parts))

    @classmethod
    def glob(cls, *a, **k):
        """sorted version of glob.glob"""
        from glob import glob as g
        return sorted(g(*a, **k))

def glob(*args, **kwargs):
    """Module-level shortcut for `Globber.glob`: `glob.glob` with sorted output."""
    matches = Globber.glob(*args, **kwargs)
    return matches

In [None]:
# Globber over result files under output/<folder>/; each entry is a
# (name, regex, glob) triple for one consecutive piece of the filename.
globber = Globber(*[
    ("meta0", "output/[0-9]+/brainweb_", "output/*/brainweb_"),
    ("rm",    "PET|PETpsf",              "PET*"),
    ("meta1", "_[0-9]+_subject_",        "_*_subject_"),
    ("pat",   "[0-9]+",                  "*"),
    ("meta6", "-S_",                     "-S_"),
    ("sigma", "[0-9.e+]+",               "*"),
    ("meta7", "-NP_",                    "-NP_"),
    ("fsPET", "[0-9.e+]+",               "*"),
    ("meta8", "-NT1_",                   "-NT1_"),
    ("fsT1",  "[0-9.e+]+",               "*"),
    ("meta2", "-C_",                     "-C_"),
    ("counts", "[0-9.e+]+",              "*"),
    ("meta3", "_t",                      "_t"),
    ("tum",   "[-0-9]+",                 "*"),
    ("meta4", "_",                       "_"),
    ("it",    "[0-9]+",                  "*"),
    ("meta5", r"\.mat",                  ".mat"),
])

In [None]:
# Filename grammar used by the cells below: one (name, regex, glob) triple per
# consecutive piece of a result filename, e.g. (illustrative)
#   output/3/brainweb_PETpsf_1_subject_04-S_1.5-NP_1e+09-NT1_2e+09-C_3.01e+08_t-1_001.mat
fParts = (
    ("meta0",  ".*brainweb_",      "*brainweb_"),
    ("rm",     "PET|PETpsf",       "PET*"),
    ("meta1",  "_[0-9]+_subject_", "_*_subject_"),
    ("pat",    "[0-9]+",           "*"),
    ("meta6",  "-S_",              "-S_"),
    ("sigma",  "[0-9.e+]+",        "*"),
    ("meta7",  "-NP_",             "-NP_"),
    ("fsPET",  "[0-9.e+]+",        "*"),
    ("meta8",  "-NT1_",            "-NT1_"),
    ("fsT1",   "[0-9.e+]+",        "*"),
    ("meta2",  "-C_",              "-C_"),
    ("counts", "[0-9.e+]+",        "*"),
    ("meta3",  "_t",               "_t"),
    ("tum",    "[-0-9]+",          "*"),
    ("meta4",  "_",                "_"),
    ("it",     "[0-9]+",           "*"),
    ("meta5",  r"\.mat",           ".mat"))
# split the triples into three parallel lists (names, regexes, globs)
fPartsKeys, fPartsVals, fPartsGlob = (list(column) for column in zip(*fParts))
# one capturing group per part: RE_INFO.findall(fn)[0][fPartsKeys.index(name)]
# is that part of fn
RE_INFO = re.compile(''.join('({})'.format(rx) for rx in fPartsVals))
FOLDERS = 8  # number of output/<n>/ folders scanned (n = 0 .. FOLDERS-1)
maskThresh = 50  # percentile passed to flatMask; also names the figure dir

In [None]:
def getGnd(fn, debug=False):
    """Return the path of the ground-truth (iteration 000) file matching `fn`.

    The lookup glob keeps every part of `fn` except: any all-digit folder in
    the "meta0" part, the "rm" part and the "meta1" part (all wildcarded),
    and the iteration (set to "000").

    fn    : result filename parseable by RE_INFO
    debug : if True, also check that every candidate file has identical
            `tAct` values over the non-zero ROI of the first candidate

    Raises ValueError if `fn` cannot be parsed or (debug) the candidates
    differ; IOError if no candidate file exists.
    """
    found = RE_INFO.findall(fn)
    if not found:
        raise ValueError("Could not parse filename:" + fn)
    p = list(found[0])
    for key in ("rm", "meta1"):
        p[fPartsKeys.index(key)] = fPartsGlob[fPartsKeys.index(key)]
    p[fPartsKeys.index("it")] = "000"  # 0^th iter (metadata)
    # wildcard every all-digit folder component (output/3/ -> output/*/),
    # including multi-digit ones; the lookahead leaves the trailing '/' in place
    iMeta0 = fPartsKeys.index("meta0")
    p[iMeta0] = re.sub(r"/[0-9]+(?=/)", "/*", p[iMeta0])
    p = ''.join(p)
    files = glob(p)
    if not files:
        raise IOError("Could not find:" + p)
    if debug:
        first = H5Reader(files[0]).tAct[:]
        ROI = autoROI(first)  # non-zero ROI
        imGnd = first[ROI]
        tmp = [(imGnd == H5Reader(i).tAct[ROI]).all() for i in tqdm(files, desc="debug")]
        if not all(tmp):
            raise ValueError("mismatched ground truths:" + p)
    return files[0]

In [None]:
# Scan every output folder for first-iteration files and collect the distinct
# value of each filename part (numeric parts sorted numerically).
# Only *_001.mat files are scanned, so `its` holds just '001'.
files = sorted([fn for folder in range(FOLDERS)
                for fn in glob("output/%d/brain*_001.mat" % folder)])
parts = [RE_INFO.findall(fn)[0] for fn in files]

iRm, iPat, iCnt, iTum, iIt = [fPartsKeys.index(name)
                              for name in ("rm", "pat", "counts", "tum", "it")]
rms = sorted(set(p[iRm] for p in parts))
pats = sorted(set(p[iPat] for p in parts), key=int)
counts = sorted(set(p[iCnt] for p in parts), key=float)
tums = sorted(set(p[iTum] for p in parts), key=int)
its = sorted(set(p[iIt] for p in parts), key=int)
print(rms, pats, counts, tums)  # , its

In [None]:
# Per-iteration variance / squared-bias curves, keyed by ':'-joined filename parts.
varDict, bsqDict = {}, {}

In [None]:
#outputPkl = "output/biasVar.pkl"
#outputPkl = "output/biasVar-flat.pkl"
outputPkl = "output/biasVar-brainweb.pkl"
# Load cached (bsqDict, varDict) if present; otherwise keep the current dicts.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
if path.exists(outputPkl):
    # binary mode: the cache is written with a binary pickle protocol (-1)
    with open(outputPkl, "rb") as fd:
        bsqDict, varDict = pickle.load(fd)

In [None]:
# For every (reconstruction model, patient, counts, tumour) combination found on
# disk: gather its realisations, then for each iteration compute the squared bias
# of the realisation mean and the mean across-realisation variance, over the
# voxels where the ground truth is positive, both divided by mean(truth ** 2).
# Results go into bsqDict / varDict (one list entry per iteration), keyed by the
# ':'-joined filename parts of the first realisation.
# NOTE(review): this cell rebinds globals (`fn`, `files`, `parts`, `ROI`,
# `mask`, `imGnd`, `ims`, ...) that other cells also use.
#ROI = slice(0, None), slice(100, 250), slice(100, 250)
fn = deepcopy(fPartsGlob)  # glob parts; each loop level below overwrites one
for rm in tqdm(rms, desc="RM"):
    fn[fPartsKeys.index("rm")] = rm
    for pat in tqdm(pats, desc="Patient"):
        fn[fPartsKeys.index("pat")] = pat
        for cnt in tqdm(counts, desc="Counts"):
            fn[fPartsKeys.index("counts")] = cnt
            for tum in tqdm(tums, desc="Tumours"):
                fn[fPartsKeys.index("tum")] = tum
                # folder is a single-character class "[01...]", so this only
                # matches single-digit folders (assumes FOLDERS <= 10)
                files = glob("output/[%s]/%s" % (''.join(map(str, range(FOLDERS))), ''.join(fn)))
                if not files:
                    continue
                #print(len(files))
                #print('\n'.join(files))
                parts = [RE_INFO.findall(i)[0] for i in files]
                #print(len(parts))
                # blank the iteration part (index -2) so each realisation
                # appears once whatever iterations exist on disk
                parts = sorted({i[:-2] + ("",) + i[-1:] for i in parts})  # ignore iteration
                print('\n'.join([''.join(i) for i in parts]))
                # mutable copies: p[-2] (the iteration) is filled in below
                parts = [list(p) for p in parts]
                #print('\n'.join(':'.join(i) for i in parts))
                # per-iteration results vectors
                bsq = []
                var = []
                # ground truth image, cropped to its non-zero ROI and scaled
                # by scale_factor[0, 0]
                imGnd = H5Reader(getGnd(files[0]))
                ROI = autoROI(imGnd.tAct[:])  # non-zero ROI
                imGnd = imGnd.tAct[ROI] * imGnd.scale_factor[0, 0]
                print("counts:%.3g" % imGnd.sum())
                #imGnd = imNorm(imGnd)
                # imGnd >= np.percentile(imGnd[imGnd != 0].flat, maskThresh)
                ####mask = flatMask(imGnd, maskThresh)
                # score only voxels where the ground truth is positive;
                # imGnd becomes the 1-D vector of those voxels
                mask = imGnd > 0
                imGnd = imGnd[mask]
                imScale = (imGnd ** 2).mean()  # normaliser: mean squared truth
                # pre-allocate one set of realisations [R, masked voxels]
                #ims = np.zeros((len(parts),) + imGnd.shape, dtype=imGnd.dtype)
                ims = np.zeros((len(parts), imGnd.size), dtype=imGnd.dtype)
                # iterations 1..300 for "psf" models, 1..100 otherwise.
                # NOTE(review): assumes every realisation has a file for each
                # of these iterations -- TODO confirm
                with trange(1, 1 + (300 if "psf" in rm else 100),
                                 desc="%d reals iter" % len(parts)) as tIters:
                  for it in tIters:
                    for i, p in enumerate(tqdm(parts, desc="realisations", disable=True)):
                        # select this iteration's file for realisation i
                        p[-2] = "%03d" % it
                        #ims[i] = imNorm(H5Reader(''.join(p)).Img[ROI])[mask]
                        ims[i] = H5Reader(''.join(p)).Img[ROI][mask]
                    imMean = ims.mean(axis=0)
                    #bsq.append(((imMean - imGnd) / imGnd).mean())
                    #var.append((np.std(ims, axis=0, ddof=1) / imMean).mean())
                    # squared bias of the realisation mean, and mean unbiased
                    # (ddof=1, so >= 2 realisations needed) variance
                    bsq.append(((imMean - imGnd) ** 2).mean() / imScale)
                    var.append((np.var(ims, axis=0, ddof=1)).mean() / imScale)
                    tIters.set_postfix(
                        # N.B: mseTest will be nonzero due to var(..., ddof=1)
                        # mseTest = ((ims - imGnd) ** 2).mean() / imScale - bsq[-1] - var[-1],
                        mse=bsq[-1] + var[-1], bsq=bsq[-1], var=var[-1], refresh=False)
                # parts[0] still holds the last iteration ("%03d") in its
                # iteration slot, so the key records the number of iterations
                bsqDict[':'.join(parts[0])] = bsq
                varDict[':'.join(parts[0])] = var

In [None]:
if False:  # disabled; set to True to overwrite the cached results in outputPkl
    # binary mode is required for the binary pickle protocol (-1)
    with open(outputPkl, "wb") as fd:
        pickle.dump((bsqDict, varDict), fd, -1)

In [None]:
print('\n'.join(sorted(varDict.keys())))
# Average the bias^2 / variance curves over all keys that share the same
# (reconstruction model, patient, counts) triple, e.g. different tumours.
groupCols = [fPartsKeys.index(name) for name in ("rm", "pat", "counts")]
# one {(rm, pat, cnt): [curve, ...]} bucket per source dict, each filled in
# that dict's own key order
buckets = ({}, {})
for bucket, src in zip(buckets, (bsqDict, varDict)):
    for key in src:
        fields = key.split(':')
        bucket.setdefault(tuple(fields[c] for c in groupCols), []).append(src[key])
# the set of groups is taken from varDict's keys
groups = set()
for key in varDict:
    fields = key.split(':')
    groups.add(tuple(fields[c] for c in groupCols))
rmPatCnt = sorted(
    grp + ([np.array(bucket.get(grp, [])).mean(axis=0) for bucket in buckets],)
    for grp in groups)

In [None]:
# One figure per patient: NRMSE against iteration for each (rm, counts) pair.
# Patients are sorted numerically, like `pats` in the scan cell.
pats = sorted({entry[1] for entry in rmPatCnt}, key=int)
for pat in pats:
    plt.figure()
    plt.title("Train" if pat == pats[0] else "Validation")
    for rm, patM, cnt, (bsq, var) in rmPatCnt:
        if patM != pat:
            continue
        mcnt = float(cnt) / 1e6
        ls = ':' if mcnt <= 43 else '-'  # dotted for runs of <= 43M counts
        # NRMSE/[%] = 100 * sqrt(bias^2 + variance); both are pre-normalised
        y = (np.array(bsq) + var) ** 0.5 * 100
        # curve index 0 holds iteration 1, so plot against 1..len(y)
        iters = np.arange(1, len(y) + 1)
        i = y.argmin()  # index of minimal NRMSE
        plt.plot(iters, y, ls,
                 label="MLEM{psf} {mcnt:.0f}M count (min NRMSE at {i}/{nit} iters)".format(
                     psf="+RM" if "psf" in rm else "", mcnt=mcnt, i=i + 1, nit=len(y)))
        plt.plot(iters[i:i + 1], y[i:i + 1], 'ko', ms=6)
    plt.xlabel(r"Iteration")
    plt.ylabel("NRMSE/[%]")
    plt.ylim(0, 200)
    plt.xlim(0, None)
    plt.legend()

In [None]:
import os
from os import path

rImgDir = path.join("output", "maskThresh", str(maskThresh))
if not path.exists(rImgDir):
    os.makedirs(rImgDir)  # also creates output/maskThresh/ if missing

# each figure compares a 3.01e+08-count result with its 4.3e+07-count twin
HI_CNT, LO_CNT = ':3.01e+08:', ':4.3e+07:'
iRm, iCnt, iIt = [fPartsKeys.index(name) for name in ("rm", "counts", "it")]

for k1 in sorted(varDict.keys(), key=lambda x: x.split(':')):
    if HI_CNT not in k1:
        continue
    plt.figure()
    plt.title(k1)
    for k, ls in zip([k1, k1.replace(HI_CNT, LO_CNT)], ['x-', 'o-']):
        if k not in bsqDict or k not in varDict:
            continue  # no matching run for this key
        # bias (y) and standard deviation (x), both in % of the RMS truth
        y, x = [np.array(d[k]) ** 0.5 * 100 for d in (bsqDict, varDict)]
        # NRMSE = sqrt(bias^2 + std^2): minimise that, not bias + std
        i = (x ** 2 + y ** 2).argmin()  # index of minimal NRMSE
        meta = k.split(':')
        label = "MLEM{rm} {cnt:.0f}M count (min NRMSE at {i}/{nit} iters)".format(
            rm="+RM" if meta[iRm][3:] else "",
            cnt=float(meta[iCnt]) / 1e6,
            i=i + 1,
            nit=meta[iIt])
        plt.plot(x, y, ls, label=label)
        plt.plot(x[i:i + 1], y[i:i + 1], 'ro', ms=6)
    plt.xlabel(r"Standard deviation, $\sigma$/[%]")
    plt.ylabel("Bias/[%]")
    plt.ylim(0, None)
    plt.xlim(0, None)
    plt.legend()
    plt.savefig(path.join(rImgDir, re.sub(r"\W", '', k1)) + ".png")

In [None]:
sorted(varDict, key=lambda key: key.split(':'))

In [None]:
# Visual check of one final-iteration reconstruction against its ground truth.
files = sorted([fn for folder in range(FOLDERS)
                for fn in glob("output/%d/*_300.mat" % folder) + glob("output/%d/*_PET_*_100.mat" % folder)])
parts = [RE_INFO.findall(fn)[0] for fn in files]
print(len(parts))
for p in tqdm(parts[:1]):
    imPath = ''.join(p)
    gndPath = getGnd(imPath, debug=True)
    print(gndPath, '\n', imPath)
    # load both volumes cropped to the ground truth's non-zero ROI
    ROI = autoROI(H5Reader(gndPath).tAct[:])
    print(ROI)
    gndReader = H5Reader(gndPath)
    truth = gndReader.tAct[ROI] * gndReader.scale_factor[0, 0]
    im = H5Reader(imPath).Img[ROI]
    # plot index 63 along axis 0 of the uncropped volume
    z = 63 - ROI[0].start
    plt.figure()
    msk = flatMask(truth[z], thresh=40)
    plt.subplot(131); imshow(msk)
    plt.subplot(132); imshow(im[z])
    plt.subplot(133); imshow(im[z] * msk)

In [None]:
# Most recently written run: show the ground truth and every `step`-th iteration.
im = sorted(glob("output/*/*_0_*_001.mat"), key=path.getmtime)[-1]

imGnd = getGnd(im)
ROI = autoROI(H5Reader(imGnd).tAct[:])  # non-zero ROI
pattern = im[:-7] + "*.mat"  # same run, any iteration (drops "NNN.mat")
print(pattern)
runFiles = glob(pattern)
assert all([im[:-7] == i[:-7] for i in runFiles])

# PET truth
imGnd = H5Reader(imGnd)
imGnd = imGnd.tAct[ROI] * imGnd.scale_factor[0, 0]
step = 1
print([i[-7:-4] for i in runFiles[::step]])
selected = [i for i in runFiles[::step] if not i.endswith("_000.mat")]
# panel titles come from the filenames ("NNN"); panel 0 is the ground truth
labels = ["000"] + [i[-7:-4] for i in selected]
ims = [imGnd] + [H5Reader(i).Img[ROI] for i in tqdm(selected)]

mask = flatMask(imGnd, maskThresh)
imGnd = imGnd[mask]
imScale = (imGnd ** 2).mean()
# index 63 along axis 0 of the uncropped volume (the volumes here are ROI-cropped)
z = 63 - ROI[0].start
nPanels = len(ims)
side = int(nPanels ** .5) + 1  # integer grid with side * side > nPanels
plt.figure(figsize=(14, 14))
for i, (im, label) in enumerate(zip(ims, labels)):
    nrmse = (((im[mask] - imGnd) ** 2).mean() / imScale) ** 0.5
    plt.subplot(side, side, i + 1)
    imshow(im[z], vmin=0, vmax=imGnd.max(),
           title="{} {:.3g}".format(label, nrmse))

In [None]:
# pick a set of realisations (first rm / patient / tumour, second count level)
fn = deepcopy(fPartsGlob)
fn[fPartsKeys.index("rm")] = rms[0]
fn[fPartsKeys.index("pat")] = pats[0]
fn[fPartsKeys.index("counts")] = counts[1]
fn[fPartsKeys.index("tum")] = tums[0]
fn[fPartsKeys.index("it")] = "001"
reals = glob("output/?/" + ''.join(fn))
if not reals:
    raise IOError("No realisations found for: output/?/" + ''.join(fn))

# get corresponding ground truth; recompute its ROI here instead of reusing
# whatever ROI a previous cell left in the kernel
imGnd = H5Reader(getGnd(reals[-1]))
ROI = autoROI(imGnd.tAct[:])  # non-zero ROI
imGnd = imGnd.tAct[ROI] * imGnd.scale_factor[0, 0]

# produce mask
#mask = imGnd > np.percentile(imGnd[imGnd > imGnd.min()].flat, 99.9)
mask = flatMask(imGnd)
print(mask.sum(), "pixels")
#plt.subplot(121); imshow(mask[len(mask) // 2])
#plt.subplot(122); imshow(imGnd[len(mask) // 2])

In [None]:
# For each realisation, its iteration files excluding _000 (the metadata/truth file).
rizyx = [[j for j in glob(i.replace("_001.mat", "_*.mat"))
          if not j.endswith("_000.mat")]
         for i in reals]
# [iteration, realisation] array of filenames.
# NOTE(review): assumes every realisation has the same number of iterations.
irzyx = np.array(rizyx).T

step = 5
irzyx = irzyx[::step]

l = len(irzyx)
side = int(l ** .5) + 1
fig, axs = plt.subplots(side, side, figsize=(14, 14), sharex=True, sharey=True)
iIt = fPartsKeys.index("it")
for i, it in enumerate(tqdm(irzyx)):
    plt.sca(axs.flat[i])
    plt.hist(np.array([H5Reader(im).Img[ROI][mask] for im in it]).ravel())
    # title with the iteration parsed from the filename
    # (row i after [::step] is iteration i * step + 1 only if none are missing)
    plt.title(str(int(RE_INFO.findall(it[0])[0][iIt])))