In [None]:
import tensorflow as tf

import numpy as np
import matplotlib.pyplot as plt
import glob

import deepposekit as dpk

import TrainingGeneratorTFRecord as TGTFR
import apt_dpk 
import run_apt_expts_2 as rae
import APT_interface as apt
import PoseTools as pt
import multiResData as mrd
import open_pose_data as opd
import util
import tfdatagen

import time
from os.path import expanduser

import mpl_toolkits.axes_grid1 as axg1

# Location of the DPK fly annotation dataset (.h5 file) that dpk.DataGenerator opens below.
# NOTE(review): absolute, machine-specific path.
DPK_DSET = '/home/al/git/dpkd/datasets/fly/annotation_data_release.h5'

## Create a DPK DataGenerator (DG) and an APT TFRecord (TFR). Confirm that we read identical ims/locs from both.

In [None]:
# DPK data generator backed by the h5 dataset.
dg = dpk.DataGenerator(DPK_DSET)

# Inputs for the APT conf of the 'dpkfly' experiment.
LEAPSTRIPPEDLBL = '/dat0/jrcmirror/groups/branson/bransonlab/apt/experiments/data/leap_dataset_gt_stripped.lbl'
EXPNAME = 'dpkfly'
CACHE = '/dat0/apt/cache'
conf = apt.create_conf(
    LEAPSTRIPPEDLBL, 0, EXPNAME, CACHE, 'dpkfly',
    quiet=False,
)
# Hack: the leap stripped lbl has NumChans=3, but the tfr was created
# directly from the dpk h5, which is 1-channel.
conf.img_dim = 1
conf.cachedir

# The cached images in the stripped lbl differ from the dpk h5:
#   - ims are 3-chan grayscale vs 1-chan
#   - locs are off-by-one; the stripped lbl is probably correct (0-based)
# Hence apt.create_tfrecord is not used here:
#apt.create_tfrecord(conf,split=False,use_cache=True)

# Instead the tfr is created directly from the h5 (for replication purposes),
# even though it may be off-by-one.
train_tf = conf.cachedir + '/train_TF.tfrecords'
#apt_dpk.apt_db_from_datagen(dg, train_tf)

In [None]:
# List the attribute names available on the DPK data generator.
dir(dg)

In [None]:
# Display the results of count_records and read_tfrecord_metadata for the training TFRecord file.
pt.count_records(train_tf), mrd.read_tfrecord_metadata(train_tf)

In [None]:
# Inspect the batch size currently stored on the conf.
conf.batch_size

In [None]:
# Create a dset
ds = tfdatagen.create_tf_datasets(conf, 
                                  bsize='UNUSED', 
                                  n_outputs=1,
                                  is_raw=True, 
                                  shuffle=False, 
                                  infinite=False, 
                                  dobatch=False)

In [None]:
# explicitly check/compare our TFR to the dg 
nkp = 32
INDICES_CHECK = [0,333,1499]
ims, locs, ifo, _ = mrd.read_and_decode_without_session(train_tf, nkp, indices=INDICES_CHECK)

In [None]:
# Fetch the same indices directly from the DPK data generator for comparison.
ims0 = dg.get_images(INDICES_CHECK)
locs0 = dg.get_keypoints(INDICES_CHECK)

In [None]:
# ... and TFD!!
resTFD = tfdatagen.read_ds_idxed(ds, INDICES_CHECK)

In [None]:
# Transpose the per-index result tuples into three parallel tuples (ims, locs, info).
imsTFD, locsTFD, ifoTFD = zip(*resTFD)

In [None]:
# Record-by-record exact comparison: direct TFR read vs. tf.data read.
for i in range(3):
    ims_match = np.array_equal(ims[i], imsTFD[i])
    locs_match = np.array_equal(locs[i], locsTFD[i])
    ifo_match = np.array_equal(ifo[i], ifoTFD[i])
    print(ims_match, locs_match, ifo_match)

In [None]:
# Record-by-record exact comparison: direct TFR read vs. DPK data generator.
for i in range(3):
    ims_match = np.array_equal(ims[i], ims0[i, ...])
    locs_match = np.array_equal(locs[i], locs0[i, ...])
    print(ims_match, locs_match)

In [None]:
# Move the leading axis of ims0 to the end, then keep only index 0 along
# axis 2, which leaves a 3-D stack whose last axis indexes the samples.
ims0stk = np.moveaxis(ims0, 0, -1)[:, :, 0, :]
ims0stk.shape

In [None]:
# Shape of the keypoints returned by the data generator.
locs0.shape

In [None]:
# Show the image stack as a montage, passing the DPK keypoints as locs.
tfdatagen.montage(ims0stk, locs=locs0)

## Create a DPK TrainingGenerator (TG), an APT TGTFR, and a tf.data dataset (TFD), all without augmentation. Confirm that we read identical ims/locs from all three.

In [None]:
# Settings for the DPK TrainingGenerator used in the no-augmentation comparison.
DSFAC = 2         # passed as downsample_factor
SIGMA = 5         # passed as sigma
VALSPLIT = 0.0    # passed as validation_split (no samples held out)
GRAPHSCALE = 1    # passed as graph_scale
# No augmenter is supplied and shuffle is off, so the generator output
# should be directly comparable with the APT side.
tg = dpk.TrainingGenerator(generator=dg,
                           downsample_factor=DSFAC,
                           use_graph=True,
                           shuffle=False,
                           sigma=SIGMA,
                           validation_split=VALSPLIT,
                           graph_scale=GRAPHSCALE,
                           random_seed=0)


In [None]:
# Capture and display the training generator's configuration.
tgc = tg.get_config()
tgc

In [None]:
# Inspect the generator's train_range attribute.
tg.train_range

In [None]:
# Call the training generator to obtain the object used for batch access
# (training side, batch size 4, confidence=True), then display it.
g = tg(n_outputs=1, batch_size=4, 
       validation=False, confidence=True)
g

In [None]:
# Compare the attribute dictionaries of tg and of the object returned by calling it.
util.dictdiff(vars(tg),vars(g))

In [None]:
# Path of the training TFRecord file.
train_tf

In [None]:
# Inspect the dpk_input_sigma value stored on the conf.
conf.dpk_input_sigma

In [None]:
# The data-generator properties that are passed to apt_dpk.update_conf_dpk in the next cell.
dg.graph, dg.swap_index, dg.compute_image_shape()

In [None]:
# Re-apply the single-channel hack, then update the conf with the DPK
# properties taken from the data generator; imgaug is disabled here.
conf.img_dim = 1  # hack, the leap stripped lbl has NumChans=3, but we created the tfr 
                  # directly using the dpk h5 which is 1-channel
conf = apt_dpk.update_conf_dpk(conf,
                               dg.graph,
                               dg.swap_index,
                               n_keypoints=dg.n_keypoints,
                               imshape=dg.compute_image_shape(),
                               useimgaug=False,
                               imgaugtype=None)
# Display every attribute of the updated conf.
vars(conf)         

In [None]:
# TFRecord-backed training generator built from the updated conf.
tgtfr = TGTFR.TrainingGeneratorTFRecord(conf)

In [None]:
# Diff the configuration of the DPK generator (c0) against the TFRecord generator (c1).
c0 = tg.get_config()
c1 = tgtfr.get_config()
util.dictdiff(c0,c1)

In [None]:
# Display the DPK generator configuration.
c0

In [None]:
def xylist2xyarr(xylist, xisscalarlist=False):
    """Merge a sequence of (x, y) pairs into two arrays.

    Args:
        xylist: sequence of 2-element items (x, y), one per batch.
        xisscalarlist: when True, every x must be a length-1 container;
            its single element is used in place of x.

    Returns:
        (x, y): all x values and all y values, each concatenated along axis 0.

    Raises:
        AssertionError: if xisscalarlist is True and some x does not have length 1.
    """
    xs, ys = zip(*xylist)
    if xisscalarlist:
        # Unwrap the one-element containers after checking their length.
        assert all([len(item) == 1 for item in xs])
        xs = [item[0] for item in xs]
    xarr = np.concatenate(xs, axis=0)
    yarr = np.concatenate(ys, axis=0)
    return xarr, yarr
    

In [None]:
# tg was constructed without any augmenter and with VALSPLIT=0, so the
# training-side generator yields unaugmented images. Recreate it here
# (same call as in the earlier cell) right before it is read.
g = tg(n_outputs=1, batch_size=4, 
       validation=False, confidence=True)

In [None]:
# For tgtfr:
# 1. there is no val db, so it will use the training db
# 2. validation=True => no distort/aug/shuffle.

# NOW OBSOLETE: TGTFR/generator has not been updated for bsize.

# Hack again: tgtfr takes the batch size from conf and ignores the batch_size argument.
conf.batch_size = 4 # hack again, note tgtfr uses batch_size in conf and ignores input arg
gtf = tgtfr(n_outputs=1, batch_size=4, 
       validation=True, confidence=True)

In [None]:
import importlib

In [None]:
# Display every attribute of the conf.
vars(conf)

In [None]:
# Use the conf held by tgtfr (constructing tgtfr set a few things on it)
# to build a tf.data dataset without distortion, shuffling or infinite repeat.
# The positional 4 and 1 are presumably bsize and n_outputs -- TODO confirm the parameter order.
conf_tgtfr = tgtfr.conf  # creating tgtfr has set a few things
ds = tfdatagen.create_tf_datasets(conf_tgtfr, 4, 1, 
                                  distort=False, shuffle=False, infinite=False)

In [None]:
# Pull the first four items from the DPK generator and merge them into
# one image array and one target array; show both shapes.
imstgts_dpk = [g[x] for x in range(4)]
imsdpk, tgtsdpk = tfdatagen.xylist2xyarr(imstgts_dpk)
imsdpk.shape, tgtsdpk.shape

In [None]:
# Pull four items from the APT TFRecord generator. xisscalarlist=True makes
# the helper unwrap each x from its one-element container before merging.
imstgts_apt = [next(gtf) for _ in range(4)]
imsapt, tgtsapt = xylist2xyarr(imstgts_apt, xisscalarlist=True)
imsapt.shape, tgtsapt.shape

In [None]:
# Read items 0..3 from the tf.data dataset.
resDS = tfdatagen.read_ds_idxed(ds, range(4))

In [None]:
# For each item keep field 0 as-is and only element 0 of field 1.
# Note: this rewrites resDS in place, so re-running the cell unwraps again.
for i in range(4):
    item = resDS[i]
    resDS[i] = (item[0], item[1][0])

In [None]:
# Merge the tf.data items into one image array and one target array; show both shapes.
imsDS, tgtsDS = tfdatagen.xylist2xyarr(resDS)
imsDS.shape, tgtsDS.shape

In [None]:
# DPK vs. APT generator: first line is the approximate (allclose) comparison,
# second line the exact (array_equal) one; images first, targets second.
pairs = ((imsdpk, imsapt), (tgtsdpk, tgtsapt))
print(*[np.allclose(a, b) for a, b in pairs])
print(*[np.array_equal(a, b) for a, b in pairs])

In [None]:
# DPK vs. tf.data dataset: first line is the approximate (allclose) comparison,
# second line the exact (array_equal) one; images first, targets second.
pairs = ((imsdpk, imsDS), (tgtsdpk, tgtsDS))
print(*[np.allclose(a, b) for a, b in pairs])
print(*[np.array_equal(a, b) for a, b in pairs])

In [None]:
# Put the sample axis of the DPK image batch last and keep index 0 along
# axis 2, giving a 3-D stack suitable for the montage.
imsdpkstk = np.moveaxis(imsdpk,0,-1)
imsdpkstk = imsdpkstk[:,:,0,:]
# `montage` is never imported into the notebook namespace (a bare call raises
# NameError on a fresh kernel); call it through tfdatagen like the other montage cells.
hfig, grid, cb0 = tfdatagen.montage(imsdpkstk,cmap='gray')

In [None]:
# Put the sample axis of the tf.data image batch last and keep index 0
# along axis 2, then show the resulting stack as a grayscale montage.
imsDSstk = np.moveaxis(imsDS, 0, -1)[:, :, 0, :]
hfig, grid, cb0 = tfdatagen.montage(imsDSstk, cmap='gray')

## Create a TG with our default imgaug augmenter (very similar to DPK example notebook 3).
## Create a TGTFR with the same augmenter. If we reseed each imgaug augmenter (IA), can we get reproducible augmented data?
## Same question for the tf.data dataset (DS).

In [None]:
# Build the 'dpkfly' imgaug augmenter for this data generator.
ia = apt_dpk.make_imgaug_augmenter('dpkfly', dg)

In [None]:
# Same generator settings as in the no-augmentation section, but this time
# the imgaug augmenter `ia` is attached.
DSFAC = 2
SIGMA = 5
VALSPLIT = 0.0
GRAPHSCALE = 1
tg = dpk.TrainingGenerator(generator=dg,
                           downsample_factor=DSFAC,
                           augmenter=ia,
                           use_graph=True,
                           shuffle=False,
                           sigma=SIGMA,
                           validation_split=VALSPLIT,
                           graph_scale=GRAPHSCALE,
                           random_seed=0)

# For tg, VALSPLIT=0 => use the training imgs
g = tg(n_outputs=1, batch_size=4, 
       validation=False, confidence=True)

# Reseed the augmenter before reading so the augmented batches can be
# reproduced; the same seed is applied to the APT-side augmenter in a later cell.
RNGSEED = 17
g.augmenter.reseed(RNGSEED)
# Read the first three items and merge them into image/target arrays.
imstgts_dpk = [g[x] for x in range(3)]
imsdpk, tgtsdpk = tfdatagen.xylist2xyarr(imstgts_dpk)
imsdpk.shape, tgtsdpk.shape

In [None]:
# Fresh conf for the augmented comparison.
conf = apt.create_conf(
    LEAPSTRIPPEDLBL, 0, EXPNAME, CACHE, 'dpkfly',
    quiet=False,
)
# Hack: the leap stripped lbl has NumChans=3, but the tfr was created
# directly from the dpk h5, which is 1-channel.
conf.img_dim = 1
# This time imgaug is enabled with the 'dpkfly' augmenter type.
conf = apt_dpk.update_conf_dpk(
    conf,
    dg.graph,
    dg.swap_index,
    n_keypoints=dg.n_keypoints,
    imshape=dg.compute_image_shape(),
    useimgaug=True,
    imgaugtype='dpkfly',
)
conf.dpk_use_tfdata = True
tgtfr = TGTFR.TrainingGeneratorTFRecord(conf)
# Reseed the augmenter held on the generator's conf with the seed used for DPK.
tgtfr.conf.dpk_augmenter.reseed(RNGSEED)

In [None]:
# Cache directory recorded on the generator's conf.
tgtfr.conf.cachedir

In [None]:
# Call the TFRecord generator (conf.dpk_use_tfdata was set to True above);
# the result is kept as `ds` and read with tfdatagen.read_ds_idxed below.
ds = tgtfr(batch_size=4, shuffle=False, infinite=False)

In [None]:
# NOTE(review): `gtf` is not (re)created in this section -- its only assignment is in
# the earlier no-augmentation section, from the previous tgtfr/conf, whereas the tgtfr
# built in this section is called into `ds`. This cell therefore appears to read a
# stale generator; confirm whether it should be dropped or `gtf` rebuilt first.
imstgts_apt = [next(gtf) for _ in range(3)]

imsapt, tgtsapt = xylist2xyarr(imstgts_apt, xisscalarlist=True)
imsapt.shape, tgtsapt.shape

In [None]:
# Read items 0..2 from the augmented tf.data dataset.
resDS = tfdatagen.read_ds_idxed(ds, range(3))

In [None]:
# For each item keep field 0 as-is and only element 0 of field 1.
# Note: this rewrites resDS in place, so re-running the cell unwraps again.
for i in range(3):
    item = resDS[i]
    resDS[i] = (item[0], item[1][0])

In [None]:
# Merge the tf.data items into one image array and one target array; show both shapes.
imsDS, tgtsDS = tfdatagen.xylist2xyarr(resDS)
imsDS.shape, tgtsDS.shape

In [None]:
# DPK vs. APT generator on augmented data:
# (ims allclose, tgts allclose, ims exactly equal, tgts exactly equal)
(
    np.allclose(imsdpk, imsapt),
    np.allclose(tgtsdpk, tgtsapt),
    np.array_equal(imsdpk, imsapt),
    np.array_equal(tgtsdpk, tgtsapt),
)

In [None]:
# DPK vs. tf.data dataset on augmented data:
# (ims allclose, tgts allclose, ims exactly equal, tgts exactly equal)
(
    np.allclose(imsdpk, imsDS),
    np.allclose(tgtsdpk, tgtsDS),
    np.array_equal(imsdpk, imsDS),
    np.array_equal(tgtsdpk, tgtsDS),
)

In [None]:
# Samples-last stack of the augmented DPK images (index 0 kept along axis 2).
imsstk = np.moveaxis(imsdpk, 0, -1)[:, :, 0, :]
imsstk.shape

In [None]:
# Grayscale montage of the current image stack (built from the DPK images in the previous cell).
hfig, grid, cb0 = tfdatagen.montage(imsstk,cmap='gray')

In [None]:
# Samples-last stack of the augmented tf.data images (index 0 kept along axis 2).
imsstk = np.moveaxis(imsDS, 0, -1)[:, :, 0, :]
imsstk.shape

In [None]:
# Grayscale montage of the current image stack (built from the tf.data images in the previous cell).
hfig, grid, cb0 = tfdatagen.montage(imsstk,cmap='gray')

In [None]:
# Grayscale montage of the last entry along axis 0 of the tf.data target array.
hfig, grid, cb0 = tfdatagen.montage(tgtsDS[-1,...],cmap='gray')

## Create an Expdir with Val so we can test val

In [None]:
# Fresh data generator and conf for the 'val10pct' experiment.
dg = dpk.DataGenerator(DPK_DSET)

LEAPSTRIPPEDLBL = '/dat0/jrcmirror/groups/branson/bransonlab/apt/experiments/data/leap_dataset_gt_stripped.lbl'
EXPNAME = 'val10pct'
CACHE = '/dat0/apt/cache'

# The cached images in the stripped lbl differ from the dpk h5:
#   - ims are 3-chan grayscale vs 1-chan
#   - locs are off-by-one; the stripped lbl is probably correct (0-based)
# Hence apt.create_tfrecord is not used here:
#apt.create_tfrecord(conf,split=False,use_cache=True)
# Instead the tfr is created directly from the h5 (for replication purposes),
# even though it may be off-by-one.

conf = apt.create_conf(
    LEAPSTRIPPEDLBL, 0, EXPNAME, CACHE, 'dpkfly',
    quiet=False,
)
# Hack: the leap stripped lbl has NumChans=3, but the tfr was created
# directly from the dpk h5, which is 1-channel.
conf.img_dim = 1
conf.cachedir



In [None]:
# Number of items reported by the data generator.
n=len(dg)

In [None]:
VAL_PCT = 10   # percentage of the n samples held out for validation
VAL_SEED = 0   # fixed seed so the split, and the TFRecords written from it, are reproducible
# Draw the validation indices without replacement from 0..n-1 using a private,
# seeded RandomState instead of the unseeded global numpy RNG, so that
# Restart & Run All yields the same split every time.
idx_val = np.random.RandomState(VAL_SEED).choice(n, size=int(n*VAL_PCT/100), replace=False)

In [None]:
# Number of validation indices drawn.
len(idx_val)

In [None]:
# TFRecord paths inside the experiment's cache dir, then write the dbs from
# the data generator; idx_val and val_tf are passed so that a separate
# validation TFRecord is produced.
train_tf = conf.cachedir + '/train_TF.tfrecords'
val_tf = conf.cachedir + '/val_TF.tfrecords'
apt_dpk.apt_db_from_datagen(dg, train_tf, val_idx=idx_val, val_tf=val_tf)

### Note: the val TFR is/was written to disk in sorted order by frame number

### Make a dpk.TG

In [None]:
# Sort the validation indices in place so they follow the sorted order in
# which the val TFR was written (see the note above this section).
idx_val.sort()

In [None]:
# Display the (now sorted) validation indices.
idx_val

In [None]:
# Read the validation TFRecord and keep only the third returned value (info records).
# nkp is the value defined in the first section of the notebook.
# indices=() presumably means "read every record" -- TODO confirm.
_, _, ifo, _ = mrd.read_and_decode_without_session(val_tf, nkp, indices=())

In [None]:
# Check that column 0 of the info records is strictly increasing, i.e. the
# val TFR is stored in sorted order (column 0 is presumably the frame number -- confirm).
ifo = np.array(ifo)
all(np.diff(ifo[:,0])>0)

In [None]:
# Display the data generator.
dg

In [None]:
# Augmented DPK TrainingGenerator for the validation experiment.
# VALSPLIT stays 0.0 here; the validation split is installed by hand in the next cell.
DSFAC = 2
SIGMA = 5
VALSPLIT = 0.0
GRAPHSCALE = 1
ia = apt_dpk.make_imgaug_augmenter('dpkfly', dg)
tg = dpk.TrainingGenerator(generator=dg,
                           downsample_factor=DSFAC,
                           augmenter=ia,
                           use_graph=True,
                           shuffle=False,
                           sigma=SIGMA,
                           validation_split=VALSPLIT,
                           graph_scale=GRAPHSCALE,
                           random_seed=0)

In [None]:
# Massage to use my val split
tg.val_index = idx_val
train_index = np.invert(np.isin(tg.index, tg.val_index))
tg.train_index = tg.index[train_index]
tg.n_validation = len(idx_val)

In [None]:
# Capture the generator configuration after the split override.
cfg = tg.get_config()

In [None]:
# Unlike the earlier sections, request the validation side here
# (validation=True), relying on the val_index installed by hand above.
# confidence=False this time.
g = tg(n_outputs=1, batch_size=4, 
       validation=True, confidence=False)

In [None]:
# First ten validation indices as seen by the generator.
g.val_index[:10]

In [None]:
# Reseed the augmenter, then read 150//4 (= 37) items from the validation
# generator and merge them into image/target arrays.
# 150 is presumably the validation-set size (10% of the dataset) -- confirm.
RNGSEED = 17
g.augmenter.reseed(RNGSEED)
imstgts_dpk = [g[x] for x in range(150//4)]
imsdpk, tgtsdpk = xylist2xyarr(imstgts_dpk)
imsdpk.shape, tgtsdpk.shape

In [None]:
# Spot-check one entry of the DPK targets (index -148 along axis 0).
tgtsdpk[-148,...]

In [None]:
# Single-channel hack again, then update the conf with the DPK properties
# (imgaug enabled, 'dpkfly' augmenter type) and force a batch size of 4.
conf.img_dim = 1
conf = apt_dpk.update_conf_dpk(conf,
                               dg.graph,
                               dg.swap_index,
                               n_keypoints=dg.n_keypoints,
                               imshape=dg.compute_image_shape(),
                               useimgaug=True,
                               imgaugtype='dpkfly')
conf.batch_size=4

In [None]:
# Reload tfdatagen so that edits to the module on disk are picked up
# without restarting the kernel.
import importlib
importlib.reload(tfdatagen)

In [None]:
# Validation tf.data dataset: no distortion, no shuffling, finite, drawconf=False.
ds = tfdatagen.create_tf_datasets(conf, is_val=True, distort=False, shuffle=False, infinite=False, drawconf=False)

In [None]:
# Same validation dataset settings as `ds`, but with is_raw=True.
dsraw = tfdatagen.create_tf_datasets(conf, is_val=True, is_raw=True, distort=False, shuffle=False, infinite=False, drawconf=False)

In [None]:
# Read 150//4 items from the raw validation dataset.
# read_ds_idxed lives in tfdatagen and is never imported into the notebook
# namespace, so the bare name raised NameError on a fresh kernel.
resDSraw = tfdatagen.read_ds_idxed(dsraw,range(150//4))

In [None]:
# Third field of the first raw item.
resDSraw[0][2]

In [None]:
# Raw items carry three fields (element [2] is inspected in the previous cell), so the
# two-field xylist2xyarr helper cannot unpack them; concatenate each field along
# axis 0 directly instead.
# NOTE(review): assumes every field is an array with a leading sample axis, and the
# third field is presumably the info record despite the `tgts` name -- confirm.
imsDSraw, locsDSraw, tgtsDSraw = (np.concatenate(fld, axis=0) for fld in zip(*resDSraw))
# Was `locsDSraw,shape` (comma instead of dot), which raised NameError on `shape`.
imsDSraw.shape, locsDSraw.shape, tgtsDSraw.shape

In [None]:
# Read 150//4 items from the (non-raw) validation dataset.
# read_ds_idxed lives in tfdatagen and is never imported into the notebook
# namespace, so the bare name raised NameError on a fresh kernel.
resDS = tfdatagen.read_ds_idxed(ds, range(150//4))

In [None]:
# Merge the validation items into one image array and one target array; show both shapes.
imsDS, tgtsDS = xylist2xyarr(resDS)
imsDS.shape, tgtsDS.shape

In [None]:
# Spot-check one entry of the tf.data targets (index -140 along axis 0).
tgtsDS[-140,...]

In [None]:
# DPK vs. tf.data dataset on the validation data:
# (ims allclose, tgts allclose, ims exactly equal, tgts exactly equal)
(
    np.allclose(imsdpk, imsDS),
    np.allclose(tgtsdpk, tgtsDS),
    np.array_equal(imsdpk, imsDS),
    np.array_equal(tgtsdpk, tgtsDS),
)

In [None]:
# Spot-check the tf.data targets at index 2 along axis 0.
tgtsDS[2,:]

In [None]:
# Samples-last stack of the validation tf.data images (index 0 kept along axis 2).
imsstk = np.moveaxis(imsDS, 0, -1)[:, :, 0, :]
imsstk.shape

In [None]:
# Montage of the first ten validation images with their targets overlaid as locs.
# `montage` is never imported into the notebook namespace (a bare call raises
# NameError on a fresh kernel); call it through tfdatagen like the other montage cells.
# NOTE(review): the other montage calls pass a samples-last stack (e.g. imsstk from
# the previous cell), whereas imsDS is sliced on its first axis here -- confirm the layout.
tfdatagen.montage(imsDS[:10,...],locs=tgtsDS[:10,...],cmap='gray',locsmrkrsz=40)

In [None]:
## TODO: use pt for tgtfr. how does it compare?