Copyright 2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import random
import scipy.linalg
from scipy.stats import multivariate_normal, uniform
from scipy.special import logsumexp

In [None]:
def transition_matrix(dt):
    """Return the 6x6 constant-acceleration state-transition matrix.

    The state is ordered (x, vx, ax, y, vy, ay), so the result holds two
    identical 3x3 position/velocity/acceleration blocks on its diagonal, one
    per image axis. Multiplying a state by the result advances it by `dt`.
    """
    half_dt_squared = 0.5 * (dt**2)
    single_axis = np.array([
        [1, dt, half_dt_squared],
        [0, 1, dt],
        [0, 0, 1],
    ])
    per_axis_blocks = [single_axis, single_axis]
    return scipy.linalg.block_diag(*per_axis_blocks)

# Maps a 6-element state (x, vx, ax, y, vy, ay) to the observed (x, y)
# position by selecting state elements 0 and 3.
observation_matrix = np.array([
    [1, 0, 0, 0, 0, 0],
    [0, 0, 0, 1, 0, 0],
])

# Frame size: bounds for uniformly sampled positions and the area used for the
# density of spurious observations.
# NOTE(review): presumably pixels - confirm against the capture source.
frame_height = 480
frame_width = 640

In [None]:
def make_synthetic_ball(start_x, start_y, step_count):
    """Simulate one ball that is repeatedly held, thrown and caught.

    The ball starts held at (`start_x`, `start_y`). Once it has been held for
    more than 0.5 time units it is thrown (vy = -500, ay = 500, vx = +/-220
    pointing towards the other half of the frame). While in flight it is
    caught again as soon as y >= 400, which zeroes vx, vy and ay.

    Returns a list with `step_count` entries, one per time step, each a dict
    with the 6-element 'state' (x, vx, ax, y, vy, ay) recorded before that
    step's physics update and 'hold' (1 while held, 0 while in flight).
    """
    # Simulation parameters
    dt = 0.1
    step_matrix = transition_matrix(dt)

    # x, vx, ax, y, vy, ay
    state_vector = np.array([start_x, 0, 0, start_y, 0, 0])
    elapsed = 0
    # Time at which the current hold began; None while the ball is in flight.
    held_since = 0

    history = []
    for _ in range(step_count):
        history.append({
            'state': state_vector,
            'hold': 1 if held_since is not None else 0,
        })

        # Physics
        state_vector = step_matrix @ state_vector
        elapsed += dt

        if held_since is None:
            # In flight: catch the ball once it has come down far enough.
            if state_vector[3] >= 400:
                state_vector[1] = 0
                state_vector[4] = 0
                state_vector[5] = 0
                held_since = elapsed
            continue

        if elapsed - held_since > 0.5:
            # Held long enough: throw towards the other half of the frame.
            held_since = None
            state_vector[4] = -500
            state_vector[5] = 500
            state_vector[1] = 220 if state_vector[0] < frame_width / 2 else -220

    return history

def make_synthetic_observations(ball_statess):
    """Turn per-step ball states into shuffled, noisy-by-clutter observations.

    `ball_statess` is a list (one entry per time step) of lists of ball-state
    dicts. For every time step each ball contributes its observed (x, y)
    position; with probability 0.1 one extra spurious position, uniform over
    the frame, is added. The observations of a step are returned in random
    order, so their order carries no information about ball identity.

    Returns a list with one array of observations per time step.
    """
    p_spurious_observation = 0.1
    frame_extent = np.array([frame_width, frame_height])

    observationss = []
    for states_at_step in ball_statess:
        # Observe
        detections = [observation_matrix @ entry['state'] for entry in states_at_step]
        if random.uniform(0, 1) < p_spurious_observation:
            detections = detections + [frame_extent * np.random.uniform(size=[2])]
        observationss.append(np.random.permutation(detections))
    return observationss

def make_synthetic_data():
    """Build a 500-step synthetic scene with three juggled balls.

    Two balls are present from the first step; a third one joins at step 80.

    Returns `(observationss, ball_statess)`: the per-step observations from
    `make_synthetic_observations` and the per-step lists of true ball states.
    """
    step_count = 500
    third_ball_delay = 80

    first_ball = make_synthetic_ball(100, 400, step_count)
    second_ball = make_synthetic_ball(500, 400, step_count)
    third_ball = make_synthetic_ball(100, 400, step_count - third_ball_delay)

    ball_statess = []
    for step_index, (first, second) in enumerate(zip(first_ball, second_ball)):
        present = [first, second]
        if step_index >= third_ball_delay:
            present.append(third_ball[step_index - third_ball_delay])
        ball_statess.append(present)

    return make_synthetic_observations(ball_statess), ball_statess

In [None]:
# Generate the synthetic data set and scatter-plot the first observation of
# every time step (observations are shuffled, so this mixes all balls).
observationss, ball_states = make_synthetic_data()
xs = [d[0][0] for d in observationss]
ys = [d[0][1] for d in observationss]
plt.scatter(xs, ys)
plt.ylim((480, 0))  # reversed limits: y = 0 is drawn at the top
plt.show()

In [None]:
def transition(N, particles, dt, p_teleport=0.05):
    """
    Advance a particle set by one time step, in place (prediction step).

    particles['continuous'] is [N, 6] float array (x, vx, ax, y, vy, ay)
    particles['discrete'] [N] int array
      - 0 = freefall
      - 1 = hold
    particles['log_p_exist'] is the log probability that the ball exists

    In order, this
      1. randomly flips particles between freefall and hold (a particle that
         becomes held gets zero velocity/acceleration, a particle that is
         thrown gets a random velocity/acceleration),
      2. applies the constant-acceleration motion model for `dt` plus
         Gaussian process noise,
      3. updates `log_p_exist` with the appear/disappear probabilities,
      4. appends `int(p_teleport * N)` fresh held particles placed uniformly
         over the frame, so the set is larger on return than on entry.

    Args:
        N: nominal particle count; only used to size the teleported batch.
        particles: particle dict as described above; modified in place.
        dt: time step handed to `transition_matrix`.
        p_teleport: fraction of `N` to add as uniformly placed particles.

    Returns:
        None.
    """

    # parameters
    # TODO: make dt invariant
    p_freefall_hold = 0.5
    p_hold_freefall = 0.5
    p_appear = 0.05
    p_disappear = 0.05
    # Distribution of (vx, ax, vy, ay) given to a particle that is thrown.
    hold_freefall_noise = multivariate_normal([0, 0, -2000, 6000], np.diag([500*500, 100*100, 200*200, 200*200]))
    continuous_noise = multivariate_normal([0, 0, 0, 0, 0, 0], np.diag([10*10, 10*10, 10*10, 10*10, 10*10, 10*10]))

    continuous = particles['continuous']
    discrete = particles['discrete']
    # Size every per-particle sample by the real particle count rather than N,
    # so a set whose size differs from N does not fail with a shape mismatch.
    particle_count = continuous.shape[0]

    # particles['discrete'] state transitions

    # Both masks are computed before either is applied, so a particle cannot
    # flip twice within one call.
    sample = np.random.uniform(size=[particle_count])
    freefall_hold = (discrete == 0) & (sample < p_freefall_hold)
    hold_freefall = (discrete == 1) & (sample < p_hold_freefall)

    discrete[freefall_hold] = 1
    discrete[hold_freefall] = 0

    # A caught ball stops moving.
    for column in (1, 2, 4, 5):
        continuous[freefall_hold, column] = 0

    # A thrown ball gets a random velocity and acceleration.
    throw_count = int(np.count_nonzero(hold_freefall))
    if throw_count:
        # rvs() squeezes a single sample down to 1-D; restore the row axis.
        throw_kick = np.reshape(hold_freefall_noise.rvs(size=[throw_count]), (-1, 4))
        continuous[hold_freefall, 1] = throw_kick[:, 0]
        continuous[hold_freefall, 2] = throw_kick[:, 1]
        continuous[hold_freefall, 4] = throw_kick[:, 2]
        continuous[hold_freefall, 5] = throw_kick[:, 3]

    # particles['continuous'] state transitions

    continuous = (transition_matrix(dt) @ continuous[:, :, np.newaxis])[:, :, 0]
    continuous += np.reshape(continuous_noise.rvs(size=[particle_count]), (-1, 6))

    # appear/disappear state transitions

    log_p_exist = particles['log_p_exist']
    # log(1 - exp(x)) written as log(-expm1(x)): the naive form cancels to
    # log(0) once x is within rounding distance of 0.
    log_p_not_exist = np.log(-np.expm1(log_p_exist))
    particles['log_p_exist'] = logsumexp([
        np.log(1 - p_disappear) + log_p_exist,
        np.log(p_appear) + log_p_not_exist,
    ])

    # teleportation

    N_new = int(p_teleport * N)
    new_discrete = np.full([N_new], 1, dtype=int)
    new_continuous = np.zeros([N_new, 6])
    new_continuous[:, 0] = np.random.uniform(0, frame_width, size=[N_new])
    new_continuous[:, 3] = np.random.uniform(0, frame_height, size=[N_new])
    particles['discrete'] = np.concatenate([discrete, new_discrete])
    particles['continuous'] = np.concatenate([continuous, new_continuous])

In [None]:
def resample(N, particles, observation):
    """Condition a particle set on an observation and resample it to size N.

    Updates `particles` in place:
      - 'log_p_exist' becomes the posterior log probability that the ball
        exists given the evidence (either `observation`, or the absence of
        one when `observation` is None);
      - 'continuous' and 'discrete' are replaced by `N` particles drawn with
        replacement, weighted by each particle's observation likelihood.

    Args:
        N: number of particles to keep.
        particles: particle dict with 'continuous' ([M, 6]), 'discrete' ([M])
            and 'log_p_exist'.
        observation: observed (x, y) position, or None if this ball was not
            observed in the current frame.

    Returns:
        None.

    Raises:
        ValueError: if the particle set is empty.
    """
    N_before = particles['continuous'].shape[0]
    if N_before == 0:
        raise ValueError('cannot resample an empty particle set')

    # parameters
    p_obs = 0.9
    p_spurious_obs = 0.1
    observation_noise = multivariate_normal([0, 0], np.diag([20*20, 20*20]))

    particle_observations = (observation_matrix @ particles['continuous'][:, :, np.newaxis])[:, :, 0]

    if observation is None:
        # likelihood from the case where we failed to observe the ball
        observation_loglikelihoods = np.full([N_before], np.log(1 - p_obs))
        logp_evidence_given_not_exists = 0
    else:
        # logpdf() returns a scalar for a single particle; keep it 1-D so it
        # can be used as a probability vector below.
        observation_loglikelihoods = np.atleast_1d(
            observation_noise.logpdf(particle_observations - observation)) + np.log(p_obs)
        logp_evidence_given_not_exists = np.log(p_spurious_obs) - np.log(frame_height * frame_width)

    # Log of the summed particle likelihoods; reused for the mean likelihood
    # and for normalising the resampling weights.
    log_total_likelihood = logsumexp(observation_loglikelihoods)
    logp_evidence_given_exists = log_total_likelihood - np.log(N_before)

    log_p_exist = particles['log_p_exist']
    # log(1 - exp(x)) via expm1, which stays accurate when x is close to 0.
    log_p_not_exist = np.log(-np.expm1(log_p_exist))
    logp_evidence = logsumexp([
        logp_evidence_given_exists + log_p_exist,
        logp_evidence_given_not_exists + log_p_not_exist,
    ])
    particles['log_p_exist'] = logp_evidence_given_exists + log_p_exist - logp_evidence

    indices = np.random.choice(
        N_before, N, replace=True,
        p=np.exp(observation_loglikelihoods - log_total_likelihood))

    particles['continuous'] = particles['continuous'][indices, :]
    particles['discrete'] = particles['discrete'][indices]
    
def logp_observation(particles, observation):
    """Return the log likelihood of `observation` given that it comes from the
    existing ball whose distribution is described by `particles`.

    "Existing ball" means that particles['log_p_exist'] is ignored for the
    purposes of this calculation: the result is the log of the mean, over all
    particles, of the Gaussian observation density around each particle's
    predicted (x, y) position.
    """
    noise_covariance = np.diag([20*20, 20*20])
    noise_model = multivariate_normal([0, 0], noise_covariance)

    continuous = particles['continuous']
    predicted = (observation_matrix @ continuous[:, :, np.newaxis])[:, :, 0]
    per_particle_logp = noise_model.logpdf(predicted - observation)

    log_particle_count = np.log(continuous.shape[0])
    return logsumexp(per_particle_logp - log_particle_count)

def logp_observation_given_spurious(observation):
    """Return the log likelihood of `observation` given that it is spurious.

    The value is the log of a uniform density over the frame area; it does
    not depend on `observation`.
    """
    frame_area = frame_width * frame_height
    return -np.log(frame_area)

In [None]:
def unweighted_resample(N, particles):
    """Replace the particle set, in place, by `N` particles drawn uniformly at
    random (with replacement) from the current ones.

    Only 'continuous' and 'discrete' are touched; returns None.
    """
    available = particles['continuous'].shape[0]
    picks = np.random.choice(available, N, replace=True)
    for key in ('continuous', 'discrete'):
        particles[key] = particles[key][picks]

In [None]:
def reject_nearby(particles, position, peak_size=100):
    """Drop, in place, every particle whose (x, y) position lies closer than
    `peak_size` to `position`. Returns None.
    """
    xy = particles['continuous'][:, [0, 3]]
    distances = np.linalg.norm(xy - position, axis=1)
    keep = np.logical_not(distances < peak_size)
    particles['continuous'] = particles['continuous'][keep]
    particles['discrete'] = particles['discrete'][keep]

In [None]:
def find_peak(positions, peak_size=100, peak_threshold=0.2):
    """Locate one cluster centre in `positions` by iterated local averaging.

    Starting from a randomly chosen position, the candidate is repeatedly
    replaced by the mean of all positions within `peak_size` of it until it
    moves by no more than 1.0 between iterations.

    Returns the final centre, or None when fewer than `peak_threshold` times
    the number of positions lie within `peak_size` of it.
    """
    def neighbours_of(center):
        return np.linalg.norm(positions - center, axis=1) < peak_size

    start_index = np.random.choice(positions.shape[0], 1)
    current = positions[start_index]
    while True:
        previous = current
        current = np.mean(positions[neighbours_of(previous), :], axis=0)
        if not np.linalg.norm(current - previous) > 1.0:
            break

    support = np.sum(neighbours_of(current))
    if support < peak_threshold * positions.shape[0]:
        return None
    return current

def find_peaks(positions, peak_size=100, peak_threshold=0.2):
    """Collect cluster centres from `positions`, removing each found cluster.

    Makes at most 10 attempts. Each attempt runs `find_peak` on the positions
    that are still unclaimed; on success the centre is recorded and every
    position within `peak_size` of it is removed. The search stops early once
    fewer than `peak_threshold` times the original number of positions remain.

    Returns the list of centres found (possibly empty).
    """
    max_attempts = 10
    min_remaining = peak_threshold * positions.shape[0]

    found = []
    remaining = positions
    for _ in range(max_attempts):
        if remaining.shape[0] < min_remaining:
            break
        peak = find_peak(remaining, peak_size=peak_size, peak_threshold=peak_threshold)
        if peak is not None:
            found.append(peak)
            inside = np.linalg.norm(remaining - peak, axis=1) < peak_size
            remaining = remaining[~inside, :]
    return found

def find_biggest_peak(positions, peak_size=100, peak_threshold=0.2):
    """Return the cluster centre with the most positions within `peak_size`.

    Centres come from `find_peaks`; support is counted against the full set
    of `positions`. Ties go to the centre found first. Returns None when no
    centre was found or none has any position near it.
    """
    candidates = find_peaks(positions, peak_size=peak_size, peak_threshold=peak_threshold)
    if not candidates:
        return None

    supports = [
        int(np.sum(np.linalg.norm(positions - candidate, axis=1) < peak_size))
        for candidate in candidates
    ]
    winner = int(np.argmax(supports))
    if supports[winner] <= 0:
        return None
    return candidates[winner]

In [None]:
def make_particles(N):
    """Create a fresh set of `N` particles for a ball not yet identified.

    Every particle starts in the hold state (1) with zero velocity and
    acceleration at a position drawn uniformly over the frame; the existence
    probability starts at 0.01.
    """
    continuous = np.zeros([N, 6])
    continuous[:, 0] = np.random.uniform(0, frame_width, size=[N])
    continuous[:, 3] = np.random.uniform(0, frame_height, size=[N])
    return {
        'discrete': np.ones([N], dtype=int),
        'continuous': continuous,
        'log_p_exist': np.log(0.01),
    }

In [None]:
def make_state(N):
    """Create the initial tracker state: `N` particles per ball, one fresh
    particle set searching for a new ball, and no identified balls yet.
    """
    state = {'N': N}
    state['new_particles'] = make_particles(N)
    state['identified_particles'] = []
    return state

In [None]:
def _associate_observations(state, observations):
    """Greedily pick the highest-likelihood explanation for each observation.

    Possible explanations are:
    - spurious observation (-2) (allowed to explain multiple observations)
    - observation of a new ball (-1)
    - observation of an identified ball (index in `state['identified_particles']`)

    Returns a list with one explanation per observation, in input order.
    Apart from -2, every explanation is used at most once.
    """
    remaining_explanations = set(range(len(state['identified_particles']))).union([-1])
    observation_explanations = []
    for observation in observations:
        best_explanation = -2
        # Hand-tuned extra penalty on the spurious explanation.
        best_logp = logp_observation_given_spurious(observation) + np.log(0.001)
        for explanation in remaining_explanations:
            particles = state['new_particles'] if explanation == -1 else state['identified_particles'][explanation]
            logp = logp_observation(particles, observation)
            if logp > best_logp:
                best_logp = logp
                best_explanation = explanation
        observation_explanations.append(best_explanation)
        # No-op for -2, which is never in the set and so stays reusable.
        remaining_explanations.discard(best_explanation)
    return observation_explanations


def step(state, dt, observations):
    """Advance the multi-ball tracker by one frame, updating `state` in place.

    Each observation is first associated with an identified ball, the
    not-yet-identified ball, or clutter. Every particle set is then predicted
    forward by `dt` and resampled against its associated observation (or
    against "no observation"). When the new-ball particle set becomes
    confident enough that a ball exists, its biggest cluster is promoted to
    an identified ball and a fresh new-ball set is started. Identified balls
    whose existence probability has dropped too low are discarded.

    Args:
        state: tracker state as created by `make_state`.
        dt: time elapsed since the previous frame. It is forwarded to the
            motion model; earlier versions ignored it and always used 0.1.
        observations: sequence of observed (x, y) positions for this frame.

    Returns:
        None.
    """
    N = state['N']

    observation_explanations = _associate_observations(state, observations)

    # Each non-spurious explanation occurs at most once, so a plain
    # assignment is enough to invert the association.
    new_particles_observation = None
    identified_observations = [None] * len(state['identified_particles'])
    for observation, explanation in zip(observations, observation_explanations):
        if explanation == -1:
            new_particles_observation = observation
        elif explanation >= 0:
            identified_observations[explanation] = observation

    biggest_peaks = []
    for particles, observation in zip(state['identified_particles'], identified_observations):
        transition(N, particles, dt, p_teleport=0.01)
        resample(N, particles, observation)
        peak = find_biggest_peak(particles['continuous'][:, [0, 3]])
        if peak is not None:
            biggest_peaks.append(peak)

    new_particles = state['new_particles']
    transition(N, new_particles, dt, p_teleport=0.1)
    # Keep the new-ball search away from balls that are already tracked.
    for peak in biggest_peaks:
        reject_nearby(new_particles, peak)
    resample(N, new_particles, new_particles_observation)

    if new_particles['log_p_exist'] > -5e-4:
        positions = new_particles['continuous'][:, [0, 3]]
        peak = find_biggest_peak(positions)
        if peak is not None:
            # Promote the dominant cluster to an identified ball.
            peak_size = 100
            close_mask = np.linalg.norm(positions - peak, axis=1) < peak_size
            new_particles['continuous'] = new_particles['continuous'][close_mask, :]
            new_particles['discrete'] = new_particles['discrete'][close_mask]
            unweighted_resample(N, new_particles)
            state['identified_particles'].append(new_particles)
            state['new_particles'] = make_particles(N)

    state['identified_particles'] = [p for p in state['identified_particles'] if p['log_p_exist'] > -4]

In [None]:
# Run the tracker on the first 10 time steps of `observationss`, printing the
# existence log-probability of the not-yet-identified ball after every step.
N = 20_000
state = make_state(N)
for time_index in range(10):
    observations = observationss[time_index]
    step(state, 0.1, observations)
    
    print("new particle log_p_exist", state['new_particles']['log_p_exist'])
    
    # NOTE(review): with range(10) above this condition is never true, so the
    # particle plots below are not produced unless the range is extended.
    if time_index > 100:
        plt.xlim([0, frame_width])
        plt.ylim([frame_height, 0])
        for p in state['identified_particles']:
            plt.scatter(p['continuous'][:, [0]], p['continuous'][:, [3]])
        plt.show()

# Run on some images!!

In [None]:
import torch
from PIL import Image
import PIL.ImageOps
import numpy as np

import matplotlib.pyplot as plt

import cv2

import train

def sigmoid(x):
    """Return the element-wise logistic function 1 / (1 + exp(-x)) of tensor `x`.

    Delegates to `torch.sigmoid` instead of spelling the formula out by hand.
    """
    return torch.sigmoid(x)

In [None]:
# Instantiate the project's UNet, load its weights from the file 'net03' and
# move it to the CUDA device (a CUDA-capable GPU is required as written).
net = train.UNet()
net.load_state_dict(torch.load('net03'))
net.train(False)  # presumably switches the module to evaluation mode - confirm UNet is an nn.Module
device = torch.device('cuda')
net = net.to(device=device)

In [None]:
def find_balls(pred):
    """Extract ball centres from a 2-D prediction map by greedy peak picking.

    The map is summed over every 20x20 window. While the best window's mean
    activation is at least 0.6, the centre of that window is reported as a
    ball and the surrounding 40x40 region of `pred` is set to zero so the
    same ball is not reported again.

    Args:
        pred: 2-D tensor of per-pixel activations (rows = y, columns = x).
            It is modified in place: the region around every reported ball
            is zeroed.

    Returns:
        List of integer numpy arrays `[x, y]`, one per ball, in detection
        order. Empty if `pred` is smaller than the 20x20 window.
    """
    window = 20
    half_window = window // 2
    suppress_radius = 20
    min_mean_activation = 0.6

    height, width = pred.shape
    if height < window or width < window:
        return []

    # Built once, on the same device/dtype as `pred`.
    kernel = torch.ones(1, 1, window, window, device=pred.device, dtype=pred.dtype)
    threshold = min_mean_activation * window * window

    balls = []
    while True:
        window_sums = torch.nn.functional.conv2d(pred[None, None], kernel)[0, 0]
        if torch.max(window_sums) < threshold:
            break
        row, col = divmod(int(torch.argmax(window_sums)), window_sums.shape[1])
        center_x = col + half_window
        center_y = row + half_window
        # Clamp the start of the suppressed region at 0: a negative start
        # would be read as an offset from the end of the axis, leaving the
        # detection in place and the loop spinning forever on it.
        top = max(center_y - suppress_radius, 0)
        left = max(center_x - suppress_radius, 0)
        pred[top:center_y + suppress_radius, left:center_x + suppress_radius] = 0
        balls.append(np.array([center_x, center_y]))

    return balls

In [None]:
# Run the network on frames 000-199 of data/cap3/img and collect, per frame,
# the list of ball positions detected in its prediction map.
observationss = []
for i in range(200):
    with torch.no_grad():
        img = Image.open('data/cap3/img/%03d.png' % i)
        # Move the last image axis first (presumably HWC -> CHW) and convert
        # to a float tensor on `device`.
        img = torch.from_numpy(np.array(img).transpose((2, 0, 1))).type(torch.FloatTensor).to(device=device)
        # Add a batch axis for the network, then strip the two leading axes
        # of its output to get a 2-D map.
        pred = net(img.unsqueeze(0)).squeeze(0).squeeze(0)
        pred = sigmoid(pred)
    observationss.append([np.array(x) for x in find_balls(pred)])

In [None]:
# Run the tracker over the 200 frames of detections and, for every frame,
# draw the particles of each identified ball on top of the source image.
N = 20_000
state = make_state(N)
for time_index in range(200):
    observations = observationss[time_index]
    step(state, 0.1, observations)
        
    plt.xlim([0, frame_width])
    plt.ylim([frame_height, 0])
    img = Image.open('data/cap3/img/%03d.png' % time_index)
    plt.imshow(img)
    # One scatter call per identified ball, plotting particle x against y.
    for p in state['identified_particles']:
        plt.scatter(p['continuous'][:, [0]], p['continuous'][:, [3]])
    plt.show()

In [None]:
import time

# Live demo: detect and track balls in frames read from camera 0 until a read
# fails or 'q' is pressed in one of the preview windows.
cap = cv2.VideoCapture(0)
xres = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
yres = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

cur_time = time.time()

N = 20_000
state = make_state(N)

try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Wall-clock time since the previous frame, handed to the tracker.
        next_time = time.time()
        dt = next_time - cur_time
        cur_time = next_time
        print(dt)

        with torch.no_grad():
            # Convert the BGR frame to RGB in memory; this yields the same
            # pixel array as writing the frame to a PNG and re-reading it
            # with PIL, without the per-frame disk round trip.
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img = torch.from_numpy(rgb.transpose((2, 0, 1))).type(torch.FloatTensor).to(device=device)
            pred = net(img.unsqueeze(0)).squeeze(0).squeeze(0)
            pred = sigmoid(pred)

        # find_balls zeroes the detected regions of `pred` in place, so the
        # 'pred' window below shows the map after suppression.
        balls = find_balls(pred)
        step(state, dt, balls)

        for ball in balls:
            # OpenCV expects plain Python ints for point coordinates.
            cv2.circle(frame, (int(ball[0]), int(ball[1])), 10, (0, 0, 255))

        print(state['new_particles']['log_p_exist'])
        for p in state['identified_particles']:
            print("p_hold: ", np.mean(p['discrete']))

        colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]
        for i, p in enumerate(state['identified_particles']):
            color = colors[i % len(colors)]
            for bs in p['continuous']:
                cv2.circle(frame, (int(bs[0]), int(bs[3])), 4, color)

        cv2.imshow('frame', frame)
        cv2.imshow('pred', pred.cpu().numpy())
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

        print()
finally:
    # Always give the camera back and close the preview windows, even when
    # the loop is interrupted or a frame fails to process.
    cap.release()
    cv2.destroyAllWindows()

In [None]:
# Scratch check of scipy's logsumexp: evaluates log(exp(1) + exp(2)).
logsumexp([1, 2])