In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from synthpop.models.flocking_model import FlockingModel

INFO:FlockingModel:Loading julia...
  ActivatingINFO:FlockingModel:Pre-compiling FlockingModel...
 project at `~/code/synthpop/synthpop/models/FlockingModel`


In [3]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from synthpop.infer import Infer, VI
from synthpop.generative import MaskedAutoRegressiveFlow
import pygtc

  from .autonotebook import tqdm as notebook_tqdm


In [43]:

# --- Simulation configuration ------------------------------------------------
n = 25  # Number of birds
k = 5  # Number of behavioural factors; MetaGenerator stacks 5 rows (stubborn, cohere, sep, match, avoid)
speed = 3.0  # Bird speed is fixed
avoid_factor = 5.0  # Avoid factor is fixed
avoid_radius = 10.0  # Avoid radius is fixed
sep_factor = 0.2  # Separation factor is fixed
sep_radius = 15.0  # Separation radius is fixed
n_timesteps = 1000
# Length of the parameter vector consumed by MetaGenerator's generator:
#   [0:2] uniform position bounds, [2:6] log Beta alphas, [6:10] log Beta betas,
#   [10:13] log-normal mus, [13:16] log log-normal sigmas.
n_parameters = 16

# Seed the Python-side RNGs so sampling and VI training are reproducible.
# NOTE(review): the Julia-backed FlockingModel may use its own RNG, which this does not seed — confirm.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

model = FlockingModel(n_agents=n, k=k, n_timesteps=n_timesteps)

In [44]:

def loss(x):
    """Mean, over agents, of the temporal spread of each agent's distance from the origin.

    For every agent the Euclidean distance to the origin is computed at every
    time step; its standard deviation over time (axis 1 of the distance array)
    is taken, and those per-agent values are averaged. A flock whose members
    each keep a constant distance from the origin scores 0.

    Args:
        x: Model output whose first element holds positions indexed as
            ``positions[:, coord, :]`` with ``coord`` 0/1 the x/y components.
            Presumably shaped (n_agents, 2, n_timesteps) — TODO confirm
            against FlockingModel's output.

    Returns:
        A 0-d ``torch.Tensor`` holding the mean per-agent standard deviation.
    """
    positions = x[0]
    distances = np.sqrt(positions[:, 0, :] ** 2 + positions[:, 1, :] ** 2)
    spread_per_agent = np.std(distances, axis=1)
    return torch.tensor(np.mean(spread_per_agent))

class MetaGenerator(MaskedAutoRegressiveFlow):
    """Normalizing flow over the parameters that control FlockingModel's initial conditions.

    A sample from the flow is a vector of at least 16 entries, used as:

    * ``params[0:2]``   - bounds of the uniform draw for positions (scaled by 50)
    * ``params[2:6]``   - log Beta ``alpha`` for orientation, stubborn, cohere, match
    * ``params[6:10]``  - log Beta ``beta`` for the same four quantities
    * ``params[10:13]`` - log-normal ``mu`` for stubborn, cohere, match radii
    * ``params[13:16]`` - log of log-normal ``sigma`` for the same radii

    Speed, separation and avoidance factors/radii are not learned: they are read
    from the notebook-level globals ``speed``, ``sep_factor``, ``avoid_factor``,
    ``sep_radius`` and ``avoid_radius`` each time the generator runs.
    """

    def forward(self, generator_params=None):
        """Build a closure that draws initial agent states.

        Args:
            generator_params: 1-D tensor of generator parameters. If ``None``,
                one vector is sampled from the flow without tracking gradients.

        Returns:
            ``generator(n_agents) -> (pos, vel, speeds, factors, radii)`` with
            ``pos`` and ``vel`` of shape (n_agents, 2), ``speeds`` of shape
            (n_agents,), and ``factors``/``radii`` of shape (5, n_agents) whose
            rows are ordered stubborn, cohere, separation, match, avoid.
        """
        if generator_params is None:
            with torch.no_grad():
                generator_params = self.sample(1)[0][0]
        # A bare .numpy() raises for tensors that require grad or are not on the CPU;
        # detach().cpu() avoids that (and is a zero-copy view for plain CPU tensors).
        params = generator_params.detach().cpu().numpy()

        def generator(n_agents):
            # Initial positions: uniform between params[0] and params[1] on each axis, scaled by 50.
            pos = 50 * np.random.uniform(params[0], params[1], (n_agents, 2))
            # Beta shape parameters are learned in log space so they are always positive.
            beta_alphas = np.exp(params[[2, 3, 4, 5]])
            beta_betas = np.exp(params[[6, 7, 8, 9]])
            # Heading angle in [0, 2*pi); every bird moves at the fixed global speed.
            orientation = 2 * np.pi * np.random.beta(beta_alphas[0], beta_betas[0], (n_agents,))
            speeds = speed * np.ones((n_agents,))
            vel = speed * np.stack([np.cos(orientation), np.sin(orientation)], axis=1)
            # Behavioural factors: stubborn, cohere and match are learned Beta draws;
            # separation and avoidance are fixed. Draw order is kept stable for reproducibility.
            stubborn_factors = np.random.beta(beta_alphas[1], beta_betas[1], (n_agents,))
            cohere_factors = np.random.beta(beta_alphas[2], beta_betas[2], (n_agents,))
            sep_factors = sep_factor * np.ones((n_agents,))
            match_factors = np.random.beta(beta_alphas[3], beta_betas[3], (n_agents,))
            avoid_factors = avoid_factor * np.ones((n_agents,))
            factors = np.stack(
                [stubborn_factors, cohere_factors, sep_factors, match_factors, avoid_factors], axis=0
            )
            # Interaction radii: stubborn, cohere and match are log-normal (sigma learned in
            # log space); separation and avoidance radii are fixed.
            lognormal_mu = params[[10, 11, 12]]
            lognormal_sigma = np.exp(params[[13, 14, 15]])
            stubborn_radii = np.random.lognormal(lognormal_mu[0], lognormal_sigma[0], (n_agents,))
            cohere_radii = np.random.lognormal(lognormal_mu[1], lognormal_sigma[1], (n_agents,))
            sep_radii = sep_radius * np.ones((n_agents,))
            match_radii = np.random.lognormal(lognormal_mu[2], lognormal_sigma[2], (n_agents,))
            avoid_radii = avoid_radius * np.ones((n_agents,))
            radii = np.stack([stubborn_radii, cohere_radii, sep_radii, match_radii, avoid_radii], axis=0)
            return pos, vel, speeds, factors, radii

        return generator


In [50]:
# Fit the flow over generator parameters with variational inference (score-function gradients).
meta_generator = MetaGenerator(
    n_parameters=n_parameters,
    n_transforms=8,
    n_hidden_units=64,
)
# Very wide uniform prior on every parameter: each coordinate in [-inff, inff].
inff = 1000
prior_vi = torch.distributions.Independent(
    torch.distributions.Uniform(
        low=torch.full((n_parameters,), -float(inff)),
        high=torch.full((n_parameters,), float(inff)),
    ),
    reinterpreted_batch_ndims=1,
)
infer = Infer(model=model, meta_generator=meta_generator, prior=prior_vi, loss=loss)
optimizer = torch.optim.AdamW(meta_generator.parameters(), lr=3e-3)
infer_method = VI(
    w=0.1,
    n_samples_per_epoch=10,
    optimizer=optimizer,
    progress_bar=True,
    progress_info=True,
    gradient_estimation_method="score",
    log_tensorboard=True,
)
meta_generator = infer.fit(infer_method, n_epochs=1000, max_epochs_without_improvement=50)
# Restore the best checkpoint (presumably written by infer.fit during training — confirm).
meta_generator.load_state_dict(torch.load("./best_estimator.pt"))


  9%|▉         | 90/1000 [01:34<15:28,  1.02s/it, loss=31.81, reg=11.47, total=43.28, best loss=23.52, epochs since improv.=50] INFO:vi:Stopping early because the loss did not improve for 50 epochs.
  9%|▉         | 90/1000 [01:34<15:57,  1.05s/it, loss=31.81, reg=11.47, total=43.28, best loss=23.52, epochs since improv.=50]


<All keys matched successfully>

In [51]:
# Reload the saved best weights. This duplicates the final line of the training cell,
# but lets this cell be re-run on its own to reset the flow to the saved checkpoint.
meta_generator.load_state_dict(torch.load("./best_estimator.pt"))

<All keys matched successfully>

In [52]:
# Simulate bird trajectories using one parameter draw from the trained flow.
# NOTE(review): despite the original "plot" label, this cell only simulates; nothing is plotted.
generator = meta_generator()
x = model(generator)
positions, velocities = x

In [53]:
# Persist the simulated trajectories as .npy files in the working directory.
np.save(file="positions.npy", arr=positions)
np.save(file="velocities.npy", arr=velocities)