In [1]:
from scipy.stats import truncnorm
from matplotlib import pyplot as plt
from metrics import *
from inference_utils import *
from systems import SpatialSIR3D

In [2]:
from viaABC import viaABC
from scipy.stats import uniform, truncnorm
import numpy as np
from typing import Union, List
from systems import *
from metrics import *
from scipy.ndimage import convolve
from scipy.integrate import solve_ivp
from functools import lru_cache
import numba

In [3]:
@numba.jit(nopython=True, cache=True)
def update_timers_and_states(grid, infection_timer, recovery_timer, susceptible_timer,
                            dt, I, R, SUSCEPTIBLE, INFECTED, RECOVERED):
    """Advance every cell's timer by ``dt`` and apply the timed state changes.

    All four arrays are modified in place; nothing is returned.

    Args:
        grid: 2-D array of per-cell state codes.
        infection_timer: per-cell time spent in the INFECTED state.
        recovery_timer: per-cell time spent in the RECOVERED state.
        susceptible_timer: per-cell time spent in the SUSCEPTIBLE state.
        dt: time increment added to the timer matching each cell's state.
        I: an INFECTED cell becomes RECOVERED once its infection timer
            reaches this value.
        R: a RECOVERED cell becomes SUSCEPTIBLE once its recovery timer
            reaches this value.
        SUSCEPTIBLE, INFECTED, RECOVERED: the state codes stored in ``grid``.
    """
    n_rows, n_cols = grid.shape

    for r in range(n_rows):
        for c in range(n_cols):
            state = grid[r, c]

            if state == INFECTED:
                infection_timer[r, c] += dt
                # Compare the value as stored in the array, not a local sum.
                if infection_timer[r, c] < I:
                    continue
                # Infection period is over: recover and restart all clocks.
                grid[r, c] = RECOVERED
                infection_timer[r, c] = 0
                recovery_timer[r, c] = 0
                susceptible_timer[r, c] = 0
                continue

            if state == RECOVERED:
                recovery_timer[r, c] += dt
                if recovery_timer[r, c] < R:
                    continue
                # Immunity has lapsed: the cell can be infected again.
                grid[r, c] = SUSCEPTIBLE
                infection_timer[r, c] = 0
                recovery_timer[r, c] = 0
                susceptible_timer[r, c] = 0
                continue

            if state == SUSCEPTIBLE:
                susceptible_timer[r, c] += dt

In [None]:
class OptimizedSpatialSIR3D(viaABC):
    """Grid-based stochastic SIR simulator exposed through the ``viaABC`` interface.

    Two parameters are inferred: ``beta`` (per-infected-neighbour infection
    rate) and ``tau_I`` (time a cell stays infected). Priors are independent
    uniforms between ``mu`` (lower bounds) and ``sigma`` (upper bounds).

    NOTE(review): ``self.num_parameters`` and ``self.time_space`` are read by
    the methods below but not assigned here; they are presumably set by
    ``viaABC.__init__`` -- confirm.
    """

    # Label-map file loaded when the constructor is not given
    # ``observational_data``.
    DEFAULT_DATA_PATH = '/home/jp4474/viaABC/data/SPATIAL/data.npy'

    def __init__(self,
        num_parameters = 2, 
        mu = np.array( [0.2, 0.2]), # Lower Bound
        sigma = np.array([4.5, 4.5]),
        model = None, 
        observational_data = None,
        state0 = None,
        t0 = 0,
        tmax = 16,
        interval = 1,
        time_space = np.arange(1, 16, 1),
        pooling_method = "no_cls",
        metric = "pairwise_cosine",
        grid_size = 80,
        initial_infected = 5,
        radius = 5):
        """Set up the simulator and forward the shared settings to ``viaABC``.

        Args:
            num_parameters: number of parameters drawn by ``sample_priors``.
            mu: lower bounds of the uniform priors.
            sigma: upper bounds of the uniform priors.
            model: forwarded unchanged to ``viaABC``.
            observational_data: observed data forwarded to ``viaABC``. When
                ``None`` the label maps at ``DEFAULT_DATA_PATH`` are loaded
                and one-hot encoded with ``labels2map``. A supplied value is
                used as is (presumably it should already be in that one-hot
                layout -- confirm against the caller).
            state0, t0, tmax, time_space, pooling_method, metric: forwarded
                unchanged to ``viaABC``.
            interval: spacing used to derive ``self.time_steps``.
            grid_size: side length of the square simulation grid.
            initial_infected: stored on the instance. NOTE(review):
                ``simulate`` seeds from five fixed centres and does not read
                this value.
            radius: maximum per-axis random offset applied to each seed centre.
        """
        if observational_data is None:
            # Fall back to the bundled observations only when the caller did
            # not supply any, instead of discarding the argument.
            observational_data = self.labels2map(np.load(self.DEFAULT_DATA_PATH))
            print(observational_data.shape)

        super().__init__(num_parameters, mu, sigma, observational_data, model, state0, t0, tmax, time_space, pooling_method, metric)
        self.grid_size = grid_size
        self.initial_infected = initial_infected
        self.radius = radius
        self.time_steps = int((tmax - t0)/interval)
        self.lower_bounds = mu
        self.upper_bounds = sigma
    
    def sample_priors(self):
        """Draw one parameter vector uniformly between the lower and upper bounds."""
        priors = np.random.uniform(self.lower_bounds, self.upper_bounds, self.num_parameters)
        return priors
            
    def calculate_prior_log_prob(self, parameters):
        """Return the summed log-density of ``parameters`` under the uniform priors.

        Uses the same bounds as ``sample_priors`` so sampling and density
        evaluation describe the same distribution.
        """
        log_probabilities = uniform.logpdf(parameters, loc=self.lower_bounds, scale=self.upper_bounds - self.lower_bounds) 
        return np.sum(log_probabilities)

    def labels2map(self, y):
        """One-hot encode an integer label array (0, 1, 2) into boolean planes.

        The three planes (susceptible, infected, resistant) are stacked along a
        new axis at position 1, so an input of shape ``(T, H, W)`` becomes
        ``(T, 3, H, W)``.
        """
        susceptible = (y == 0)
        infected = (y == 1)
        resistant = (y == 2)

        # New class axis goes at position 1: (T, H, W) -> (T, 3, H, W)
        y_onehot = np.stack([susceptible, infected, resistant], axis=1)

        return y_onehot
    
    def preprocess(self, x):
        """Reorder and pad ``x`` before it is handed to downstream code.

        If the leading axis has length 15 the first two axes are swapped;
        afterwards a 4-D array gets a new leading axis of length 1.

        NOTE(review): 15 equals the length of the default ``time_space``;
        presumably this swaps (time, channel) into (channel, time) -- confirm.
        """
        if x.shape[0] == 15:
            x = x.transpose(1, 0, 2, 3)

        if x.ndim == 4:
            x = np.expand_dims(x, axis=0)

        return x

    def simulate(self, parameters: np.ndarray):
        """Run one stochastic SIR realisation on the grid.

        Args:
            parameters: two values, unpacked as ``(beta, tau_I)``.

        Returns:
            A tuple ``(frames, 0)``. ``frames`` is a float32 array of shape
            ``(3, len(time_space), grid_size, grid_size)`` holding the one-hot
            susceptible / infected / recovered planes at each requested time.
            NOTE(review): the trailing ``0`` is a constant; presumably a status
            code expected by ``viaABC`` -- confirm.
        """
        SUSCEPTIBLE, INFECTED, RECOVERED = 0, 1, 2
        
        beta, tau_I = parameters
        dt = 0.01
        I = tau_I  # infection duration
        R = 1.0    # recovered (immune) duration
        steps = int(np.round(np.max(self.time_space) / dt))
        
        # State grid and per-cell clocks
        grid = np.zeros((self.grid_size, self.grid_size), dtype=np.uint8)
        grid_shape = grid.shape
        infection_timer = np.zeros(grid_shape, dtype=np.float32)
        recovery_timer = np.zeros(grid_shape, dtype=np.float32)
        susceptible_timer = np.zeros(grid_shape, dtype=np.float32)
        
        # Seed one infected cell near each fixed centre, jittered by up to
        # ``radius`` cells per axis and clipped to the grid.
        centers = np.array([[44, 67], [24, 67], [64, 73], [3, 55], [12, 20]])
        for x, y in centers:
            dx, dy = np.random.randint(-self.radius, self.radius + 1, 2)
            xi, yi = np.clip([x + dx, y + dy], 0, self.grid_size - 1)
            grid[xi, yi] = INFECTED
        
        # 8-neighbourhood kernel (centre excluded)
        kernel = np.array([[1, 1, 1], [1, 0, 1], [1, 1, 1]], dtype=np.uint8)
        
        # Step index at which each requested time is reached. Round rather
        # than truncate, matching how ``steps`` is derived, so a quotient such
        # as 28.999... still maps to the intended step.
        frames_idx = np.rint(self.time_space / dt).astype(int) - 1
        frames_to_save = set(frames_idx.tolist())
        
        output_frames = np.zeros((len(self.time_space), self.grid_size, self.grid_size, 3), 
                            dtype=np.float32)
        saved_frame_count = 0
        
        # Exponent factor of the per-step infection probability
        neg_beta_dt = -beta * dt
        
        # Reusable buffers; random draws are stored as float32
        infected_mask = np.zeros(grid_shape, dtype=bool)
        susceptible_mask = np.zeros(grid_shape, dtype=bool)
        rand_vals = np.zeros(grid_shape, dtype=np.float32)
        
        for t in range(steps):
            # Count infected neighbours of every cell
            infected_mask[:] = (grid == INFECTED)
            infected_neighbors = convolve(infected_mask.astype(np.uint8), kernel, 
                                        mode='constant', cval=0)
            
            rand_vals[:] = np.random.rand(*grid_shape)
            susceptible_mask[:] = (grid == SUSCEPTIBLE)
            
            # Only susceptible cells with at least one infected neighbour can flip
            infection_candidates = susceptible_mask & (infected_neighbors > 0)
            
            if np.any(infection_candidates):
                # P(infection in dt) = 1 - exp(-beta * dt * n_infected_neighbours)
                p_inf = 1 - np.exp(neg_beta_dt * infected_neighbors[infection_candidates])
                will_be_infected = rand_vals[infection_candidates] < p_inf
                
                if np.any(will_be_infected):
                    # Scatter the per-candidate outcome back onto the full grid
                    new_infection_mask = np.zeros(grid_shape, dtype=bool)
                    new_infection_mask[infection_candidates] = will_be_infected
                    
                    grid[new_infection_mask] = INFECTED
                    recovery_timer[new_infection_mask] = 0
                    infection_timer[new_infection_mask] = 0
                    susceptible_timer[new_infection_mask] = 0
            
            # Advance clocks and apply timed I->R and R->S transitions
            update_timers_and_states(grid, infection_timer, recovery_timer, susceptible_timer,
                                          dt, I, R, SUSCEPTIBLE, INFECTED, RECOVERED)
            
            # Record a one-hot snapshot at the requested times
            if t in frames_to_save:
                output_frames[saved_frame_count, :, :, 0] = (grid == SUSCEPTIBLE)
                output_frames[saved_frame_count, :, :, 1] = (grid == INFECTED)
                output_frames[saved_frame_count, :, :, 2] = (grid == RECOVERED)
                saved_frame_count += 1
        
        # (time, height, width, 3) -> (3, time, height, width)
        output = output_frames.transpose(3, 0, 1, 2)
        return output, 0

In [13]:
# Build the baseline simulator (imported from `systems`) and the optimized one
# defined in this notebook, both with default settings, for the timing
# comparison in the following cells.
spatial = SpatialSIR3D()
optimized = OptimizedSpatialSIR3D()

INFO:viaABC:Initializing viaABC class
The class can be initialized without a model, but it will not
be able to run the algorithm.
INFO:viaABC:viaABC class initialized with the following parameters:
INFO:viaABC:num_parameters: 2
INFO:viaABC:Mu: [0.2 0.2]
INFO:viaABC:Sigma: [4.5 4.5]
INFO:viaABC:t0: 0
INFO:viaABC:tmax: 16
INFO:viaABC:time_space: [ 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]
INFO:viaABC:pooling_method: no_cls
INFO:viaABC:metric: pairwise_cosine
INFO:viaABC:Initializing viaABC class
The class can be initialized without a model, but it will not
be able to run the algorithm.
INFO:viaABC:viaABC class initialized with the following parameters:
INFO:viaABC:num_parameters: 2
INFO:viaABC:Mu: [0.2 0.2]
INFO:viaABC:Sigma: [4.5 4.5]
INFO:viaABC:t0: 0
INFO:viaABC:tmax: 16
INFO:viaABC:time_space: [ 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]
INFO:viaABC:pooling_method: no_cls
INFO:viaABC:metric: pairwise_cosine


(15, 3, 80, 80)
(15, 3, 80, 80)


In [15]:
# Benchmark the baseline simulator on one fixed parameter vector.
# %timeit picks the loop and repeat counts itself.
%timeit spatial.simulate(np.array([0.5, 2.0]))

57 ms ± 375 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [16]:
# Same benchmark for the optimized simulator (parameters unpack as beta=0.5, tau_I=2.0).
%timeit optimized.simulate(np.array([0.5, 2.0]))

39.9 ms ± 507 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
