In [9]:
import numpy as np
from scipy.integrate import solve_ivp
import synthetic_data
# import tensorflow as tf
import torch

print("done importing")

class TrapEnvironment:
    def __init__(self, num_traps=20,max_steps=100):
        # Initialize environment parameters
        self.num_traps = num_traps
        self.n = 25
        self.m = 25
        self.predator_density = None
        self.prey_density = None
        self.trap_replacement_rate = 10
        self.pts_per_sec = 100
        self.len_traj=50
        self.current_step = 0
        self.max_steps = max_steps
        self.reset()

    def reset(self):
        # Reset the environment to its initial state (just predator and prey)
        # Generate initial predator and prey densities
        # default params for initial generation num_traj=200,len_traj=50, pts_per_sec=100, save_loc='../Data/val.npy', prey_range=(1, 5), predator_range=(1, 3)
        data_init = synthetic_data.generate().reshape(self.n, self.m, 2, self.pts_per_sec*self.len_traj)
        self.prey_density, self.predator_density = data_init[:, :, 0, -1], data_init[:, :, 1, -1]
        return self.predator_density, self.prey_density

    def step(self, action):
        # Take action (place traps) and observe the next state and reward
        # Update predator and prey densities based on the action
        # Calculate reward based on the change in predator density
        # Return next state, reward, and done flag

        # Simulate predator dynamics with traps placed at specified locations
        # Here, action is a list of trap locations [(i1, j1), (i2, j2), ..., (in, jn)]
        y0 = np.zeros((self.n, self.m, 2))
        y0[:,:,0] = self.predator_density
        prey_data = self.prey_density
        for i, j in action:
            y0[i,j,1] = 10  # place those traps at each cell
        y0 = y0.flatten()

        # get impact on predator and prey spread after placement window steps
        master_sol = np.ndarray((self.n*self.m*2,self.pts_per_sec))
        
        for _ in range(self.trap_replacement_rate):
            # trap solver, only grab single timestep
            sol = solve_ivp(synthetic_data.spatial_dynamics_traps, y0=y0, t_span=[0,1], t_eval=np.linspace(0, 1, self.pts_per_sec), args=(self.n, self.m))
            # prey solver y0 creation
            y_prey = np.zeros((self.n, self.m, 2))
            last_dim = int(self.pts_per_sec)
            sol_use = sol.y.reshape((self.n, self.m, 2, last_dim))
            pred_data_new, trap_data_new = sol_use[:, :, 0, :], sol_use[:, :, 1, :]
            
            # set prey from timestep of interest as predator in y_prey
            y_prey[:,:,1] = pred_data_new[:, :, -1]
            # grab predator information from prey data
            y_prey[:,:,0] = prey_data
            print(y_prey.shape, "yellow")
            y_prey = y_prey.flatten()
            # prey solver, only grab single timestep
            sol_prey = solve_ivp(synthetic_data.spatial_dynamics, y0=y_prey, t_span=[0,1], t_eval=np.linspace(0, 1, self.pts_per_sec), args=(self.n, self.m))
            # create y0 for next run of trap solver, overwrite y0 and prey_data
            y0 = np.zeros((self.n, self.m, 2))
            sol_prey_use = sol_prey.y.reshape((self.n, self.m, 2, last_dim))
            prey_data, predator_data = sol_prey_use[:, :, 0, :], sol_prey_use[:, :, 1, :]
            y0[:,:,0] = predator_data[:, :, -1]
            prey_data = prey_data[:, :, -1]
            # initialize trap locations based on number of desired traps and density, re initialize per replacement time
            y0[:,:,1] = trap_data_new[:,:,-1]
            y0 = y0.flatten()
            master_sol = np.concatenate((master_sol, sol_prey.y), 1)

        master_sol = master_sol[:,100:].reshape(self.n, self.m,2,self.pts_per_sec*self.trap_replacement_rate)
        
        # Extract the last step predator and prey densities from the solution
        self.prey_density , self.predator_density = master_sol[:, :, 0, -1], master_sol[:, :, 1, -1]

        # Calculate reward based on the change in predator density
        reward = -np.sum(self.predator_density)

        # Check termination condition
        predator_sum_zero = np.sum(self.predator_density) == 0
        # Check if the maximum number of steps is reached
        max_steps_reached = self.current_step >= self.max_steps
    
        # Combine termination conditions
        done = predator_sum_zero or max_steps_reached
        
        # Return next state, reward, and done flag (assuming no termination condition for now)
        self.current_step += 1
        print(f"finished step {self.current_step}")
        return self.predator_density, self.prey_density, reward, done

# define DQN
class QNetwork(torch.nn.Module):
    def __init__(self, num_actions, input_size):
        super(QNetwork, self).__init__()
        self.dense1 = torch.nn.Linear(input_size, 64)
        self.dense2 = torch.nn.Linear(64, 64)
        self.output_layer = torch.nn.Linear(64, num_actions)


    def forward(self, state):
        x = torch.nn.functional.relu(self.dense1(state))
        x = torch.nn.functional.relu(self.dense2(x))
        return self.output_layer(x)

# Simple Q-learning algorithm with experience replay
class QLearningAgent:
    def __init__(self, m,n,input_size,num_traps):
        self.num_actions = m * n
        self.q_network = QNetwork(self.num_actions,input_size)
        self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=0.001)
        self.memory = []
        self.m = m
        self.n = n
        self.num_traps = num_traps

    def select_action(self, state):
        # Epsilon-greedy policy
        if np.random.rand() < 0.1:
            # Explore: Randomly select trap locations
            trap_indices = []
            for _ in range(self.num_traps):
                trap_indices.append((np.random.randint(self.n), np.random.randint(self.m)))
            return trap_indices
        else:
            # Exploit: Select actions with highest Q-values
            state_tensor = torch.tensor(state, dtype=torch.float32)
            q_values = self.q_network(state_tensor)

            # Select the top num_traps indices with highest Q-values
            top_indices = torch.topk(q_values, self.num_traps).indices.tolist()
            
            # Convert indices to trap locations
            trap_indices = []
            for idx in top_indices:
                i = idx // self.m
                j = idx % self.m
                trap_indices.append((i, j))
            return trap_indices


    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def experience_replay(self, batch_size=32):
        if len(self.memory) < batch_size:
            return

        minibatch = np.random.choice(len(self.memory), batch_size, replace=False)
        states, targets = [], []
        for idx in minibatch:
            state, action, reward, next_state, done = self.memory[idx]
            states.append(state)
            q_values = self.q_network(torch.tensor([state], dtype=torch.float32)).detach().numpy()[0]
            if done:
                q_values[action] = reward
            else:
                next_q_values = self.q_network(torch.tensor([next_state], dtype=torch.float32)).detach().numpy()[0]
                q_values[action] = reward + 0.9 * np.max(next_q_values)
            targets.append(q_values)
        
        states = np.array(states, dtype=np.float32)
        targets = np.array(targets, dtype=np.float32)
        
        states_tensor = torch.tensor(states)
        targets_tensor = torch.tensor(targets)
        
        self.optimizer.zero_grad()
        q_values = self.q_network(states_tensor)
        loss = torch.nn.functional.mse_loss(q_values, targets_tensor)
        loss.backward()
        self.optimizer.step()

        
# Training loop
def train_agent(env, agent, num_episodes=1000):
    for episode in range(num_episodes):
        state = env.reset()
        total_reward = 0
        done = False
        while not done:
            state = np.array(state).flatten()
            action = agent.select_action(state)
            next_state_prey, next_state_pred, reward, done = env.step(action)
            next_state = np.concatenate((next_state_prey.flatten(), next_state_pred.flatten()))
            agent.remember(state, action, reward, next_state, done)
            agent.experience_replay()
            state = next_state
            total_reward += reward
        print(f"Episode {episode + 1}/{num_episodes}, Total Reward: {total_reward}")

done importing


In [10]:
from ipywidgets import interact, IntSlider
import matplotlib.pyplot as plt

def plot_predator_locations_single(state, n):
    predator_density, prey_density = state[0], state[1]
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))  # One row, two columns
    titles = ['Predator Grid', 'Prey Grid']
    grids = [predator_density, prey_density]

    for ax, title, grid in zip(axes, titles, grids):
        im = ax.imshow(grid, cmap='YlGn')
        ax.set_title(title)
        ax.set_xlabel('Column')
        ax.set_ylabel('Row')
        ax.axis('on')
        plt.colorbar(im, ax=ax, orientation='vertical', fraction=0.05, pad=0.04)

    plt.tight_layout()
    plt.show()
    


def predict(env, agent):
    state = env.reset()
    # # plot the initialized pred prey spread
    # plot_predator_locations_single(state,env.n)
    state = np.array(state).flatten()
    action = agent.select_action(state)
    print("Predicted trap locations for next episode:", action)
    return action

In [11]:
# Create environment and agent
env = TrapEnvironment()
agent = QLearningAgent(n=env.n, m= env.m, input_size=env.n * env.m*2, num_traps=env.num_traps)

# Train the agent
train_agent(env, agent,num_episodes=1)

# Predict and display predictions in interactive slider
trap_locations = predict(env, agent)

(25, 25, 2, 100) eskimo
(25, 25, 2, 100) eskimo
(25, 25, 2, 100) eskimo
(25, 25, 2, 100) eskimo


KeyboardInterrupt: 

In [45]:

# display predictions
trap_locations = predict(env, agent)
max_timestep = 24


# run and save a few timesteps of trap behavior with those locations
def generate_traps_next(env, trap_locations=trap_locations, num_traj=200, num_replacements=2, len_traj=50, pts_per_sec=100, save_loc='../Data/val_predict.npy', prey_range=(1, 5), predator_range=(1, 3)):
    # initialization stuff
    n = 25
    m = 25
    replacement_window = int(len_traj/num_replacements)
    dataset = np.zeros((num_traj, n * m * 2, len_traj * pts_per_sec))  # That will store each simulation
    t_span = [0, replacement_window]
    t_eval = np.linspace(0, replacement_window, replacement_window * pts_per_sec)  # Time vector

    # Change this line to configure how much you downsample the data, and the final time range
    downsample_rate = int(len(t_eval) / (replacement_window * pts_per_sec))
    idx = np.arange(0, len(t_eval), downsample_rate)

    # Generate random initial values for prey and predator populations within the specified ranges
    y0 = np.zeros((n, m, 2))

    # Generating inital points for both populations
    # set prey (invasive species) location at last timestep of initialization for new object where the traps will be the predators
    y0[:,:,0] = env.predator_density
    prey_data = env.prey_density

    # initialize trap locations based on predictions
    for i,j in trap_locations:
        y0[i, j, 1] = 10
    y0 = y0.flatten()

    # modify a run s.t. each timestep consists of two solvers: one for traps and one for prey
    # sol.y has shape (2 * n * m, 2500) with one row representing the prey(t) function and the other representing the predator(t)function
    # sol_use = None
    master_sol = np.ndarray((n*m*2,pts_per_sec))


    for _ in range(replacement_window):
        # trap solver, only grab single timestep
        sol = solve_ivp(synthetic_data.spatial_dynamics_traps, y0=y0, t_span=[0,1], t_eval=np.linspace(0, 1, pts_per_sec), args=(n, m))
        # prey solver y0 creation
        y_prey = np.zeros((n, m, 2))
        last_dim = int(pts_per_sec)
        sol_use = sol.y.reshape((n, m, 2, last_dim))
        pred_data_new, trap_data_new = sol_use[:, :, 0, :], sol_use[:, :, 1, :]
        # set prey from timestep of interest as predator in y_prey
        y_prey[:,:,1] = pred_data_new[:, :, -1]
        # grab predator information from prey data
        y_prey[:,:,0] = prey_data
        y_prey = y_prey.flatten()
        # prey solver, only grab single timestep
        sol_prey = solve_ivp(synthetic_data.spatial_dynamics, y0=y_prey, t_span=[0,1], t_eval=np.linspace(0, 1, pts_per_sec), args=(n, m))
        # create y0 for next run of trap solver, overwrite y0 and prey_data
        y0 = np.zeros((n, m, 2))
        sol_prey_use = sol_prey.y.reshape((n, m, 2, last_dim))
        prey_data, predator_data = sol_prey_use[:, :, 0, :], sol_prey_use[:, :, 1, :]
        y0[:,:,0] = predator_data[:, :, -1]
        prey_data = prey_data[:, :, -1]
        # initialize trap locations based on number of desired traps and density, re initialize per replacement time
        y0[:,:,1] = trap_data_new[:,:,-1]
        y0 = y0.flatten()
        master_sol = np.concatenate((master_sol, sol_prey.y), 1)

    # save master sol object
    print(master_sol[:,100:].shape)
    np.save(save_loc, master_sol[:,100:])


def plot_predator_locations_at_timestep_here(file_loc, timestep):
    dataset = np.load(file_loc)
    data = dataset.reshape((25, 25, 2, 2500))
    plot_predator_locations_here(data, timestep*50)

def plot_predator_locations_here(grid, timestep):
    #Get the min and max of all your data
    _min, _max = np.amin(grid), np.amax(grid)

    fig = plt.figure()
    ax = fig.add_subplot(2, 1, 1)
    #Add the vmin and vmax arguments to set the color scale
    ax.imshow(grid[:, :, 0, timestep], cmap=plt.cm.YlGn, vmin = _min, vmax = _max)
    ax2 = fig.add_subplot(2, 1, 2)
    #Add the vmin and vmax arguments to set the color scale
    ax2.imshow(grid[:, :, 1, timestep], cmap=plt.cm.YlGn, vmin = _min, vmax = _max)
    plt.show()

generate_traps_next(env,trap_locations)

# Create an interactive slider
interact(plot_predator_locations_at_timestep_here, file_loc='../Data/val_predict.npy', timestep=IntSlider(min=0, max=max_timestep, step=1, value=0))


Predicted trap locations for next episode: [(18, 15), (23, 14), (4, 2), (20, 7), (2, 0), (2, 12), (12, 9), (12, 4), (19, 4), (7, 20), (19, 0), (23, 15), (15, 1), (14, 8), (13, 24), (14, 6), (9, 16), (21, 23), (7, 10), (1, 19)]
(1250, 2500)


interactive(children=(Text(value='../Data/val_predict.npy', description='file_loc'), IntSlider(value=0, descri…

<function __main__.plot_predator_locations_at_timestep_here(file_loc, timestep)>