In [None]:
import numpy
import matplotlib
import matplotlib.pyplot as plt

from pathlib import Path
from scipy.ndimage import gaussian_filter
from scipy.stats import pearsonr
from tifffile import imread
from time import perf_counter


import sys
sys.path.append("/g/kreshuk/beuttenm/repos/lnet")
from lnet.utils.plotting import turbo_colormap 

In [None]:
def get_raw_data(name):
    """Load every ``*.tif`` file in a directory into one float32 array.

    Files are read in sorted path order and stacked along a new leading axis.
    Each image is divided by the uint16 maximum (65535) and cast to float32;
    presumably the files are uint16 images, so values end up in [0, 1] —
    confirm before using this on files of another dtype.

    Parameters
    ----------
    name : str or os.PathLike
        Directory containing the ``.tif`` files (not searched recursively).

    Returns
    -------
    numpy.ndarray
        float32 array of shape ``(n_files, *image_shape)``.

    Raises
    ------
    FileNotFoundError
        If no ``*.tif`` file is found (including when the directory does not
        exist). Previously this returned an empty array, which only failed
        later with an unrelated shape assertion.
    """
    paths = sorted(Path(name).glob("*.tif"))
    if not paths:
        raise FileNotFoundError(f"no *.tif files found in {str(name)!r}")

    scale = numpy.iinfo(numpy.uint16).max
    frames = [(imread(path.as_posix()) / scale).astype(numpy.float32) for path in paths]
    return numpy.stack(frames)

In [None]:
# Prediction directory; get_raw_data stacks its sorted .tif files along axis 0.
prediction_dir = "/g/kreshuk/beuttenm/repos/lensletnet/logs/platy/3-100-500/1/19-09-03_09-14_63d6439_m12dout-bc/result/test/prediction"
pred = get_raw_data(prediction_dir)

In [None]:
print(numpy.shape(pred))  # shape of the stacked prediction array

In [None]:
# Target directory; get_raw_data stacks its sorted .tif files along axis 0.
target_dir = "/g/kreshuk/beuttenm/repos/lensletnet/logs/platy/3-100-500/1/19-09-03_09-14_63d6439_m12dout-bc/result/test/target"
tgt = get_raw_data(target_dir)

In [None]:
print(numpy.shape(tgt))  # shape of the stacked target array

In [None]:
# Max projection over the leading axis of pred[0, 0] (first stacked file, index 0 of axis 1).
first_pred = pred[0, 0]
plt.imshow(numpy.max(first_pred, axis=0))

In [None]:
# Max projection over the leading axis of tgt[0, 0] (first stacked file, index 0 of axis 1).
first_tgt = tgt[0, 0]
plt.imshow(numpy.max(first_tgt, axis=0))

In [None]:
def bin(data, step=10):
    """Average ``step`` x ``step`` spatial tiles, then smooth along Z and T.

    Note: this name shadows the builtin ``bin`` inside the notebook.

    Parameters
    ----------
    data : numpy.ndarray
        5D array ``(T, 1, Z, Y, X)`` with a singleton axis 1; the last two
        axes must be multiples of ``step``.
    step : int, optional
        Edge length of the square spatial tiles (default 10, the value that
        used to be hard-coded).

    Returns
    -------
    numpy.ndarray
        4D array ``(T - 2, Z - 2, Y // step, X // step)``. It holds the tile
        means, filtered with the kernel ``[0.25, 0.5, 0.25]`` first along Z
        (axis 1) and then along T (axis 0). Only fully overlapped positions
        are kept, so each of those axes shrinks by 2.

    Raises
    ------
    AssertionError
        If the rank, the singleton axis, the divisibility by ``step`` or the
        minimum sizes are violated.
    """
    shape = data.shape
    assert len(shape) == 5, f"expected a 5D array, got shape {shape}"
    assert shape[1] == 1, f"axis 1 must have size 1, got shape {shape}"
    assert all(size % step == 0 for size in shape[3:]), (
        f"last two axes must be multiples of {step}, got shape {shape}"
    )
    n_t, _, n_z, n_y, n_x = shape
    out_shape = (n_t - 2, n_z - 2, n_y // step, n_x // step)
    assert all(size > 0 for size in out_shape), f"input too small to bin: shape {shape}"

    volume = data[:, 0]
    # Split each spatial axis into (tiles, step) and average over both step axes.
    # This is mathematically the same as averaging the step**2 strided sub-grids,
    # but avoids stacking step**2 copies of the data in memory.
    xy = volume.reshape(n_t, n_z, n_y // step, step, n_x // step, step).mean(axis=(3, 5))
    # [0.25, 0.5, 0.25] along Z (axis 1), valid positions only.
    z = 0.25 * xy[:, :-2] + 0.5 * xy[:, 1:-1] + 0.25 * xy[:, 2:]
    # Same kernel along T (axis 0).
    return 0.25 * z[:-2] + 0.5 * z[1:-1] + 0.25 * z[2:]

In [None]:
# Crop axes 3 and 4 (from 30 and 76), presumably so their sizes become multiples of 10 as bin requires.
pred_cropped = pred[:, :, :, 30:, 76:]
pred_binned = bin(pred_cropped)
print("shape", pred_binned.shape)
plt.imshow(numpy.max(pred_binned[0], axis=0))

In [None]:
# Maximum of the cropped target over axes 0 and 1 (stacked files, then depth), saved to test.svg.
tgt_max_proj = numpy.max(tgt[:, 0, :, 30:, 76:], axis=(0, 1))
plt.imshow(tgt_max_proj, cmap=turbo_colormap)
plt.colorbar()
plt.savefig("test.svg")

In [None]:
# Standard deviation over axis 0 of the cropped target, max-projected over depth.
tgt_std_cropped = numpy.std(tgt[:, 0, :, 30:, 76:], axis=0)
plt.imshow(numpy.max(tgt_std_cropped, axis=0), cmap=turbo_colormap)
plt.colorbar()
# Optional export: plt.savefig("/g/kreshuk/beuttenm/Documents/for_oc/std.svg")

In [None]:
# Same crop as applied to the prediction, then the same binning.
tgt_cropped = tgt[:, :, :, 30:, 76:]
tgt_binned = bin(tgt_cropped)
tgt_binned.shape

In [None]:
# Pearson correlation of the full axis-0 series at binned voxel (13, 10, 10).
probe_voxel = (13, 10, 10)
roi = (slice(None),) + probe_voxel
print(pearsonr(pred_binned[roi], tgt_binned[roi]))

In [None]:
start = perf_counter()
prs = numpy.empty(tgt_binned.shape[1:], dtype=numpy.float32)
# One Pearson r per voxel along axis 0; pearsonr's p-value is discarded.
# numpy.ndindex walks the indices in the same row-major order as nested loops.
for z, y, x in numpy.ndindex(prs.shape):
    series = (slice(None), z, y, x)
    prs[z, y, x] = pearsonr(pred_binned[series], tgt_binned[series])[0]

print(perf_counter() - start)

In [None]:
# Save Data: numpy.save appends ".npy" to each stem; files go to the working directory.
for stem, array in (("pred_binned", pred_binned), ("tgt_binned", tgt_binned), ("prs", prs)):
    numpy.save(stem, array)

In [None]:
n_frames = 300  # correlate only the first 300 entries along axis 0
start = perf_counter()
prs300 = numpy.empty(tgt_binned.shape[1:], dtype=numpy.float32)
for z, y, x in numpy.ndindex(prs300.shape):
    prs300[z, y, x] = pearsonr(pred_binned[:n_frames, z, y, x], tgt_binned[:n_frames, z, y, x])[0]

print(perf_counter() - start)

In [None]:
z_plane = 13  # depth index of the correlation map to display
plt.imshow(prs[z_plane])
plt.colorbar()

In [None]:
numpy.max(prs)  # best voxel correlation (NaN if any voxel is NaN)

In [None]:
numpy.max(prs300)  # best voxel correlation over the first 300 entries

In [None]:
# Per-voxel standard deviation of the binned target along axis 0.
stds = numpy.std(tgt_binned, axis=0)
numpy.shape(stds)

In [None]:
numpy.max(stds)  # largest per-voxel std

In [None]:
numpy.percentile(stds, q=80)  # the threshold used for the mask below

In [None]:
numpy.median(stds, axis=None)  # median over all voxels

In [None]:
numpy.mean(stds)  # mean over all voxels

In [None]:
# Keep voxels whose target std is above the 80th percentile.
threshold = numpy.percentile(stds, 80)
mask = stds > threshold
masked_prs = prs.copy()
masked_prs[numpy.logical_not(mask)] = numpy.nan
print('mean masked pr', numpy.nanmean(masked_prs))
plt.imshow(numpy.any(mask, axis=0))

In [None]:
# Number of selected voxels and their fraction of the whole volume.
msum = mask.sum()
print(msum, msum / mask.size)

In [None]:
# First binned target entry with unmasked voxels blanked, projected with nanmax over depth.
show = tgt_binned[0].copy()
show[numpy.logical_not(mask)] = numpy.nan
plt.imshow(numpy.nanmax(show, axis=0))

In [None]:
def plot_masked(data, mask):
    """Show the axis-0 ``nanmax`` projection of ``data`` restricted to ``mask``, with a colorbar.

    Entries where ``mask`` is False become NaN and are ignored by ``numpy.nanmax``.
    ``numpy.where`` builds a new array, so ``data`` is never modified. Unlike
    assigning NaN into a copy, it also promotes integer/bool input to float.

    Parameters
    ----------
    data : array_like
        Array to display; same shape as ``mask``.
    mask : numpy.ndarray of bool
        True where ``data`` should be shown.
    """
    masked = numpy.where(mask, data, numpy.nan)
    plt.imshow(numpy.nanmax(masked, axis=0))
    plt.colorbar()

In [None]:
plot_masked(data=prs, mask=mask)  # correlation map restricted to high-variance voxels
plt.show()

In [None]:
def plot_ts(tgt, pred, name="", dt=0.11, out_dir="/g/kreshuk/beuttenm/Documents/for_oc"):
    """Plot a target and a predicted time series on twin y-axes, then show the figure.

    Parameters
    ----------
    tgt : array_like
        1D series drawn in red against the left axis ("RL reconstruction").
    pred : array_like
        1D series of the same length drawn in blue against the right axis.
    name : str, optional
        If non-empty, the figure is saved as ``{out_dir}/{name}.pgf`` before showing.
    dt : float, optional
        Spacing between consecutive samples on the time axis, in seconds (default 0.11).
    out_dir : str, optional
        Directory for the ``.pgf`` file.
    """
    fig, ax_tgt = plt.subplots()
    ax_pred = ax_tgt.twinx()
    ax_tgt.set_xlabel("time [s]")
    # Raw strings: "\D" is an invalid escape in a normal literal (DeprecationWarning,
    # SyntaxWarning since Python 3.12). The resulting label text is unchanged.
    ax_tgt.set_ylabel(r"RL reconstruction [$\Delta F/F_0$]")
    ax_pred.set_ylabel(r"Network prediction [$\Delta F/F_0$]")

    tgt = numpy.asarray(tgt)
    pred = numpy.asarray(pred)
    t = numpy.arange(tgt.shape[0]) * dt
    lines = ax_tgt.plot(t, tgt, color="r", label="RL reconstruction")
    lines += ax_pred.plot(t, pred, color="b", label="Network prediction")

    # The two lines live on different axes; collect both into one legend.
    ax_tgt.legend(lines, [line.get_label() for line in lines])
    if name:
        fig.savefig(f"{out_dir}/{name}.pgf")

    plt.show()
    
    
def save_idx(idx, name, t_max=None):
    """Plot the time course of the ``idx``-th masked voxel via ``plot_ts``.

    Masked voxels are numbered in the row-major order used by
    ``tgt_binned[:, mask]``. Relies on the notebook globals ``tgt_binned``,
    ``pred_binned`` and ``mask``.

    Parameters
    ----------
    idx : int
        Index into the masked voxels; negative values count from the end.
    name : str
        Forwarded to ``plot_ts``; the figure is saved when non-empty.
    t_max : int or None, optional
        Plot only the first ``t_max`` entries along axis 0 (all when None).
    """
    # Resolve the voxel's coordinates directly. This avoids materializing the
    # full tgt_binned[:, mask] / pred_binned[:, mask] copies on every call.
    voxel = tuple(coords[idx] for coords in numpy.nonzero(mask))
    selection = (slice(t_max),) + voxel
    plot_ts(tgt_binned[selection], pred_binned[selection], name=name)

In [None]:
# Masked voxels ordered from highest to lowest Pearson correlation (indices into prs[mask]).
# argsort places NaN last, so the reversed order would list NaN correlations
# (e.g. from constant series) first. They are dropped, so idxs[0] is the best
# finite correlation and idxs[-1] the worst.
prs_in_mask = prs[mask]
order_desc = numpy.argsort(prs_in_mask)[::-1]
idxs = order_desc[~numpy.isnan(prs_in_mask[order_desc])]

In [None]:
# Last entry of the descending order = masked voxel with the lowest correlation.
lowest = idxs[-1]
save_idx(lowest, name="out")
lowest

In [None]:
# Ten best-correlated masked voxels. This must iterate over `idxs` (the
# descending order computed above); `idx` is undefined here on a fresh kernel
# and is only bound as an int by a later loop.
for i in idxs[:10]:
    print(i)
    # NOTE(review): every iteration writes the same out.pgf, so only the last
    # figure survives on disk; pass e.g. name=f"out_{i}" to keep all ten.
    save_idx(i, name="out")

In [None]:
# Every 100th masked voxel starting at 0. The masked columns are extracted
# once instead of on every iteration, and the range is clamped to the number
# of masked voxels so a small mask does not raise IndexError.
tgt_masked = tgt_binned[:, mask]
pred_masked = pred_binned[:, mask]
for idx in range(0, min(1000, tgt_masked.shape[1]), 100):
    plot_ts(tgt_masked[:, idx], pred_masked[:, idx])

In [None]:
# Every 100th masked voxel starting at 1. The masked columns are extracted
# once instead of on every iteration, and the range is clamped to the number
# of masked voxels so a small mask does not raise IndexError.
tgt_masked = tgt_binned[:, mask]
pred_masked = pred_binned[:, mask]
for idx in range(1, min(1000, tgt_masked.shape[1]), 100):
    plot_ts(tgt_masked[:, idx], pred_masked[:, idx])

In [None]:
# Every 100th masked voxel starting at 2. The masked columns are extracted
# once instead of on every iteration, and the range is clamped to the number
# of masked voxels so a small mask does not raise IndexError.
tgt_masked = tgt_binned[:, mask]
pred_masked = pred_binned[:, mask]
for idx in range(2, min(1000, tgt_masked.shape[1]), 100):
    plot_ts(tgt_masked[:, idx], pred_masked[:, idx])

In [None]:
# Plot up to 10 masked voxels (scanning every 10th of the first 1000) whose
# target value at frame 100 is at least 0.2, limited to the first t_max frames.
t_max = 300
probe_frame = 100
min_level = .2
max_plots = 10
cnt = 0
# Loop-invariant: extract the masked time courses once, not per candidate.
tgt_masked = tgt_binned[:t_max, mask]
pred_masked = pred_binned[:t_max, mask]
for idx in range(0, 1000, 10):
    selected_tgt = tgt_masked[:, idx]
    selected_pred = pred_masked[:, idx]

    if selected_tgt[probe_frame] < min_level:
        continue

    cnt += 1
    if cnt > max_plots:
        break
    plot_ts(selected_tgt, selected_pred)

In [None]:
# Plot up to 10 masked voxels (scanning every 10th of the first 1000) whose
# target value at frame 100 is at least 0.2, limited to the first t_max frames.
t_max = 400
probe_frame = 100
min_level = .2
max_plots = 10
cnt = 0
# Loop-invariant: extract the masked time courses once, not per candidate.
tgt_masked = tgt_binned[:t_max, mask]
pred_masked = pred_binned[:t_max, mask]
for idx in range(0, 1000, 10):
    selected_tgt = tgt_masked[:, idx]
    selected_pred = pred_masked[:, idx]

    if selected_tgt[probe_frame] < min_level:
        continue

    cnt += 1
    if cnt > max_plots:
        break
    plot_ts(selected_tgt, selected_pred)

In [None]:
# TODO: `psnr` is not defined anywhere in this notebook. Evaluating the bare
# name raised NameError and stopped Restart & Run All. Presumably a PSNR
# between pred_binned and tgt_binned was intended; implement it here.

In [None]:
def plot_hist(prs, bins=25):
    """Draw a histogram of Pearson correlation coefficients on a new figure.

    The y-axis ticks and label are placed on the right. The x-axis lower
    limit is set to 0 before plotting, so negative coefficients are not
    displayed.

    Parameters
    ----------
    prs : array_like
        Correlation coefficients to histogram.
    bins : int or sequence, optional
        Passed to ``Axes.hist`` (default 25).
    """
    fig, ax = plt.subplots()
    ax.yaxis.tick_right()
    ax.yaxis.set_label_position("right")
    ax.set_xlim(left=0)  # `left` is the canonical name of the `xmin` alias
    ax.set_xlabel("Pearson Correlation Coefficient")
    ax.set_ylabel("counts")

    # Draw on the axes configured above rather than via the pyplot state machine.
    ax.hist(prs, bins=bins)
    
    

In [None]:
# Correlations of the high-variance voxels only, exported to hist.svg.
plot_hist(prs=prs[mask])
plt.savefig("hist.svg")

In [None]:
numpy.mean(prs[mask])  # mean correlation over masked voxels

In [None]:
plot_hist(prs.ravel())  # correlations of all voxels

In [None]:
numpy.mean(prs)  # mean correlation over all voxels

In [None]:

# Hand-picked indices into the masked-voxel ordering (as used by save_idx).
selected_masked = [
    1256,
    1963,
    1688,
    2365,
]

In [None]:
# mask has the (Z, Y, X) layout of tgt_binned[0], so numpy.where returns the
# z coordinates first. The previous `x, y, z = ...` unpacking mislabeled the
# z and x arrays. The printed values are unchanged: x y z, then the 'here' line.
z, y, x = numpy.where(mask)
for s in selected_masked:
    print(s)
    print(x[s], y[s], z[s])

    # NOTE(review): 28 and 21 are hard-coded normalizers, presumably axis
    # extents of the binned volume; confirm against the printed mask.shape.
    print('here', s, y[s] / 28, z[s] / 21)

print(mask.shape)

In [None]:
# Plot and save each hand-picked voxel, using its index as the file name.
for s in selected_masked:
    save_idx(s, name=f"{s}")

In [None]:
# NOTE(review): this cell contained only the bare name `p`. Nothing visible
# defines `p`, so it raised NameError on Restart & Run All; restore whatever
# expression was intended here.