In [None]:
from __future__ import division
%matplotlib notebook
import matplotlib.pyplot as plt
import tensorflow as tf
# TF1-style session for Keras: only GPU "0" is visible, and this process may
# allocate at most 40% of that GPU's memory.
_gpu_options = tf.GPUOptions(visible_device_list="0",
                             per_process_gpu_memory_fraction=0.4)
tf.keras.backend.set_session(
    tf.Session(config=tf.ConfigProto(gpu_options=_gpu_options)))

import numpy as np
from caspyr.utils import H5Reader, glob
from caspyr.plotting import volshow
from argopt import DictAttrWrap
import logging

# Configure the root logger (default handler writes to stderr) at INFO level.
# `load` and `biasStdv` below only emit DEBUG messages, so those stay hidden
# unless the level is lowered. No-op if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)

In [None]:
def imshow(im, cmap="Greys", origin="lower", title=None):
    """
    Show an image on the current axes with the axis decorations hidden.

    @param im  : image array (anything `plt.imshow` accepts)
    @param cmap  : colour map name
    @param origin  : passed to `plt.imshow`; "lower" puts index [0, 0] at the bottom-left
    @param title  : optional title (skipped when empty/None)
    @return the image artist created by `plt.imshow` (e.g. for `plt.colorbar`)
    """
    img = plt.imshow(im, cmap=cmap, origin=origin)
    plt.axis("off")
    if title:
        plt.title(title)
    return img

# raw data

# reconstructions

# bias-stdv

In [None]:
nReals = 5  # noise realisations per subject (see `load`)
PAD = 4  # margin (voxels) added on both sides of every ROI axis
# (start, stop) bounds per axis in the ZXY order `load` expects, widened by PAD
ROI = [slice(start - PAD, stop + PAD)
       for start, stop in ((19, 108), (128, 100 + 115), (120, 100 + 124))]

In [None]:
def load(nPats, nReals=5, counts=None, ROI=None, dat_root="/data/cc16/apirl/o", **kwargs):
    """
    Intended for loading training data, e.g.:
    >>> x, y = load(10, 1, counts=[30, 300])

    @param nPats  : number of subjects (the first `nPats` in sorted ID order)
    @param nReals  : maximum number of realisations loaded per subject
    @param counts  : sequence of count levels in millions [default: (3.01, 301)].
      A negative level loads the scaled ground-truth PET image instead,
      repeated once per realisation.
    @param ROI  : 3 slices, ZXY (even though return is ZYX) [default: everything]
    @param dat_root  : directory containing the `reconAPIRL_brainweb_*.mat` files
    @param kwargs  : ignored
    @return list with one ndarray per entry of `counts`, each of shape
      (nPats*N, D, H, W, Ch) with N <= nReals realisations per subject.
      Ch is 3=<T1, PSF, PET> if `counts > 0`, 1=<TRUTH> otherwise.
    @raises ValueError if no subject files are found in `dat_root`.
    """
    from caspyr.utils import Globber, H5Reader
    from os import path
    import numpy as np
    import logging
    from tqdm.auto import tqdm
    glob = Globber.glob
    log = logging.getLogger(__name__)

    #"brainweb_PETpsf_1_subject_{}-S_1-NP_1-NT1_0.75-C_{:.3g}_t3_{:03d}.mat"
    prefix = "reconAPIRL_brainweb_subject_"
    # subject ID = file-name text between `prefix` and the next "-"
    # (independent of how long `dat_root` is)
    subjects = sorted({
        path.basename(f)[len(prefix):].split("-", 1)[0]
        for f in glob(path.join(dat_root, prefix + "*-S_1-NP_1-NT1_0.75-C_*_t3.mat"))})
    log.debug("subj:" + ', '.join(subjects))
    if not subjects:
        raise ValueError("no subject files found in " + dat_root)

    counts = counts or (3.01, 301)
    ROI = ROI or (slice(0, None),) * 3
    assert len(ROI) == 3
    ROI = tuple(ROI)

    res = []  # one (N, z, y, x, Ch) array per count level
    for c in tqdm(counts, desc="counts"):
        vols = []
        for subj in tqdm(subjects[:nPats], desc="subject", leave=False):
            fname = path.join(dat_root,
                "reconAPIRL_brainweb_subject_{}-S_1-NP_1-NT1_0.75-C_{:.3g}_t3.mat".format(
                    subj, abs(c)*1e6))
            d = H5Reader(fname, prefix="reconAPIRL")
            if c < 0:
                # ground truth, one identical copy per available realisation
                vol = d.PET[ROI][None, :, :, :, None] * d.scale_factor[0, 0]  # 1zxy1
                vol = np.repeat(vol, min(d.mlem.shape[3], nReals), axis=-1)  # 1zxyN
            else:
                recon = np.concatenate((
                    d.mlem_psf[ROI + (slice(0, nReals),)][None],
                    d.mlem[ROI + (slice(0, nReals),)][None],
                ))  # 2zxyN: <PSF, PET>
                t1 = np.repeat(d.T1[ROI][None, :, :, :, None], recon.shape[-1], axis=-1)  # 1zxyN
                vol = np.concatenate((t1, recon))  # 3zxyN: <T1, PSF, PET>
            # channels last, realisations first, swap x/y: Nzyx{1,3}
            vols.append(vol.transpose((4, 1, 3, 2, 0)))
        res.append(np.concatenate(vols))
    return res

# Training pairs: x = low-count inputs <T1, PSF, PET>, y = ground truth
# (a negative count makes `load` return the true PET image, not a recon).
# Alternative pairings, currently disabled:
#   counts=[3.01, 301]   -> 'bweb-3.01-301.npz'  (low -> high counts)
#   counts=[30.1, -30.1] -> 'bweb-30.1-T.npz'
dat = load(
    10,
    nReals=nReals,
    counts=[3.01, -3.01],
    ROI=ROI,
)  # list of 2 arrays, each (N, z, y, x, Ch)
np.savez_compressed(
    'bweb-3.01-T.npz',
    x=dat[0],  # inputs, Ch=3
    y=dat[1],  # targets, Ch=1
)

In [None]:
# bias-var
from argopt import DictAttrWrap
from viper.utils import stats
import functools

# Channel indices. Inputs dat[0] are <T1, PSF, PET>; the targets dat[1] have a
# single <TRUTH> channel, so TRUE shares index 0 with T1. Model outputs are
# indexed with TRUE as well.
IDX = DictAttrWrap({"TRUE": 0, "T1": 0, "PSF": 1, "PET": 2})
# The saved models refer to custom loss/metric names; mse stands in for both
# so they can be deserialised (only `predict` is used below).
load_model = functools.partial(
    tf.keras.models.load_model,
    custom_objects={"lossGen": tf.keras.losses.mse, "nrmse": tf.keras.losses.mse},
)
modelGen, modelGAN = [load_model(fname)
                      for fname in ("modelGen-3.01-T.h5", "modelGAN-3.01-T.h5")]

# Per-method arrays of shape (N, z, y, x)
inputs = dat[0]
tru = dat[1][..., IDX.TRUE]
pet = inputs[..., IDX.PET]
psf = inputs[..., IDX.PSF]
net = modelGen.predict(inputs)[..., IDX.TRUE]
gan = modelGAN.predict(inputs)[..., IDX.TRUE]

In [None]:
# Bias / standard deviation over realisations, separately for every patient.
# Rows are patient-major: rows [p*nReals, (p+1)*nReals) hold the realisations
# of patient p (see `load`). The manual, unmasked formulas previously kept
# here for reference were:
#   scale = 1 / (ref ** 2).mean()
#   bias = (((est.mean(axis=0) - ref[0]) ** 2).mean() * scale) ** 0.5
#   stdv = ((est.var(axis=0, ddof=1)).mean() * scale) ** 0.5
#   (with ddof=0, nrmse = hypot(bias, stdv))
assert len(tru) % nReals == 0, "expected whole groups of nReals realisations"
nPatients = len(tru) // nReals  # was range(nReals): only evaluated 5 of the patients
for p in range(nPatients):
    reals = slice(p * nReals, (p + 1) * nReals)
    ref = tru[reals][:1]  # truth is repeated per realisation; one copy suffices
    # NOTE(review): network outputs are rescaled by the std of the *true* image
    # before scoring, i.e. the evaluation uses ground-truth information - confirm.
    scale = ref.std()
    print("patient", p)
    for name, est in (("pet", pet[reals]), ("psf", psf[reals]),
                      ("net", net[reals] * scale), ("gan", gan[reals] * scale)):
        bias, stdv = stats.biasStdMask(est, ref)
        print(name, bias, stdv, np.hypot(bias, stdv))

In [None]:
def biasStdv(nPats, nReals=5, counts=3.01, ROI=None, dat_root="/data/cc16/apirl/o", **kwargs):
    """
    Bias and standard deviation over noise realisations, per MLEM iteration
    and per post-smoothing (PS) FWHM, averaged over subjects, e.g.:
    >>> biasStdv(10, counts=30)

    @param nPats  : number of subjects (the first `nPats` in sorted ID order)
    @param nReals  : currently unused; every realisation file found is used
    @param counts  : scalar count level in millions
    @param ROI  : 3 slices, ZXY [default: everything]
    @param dat_root  : directory containing the `*.mat` reconstructions
    @param kwargs  : ignored
    @return ndarray, shape ((MLEM, MLEM+PS, PSF, PSF+PS), (bias, stdv), nItr).
      For the MLEM/PSF rows the last axis is the (sorted) iteration; for the
      +PS rows it is the Gaussian FWHM, 0..25 mm in nItr steps, applied to the
      realisations of the final iteration.
    @raises ValueError if no subject or iteration files are found.
    """
    from caspyr.utils import Globber, H5Reader
    from viper.utils.stats import biasStdMask  # biasStd as biasStdMask
    from viper.constants import SIGMA2FWHM_MMRzyx
    from viper.imsample import gauss
    from os import path
    import numpy as np
    import logging
    from tqdm.auto import tqdm

    glob = Globber.glob
    log = logging.getLogger(__name__)

    #"brainweb_PETpsf_1_subject_{}-S_1-NP_1-NT1_0.75-C_{:.3g}_t3_{:03d}.mat"
    prefix = "reconAPIRL_brainweb_subject_"
    # subject ID = file-name text between `prefix` and the next "-"
    subjects = sorted({
        path.basename(f)[len(prefix):].split("-", 1)[0]
        for f in glob(path.join(dat_root, prefix + "*-S_1-NP_1-NT1_0.75-C_*_t3.mat"))})
    log.debug("subj:" + ', '.join(subjects))
    if not subjects:
        raise ValueError("no subject files found in " + dat_root)

    c = counts * 1e6
    ROI = ROI or (slice(0, None),) * 3
    assert len(ROI) == 3
    ROI = tuple(ROI)

    def stack_reals(kind, subj, itr):
        """All realisations of `kind` ("PET" or "PETpsf") at iteration `itr` -> Nzxy."""
        files = glob(path.join(dat_root,
            "brainweb_{}_*_subject_{}-S_1-NP_1-NT1_0.75-C_{:.3g}_t3_{}.mat".format(
                kind, subj, c, itr)))
        return np.array([H5Reader(f).Img[ROI] for f in files])

    res = []  # [nPat, [MLEM, MLEM+PS, PSF, PSF+PS], nItr, [bias, stdv]]
    for subj in tqdm(subjects[:nPats], unit="subject", leave=True):
        fname = path.join(dat_root,
            "reconAPIRL_brainweb_subject_{}-S_1-NP_1-NT1_0.75-C_{:.3g}_t3.mat".format(
                subj, c))
        d = H5Reader(fname, prefix="reconAPIRL")
        tru = d.PET[ROI][None] * d.scale_factor[0, 0]  # 1zxy

        # Iteration tags are the 3 characters before ".mat". Sorted so that all
        # subjects list iterations in the same order (they are averaged
        # element-wise below) and so `last` really is the final iteration.
        iters = sorted(f[-7:-4] for f in glob(path.join(dat_root,
            "brainweb_PET_0_subject_{}-S_1-NP_1-NT1_0.75-C_{:.3g}_t3_*.mat".format(
                subj, c))))
        log.debug("itr:" + ' '.join(iters))
        if not iters:
            raise ValueError("no iteration files found for subject " + subj)

        with tqdm(total=2, desc="Stage", leave=False) as stage:
            bs = []
            bsPSF = []
            for itr in tqdm(iters, desc="Iterations", leave=False):
                last = stack_reals("PET", subj, itr)  # Nzxy
                bs.append(biasStdMask(last, tru))
                lastPSF = stack_reals("PETpsf", subj, itr)  # Nzxy
                bsPSF.append(biasStdMask(lastPSF, tru))
            stage.update()

            # post-smooth the final iteration with increasing FWHM
            bsPS = []
            bsPSF_PS = []
            for fwhm in tqdm(np.linspace(0, 25, num=len(iters)), unit_scale=25.0/len(iters),
                             unit="mm", desc="PS", leave=False):
                sigma = [fwhm / s for s in SIGMA2FWHM_MMRzyx]
                bsPS.append(biasStdMask(np.array([gauss(r, sigma) for r in last]), tru))
                bsPSF_PS.append(biasStdMask(np.array([gauss(r, sigma) for r in lastPSF]), tru))
            stage.update()

        res.append((bs, bsPS, bsPSF, bsPSF_PS))
    # mean over subjects, then (rows, nItr, 2) -> (rows, 2, nItr)
    return np.mean(res, axis=0).transpose((0, 2, 1))

logging.getLogger().setLevel(logging.INFO)
# low counts (3.01M), first subject only. The result always contains the PSF
# rows; there is no `psf` switch (extra keywords are ignored by `biasStdv`).
bsPet = biasStdv(1, nReals=nReals, counts=3.01, ROI=ROI)  # (4, 2, nItr)
# high counts (returns a single array like the call above, not a pair):
#bsPetHigh = biasStdv(10, nReals=nReals, counts=301, ROI=ROI)

In [None]:
fig, ax = plt.subplots(1, 1)
# bsPet rows: 0=MLEM, 1=MLEM+PS, 2=PSF (resolution modelling), 3=PSF+PS;
# bsPet[row, 0] is the bias and bsPet[row, 1] the standard deviation.
for row, label in ((0, "MLEM"), (2, "MLEM+RM")):
    ax.plot(bsPet[row, 1], bsPet[row, 0], 'o-', ms=4, label=label)

# Post-smoothed curves (black): mark the point(s) with the smallest
# hypot(bias, stdv). Only the first one gets a legend entry.
for row, extra in ((1, {"label": "PS up to 25mm"}), (3, {})):
    combined = np.hypot(bsPet[row, 1], bsPet[row, 0])
    best = np.flatnonzero(combined == combined.min())
    ax.plot(bsPet[row, 1], bsPet[row, 0], 'ko-', markevery=best, **extra)

ax.set_xlabel(r"Standard deviation, $\sigma$/[%]")
ax.set_ylabel(r"Bias, $b$/[%]")
ax.set_xlim(0, None)
ax.legend();