In [None]:
import glob
import importlib
import time
from os.path import expanduser
import copy

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2
import mpl_toolkits.axes_grid1 as axg1

import deepposekit as dpk
import TrainingGeneratorTFRecord as TGTFR
import apt_dpk 
import apt_dpk_exps as ade
import run_apt_expts_2 as rae
import APT_interface as apt
import PoseTools as pt
import multiResData as mrd
import open_pose_data as opd
import util
import tfdatagen

# Paths/constants for this DPK-vs-APT augmentation comparison.
# NOTE(review): absolute machine-specific paths; adjust for your environment.
# DPK fly annotation set (HDF5); feeds the DPK DataGenerator below.
DPK_DSET = '/home/al/git/dpkd/datasets/fly/annotation_data_release_AL.h5'

# Stripped leap GT label file; used to build the APT confs (apt.create_conf).
LEAPSTRIPPEDLBL = '/dat0/jrcmirror/groups/branson/bransonlab/apt/experiments/data/leap_dataset_gt_stripped.lbl'
EXPNAME = 'dpkfly'
CACHE = '/dat0/apt/cache'

# DPK DataGenerator over the annotation H5. Later cells read its graph,
# swap_index, n_keypoints and compute_image_shape().
dg = dpk.DataGenerator(DPK_DSET)


## Create a DPK TrainingGenerator (TG) with the default imgaug (IA) augmentation; montage it

In [None]:
# Build the APT conf from the stripped leap label file.
conf = apt.create_conf(LEAPSTRIPPEDLBL, 0, EXPNAME, CACHE, 'dpkfly',
                       quiet=False)
# Hack: the leap stripped lbl has NumChans=3, but the tfrecord used here was
# created directly from the DPK h5, which is 1-channel.
conf.img_dim = 1

# Cached images in strippedlbl differ from dpk h5!
# - Ims are 3-chan grayscale vs 1-chan
# - Locs are off-by-one; strippedlbl prob correct (0-based)
# So rather than apt.create_tfrecord(conf, split=False, use_cache=True), the
# tfr is created directly from the h5 (even though it may be off-by-one) for
# replication purposes:
#   train_tf = conf.cachedir + '/train_TF.tfrecords'
#   apt_dpk.apt_db_from_datagen(dg, train_tf)
conf.cachedir

In [None]:
# TrainingGenerator settings (reused when tg is rebuilt under "Check extents").
DSFAC = 2          # downsample_factor for the TrainingGenerator
SIGMA = 5          # sigma for the TrainingGenerator targets
VALSPLIT = 0.0     # no validation split
GRAPHSCALE = 1

ia = apt_dpk.make_imgaug_augmenter('dpkfly', dg)
tg = dpk.TrainingGenerator(
    generator=dg,
    augmenter=ia,
    downsample_factor=DSFAC,
    sigma=SIGMA,
    graph_scale=GRAPHSCALE,
    use_graph=True,
    validation_split=VALSPLIT,
    shuffle=False,
    random_seed=0,
)

# VALSPLIT=0, so validation=False draws from the training images.
g = tg(n_outputs=1, batch_size=4, validation=False, confidence=True)

In [None]:
# Reseed the DPK augmenter for reproducibility, then draw the first 6 batches
# and stack them into image / target arrays.
RNGSEED = 17
g.augmenter.reseed(RNGSEED)
imstgts_dpk = [g[ibch] for ibch in range(6)]
imsdpk, tgtsdpk = tfdatagen.xylist2xyarr(imstgts_dpk)
(imsdpk.shape, tgtsdpk.shape)

In [None]:
# Move the sample axis last and keep channel 0, giving the layout passed to
# tfdatagen.montage. The input is presumably [nims x nr x nc x nch]; TODO confirm.
# Guarded: without the ndim check, re-running this cell raises IndexError.
if imsdpk.ndim == 4:
    imsdpk = np.moveaxis(imsdpk, 0, -1)
    imsdpk = imsdpk[:, :, 0, :]

In [None]:
def get_conf_map_argmaxs_rescaled(tgts0, ipt=31, sz=(192, 192)):
    '''Argmax location of a single target map per sample, after resizing.

    tgts0: 4-D array [nsamp x nr x nc x nmap] of target maps.
    ipt: index (along the last axis of tgts0) of the map to use.
    sz: dsize passed to cv2.resize (OpenCV order is (width, height)).

    Returns a [nsamp x 2] array: one location per sample, as produced by
    pt.get_pred_locs on the resized maps.
    '''
    tgts = tgts0[..., ipt].copy()    # [nsamp x nr x nc]
    tgts = np.moveaxis(tgts, 0, -1)  # [nr x nc x nsamp]; samples act as "channels"
    tgts2 = cv2.resize(tgts, sz)
    if tgts2.ndim == 2:
        # cv2.resize drops a trailing singleton channel (nsamp == 1); restore
        # it so get_pred_locs still receives [1 x nr x nc x nsamp].
        tgts2 = tgts2[..., np.newaxis]
    locs = pt.get_pred_locs(tgts2[np.newaxis, ...])
    return locs[0, ...]  # [nsamp x 2]

In [None]:
locs = get_conf_map_argmaxs_rescaled(tgtsdpk)

In [None]:
locs.shape

In [None]:
locs = locs[:,np.newaxis,:]

In [None]:
hfig, grid, cb0 = tfdatagen.montage(imsdpk, cmap='gray', locs=locs, locsmrkrsz=128)

## Create a TGTFR (tf.data dataset, TFDS) with the default imgaug (IA) augmentation; montage it

In [None]:
# Cache dir of the APT conf (the tfrecord is presumably read from here).
conf.cachedir

In [None]:
# Create a TGTFR (tf.data-based) using our default imgaug augmenter
# ('dpkfly'), for comparison with the DPK TG above.
conf.img_dim = 1
conf = apt_dpk.update_conf_dpk(conf,
                               dg.graph,
                               dg.swap_index,
                               n_keypoints=dg.n_keypoints,
                               imshape=dg.compute_image_shape(),
                               useimgaug=True,
                               imgaugtype='dpkfly')
conf.dpk_use_tfdata = True
# TGTFR takes its batch size from conf and ignores the batch_size arg below.
# Set it BEFORE constructing the generator so it applies even if conf is
# copied or read at construction time.
conf.batch_size = 4
tgtfr = TGTFR.TrainingGeneratorTFRecord(conf)
tgtfr.conf.dpk_augmenter.reseed(RNGSEED)
ds = tgtfr(n_outputs=1, batch_size=4,
           validation=False, confidence=True, shuffle=False)

# tf.data wraps targets in a list even for n_outputs=1; unwrap them.
resDS = [(res[0], res[1][0]) for res in tfdatagen.read_ds_idxed(ds, range(6))]
imsDS, tgtsDS = tfdatagen.xylist2xyarr(resDS)
imsDS.shape, tgtsDS.shape

In [None]:
imsDS = np.moveaxis(imsDS, 0, -1)
imsDS = imsDS[:, :, 0, :]
locsDS = get_conf_map_argmaxs_rescaled(tgtsDS)

In [None]:
locsDS = locsDS[:,np.newaxis,:]

In [None]:
hfig, grid, cb0 = tfdatagen.montage(imsDS, cmap='gray', locs=locsDS, locsmrkrsz=128)

## Create a TGTFR/TFDS with PoseTools (PT) aug; montage it
## Oops: this apparently used the default PT data-aug params (the exp2 leapfly params are applied in "Take 2" below)

In [None]:
# Same TGTFR pipeline, but with PoseTools aug instead of the imgaug augmenter.
# NOTE: this mutates the shared `conf`, which later cells deep-copy and diff.
conf.dpk_use_augmenter = False
# TGTFR reads batch size from conf (the batch_size arg is ignored), so set it
# before constructing the generator.
conf.batch_size = 4
tgtfrPT = TGTFR.TrainingGeneratorTFRecord(conf)
# No imgaug reseed on this path (dpk_use_augmenter=False).
ds = tgtfrPT(n_outputs=1, batch_size=4,
             validation=False, confidence=True, shuffle=False)

# tf.data wraps targets in a list even for n_outputs=1; unwrap them.
resDS = [(res[0], res[1][0]) for res in tfdatagen.read_ds_idxed(ds, range(3))]
imsDS, tgtsDS = tfdatagen.xylist2xyarr(resDS)
imsDS.shape, tgtsDS.shape

In [None]:
imsDS = np.moveaxis(imsDS, 0, -1)
imsDS = imsDS[:, :, 0, :]
locsDS = get_conf_map_argmaxs_rescaled(tgtsDS)

In [None]:
locsDS = locsDS[:,np.newaxis,:]

In [None]:
hfig, grid, cb0 = tfdatagen.montage(imsDS,cmap='gray', locs=locsDS, locsmrkrsz=128)

In [None]:
# Load saved traindata from a multitarget_bubble openpose run to inspect its
# data-aug settings.
# NOTE(review): absolute path; pt.pickle_load presumably unpickles, so only use
# it on trusted files.
tdf='/dat0/jrcmirror/groups/branson/bransonlab/apt/dl.al.2020/cache/multitarget_bubble/openpose/view_0/apt_expts_opgal/multitarget_bubble_deepnet_20200429_traindata'
td=pt.pickle_load(tdf)

In [None]:
# Last element of the traindata, used below as that run's conf.
conf_bub = td[-1]

In [None]:
conf_bub.print_dataaug_flds()

In [None]:
# Read the 'alice' skeleton csv (presumably the fly-bubble skeleton): graph + swap index.
skelcsv = apt_dpk.skeleton_csvs['alice'][0]
graph, swapidx = apt_dpk.read_skel_csv(skelcsv)
swapidx

In [None]:
# Convert the swap index to flip-landmark-matches form for comparison below.
flm_flybub = apt_dpk.swap_index_to_flip_landmark_matches(swapidx)

In [None]:
# Does the skeleton-derived flip matching agree with the bubble run's conf?
flm_flybub==conf_bub.flipLandmarkMatches

In [None]:
## Notes on the imgaug (ia) aug config that the PoseTools (PT) exp2 settings presumably replicate.
## sometimes 75%:

 # flip: both flips => true. landmark matches etc. want random prob
 # scale: (0.95, 1.05) or (95%->105%) in r/c independently (SKIPPED IN PT IMPL)
 # xlate: (-.05,.05) or +/- 5% of im size in r/c independently
 # shear: (-8,8) or (-8deg,8deg) (SKIPPED IN PT IMPL)
 # order = ia.ALL, cval=ia.ALL, 

 # THEN another scale (0.9, 1.1), this is uniform x/y.

## then always rotate -180:180

In [None]:
# Start from a deep copy of conf and apply the exp2 leapfly PoseTools aug settings.
conf_exp2 = copy.deepcopy(conf)
ade.exp2_set_posetools_aug_config_leapfly(conf_exp2)

In [None]:
conf_exp2.print_dataaug_flds()

In [None]:
# TODO:
# - check the conf data-aug fields against the imgaug (ia) notes in the cell above
# - diff conf against the original conf
# - how to run: set dpk_use_augmenter etc.?

In [None]:
# Presumably lists the fields exp2_set_posetools_aug_config_leapfly changed
# (conf_exp2 is a deep copy of conf).
util.dictdiff(conf,conf_exp2)

## Create a TGTFR/TFDS with PT aug; montage it
## Take 2!!

In [None]:
# Create a TGTFR conf that uses PoseTools (PT) aug, NOT imgaug:
# useimgaug=False and dpk_use_augmenter=False, with the exp2 leapfly PT aug
# params applied. This rebuilds conf_exp2 from scratch.
conf_exp2 = apt.create_conf(LEAPSTRIPPEDLBL, 0, EXPNAME,
                            CACHE, 'dpkfly', quiet=False)
conf_exp2.img_dim = 1  # DPK h5 images are 1-channel (the lbl says NumChans=3)
conf_exp2 = apt_dpk.update_conf_dpk(conf_exp2,
                                    dg.graph,
                                    dg.swap_index,
                                    n_keypoints=dg.n_keypoints,
                                    imshape=dg.compute_image_shape(),
                                    useimgaug=False,
                                    imgaugtype='dpkfly')
conf_exp2.dpk_use_tfdata = True
conf_exp2.dpk_use_augmenter = False
ade.exp2_set_posetools_aug_config_leapfly(conf_exp2)
conf_exp2.print_dataaug_flds()

In [None]:
# DPK swap index, for reference.
dg.swap_index

In [None]:
conf_exp2.dpk_use_tfdata

In [None]:
# Build the PT-aug TGTFR. Its batch size comes from conf (the batch_size arg
# is ignored), so it is set before the generator is constructed.
conf_exp2.batch_size = 4
tgtfr = TGTFR.TrainingGeneratorTFRecord(conf_exp2)
ds = tgtfr(n_outputs=1, batch_size=4,
           validation=False, confidence=True, shuffle=False)

# Unwrap targets: tf.data puts them in a list even for n_outputs=1.
resDS = [(res[0], res[1][0]) for res in tfdatagen.read_ds_idxed(ds, range(6))]
imsDS, tgtsDS = tfdatagen.xylist2xyarr(resDS)
imsDS.shape, tgtsDS.shape

In [None]:
imsDS = np.moveaxis(imsDS, 0, -1)
imsDS = imsDS[:, :, 0, :]
locsDS = get_conf_map_argmaxs_rescaled(tgtsDS)

In [None]:
locsDS = locsDS[:,np.newaxis,:]

In [None]:
locsDS.shape

In [None]:
hfig, grid, cb0 = tfdatagen.montage(imsDS, cmap='gray', locs=locsDS, locsmrkrsz=128)

In [None]:
conf_exp2.dpk_use_augmenter

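## Check extents: compare distributions of target-map argmax locations (DPK/imgaug vs. TFDS/PoseTools aug) over many batches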

In [None]:
# Draw NBCH batches from the PT-aug TGTFR for the extents comparison.
NBCH = 200
ds = tgtfr(n_outputs=1, batch_size=4,
           validation=False, confidence=True, shuffle=False)
# Unwrap targets: tf.data puts them in a list even for n_outputs=1.
resDS = [(res[0], res[1][0]) for res in tfdatagen.read_ds_idxed(ds, range(NBCH))]
imsDS, tgtsDS = tfdatagen.xylist2xyarr(resDS)
imsDS.shape, tgtsDS.shape

In [None]:
# Rebuild the DPK TrainingGenerator with the same settings and the same imgaug
# augmenter object `ia` as the first section, then draw NBCH batches.
tg = dpk.TrainingGenerator(generator=dg,
                           downsample_factor=DSFAC,
                           augmenter=ia,
                           use_graph=True,
                           shuffle=False,
                           sigma=SIGMA,
                           validation_split=VALSPLIT,
                           graph_scale=GRAPHSCALE,
                           random_seed=0)

# For tg, VALSPLIT=0 => use the training imgs
g = tg(n_outputs=1, batch_size=4, 
       validation=False, confidence=True)

# Reseed with the same seed as the first DPK montage (the value is unchanged;
# it is re-set so this cell reads on its own).
RNGSEED = 17
g.augmenter.reseed(RNGSEED)
imstgts_dpk = [g[x] for x in range(NBCH)]
imsdpk, tgtsdpk = tfdatagen.xylist2xyarr(imstgts_dpk)
imsdpk.shape, tgtsdpk.shape

In [None]:
def get_conf_map_argmaxs_allmaps(tgts0, sz=(192, 192), verbose=True):
    '''Argmax location of every map of every sample, after resizing each map.

    tgts0: 4-D array [bsize x nr x nc x nmap].
    sz: dsize passed to cv2.resize (OpenCV order is (width, height)).
    verbose: if True (the default), print the unpacked dimensions.

    Returns locs: [nmap*bsize x 2]. Rows are flattened in Fortran order over
    (map, sample): row k is map k % nmap of sample k // nmap. Therefore
    np.reshape(locs, (nmap, bsize, 2), order='F') recovers [nmap x bsize x 2].

    Raises ValueError if tgts0 is not 4-D.
    '''
    if tgts0.ndim != 4:
        raise ValueError(
            'tgts0 must be 4-D [bsize x nr x nc x nmap], got shape {}'.format(tgts0.shape))
    tgts = np.moveaxis(tgts0, 0, -1)  # [nr x nc x nmap x bsize]; tgts0 is not modified
    nr, nc, nmap, bsize = tgts.shape
    if verbose:
        print('nr nc nmap bsize = {} {} {} {}'.format(nr, nc, nmap, bsize))
    # Fortran order: the map index varies fastest, then the sample index.
    tgts = np.reshape(tgts, (nr, nc, nmap*bsize), order='F')

    locs = np.zeros((nmap*bsize, 2))
    for i in range(nmap*bsize):
        tgttmp = cv2.resize(tgts[..., i], sz)  # single 2-D map
        locstmp = pt.get_pred_locs(tgttmp[np.newaxis, ..., np.newaxis])
        locs[i, ...] = locstmp[0, 0, ...]

    return locs  # [bsize*nmap x 2]

In [None]:
# All-map argmax locations for the TFDS (PT aug) batches drawn above.
locsDS = get_conf_map_argmaxs_allmaps(tgtsDS)
locsDS.shape

In [None]:
# ... and for the DPK (imgaug) batches.
locsdpk = get_conf_map_argmaxs_allmaps(tgtsdpk)
locsdpk.shape

In [None]:
# Histograms of location coordinate 0 / 1 over all maps and samples:
# DPK on the top row, TFDS on the bottom row.
f = plt.figure(num=21)
f.clf()
ax = f.subplots(2, 2, sharex=True, sharey=True)
for iax in range(2):
    ax[0, iax].hist(locsdpk[:, iax], bins=50)
    ax[0, iax].set_title('dpk {}'.format(iax))
    ax[1, iax].hist(locsDS[:, iax], bins=50)
    ax[1, iax].set_title('DS {}'.format(iax))
    


In [None]:
# Images with the sample axis last and channel 0 kept, for montage
# (copied so imsDS itself is untouched).
imsDSplot = imsDS.copy()
imsDSplot = np.moveaxis(imsDSplot, 0, -1)
imsDSplot = imsDSplot[:, :, 0, :]
# locsDS rows are k = imap + nmap*isamp (F-order flatten in
# get_conf_map_argmaxs_allmaps). Unflatten to [nmap x nsamp x 2], then swap to
# [nsamp x nmap x 2]. nmap/nsamp come from tgtsDS instead of a hard-coded
# 66/800, so changing NBCH or the batch size does not break this.
locsDSplot = np.reshape(locsDS, (tgtsDS.shape[-1], tgtsDS.shape[0], 2), order='F')
locsDSplot = np.moveaxis(locsDSplot, 0, 1)
locsDSplot.shape

In [None]:
# Montage of the first 100 TFDS images with the argmax locations of their
# first 32 maps. NOTE(review): presumably the first 32 maps are the keypoint
# maps (followed by graph maps); confirm.
hfig, grid, cb0 = tfdatagen.montage(imsDSplot[...,:100], 
                                    cmap='gray', 
                                    locs=locsDSplot[:100,:32,:], 
                                    locsmrkrsz=12)

In [None]:
# DPK images with the sample axis last and channel 0 kept, for montage
# (copied so imsdpk itself is untouched).
imsdpkplot = imsdpk.copy()
imsdpkplot = np.moveaxis(imsdpkplot, 0, -1)
imsdpkplot = imsdpkplot[:, :, 0, :]
# locsdpk rows are k = imap + nmap*isamp (F-order flatten in
# get_conf_map_argmaxs_allmaps). Unflatten to [nmap x nsamp x 2], then swap
# to [nsamp x nmap x 2]. Sizes come from tgtsdpk instead of a hard-coded 66/800.
locsdpkplot = np.reshape(locsdpk, (tgtsdpk.shape[-1], tgtsdpk.shape[0], 2), order='F')
locsdpkplot = np.moveaxis(locsdpkplot, 0, 1)
NPLOT = 4
# Montage of the first NPLOT images with the locations of maps 0 and 25.
hfig, grid, cb0 = tfdatagen.montage(imsdpkplot[..., :NPLOT],
                                    cmap='gray',
                                    locs=locsdpkplot[:NPLOT, [0, 25], :],
                                    locsmrkrsz=12)

In [None]:
# [nsamp x nmap x 2] after the reshape/moveaxis that builds locsdpkplot.
locsdpkplot.shape

In [None]:
# Per-coordinate histograms of the location of map IPTPLOT: DPK on the top
# row, TFDS on the bottom row, with the x-range fixed to the 192-px resize.
IPTPLOT = 8
f = plt.figure(num=22)
f.clf()
ax = f.subplots(2, 2, sharex=True, sharey=True)
for iax in range(2):
    ax[0, iax].hist(locsdpkplot[:, IPTPLOT, iax], bins=20)
    ax[0, iax].set_title('dpk {}'.format(iax))
    ax[0, iax].set_xlim((0, 192))
    ax[1, iax].hist(locsDSplot[:, IPTPLOT, iax], bins=20)
    ax[1, iax].set_title('DS {}'.format(iax))
    ax[1, iax].set_xlim((0, 192))
    

In [None]:
# Vector from point IPTTAIL to point IPTHEAD for every sample, and its angle
# arctan2(coord 1, coord 0), for TFDS and DPK.
IPTHEAD = 0
IPTTAIL = 5
vtailhead_ds, vtailhead_dpk = [
    L[:, IPTHEAD, :] - L[:, IPTTAIL, :] for L in (locsDSplot, locsdpkplot)
]
phi_ds, phi_dpk = [np.arctan2(v[:, 1], v[:, 0]) for v in (vtailhead_ds, vtailhead_dpk)]

phi_ds.shape, phi_dpk.shape

In [None]:
# Side-by-side histograms of the tail->head angle: TFDS (left) vs DPK (right).
f = plt.figure(num=23)
f.clf()
ax = f.subplots(1, 2, sharex=True, sharey=True)
ax[0].hist(phi_ds, bins=20)
ax[0].set_title('ds')
ax[1].hist(phi_dpk, bins=20)
ax[1].set_title('dpk')