In [None]:
import os
import sys
from copy import deepcopy
from importlib import import_module
from time import perf_counter
from time import time

import matplotlib.pyplot as plt
import numpy as np
import scipy
import tensorflow as tf
from matplotlib import cm, animation, colors
from mpl_toolkits.mplot3d import Axes3D

from stochnet_v2.dataset.dataset import HDF5Dataset
from stochnet_v2.static_classes.model import StochNet
from stochnet_v2.dynamic_classes.model import NASStochNet
from stochnet_v2.utils.file_organisation import ProjectFileExplorer
from stochnet_v2.utils.util import generate_gillespy_traces, plot_random_traces, maybe_create_dir
from stochnet_v2.utils.util import merge_species_and_param_settings
from stochnet_v2.static_classes.grid_runner import *

# Re-import edited modules (e.g. stochnet_v2) before each cell runs, without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Hide all CUDA devices so TensorFlow runs on CPU only.
# NOTE(review): this is set after `import tensorflow`; it only takes effect if TF has not yet
# initialised its GPU devices — moving it above the imports would be more robust.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Compact numpy printing: no scientific notation, 2 decimals, wide lines.
np.set_printoptions(suppress=True, precision=2, linewidth=120)

# Sanity check of the device setup above (should report no GPU since CUDA devices are hidden).
tf.test.is_gpu_available()

In [None]:
# Experiment configuration.
model_name = 'Bees'                      # CRN class/module name under stochnet_v2.CRN_models
timestep, endtime = 0.5, 100.0           # passed to the CRN model constructor
dataset_id, model_id = 1, 1001           # dataset / trained-model identifiers
nb_features, nb_past_timesteps = 4, 1
params_to_randomize = []                 # empty list: no parameters are randomized

In [None]:
# Root data directory. Set the STOCHNET_DATA_DIR environment variable to relocate it instead of
# editing a hard-coded absolute path; the default keeps the original location.
data_root = os.environ.get('STOCHNET_DATA_DIR', '/home/dn/DATA')
project_folder = os.path.join(data_root, 'PARAMETERIZED', model_name)

# Alternative layout without the PARAMETERIZED sub-folder, using the file explorers:
# project_folder = os.path.join(data_root, model_name)
# project_explorer = ProjectFileExplorer(project_folder)
# dataset_explorer = project_explorer.get_dataset_file_explorer(timestep, dataset_id)
# model_explorer = project_explorer.get_model_file_explorer(timestep, model_id)

In [None]:
# Look up the CRN class named `model_name` inside stochnet_v2.CRN_models.<model_name>
# and instantiate it with the configured end time and time step.
CRN_module = import_module(f"stochnet_v2.CRN_models.{model_name}")
CRN_class = getattr(CRN_module, model_name)
m = CRN_class(endtime, timestep)

### Choose model

In [None]:
# model = nn  # alternative: a trained network object `nn` (not defined in this notebook — presumably a StochNet)
# Wrap the CRN instance `m` together with the parameters to randomize; this object is handed to
# GridRunner below. `Model` is presumably provided by the grid_runner star import — verify.
model = Model(m, params_to_randomize)

## Initialize GridRunner

In [None]:
# GridRunner working directory. The data root can be overridden with the STOCHNET_DATA_DIR
# environment variable instead of editing a hard-coded absolute path (default unchanged).
w_dir = os.path.join(os.environ.get('STOCHNET_DATA_DIR', '/home/dn/DATA'), 'GRID_RUNNER', model_name)
maybe_create_dir(w_dir)

# 10 x 10 grid over the unit square.
grid_spec = GridSpec(
    boundaries=[[0.0, 1.0], [0.0, 1.0]],
    grid_size=[10, 10],
)

gr = GridRunner(
    model,
    grid_spec,
    w_dir,
    diffusion_kernel_size=3,
    diffusion_sigma=0.7,
)

gr.grid.shape

### Set custom diffusion kernel

In [None]:
# 3x3 kernel: weight 1.0 on the centre cell and 0.8 on each of the eight neighbours, with a
# trailing singleton axis appended (shape (3, 3, 1); presumably what GridRunner expects — TODO confirm).
kernel = np.array([
    [0.8] * 3,
    [0.8, 1.0, 0.8],
    [0.8] * 3,
])[..., np.newaxis]
gr.diffusion_kernel = kernel

### Set initial state

In [None]:
n_settings = 10

# Build n_settings initial species settings and the matching parameter values, then merge them
# into one array (presumably species columns first, parameters after — TODO confirm).
initial_settings, randomized_params = (
    m.get_initial_settings(n_settings),
    m.get_randomized_parameters(params_to_randomize, n_settings),
)
settings = merge_species_and_param_settings(initial_settings, randomized_params)
settings

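#### Apply settings to the grid

Clear the grid, apply the parameters of the first setting, then seed species values in individual cells.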

In [None]:
# Reset the grid, then apply the parameter part of the first setting.
gr.clear_state(mode='all')
gr.set_state(settings[..., model.nb_features:][0], mode='params')

# Species values per cell: the None entry is applied first (presumably to every cell — TODO
# confirm), then three cells get extra amounts of the species at index 3.
species_seeds = [
    ([10, 0, 0, 0], None),
    ([10, 0, 0, 1], (8, 9)),
    ([10, 0, 0, 5], (9, 9)),
    ([10, 0, 0, 1], (9, 8)),
]
for values, cell in species_seeds:
    gr.set_state(values, cell, mode='species')

In [None]:
# Show the grid state split into one array per channel (last axis).
n_channels = gr.state.shape[-1]
[gr.state[..., channel] for channel in range(n_channels)]

### Single diffusion step

In [None]:
# Apply one diffusion step to the species at index 3 with conservation disabled;
# raise the range bound to apply several steps.
for _ in range(1):
    gr.diffusion_step(conservation=False, species_idx=3)

### Single max_propagation step

In [None]:
# Apply one max-propagation step (alpha=0.5) to the species at index 3;
# raise the range bound to apply several steps.
for _ in range(1):
    gr.max_propagation_step(alpha=0.5, species_idx=3)

### Model steps

In [None]:
# Advance the model three steps and report how long each step took.
# perf_counter is monotonic and high-resolution, so timings are not skewed by system clock
# adjustments the way wall-clock time() differences can be.
for _ in range(3):
    start = perf_counter()
    gr.model_step()
    elapsed = perf_counter() - start
    print(f'.. elapsed {elapsed:.2f}')

### Display state

In [None]:
# One panel per species showing the current grid values.
n = gr.model.nb_features
names = gr.model.model.get_species_names()
n_rows, n_cols = gr.state[..., 0].shape

# squeeze=False keeps `axes` indexable even when n == 1.
fig, axes = plt.subplots(1, n, figsize=(n * 4, 4), squeeze=False)
axes = axes.ravel()
for i in range(n):
    ax = axes[i]
    # pcolormesh places cell edges at 0..n_cols and 0..n_rows, so the limits come from the grid
    # shape rather than a hard-coded 10. An imshow of the same data underneath would be fully
    # covered by the mesh, so only its equal aspect ratio is reproduced.
    pcm = ax.pcolormesh(gr.state[..., i])
    ax.set_aspect('equal')
    ax.set_xlim([0, n_cols])
    ax.set_ylim([n_rows, 0])  # inverted y axis: row 0 at the top, like imshow
    ax.set_title(names[i])
    fig.colorbar(pcm, ax=ax, shrink=0.7)
plt.tight_layout()

In [None]:
# Raw grid values of the species at index 3 (the species seeded per-cell and propagated above).
gr.state[..., 3]

In [None]:
# Run the grid simulation with max-propagation ('mp') on species index 3 and keep the returned
# sequence of grid states for saving and animation below.
# NOTE(review): the meaning of the positional arguments (100, 1, 5) is defined by
# GridRunner.run_model and is not visible here — TODO pass them by keyword / name them.
states = gr.run_model(100, 1, 5, propagation_mode='mp', species_idx=3, alpha=0.33)

In [None]:
# Persist the simulated states next to the other GridRunner outputs
# (np.save appends the '.npy' extension to the file name).
np.save(file=os.path.join(gr.save_dir, 'states'), arr=states)

## Animated figure

## Special animation for Bees

Sum the 0th (normal bee) and 1st (aggressive bee) species to get the population of alive bees; the other two panels show species 2 and 3 individually.

In [None]:
# Three panels: "Alive Bees" (species 0 + species 1 summed), then species 2 and 3 individually.
n = 3
names = gr.model.model.get_species_names()
cmap = 'viridis'

# squeeze=False keeps `axes` indexable even if n is changed to 1.
fig, axes = plt.subplots(1, n, figsize=(n * 4, 4), squeeze=False)
axes = axes.ravel()

all_images = []
for species_idx in range(n):

    ax = axes[species_idx]
    ax.set_title(names[species_idx + 1] if species_idx > 0 else 'Alive Bees')
    images = []

    for state in states:
        # Panel 0 sums species 0 and 1; panel k > 0 shows species k + 1.
        frame = state[..., species_idx + 1] if species_idx > 0 else state[..., 0] + state[..., 1]
        im = ax.imshow(frame, cmap=cmap, animated=True)
        images.append(im)

    # One colour scale shared by every frame of this panel, so colours are comparable over time.
    vmin = min(image.get_array().min() for image in images)
    vmax = max(image.get_array().max() for image in images)
    # vmax = 140.
    norm = colors.Normalize(vmin=vmin, vmax=vmax)
    for im in images:
        im.set_norm(norm)

    fig.colorbar(images[0], ax=ax, shrink=0.75)

    # `images=images` binds THIS panel's frames when the function is defined. A plain closure
    # would look up the global `images` at call time, i.e. always the LAST panel's list.
    def update(changed_image, images=images):
        """Propagate a cmap/clim change on one frame to every frame of the same panel."""
        for im in images:
            if (changed_image.get_cmap() != im.get_cmap()
                    or changed_image.get_clim() != im.get_clim()):
                im.set_cmap(changed_image.get_cmap())
                im.set_clim(changed_image.get_clim())

    for im in images:
        # ScalarMappable.callbacksSM was deprecated in matplotlib 3.5 in favour of `callbacks`
        # (and later removed); fall back to the old name on older matplotlib versions.
        registry = getattr(im, 'callbacks', None)
        if registry is None:
            registry = im.callbacksSM
        registry.connect('changed', update)
    all_images.append(images)

plt.tight_layout()

## Make a GIF

In [None]:
# ArtistAnimation takes one list of artists per frame: frame k shows the k-th image of every panel.
n_frames = len(all_images[0])
a = [[panel_images[k] for panel_images in all_images] for k in range(n_frames)]
ani = animation.ArtistAnimation(fig, a, interval=100, blit=True, repeat_delay=2000)
ani.save(os.path.join(gr.save_dir, 'animated_progress.gif'))

### Separate animation for every species

In [None]:
# Write one GIF per species showing how its grid values evolve across `states`.
for species_idx in range(gr.model.nb_features):

    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    cmap = 'viridis'
    images = []

    for state in states:
        im = ax.imshow(state[..., species_idx], cmap=cmap, animated=True)
        images.append(im)

    # One colour scale shared by every frame, so colours are comparable over time.
    vmin = min(image.get_array().min() for image in images)
    vmax = max(image.get_array().max() for image in images)
    # vmax = 140.
    norm = colors.Normalize(vmin=vmin, vmax=vmax)
    for im in images:
        im.set_norm(norm)

    fig.colorbar(images[0], ax=ax, shrink=0.75)

    # `images=images` binds THIS species' frames when the function is defined; a plain closure
    # would resolve the global `images` at call time, i.e. whichever species was processed last.
    def update(changed_image, images=images):
        """Propagate a cmap/clim change on one frame to every frame of this species."""
        for im in images:
            if (changed_image.get_cmap() != im.get_cmap()
                    or changed_image.get_clim() != im.get_clim()):
                im.set_cmap(changed_image.get_cmap())
                im.set_clim(changed_image.get_clim())

    for im in images:
        # ScalarMappable.callbacksSM was deprecated in matplotlib 3.5 in favour of `callbacks`
        # (and later removed); fall back to the old name on older matplotlib versions.
        registry = getattr(im, 'callbacks', None)
        if registry is None:
            registry = im.callbacksSM
        registry.connect('changed', update)

    plt.tight_layout()

    ani = animation.ArtistAnimation(fig, [[im] for im in images], interval=200, blit=True, repeat_delay=2000)
    ani.save(os.path.join(gr.save_dir, f'nn_animated_progress_{species_idx}.gif'))
    # The GIF is on disk; close the figure so one open figure per species does not accumulate.
    plt.close(fig)