In [None]:
import glob
import importlib
import time
from os.path import expanduser
from os.path import join
import copy

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2
import mpl_toolkits.axes_grid1 as axg1

import deepposekit as dpk
import TrainingGeneratorTFRecord as TGTFR
import apt_dpk 
import apt_dpk_exps as ade
import run_apt_expts_2 as rae
import APT_interface as apt
import PoseTools as pt
import multiResData as mrd
import util
import tfdatagen

# Figures are written to RESDIR with RESPFIX prepended to each filename (see savegcf).
RESDIR = '/dat0/apt/res/'
RESPFIX = 'aug_bub_'

# Experiment name and APT cache directory, both handed to apt.create_conf.
EXPNAME = 'multitarget_bubble'
CACHE = '/dat0/apt/cache'

# The stripped label file and the DPK skeleton CSV share one experiments/data directory.
SLBL, SKEL = (
    join('/dat0/jrcmirror/groups/branson/bransonlab/apt/experiments/data', fname)
    for fname in (
        'multitarget_bubble_expandedbehavior_20180425_FxdErrs_OptoParams20200317_stripped20200403.lbl',
        'multitarget_bubble_dpk_skeleton.csv',
    )
)

### Inputs for the bubble ("bub") project: i) the train TFRecord; ii) the stripped label file (SLBL, presumably); iii) the DPK skeleton CSV (SKEL)


In [None]:
# Build the APT config for the 'dpk' net (second positional arg 0 is presumably
# the view index — confirm), overriding the DPK skeleton CSV path. The override
# value is wrapped in double quotes before being passed via conf_params.
conf_params = ['dpk_skel_csv', f'"{SKEL}"']
conf = apt.create_conf(SLBL, 0, EXPNAME, CACHE, 'dpk',
                       quiet=False, conf_params=conf_params)

In [None]:
# Dump the config's data-augmentation fields (the builtin print is passed in as an argument).
conf.print_dataaug_flds(print)

In [None]:
# Where the cache lives, and whether the dpk_use_augmenter flag is set
# (the IA section below switches it on for a copy of conf).
conf.cachedir, conf.dpk_use_augmenter

### Build a TGTFR that yields raw (undistorted) images/keypoints from train_TF, then montage them
### Strictly, this is "raw, but via confmaps": confidence maps are generated first, and the keypoint locations are then extracted from them

In [None]:
# Batch size used for every TGTFR-generated dataset in this notebook.
TGTFR_BSIZE = 4

In [None]:
# Point the *validation* file at the training TFRecord. Validation reads are
# undistorted, so this yields raw training images — still rendered as
# confmaps, from which locs are extracted afterwards.
confraw = copy.deepcopy(conf)  # leave `conf` itself untouched
confraw.valfilename = confraw.trainfilename
tgtfrraw = TGTFR.TrainingGeneratorTFRecord(confraw)
dsraw = tgtfrraw(
    n_outputs=1,
    batch_size=TGTFR_BSIZE,
    shuffle=False,
    validation=True,
    confidence=True,
)

In [None]:
def read_batches_from_dataset(ds, nbch_read=3):
    """Read the first `nbch_read` batches of `ds` and convert them to arrays.

    Batches 0..nbch_read-1 are fetched with tfdatagen.read_ds_idxed. The
    resulting list is then converted by tfdatagen.xylist2xyarr with
    yisscalarlist=True. The resulting shapes are printed.

    Returns:
        (ims, tgts) as produced by xylist2xyarr.
    """
    batch_idxs = range(nbch_read)
    batches = tfdatagen.read_ds_idxed(ds, batch_idxs)
    ims, tgts = tfdatagen.xylist2xyarr(batches, yisscalarlist=True)
    print("ims.shape={}, tgts.shape={}".format(ims.shape, tgts.shape))
    return ims, tgts

def get_pred_locs_with_unscale(ims, tgts):
    """Extract keypoint locations from confmaps and rescale them to image pixels.

    The scale factors are the image-to-confmap size ratios along axis 2 (used
    as x) and axis 1 (used as y). ims and tgts are assumed to share a
    [batch x rows x cols x ...] layout — TODO confirm.

    Returns:
        pt.get_pred_locs(tgts) passed through pt.unscale_points.
    """
    scaley, scalex = (ims.shape[ax] / tgts.shape[ax] for ax in (1, 2))
    print("scalex={}, scaley={}".format(scalex, scaley))
    return pt.unscale_points(pt.get_pred_locs(tgts), scalex, scaley)

def savegcf(fig, fname, **savefig_kwargs):
    """Save `fig` to RESDIR under the name RESPFIX + fname.

    RESDIR is created if it does not exist yet; otherwise savefig would fail
    on a fresh machine. Extra keyword arguments (e.g. dpi, format) are
    forwarded to fig.savefig. bbox_inches defaults to 'tight', as before.
    """
    import os

    os.makedirs(RESDIR, exist_ok=True)
    outfile = join(RESDIR, RESPFIX + fname)
    savefig_kwargs.setdefault('bbox_inches', 'tight')
    fig.savefig(outfile, **savefig_kwargs)
    print("Saved {}".format(outfile))

In [None]:
# Read the helper's default 3 batches: raw training images plus confmaps.
imsraw, tgtsraw = read_batches_from_dataset(dsraw)

# Only the first 17 target channels are turned into keypoint locations.
# NOTE(review): 17 is presumably the number of keypoints (lcm below has 17
# colors); what the remaining channels hold is not visible here — confirm.
locsraw = get_pred_locs_with_unscale(imsraw, tgtsraw[...,:17])

In [None]:
# Marker styling shared by all keypoint montages: one color per keypoint (17 entries).
from matplotlib.colors import ListedColormap
lcm = ListedColormap(['r','r','w','w','r','w','r','c','c','m','m','r','r','r','g','g','g'])
LOCS_MRKRSZ = 80

In [None]:
# Montage raw images with confmap-derived keypoints and save the figure to RESDIR.
fig, _, _ = tfdatagen.montage(ims0=imsraw, 
                  ims0type='batchfirst', 
                  locs=locsraw, 
                  cmap='gray',
                  locsmrkr='+',
                  locsmrkrsz=LOCS_MRKRSZ,
                  locscmap=lcm
                 )

savegcf(fig, 'raw_wconfmaps')

### Build a TGTFR that reads raw (undistorted) images/keypoints from train_TF, then montage them
### This is "true raw": confidence=False, so no confidence maps are generated

In [None]:
# Same generator as dsraw, but with confidence=False. The second output is used
# directly as keypoint locations below, with no confmap round trip.
dsrawtrue = tgtfrraw(n_outputs=1, 
                     batch_size=TGTFR_BSIZE, 
                     validation=True,
                     confidence=False,
                     shuffle=False,
                    )

In [None]:
# Read 3 batches by hand: unlike read_batches_from_dataset, xylist2xyarr is
# called here without yisscalarlist=True.
nbch_read = 3
resrawtrue = tfdatagen.read_ds_idxed(dsrawtrue, range(nbch_read))
imsrawtrue, locsrawtrue = tfdatagen.xylist2xyarr(resrawtrue)
print("ims.shape={}, tgts.shape={}".format(imsrawtrue.shape, locsrawtrue.shape))

In [None]:
locsrawtrue  # note: not at half-pixel positions

In [None]:
# Montage the true-raw images/keypoints and save.
fig, _, _ = tfdatagen.montage(ims0=imsrawtrue, 
                  ims0type='batchfirst', 
                  locs=locsrawtrue, 
                  cmap='gray',
                  locsmrkr='+',
                  locsmrkrsz=LOCS_MRKRSZ,
                  locscmap=lcm
                 )
savegcf(fig, 'raw')






### Build a TGTFR with the bub PT augmentation and montage the result
### Recall the pipeline: raw images/locs are read from the TFRecord, padded, and distorted; then confmaps are generated

In [None]:
# Seed numpy's global RNG before building/reading the PT-augmented dataset.
# NOTE(review): this only makes the distortions repeatable if the PT
# augmentation draws from np.random — confirm.
np.random.seed(0)

In [None]:
# Uses `conf` directly (not the confraw copy).
tgtfr = TGTFR.TrainingGeneratorTFRecord(conf)
#tgtfr.conf.dpk_augmenter.reseed(RNGSEED)  # NOTE(review): RNGSEED is not defined in this notebook
dspt = tgtfr(n_outputs=1, 
             batch_size=TGTFR_BSIZE, 
             validation=False,  # reads the train TFRecord (same file confraw uses as "val"); validation=False turns on distortion
             confidence=True, 
             shuffle=False)

In [None]:
# 9 batches here, vs. 3 for the raw montages.
imspt, tgtspt = read_batches_from_dataset(dspt, 9)

locspt = get_pred_locs_with_unscale(imspt, tgtspt[...,:17])

In [None]:
locspt  # note even half-pxs

In [None]:
# Montage the PT-augmented images with confmap-derived keypoints and save.
fig, _, _ = tfdatagen.montage(ims0=imspt, 
                  ims0type='batchfirst', 
                  locs=locspt, 
                  cmap='gray',
                  locsmrkr='+',
                  locsmrkrsz=LOCS_MRKRSZ,
                  locscmap=lcm
                 )
savegcf(fig, 'pt_wconfmaps')

In [None]:
# Confirm which augmentation path `conf` uses (the IA section sets this True on a copy).
conf.dpk_use_augmenter






### PT augmentation, read back without confidence maps (conf off)

In [None]:
# Re-seed with the same value as the PT section.
# NOTE(review): presumably meant to reproduce the same distortions; that only
# holds if the PT augmentation draws from np.random in the same order — confirm.
np.random.seed(0)

In [None]:
# Reuses `tgtfr` from the PT section; only confidence differs (False here).
#tgtfr = TGTFR.TrainingGeneratorTFRecord(conf)
dsptraw = tgtfr(n_outputs=1, 
             batch_size=TGTFR_BSIZE, 
             validation=False,  # reads the train TFRecord (same file confraw uses as "val"); validation=False turns on distortion
             confidence=False, 
             shuffle=False)

In [None]:
# Read 3 batches by hand (second output used as locs, as in the true-raw section).
nbch_read = 3
resptraw = tfdatagen.read_ds_idxed(dsptraw, range(nbch_read))
imsptraw, locsptraw = tfdatagen.xylist2xyarr(resptraw)
print("ims.shape={}, tgts.shape={}".format(imsptraw.shape, locsptraw.shape))

In [None]:
locsptraw

In [None]:
# NOTE(review): leftover development-time module reload; not needed on a fresh kernel.
importlib.reload(tfdatagen)

In [None]:
# Montage the PT-augmented images with locations read back directly, and save.
fig, _, _ = tfdatagen.montage(ims0=imsptraw, 
                  ims0type='batchfirst', 
                  locs=locsptraw, 
                  cmap='gray',
                  locsmrkr='+',
                  locsmrkrsz=LOCS_MRKRSZ,
                  locscmap=lcm
                 )
savegcf(fig, 'pt_raw')






### Apply the imgaug (IA) augmenter directly to the raw images


In [None]:
import imgaug as ia
import imgaug.augmenters as iaa
from imgaug.augmentables import Keypoint, KeypointsOnImage
from deepposekit.augment import FlipAxis

In [None]:
# Inputs for direct imgaug manipulation: the "true raw" images/locs from above.
print(imsrawtrue.shape,locsrawtrue.shape)

In [None]:
# Build the 'bub' imgaug augmenter.
# NOTE(review): dpk_swap_index presumably lists left/right keypoint pairs to swap on flips — confirm.
aug = apt_dpk.make_imgaug_augmenter('bub', conf.dpk_swap_index)

In [None]:
# Augment the raw batch directly, bypassing the TGTFR pipeline.
# Note: imsia/locsia are overwritten later by the end-to-end IA section.
imsia, locsia = tfdatagen.imgaug_augment(aug, 
                                         images=imsrawtrue, 
                                         keypoints=locsrawtrue)
print("imsia.shape={}, locsia.shape={}".format(imsia.shape, locsia.shape))


In [None]:
#imsia2 = np.minimum(255, imsia)

In [None]:
#imsia3 = imsia2.clip(0, 255)

In [None]:
# Intensity range after augmentation (the commented-out cells above sketch clipping to [0, 255]).
np.min(imsia), np.max(imsia)

In [None]:
# Montage the directly augmented images; note marker size 55 rather than LOCS_MRKRSZ.
hfig, grid, cb0 = tfdatagen.montage(imsia, 
                                    ims0type='batchfirst',
                                    cmap='gray', 
                                    locs=locsia, 
                                    locsmrkrsz=55, locsmrkr='+', locscmap=lcm)

savegcf(hfig, 'ia_directmanip')






### End-to-end imgaug (IA) augmentation through the TGTFR


In [None]:
# Copy of conf with the DPK imgaug augmenter switched on, using the 'bub' type.
confia = copy.deepcopy(conf)
confia.dpk_use_augmenter = True
confia.dpk_augmenter_type = {'type': 'bub'}

In [None]:
# NOTE(review): unlike the PT sections, np.random is not re-seeded right before this.
tgtfria = TGTFR.TrainingGeneratorTFRecord(confia)
#tgtfr.conf.dpk_augmenter.reseed(RNGSEED)  # NOTE(review): RNGSEED is not defined in this notebook
dsia = tgtfria(n_outputs=1, 
             batch_size=TGTFR_BSIZE, 
             validation=False,  # reads the train TFRecord (same file confraw uses as "val"); validation=False turns on distortion
             confidence=True, 
             shuffle=False)

In [None]:
# Overwrites imsia/locsia from the direct-manipulation section.
imsia, tgtsia = read_batches_from_dataset(dsia, 9)

locsia = get_pred_locs_with_unscale(imsia, tgtsia[...,:17])

In [None]:
locsia

In [None]:
# Montage the IA-augmented images with confmap-derived keypoints and save.
fig, _, _ = tfdatagen.montage(ims0=imsia, 
                  ims0type='batchfirst', 
                  locs=locsia, 
                  cmap='gray',
                  locsmrkr='+',
                  locsmrkrsz=LOCS_MRKRSZ,
                  locscmap=lcm
                 )
savegcf(fig, 'ia_confmap')

### End-to-end imgaug augmentation through the TGTFR, without confidence maps


In [None]:
# Same IA generator as dsia, but with confidence=False (second output used as locs below).
dsiaraw = tgtfria(n_outputs=1, 
                  batch_size=TGTFR_BSIZE, 
                  validation=False,  # reads the train TFRecord (same file confraw uses as "val"); validation=False turns on distortion
                  confidence=False, 
                  shuffle=False)

In [None]:
# Read 3 batches by hand, as in the true-raw section.
nbch_read = 3
resiaraw = tfdatagen.read_ds_idxed(dsiaraw, range(nbch_read))
imsiaraw, locsiaraw = tfdatagen.xylist2xyarr(resiaraw)
print("ims.shape={}, tgts.shape={}".format(imsiaraw.shape, locsiaraw.shape))

In [None]:
locsiaraw

In [None]:
# Montage the IA-augmented images with locations read back directly, and save.
fig, _, _ = tfdatagen.montage(ims0=imsiaraw, 
                  ims0type='batchfirst', 
                  locs=locsiaraw, 
                  cmap='gray',
                  locsmrkr='+',
                  locsmrkrsz=LOCS_MRKRSZ,
                  locscmap=lcm
                 )
savegcf(fig, 'ia_raw')

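## TODO: the cells below are stale — they use variables (tgtsDS, tgtsdpk, imsDS, imsdpk, augmenter) that are not defined earlier in this notebook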

## Check extents

In [None]:
def get_conf_map_argmaxs_allmaps(tgts0, sz=(192, 192)):
    '''
    Resize every confmap in a batch to `sz` and return each map's peak location.

    tgts0: [bsize x nr x nc x nmap] array of confidence maps.
    sz: target size handed straight to cv2.resize, i.e. (width, height).

    Returns a [bsize*nmap x 2] array of locations from pt.get_pred_locs,
    in the resized coordinates. Rows are ordered map-fastest: row i holds
    map i % nmap of batch element i // nmap.
    '''
    maps = np.moveaxis(tgts0.copy(), 0, -1)  # [nr x nc x nmap x bsize]
    nr, nc, nmap, bsize = maps.shape
    print('nr nc nmap bsize = {} {} {} {}'.format(nr, nc, nmap, bsize))
    # Fortran order flattens (nmap, bsize) with nmap varying fastest.
    maps = np.reshape(maps, (nr, nc, nmap*bsize), order='F')

    locs = np.zeros((nmap*bsize, 2))
    for imap in range(maps.shape[2]):
        resized = cv2.resize(maps[..., imap], sz)
        # Wrap as a single-image, single-map array; [0, 0] picks out that one result.
        peak = pt.get_pred_locs(resized[np.newaxis, ..., np.newaxis])
        locs[imap, ...] = peak[0, 0, ...]

    return locs  # [bsize*nmap x 2]

In [None]:
# NOTE(review): stale — tgtsDS and tgtsdpk are not defined anywhere in this notebook,
# so these cells fail on a fresh kernel. They presumably held confmaps from two
# pipelines ("DS" vs "dpk") to be compared — confirm and recreate them above.
locsDS = get_conf_map_argmaxs_allmaps(tgtsDS)
locsDS.shape

In [None]:
locsdpk = get_conf_map_argmaxs_allmaps(tgtsdpk)
locsdpk.shape

In [None]:
# Histograms of every peak location, one panel per coordinate column (0, 1):
# dpk on the top row, DS on the bottom row.
f = plt.figure(num=21)
f.clf()
ax = f.subplots(2,2,sharex=True,sharey=True)  # figsize=(16,12),
for iax in range(2):
    plt.sca(ax[0,iax])
    plt.hist(locsdpk[:,iax],bins=50)
    plt.title('dpk {}'.format(iax))
    plt.sca(ax[1,iax])
    plt.hist(locsDS[:,iax],bins=50)
    plt.title('DS {}'.format(iax))
    


In [None]:
# Rearrange for montage. Assuming imsDS is [b x h x w x c] (TODO confirm),
# imsDSplot becomes [h x w x b], keeping channel 0.
imsDSplot = imsDS.copy()
imsDSplot = np.moveaxis(imsDSplot, 0, -1)
imsDSplot = imsDSplot[:, :, 0, :]
# Unflatten [bsize*nmap x 2] peaks into [bsize x nmap x 2]. Fortran order matches
# the map-fastest row order produced by get_conf_map_argmaxs_allmaps.
# NOTE(review): nmap=66 and bsize=800 are hard-coded and must match tgtsDS's shape.
locsDSplot = np.reshape(locsDS, (66, 800, 2), order='F')
locsDSplot = np.moveaxis(locsDSplot, 0, 1)
locsDSplot.shape

In [None]:
# First 100 images overlaid with their first 32 map peaks.
hfig, grid, cb0 = tfdatagen.montage(imsDSplot[...,:100], 
                                    cmap='gray', 
                                    locs=locsDSplot[:100,:32,:], 
                                    locsmrkrsz=12)

In [None]:
# Same rearrangement for the dpk outputs (same hard-coded 66 x 800 caveat).
imsdpkplot = imsdpk.copy()
imsdpkplot = np.moveaxis(imsdpkplot, 0, -1)
imsdpkplot = imsdpkplot[:, :, 0, :]
locsdpkplot = np.reshape(locsdpk, (66, 800, 2), order='F')
locsdpkplot = np.moveaxis(locsdpkplot, 0, 1)
NPLOT = 100
# Only maps 0 and 25 are overlaid here.
hfig, grid, cb0 = tfdatagen.montage(imsdpkplot[...,:NPLOT], 
                                    cmap='gray', 
                                    locs=locsdpkplot[:NPLOT,[0,25],:], 
                                    locsmrkrsz=12)

In [None]:
locsdpkplot.shape

In [None]:
# Per-coordinate histogram of a single map (IPTPLOT) across all images.
# The x-limits assume the default 192-px resize in get_conf_map_argmaxs_allmaps.
IPTPLOT = 8
f = plt.figure(num=22)
f.clf()
ax = f.subplots(2,2,sharex=True,sharey=True)  # figsize=(16,12),
for iax in range(2):
    plt.sca(ax[0,iax])
    plt.hist(locsdpkplot[:,IPTPLOT,iax],bins=20)
    plt.title('dpk {}'.format(iax))
    plt.xlim((0,192))
    plt.sca(ax[1,iax])
    plt.hist(locsDSplot[:,IPTPLOT,iax],bins=20)
    plt.title('DS {}'.format(iax))
    plt.xlim((0,192))
    

In [None]:
# Orientation of the vector from point IPTTAIL to point IPTHEAD, as arctan2(col 1, col 0).
# NOTE(review): indices 0/5 presumably are head/tail keypoints — confirm.
IPTHEAD = 0
IPTTAIL = 5
vtailhead_ds = locsDSplot[:,IPTHEAD,:]-locsDSplot[:,IPTTAIL,:]
vtailhead_dpk = locsdpkplot[:,IPTHEAD,:]-locsdpkplot[:,IPTTAIL,:]

phi_ds = np.arctan2(vtailhead_ds[:,1],vtailhead_ds[:,0])
phi_dpk = np.arctan2(vtailhead_dpk[:,1],vtailhead_dpk[:,0])

phi_ds.shape, phi_dpk.shape

In [None]:
# Compare the orientation distributions of the two pipelines.
f = plt.figure(num=23)
f.clf()
ax = f.subplots(1,2,sharex=True,sharey=True)  # figsize=(16,12),

plt.sca(ax[0])
plt.hist(phi_ds,bins=20)
plt.title('ds')
#plt.xlim((0,192))
plt.sca(ax[1])
plt.hist(phi_dpk,bins=20)
plt.title('dpk')
#plt.xlim((0,192))

### Messing around

In [None]:


# Scratch check of imgaug's LinearContrast on a tiny float32 grayscale image
# with two keypoints: draw the keypoints before/after, then compare maxima.
aug2 = iaa.Sequential([
    iaa.LinearContrast((1.0, 3.0)),
])

im = ia.quokka(size=(32, 32))
# imgaug images are RGB, so convert with RGB2GRAY; BGR2GRAY would swap the
# red/blue luma weights.
im = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY)
im = im.astype(np.float32)
kp = KeypointsOnImage([
    Keypoint(x=15, y=25),
    Keypoint(x=25, y=20),
], shape=im.shape)

ima, kpa = aug2(image=im, keypoints=kp)

im0 = kp.draw_on_image(im, color=0)
im1 = kpa.draw_on_image(ima, color=0)

# Same fixed [0, 255] color range for both so the contrast change is visible.
for img in (im0, im1):
    plt.figure()
    plt.imshow(img)
    plt.clim((0, 255))
    plt.colorbar()

np.max(im), np.max(ima)

In [None]:
# NOTE(review): imsDS is not defined earlier in this notebook. If it is
# [b x h x w x c] (as the imsDSplot cell's indexing suggests), imbub ends up as
# [1 x b x h x w] — i.e. ALL images, not a single one. Possibly
# imsDS[0, ..., 0] was intended — confirm.
imbub = imsDS[...,0]
imbub = imbub[np.newaxis,...]


In [None]:
# locsDS is 2-D ([bsize*nmap x 2]), so it cannot take a 3-index slice. Use the
# per-image view locsDSplot ([bsize x nmap x 2]) instead: the first image and its
# first 17 map peaks, shaped [1 x 17 x 2].
kpbub = locsDSplot[0, np.newaxis, :17, :]

In [None]:
imbub.shape, kpbub.shape

In [None]:
# NOTE(review): `augmenter` is not defined in this notebook; possibly `aug` from the
# direct-IA section was intended — confirm. imgaug's image= keyword takes a
# single image, while imbub has a leading batch axis; `images=` may be needed.
imbub2, kpbub2 = augmenter(image=imbub, keypoints=kpbub)


In [None]:
imbub2.shape, kpbub2.shape

In [None]:
# Before/after montages; the leading axis is moved to the end for montage.
tfdatagen.montage(np.moveaxis(imbub,0,-1), locs=kpbub,figsize=(10,10))

In [None]:
tfdatagen.montage(np.moveaxis(imbub2,0,-1), locs=kpbub2,figsize=(10,10))