In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from synthpop.models.flocking_model import FlockingModel

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from synthpop.infer import Infer, VI, SMCABC

from synthpop.generative import MaskedAutoRegressiveFlow, SampleGenerator, DiracDelta
import pygtc

In [None]:
# Model parameters
n_agents = 20 # Number of birds
k = 5 # Number of factors
speed = 3.0 # Bird speed is fixed
avoid_factor = 5.0 # Avoid factor is fixed
avoid_radius = 10.0 # Avoid radius is fixed
sep_factor = 0.2 # Separation factor is fixed
sep_radius = 15.0 # Separation radius is fixed
n_timesteps = 500 # Number of timesteps handed to the model
# Length of the flat parameter vector read by make_generator:
# 2 + 2 (position bounds) + 2 (orientation beta) + 3 * 2 (factor lognormals)
# + 3 + 3 (radius lognormal means and log-sigmas) = 18
n_parameters = 18
model = FlockingModel(n_agents=n_agents, k=k, n_timesteps=n_timesteps)

In [None]:
# define loss
def loss_circle_std(x):
    """Average, over birds, of the spread of each bird's distance to the origin.

    Only the first element of ``x`` (the positions) is read. The indexing
    below treats axis 1 as the coordinate axis and reduces over axis 2, i.e.
    it assumes positions are laid out as (n_agents, 2, n_steps) — TODO confirm
    against FlockingModel's output.

    Returns the mean of the per-bird standard deviations; a bird that keeps a
    constant distance from the origin contributes zero.
    """
    positions = x[0]
    x_coord = positions[:, 0, :]
    y_coord = positions[:, 1, :]
    distance_to_origin = np.sqrt(x_coord**2 + y_coord**2)
    # One standard deviation per bird, taken along the remaining axis.
    spread_per_bird = np.std(distance_to_origin, axis=1)
    assert spread_per_bird.shape == (n_agents,)
    return np.mean(spread_per_bird)

def loss_circle_hausdorff(x, radius=100.0, n_circle_points=100):
    """Worst-case Hausdorff distance between a bird trajectory and a circle.

    Parameters
    ----------
    x : sequence
        Model output; only ``x[0]`` (the positions) is used. It is reshaped to
        ``(n_agents, n_timesteps, 2)`` using the notebook-level constants.
        NOTE(review): ``reshape`` does not reorder axes, so this assumes the
        positions are already stored bird-major with (x, y) last — confirm
        against FlockingModel's output layout.
    radius : float, optional
        Radius of the target circle centred on the origin (default 100).
    n_circle_points : int, optional
        Number of points used to discretise the circle (default 100).

    Returns
    -------
    torch.Tensor
        0-dim tensor: the largest, over birds, of the symmetric Hausdorff
        distance between that bird's trajectory and the discretised circle.
    """
    trajectories = torch.tensor(x[0], dtype=torch.float).reshape(n_agents, n_timesteps, 2)
    angles = torch.linspace(0, 2 * torch.pi, n_circle_points)
    circle = torch.stack(
        [radius * torch.cos(angles), radius * torch.sin(angles)], dim=1
    ).reshape(1, n_circle_points, 2)
    # Calculate the pairwise distance between each bird trajectory and the circle.
    # The batch dimension of `circle` broadcasts over birds, so the result has
    # shape (n_agents, n_timesteps, n_circle_points).
    pw = torch.cdist(trajectories, circle)
    # Directed distances must be reduced within each bird: over circle points
    # (dim=2) for trajectory -> circle and over timesteps (dim=1) for
    # circle -> trajectory. Reducing over dim=0 would instead take the minimum
    # across birds and let one well-placed bird hide the others.
    trajectory_to_circle = torch.max(torch.min(pw, dim=2)[0])
    circle_to_trajectory = torch.max(torch.min(pw, dim=1)[0])
    # Symmetric Hausdorff distance, maximised over birds.
    return torch.max(trajectory_to_circle, circle_to_trajectory)

def make_generator(generator_params):
    """Build a sampler of flock initial conditions from a flat parameter vector.

    Parameters
    ----------
    generator_params : torch.Tensor
        At least 18 values, read in this order (extra values are ignored):

        * 0-1   bounds of the uniform x position (result scaled by 250)
        * 2-3   bounds of the uniform y position (result scaled by 250)
        * 4-5   log of the two beta shape parameters of the heading
                (the beta sample is scaled to [0, 2*pi])
        * 6-7   lognormal mean and log-sigma of the stubborn factor
        * 8-9   lognormal mean and log-sigma of the cohere factor
        * 10-11 lognormal mean and log-sigma of the match factor
        * 12-14 lognormal means of the stubborn, cohere and match radii
        * 15-17 log-sigmas of the stubborn, cohere and match radii

    Returns
    -------
    callable
        ``generator(n_agents)`` returning ``(pos, vel, speeds, factors, radii)``
        where ``pos`` and ``vel`` are ``(n_agents, 2)``, ``speeds`` is
        ``(n_agents,)`` and ``factors`` / ``radii`` are ``(5, n_agents)`` with
        rows ordered stubborn, cohere, separation, match, avoid. Speed and the
        separation / avoid rows are filled from the notebook-level constants
        ``speed``, ``sep_factor``, ``sep_radius``, ``avoid_factor`` and
        ``avoid_radius``.

    Raises
    ------
    ValueError
        When the returned generator is called and fewer than 18 parameters
        were supplied.
    """
    n_expected = 18  # total number of values consumed below

    def generator(n_agents):
        # Converted on every call, as before. detach()/cpu() make this work
        # for tensors that track gradients or live on another device, where a
        # bare .numpy() raises.
        values = generator_params.detach().cpu().numpy().ravel()
        if values.size < n_expected:
            # Fail with a clear message instead of a bare StopIteration from next().
            raise ValueError(
                f"make_generator needs at least {n_expected} parameters, got {values.size}"
            )
        params = iter(values)
        # position is randomly distributed
        pos_x = 250 * np.random.uniform(next(params), next(params), (n_agents,))
        pos_y = 250 * np.random.uniform(next(params), next(params), (n_agents,))
        pos = np.stack([pos_x, pos_y], axis=1)
        # angle is between 0 and 2pi
        orientation = 2 * np.pi * np.random.beta(np.exp(next(params)), np.exp(next(params)), (n_agents,))
        vel = np.stack([np.cos(orientation), np.sin(orientation)], axis=1)
        speeds = speed * np.ones((n_agents,))
        # factors: stubborn, cohere and match are sampled; separation and
        # avoid are held at their fixed notebook-level values
        stubborn_factors = np.random.lognormal(next(params), np.exp(next(params)), (n_agents,))
        cohere_factors = np.random.lognormal(next(params), np.exp(next(params)), (n_agents,))
        sep_factors = sep_factor * np.ones((n_agents,))
        match_factors = np.random.lognormal(next(params), np.exp(next(params)), (n_agents,))
        avoid_factors = avoid_factor * np.ones((n_agents,))
        factors = np.stack([stubborn_factors, cohere_factors, sep_factors, match_factors, avoid_factors], axis=0)
        # radii: lognormally distributed; all three means are read first,
        # then all three log-sigmas
        lognormal_mu = [next(params) for _ in range(3)]
        lognormal_sigma = np.exp([next(params) for _ in range(3)])
        stubborn_radii = np.random.lognormal(lognormal_mu[0], lognormal_sigma[0], (n_agents,))
        cohere_radii = np.random.lognormal(lognormal_mu[1], lognormal_sigma[1], (n_agents,))
        sep_radii = sep_radius * np.ones((n_agents,))
        match_radii = np.random.lognormal(lognormal_mu[2], lognormal_sigma[2], (n_agents,))
        avoid_radii = avoid_radius * np.ones((n_agents,))
        radii = np.stack([stubborn_radii, cohere_radii, sep_radii, match_radii, avoid_radii], axis=0)
        return pos, vel, speeds, factors, radii
    return generator


# 1. SMCABC

In [None]:

class SampleMetaGenerator(SampleGenerator):
    """Meta-generator that maps a parameter vector to a flock initial-condition sampler.

    NOTE(review): how SampleGenerator invokes ``forward`` (and what a call with
    no arguments does) is defined in synthpop and not visible here.
    """
    def forward(self, generator_params):
        # Pure delegation: the returned callable is the generator built by make_generator.
        return make_generator(generator_params)

# Prior over the n_parameters generator parameters: independent Uniform(0, 1)
# on every coordinate, treated as a single multivariate event.
prior = torch.distributions.Independent(torch.distributions.Uniform(torch.zeros(n_parameters), torch.ones(n_parameters)), reinterpreted_batch_ndims=1)

In [None]:
# Simulate the flock once with parameters drawn from the prior (before any inference).
sample_meta_generator = SampleMetaGenerator()
generator = sample_meta_generator(prior.sample())
# The model output unpacks into (positions, velocities).
x = model(generator)
positions, velocities = x
# Presumably renders the run to the given GIF path — confirm against FlockingModel.plot.
model.plot(positions, velocities, "../figures/birds/birds_smcabc_prior.gif", plot_lim=500)

In [None]:
# `loss_circle` is not defined anywhere in this notebook, so this cell raised a
# NameError on a fresh kernel. Use loss_circle_std, the same loss as the VI section.
# NOTE(review): switch to loss_circle_hausdorff here if that was the intended SMCABC loss.
infer = Infer(model=model, meta_generator=sample_meta_generator, prior=prior, loss=loss_circle_std)
infer_method = SMCABC(num_particles = 100, num_initial_pop=1_000, num_simulations=1_000, epsilon_decay=0.6)
# fit() returns the trained meta-generator used in the next cell.
trained_meta_generator = infer.fit(infer_method, num_workers=1)

In [None]:
# Simulate the flock once with a generator obtained from the SMCABC result.
# Called with no arguments; presumably this samples parameters from the fitted
# posterior — confirm against synthpop.
generator = trained_meta_generator()
x = model(generator)
positions, velocities = x
model.plot(positions, velocities, "../figures/birds/birds_smcabc_trained.gif")

# 2. Variational Inference

In [None]:

class FlowMetaGenerator(MaskedAutoRegressiveFlow):
    """Normalizing-flow meta-generator: produces flock initial-condition samplers.

    ``forward`` builds a generator from the given parameter vector, or from a
    fresh draw of the flow when none is given.
    """
    def forward(self, generator_params=None):
        if generator_params is None:
            # No parameters supplied: draw one vector from the flow without tracking gradients.
            with torch.no_grad():
                # assumes self.sample(1) returns a tuple whose first entry holds
                # the samples, one per row — TODO confirm against synthpop
                generator_params = self.sample(1)[0][0]
        return make_generator(generator_params)

In [None]:
# Untrained flow over the n_parameters-dimensional generator parameter space.
flow_meta_generator = FlowMetaGenerator(n_parameters=n_parameters, n_hidden_units=32, n_transforms=8)

In [None]:

# Simulate the flock once with parameters sampled from the untrained flow (baseline for VI).
generator = flow_meta_generator()
x = model(generator)
positions, velocities = x
model.plot(positions, velocities, "../figures/birds/vi_prior.gif")

In [None]:
# Very wide prior for VI: independent Uniform(-1000, 1000) on every parameter.
inff = 1000
prior_vi = torch.distributions.Independent(torch.distributions.Uniform(-inff * torch.ones(n_parameters), inff * torch.ones(n_parameters)), 1)
infer = Infer(model=model, meta_generator=flow_meta_generator, prior=prior_vi, loss=loss_circle_std)
# The optimizer updates the flow's own parameters.
optimizer = torch.optim.AdamW(flow_meta_generator.parameters(), lr=1e-3)
# NOTE(review): w=0. presumably switches off the prior/regularisation term — confirm against synthpop's VI.
infer_method = VI(w=0., n_samples_per_epoch=25, optimizer=optimizer, 
                    progress_bar=True, progress_info=True, gradient_estimation_method="score", log_tensorboard=True)
# The result of fit() replaces the untrained flow under the same name.
flow_meta_generator = infer.fit(infer_method, n_epochs=1000, max_epochs_without_improvement=50)
# Restore weights from ./best_estimator.pt — presumably the best checkpoint written during fit(); verify.
flow_meta_generator.load_state_dict(torch.load("./best_estimator.pt"))

In [None]:
# Reload the saved weights so this cell can be run without re-training,
# then simulate the flock once with parameters sampled from the trained flow.
flow_meta_generator.load_state_dict(torch.load("./best_estimator.pt"))
generator = flow_meta_generator()
x = model(generator)
positions, velocities = x
model.plot(positions, velocities, "../figures/birds/vi_trained.gif")

In [None]:
# Loss of the last simulated run. `loss_circle` is undefined in this notebook;
# loss_circle_std is the loss the VI fit optimised.
loss_circle_std(x)