In [18]:
import h5py
import numpy as np
import os
from matplotlib import pyplot as plt
na = np.newaxis

from renderer.renderer import MouseScene
from renderer.load_data import load_new_behavior_data

import experiments
from experiments import Experiment
import pose_models
import particle_filter
import predictive_models as pm
import predictive_distributions as pd
from util.text import progprint_xrange

# Project root = the parent directory of the notebook's working directory;
# every data/result path below is built relative to it.
root = os.path.dirname(os.path.abspath(os.getcwd()))

In [19]:
# Render-grid size handed to MouseScene as numRows / numCols.
msNumRows, msNumCols = 32, 32
# Lazily built MouseScene shared by everything in this kernel. Experiment code
# reads this global directly (e.g. ``ms.get_likelihood``).
ms = None


def _build_mousescene(scenefilepath):
    """Return the shared MouseScene, building and GL-initialising it on first use.

    Parameters
    ----------
    scenefilepath : str
        Scene file passed to ``MouseScene`` when the scene is first built.

    Returns
    -------
    MouseScene
        The cached scene, which is also stored in the module-level ``ms``.

    Notes
    -----
    Once a scene has been built successfully, later calls return the cached
    instance and ignore ``scenefilepath``.
    """
    global ms
    if ms is None:
        scene = MouseScene(scenefilepath, mouse_width=80, mouse_height=80,
                           scale_width=18.0, scale_height=200.0,
                           scale_length=18.0,
                           numCols=msNumCols, numRows=msNumRows,
                           useFramebuffer=True, showTiming=False)
        # Publish to the cache only after gl_init succeeds. Otherwise a failed
        # init would leave a half-built scene cached and returned forever.
        scene.gl_init()
        ms = scene
    return ms

def _load_data(datapath,frame_range):
    """Load images and centroid/heading for frames ``frame_range[0]..frame_range[1]``.

    Parameters
    ----------
    datapath : str
        Path to the extracted-behaviour HDF5 file.
    frame_range : sequence of int
        ``(first, last)`` frame indices. ``last + 1`` frames are requested from
        the loader and the leading ``first`` ones are dropped, so the range is
        inclusive on both ends.

    Returns
    -------
    images_rot : np.ndarray
        float32 frames, each transposed and flipped along its first axis
        (``image.T[::-1, :]``), stacked along axis 0.
    xytheta : np.ndarray
        Per-frame rows of ``[centroid..., angle]``: the centroid columns with
        the angle appended as the last column.
    """
    first, last = frame_range[0], frame_range[1]
    # NOTE(review): the loader's second argument appears to be a frame count
    # (frames 0..last are loaded, then sliced from `first`); confirm.
    n_frames = last + 1

    xy = load_new_behavior_data(datapath, n_frames, 'centroid')[first:]
    theta = load_new_behavior_data(datapath, n_frames, 'angle')[first:]
    xytheta = np.concatenate((xy, theta[:, na]), axis=1)

    # Slice before casting so the discarded leading frames are never
    # converted to float32.
    images = load_new_behavior_data(datapath, n_frames, 'images')[first:].astype('float32')
    images_rot = np.array([image.T[::-1, :] for image in images])
    return images_rot, xytheta

In [20]:
class RandomWalkFixedNoiseFrozenTrack_AW_5Joints_simplified(Experiment):
    """Particle-filter pose tracking with a random-walk prior and fixed noise.

    Should look a lot like RandomWalkFixedNoise. After a ``lag``-step burn-in,
    each step first freezes the weight-averaged particle track entry from
    ``lag`` steps back, then advances the filter by one frame.
    """

    def run(self, frame_range):
        """Track the mouse over ``frame_range`` and checkpoint via ``save_progress``.

        Parameters
        ----------
        frame_range : sequence of int
            ``(first, last)`` frame indices (inclusive), forwarded to
            ``_load_data`` and ``save_progress``.
        """
        datapath = os.path.join(root, "data/extracted_data/2019-03-20_17-31-38_saline_example_0_000070_results_00.h5")

        num_particles_firststep = 1024*10
        num_particles = 1024*10
        cutoff = 1024*10  # passed to ParticleFilter as its second argument

        # How many steps a track entry ages before its weighted mean is frozen.
        lag = 15

        pose_model = pose_models.PoseModel_5Joint_origweights_AW()

        # Per-field diagonal random-walk noise: 'init' is used for the first
        # step, 'subsq' for every later step.
        # NOTE(review): these values fill a matrix named *noisechol*; if
        # FixedNoise treats it as a Cholesky factor they act as standard
        # deviations rather than variances. Confirm.
        variances = {
            'x':             {'init':3.0,  'subsq':1.5},
            'y':             {'init':3.0,  'subsq':1.5},
            'theta_yaw':     {'init':7.0,  'subsq':3.0},
            'z':             {'init':3.0,  'subsq':0.25},
            'theta_roll':    {'init':0.01, 'subsq':0.01},
            's_w':           {'init':2.0,  'subsq':1e-6},
            's_l':           {'init':2.0,  'subsq':1e-6},
            's_h':           {'init':1.0,  'subsq':1e-6}
        }
        particle_fields = pose_model.ParticlePose._fields
        # Every joint-angle field ('psi_*') gets the same, broader noise.
        for joint in (f for f in particle_fields if 'psi_' in f):
            variances[joint] = {'init': 20.0, 'subsq': 5.0}

        randomwalk_noisechol = np.diag([variances[p]['init'] for p in particle_fields])
        subsequent_randomwalk_noisechol = np.diag([variances[p]['subsq'] for p in particle_fields])

        # TODO check z size
        # TODO try cutting scale, fit on first 10 or so

        _build_mousescene(pose_model.scenefilepath)
        images, xytheta = _load_data(datapath, frame_range)

        # Centre the default poses on the first frame's tracked centroid/heading.
        pose_model.default_renderer_pose = pose_model.default_renderer_pose._replace(
            theta_yaw=xytheta[0, 2], x=xytheta[0, 0], y=xytheta[0, 1])
        pose_model.default_particle_pose = pose_model.default_particle_pose._replace(
            theta_yaw=xytheta[0, 2], x=xytheta[0, 0], y=xytheta[0, 1])

        def log_likelihood(stepnum, im, poses):
            # Renderer likelihood of frame `stepnum`, scaled down by 2000.
            # NOTE(review): the origin of the 1/2000 tempering factor is not
            # documented here.
            return ms.get_likelihood(im, particle_data=pose_model.expand_poses(poses),
                                     x=xytheta[stepnum, 0], y=xytheta[stepnum, 1],
                                     theta=xytheta[stepnum, 2]) / 2000.

        pf = particle_filter.ParticleFilter(
            pose_model.particle_pose_tuple_len,
            cutoff,
            log_likelihood,
            [particle_filter.AR(
                num_ar_lags=1,
                previous_outputs=(pose_model.default_particle_pose,),
                # The inner lambda closes over the randomwalk_noisechol array
                # *object*, which is overwritten in place after the first step.
                baseclass=lambda: pm.RandomWalk(noiseclass=lambda: pd.FixedNoise(randomwalk_noisechol)),
            ) for _ in range(num_particles_firststep)])

        pf.step(images[0])
        pf.change_numparticles(num_particles)
        # Overwrite in place rather than rebinding, presumably so noise objects
        # that reference this array (or are created lazily by the lambdas above)
        # switch to the 'subsq' noise. Confirm against FixedNoise.
        randomwalk_noisechol[:] = subsequent_randomwalk_noisechol

        # Burn-in: step until each track is `lag` entries long.
        for i in progprint_xrange(1, lag):
            pf.step(images[i])
        self.save_progress(pf, pose_model, datapath, frame_range, means=[])

        # Now step with freezing means: before stepping frame i, freeze the
        # weighted mean of the track entry from `lag` steps back.
        means = []
        for i in progprint_xrange(lag, images.shape[0], perline=10):
            frozen_idx = i - lag
            means.append(np.sum(pf.weights_norm[:, na] * np.array([p.track[frozen_idx] for p in pf.particles]), axis=0))
            # Use `lag`, not a hard-coded 15, so the diagnostic counts unique
            # particles at the same index whose mean was just frozen.
            print('\nsaved a mean for index %d with %d unique particles!\n' %
                  (frozen_idx, len(np.unique([p.track[frozen_idx][0] for p in pf.particles]))))

            pf.step(images[i])

            # Periodic checkpoint every 5 frames.
            if i % 5 == 0:
                self.save_progress(pf, pose_model, datapath, frame_range, means=means)

        self.save_progress(pf, pose_model, datapath, frame_range, means=means)

In [26]:
# ==================================================
#                     SERIAL
# ==================================================
# This is how we do things serially. The constructor argument shows up in the
# results-directory name (e.g. "results/<id>.100.200"), while run() uses
# frame_range to select the frames to track. Share one constant so the two
# cannot drift apart.
FRAME_RANGE = (100, 200)
RandomWalkFixedNoiseFrozenTrack_AW_5Joints_simplified(FRAME_RANGE).run(FRAME_RANGE)

results directory: results/8218897506270013465.100.200
169.34881835795883
1776
1991.95768844505
4443
.5641.54581932346
6199
.7926.947096973737
7993
.8503.386966865346
8393
.8440.144780359376
8336
.8334.700937448244
8269
.8425.028192425756
8342
.8417.867108933408
8281
.8167.201887606497
8147
.8092.248678984923
8049
.8317.604246836408
8174
.8191.356847945012
8065
.8401.68791893004
8199
.8320.87132170167
8173
.
   1.73sec avg,   24.23sec total


saved a mean for index 0 with 1 unique particles!

8705.024344430696
8450
.
saved a mean for index 1 with 255 unique particles!

8676.033527570204
8419
.
saved a mean for index 2 with 461 unique particles!

8978.744195479627
8664
.
saved a mean for index 3 with 511 unique particles!

8842.6813657779
8586
.
saved a mean for index 4 with 516 unique particles!

9000.862285425857
8655
.
saved a mean for index 5 with 537 unique particles!

9239.540270318616
8898
.
saved a mean for index 6 with 560 unique particles!

9138.165269603785
8777
.
saved a mea

.
saved a mean for index 2 with 743 unique particles!

8942.536283413334
8657
.
saved a mean for index 3 with 728 unique particles!

8834.61937613716
8557
.
saved a mean for index 4 with 696 unique particles!

8951.206497726893
8644
.
saved a mean for index 5 with 665 unique particles!

9135.282522129703
8813
.
saved a mean for index 6 with 646 unique particles!

9065.512862188927
8744
.
saved a mean for index 7 with 624 unique particles!

9136.134119941602
8795
.
saved a mean for index 8 with 622 unique particles!

9194.923693755894
8846
.
saved a mean for index 9 with 647 unique particles!

8909.37601089691
8610
.  [ 10/86,    2.39sec avg,  181.99sec ETA ]

saved a mean for index 10 with 690 unique particles!

8950.91642218881
8677
.
saved a mean for index 11 with 743 unique particles!

9102.230968489353
8786
.
saved a mean for index 12 with 781 unique particles!

8996.696595978807
8721
.
saved a mean for index 13 with 832 unique particles!

9191.31098314514
8870
.
saved a mean for i

In [27]:
import visualization

In [28]:
from pathlib import Path

# Re-derive the project root so this cell also works on its own.
root = os.path.dirname(os.path.abspath(os.getcwd()))

# Checkpoint written by save_progress during the serial run; the final path
# component appears to be a checkpoint index (TODO confirm).
pickle_file = os.path.join(root, "pyparticles/results/8218897506270013465.100.200/96")

# Directory that receives the rendered movie frames (created if missing).
out_file = os.path.join(root, "data/results_movies/mouse13")
Path(out_file).mkdir(exist_ok=True, parents=True)

In [29]:
# Render the mean-track movie from the checkpoint into out_file; presumably
# this writes the .png frames that the next cell stitches into a video.
visualization.meantrack_sidebyside_movie(pickle_file, out_file)

In [30]:
import os
import re

import cv2

image_folder = out_file
video_name = os.path.join(root,'data/results_movies/mouse13_video1.avi')

# os.listdir returns entries in arbitrary order, so sort the frames explicitly.
# The sort is "natural": digit runs compare as integers, so 2.png < 10.png
# whether or not the names are zero-padded.
images = sorted(
    (img for img in os.listdir(image_folder) if img.endswith(".png")),
    key=lambda name: [int(tok) if tok.isdigit() else tok
                      for tok in re.split(r'(\d+)', name)],
)
if not images:
    raise FileNotFoundError("no .png frames found in %s" % image_folder)

frame = cv2.imread(os.path.join(image_folder, images[0]))
if frame is None:  # imread signals failure by returning None, not raising
    raise IOError("could not read frame %s" % os.path.join(image_folder, images[0]))
height, width, layers = frame.shape

# fourcc=0: no codec requested (backend-dependent output); 10 frames per second.
video = cv2.VideoWriter(video_name, 0, 10, (width,height))
try:
    for image in images:
        video.write(cv2.imread(os.path.join(image_folder, image)))
finally:
    # Always finalise the file, even if a frame fails to read or write.
    video.release()

In [14]:
from renderer import load_data

# Inspect the raw HDF5 file that the tracking experiment reads.
datapath = os.path.join(
    root,
    "data/extracted_data/2019-03-20_17-31-38_saline_example_0_000070_results_00.h5",
)
test_dict = load_data.h5_to_dict(datapath)

In [17]:
# Top-level keys of the loaded HDF5 dict (shown as the cell's output).
test_dict.keys()

dict_keys(['frames', 'frames_mask', 'metadata', 'scalars', 'timestamps'])