In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

# Radius (in cells) of the square window whose mean temperature a cell diffuses toward.
affection_distance = 3
# Fractional fuel burn rate per unit time: each tick a burning tree loses
# fuel_loss * fuel * dt (i.e. 1% of its remaining fuel per tick with dt = 0.1).
# Once less than fuel_loss remains, the rest is consumed in a single tick.
fuel_loss = 0.1
# Temperature every cell starts at and passively cools back toward.
ambient_temperature = 25
# Fraction of the excess over ambient temperature lost per unit time.
passive_cooling_coefficient = 0.05
# Dryness gained per degree above ambient per unit time (dryness is capped at 1.0).
drying_constant = 0.01
# Overall multiplier on the random ignition probability.
ignition_constant = 1
# Length of one simulation tick.
dt = 0.1

def mix_rgb_linear(color1, color2, ratio=0.5):
    """
    Blend two RGB colours by linear interpolation.

    color1 and color2 are 3-element sequences (R, G, B) with channels in 0-255.
    ratio selects the blend: 0.0 gives color1, 1.0 gives color2, 0.5 an even
    mix. Each channel is truncated to an int and clamped into 0-255, and the
    result is returned as a list [R, G, B].
    """
    (r1, g1, b1), (r2, g2, b2) = color1, color2
    keep = 1.0 - ratio  # weight of color1

    channel_pairs = ((r1, r2), (g1, g2), (b1, b2))
    # int() truncates; the clamp keeps out-of-range ratios from producing invalid channels.
    return [max(0, min(255, int(a * keep + b * ratio))) for a, b in channel_pairs]

class Cell:
    """One grid cell of the forest-fire simulation: either a tree or bare ground.

    Constructor arguments / attributes:
        combustion_temp: temperature at or above which the cell ignites; a
            freshly ignited cell is set to this temperature.
        thermal_conductivity: 0-1 factor scaling how quickly the cell's
            temperature moves toward the mean temperature of nearby cells.
        dryness: 0-1; grows while the cell is above ambient temperature and
            scales its chance of random ignition.
        fuel_amount: how much can still be burnt (non-negative). A cell with no
            fuel cannot be ignited.
        energy_density: temperature change applied per unit of fuel consumed.
        is_ground: ground cells are never flammable.
        rng_generator: random generator (must provide ``random()``) used for
            ignition rolls.
        base_color: RGB colour (0-255 ints) of the cell at ambient temperature.

    Simulation state set by the constructor:
        starting_fuel_amount: fuel at construction; used to shade burning cells.
        is_burning: whether the cell is currently on fire.
        current_temperature: starts at the module-level ambient_temperature.
    """

    def __init__(self, combustion_temp, thermal_conductivity, dryness, fuel_amount, energy_density, is_ground, rng_generator, base_color):
        self.combustion_temp = combustion_temp
        self.thermal_conductivity = thermal_conductivity
        self.dryness = dryness
        self.fuel_amount = fuel_amount
        self.starting_fuel_amount = self.fuel_amount
        self.energy_density = energy_density
        self.is_ground = is_ground
        self.is_burning = False
        self.current_temperature = ambient_temperature
        self.rng_generator = rng_generator
        self.base_color = base_color

    def clone(self):
        """Return a copy of this cell including its live simulation state.

        rng_generator and base_color are shared with the original, not copied.
        """
        new_cell = Cell(self.combustion_temp, self.thermal_conductivity, self.dryness, self.fuel_amount,
                        self.energy_density, self.is_ground, self.rng_generator, self.base_color)
        # The constructor resets these three; carry the current state over.
        new_cell.starting_fuel_amount = self.starting_fuel_amount
        new_cell.is_burning = self.is_burning
        new_cell.current_temperature = self.current_temperature
        return new_cell

    def ignite(self):
        """Set the cell on fire at its combustion temperature (no-op if not flammable)."""
        if self.is_flammable():
            self.is_burning = True
            self.current_temperature = self.combustion_temp

    def is_flammable(self):
        """Return True if the cell still has fuel and is not ground."""
        return self.fuel_amount > 0 and not self.is_ground

    def _get_color_internal(self):
        """Return the cell's colour as a list of three 0-255 ints."""
        if self.is_burning:
            # Burning trees fade from orange-red toward near-black as their fuel is spent.
            ratio = self.fuel_amount / self.starting_fuel_amount
            visible_ratio = ratio ** 0.5
            return mix_rgb_linear([30, 30, 30], [255, 69, 0], visible_ratio)
        if self.is_ground:
            return [200, 200, 200]  # ground is very light gray
        if not self.is_flammable():
            # Burnt out (or a tree that never had fuel): charred black.
            return [0, 0, 0]

        # Tint toward red as the tree heats up: 0.0 at ambient, 1.0 at combustion_temp.
        temp_diff = self.current_temperature - ambient_temperature
        max_diff = self.combustion_temp - ambient_temperature
        if max_diff <= 0:
            # Combustion temperature at/below ambient: treat the tree as fully heated.
            ratio = 1.0
        else:
            ratio = max(0, min(1, temp_diff / max_diff))
        # Square root boosts small ratios so slight warming is already visible.
        visible_ratio = ratio ** 0.5
        return mix_rgb_linear(self.base_color, [255, 0, 0], visible_ratio)

    def get_color(self):
        """Return the cell's colour as a list of three floats in 0-1 (for imshow)."""
        color = self._get_color_internal()
        return [color[0] / 255.0, color[1] / 255.0, color[2] / 255.0]

    @staticmethod
    def update(forest, i, j, wind):
        """Return the state of forest[i][j] after one tick of length dt.

        The cell is cloned and only the clone is modified, so `forest` is left
        untouched and every cell of a tick can be computed from one snapshot.

        Args:
            forest: 2-D grid of Cells supporting ``.shape`` and ``forest[i][j]``
                (a numpy object array in this file).
            i, j: row and column of the cell to update.
            wind: (row, col) integer offset; the heat-averaging window is centred
                on (i - wind[0], j - wind[1]) instead of on the cell itself.

        Returns:
            A new Cell holding the updated state.
        """
        cell_to_update = forest[i][j].clone()
        rows, cols = forest.shape[0], forest.shape[1]

        if cell_to_update.is_burning:
            # Fuel decays exponentially; once less than fuel_loss remains, the rest
            # is consumed in one tick so it reaches exactly 0 instead of only
            # approaching it.
            if cell_to_update.fuel_amount < fuel_loss:
                fuel_diff = -cell_to_update.fuel_amount
            else:
                fuel_diff = -fuel_loss * cell_to_update.fuel_amount * dt
            cell_to_update.fuel_amount += fuel_diff

            # NOTE(review): fuel_diff is negative, so consuming fuel LOWERS the
            # temperature here. If energy_density is meant as heat released per
            # unit of fuel, this sign should be flipped - confirm intent.
            cell_to_update.current_temperature += fuel_diff * cell_to_update.energy_density

            if cell_to_update.fuel_amount <= 0:
                # Out of fuel: the fire goes out and the cell renders as burnt.
                cell_to_update.fuel_amount = 0
                cell_to_update.is_burning = False
            return cell_to_update

        # Spontaneous combustion once a flammable tree reaches its combustion
        # temperature. Non-flammable cells fall through so they can still cool.
        if cell_to_update.is_flammable() and cell_to_update.current_temperature >= cell_to_update.combustion_temp:
            cell_to_update.ignite()
            return cell_to_update

        # Move toward the unweighted mean temperature of the in-bounds cells in a
        # (2*affection_distance+1)^2 window shifted by -wind; the window's centre
        # is excluded.
        total_temp = 0
        neighbor_count = 0
        for i_offset in range(-affection_distance, affection_distance + 1):
            for j_offset in range(-affection_distance, affection_distance + 1):
                if i_offset == 0 and j_offset == 0:
                    continue
                new_i = i + i_offset - wind[0]
                new_j = j + j_offset - wind[1]
                if 0 <= new_i < rows and 0 <= new_j < cols:
                    total_temp += forest[new_i][new_j].current_temperature
                    neighbor_count += 1

        if neighbor_count > 0:
            avg_neighbor_temp = total_temp / neighbor_count
            diffusion = (avg_neighbor_temp - cell_to_update.current_temperature) * cell_to_update.thermal_conductivity
            cell_to_update.current_temperature += diffusion * dt

        # Passive cooling pulls the temperature back toward ambient_temperature.
        temp_loss = (cell_to_update.current_temperature - ambient_temperature) * passive_cooling_coefficient
        cell_to_update.current_temperature -= temp_loss * dt

        # Count burning cells among the 8 immediate neighbours (not wind-shifted).
        amount_of_burning_neighbors = 0
        for i_offset in (-1, 0, 1):
            for j_offset in (-1, 0, 1):
                if i_offset == 0 and j_offset == 0:
                    continue
                n_i, n_j = i + i_offset, j + j_offset
                if 0 <= n_i < rows and 0 <= n_j < cols and forest[n_i][n_j].is_burning:
                    amount_of_burning_neighbors += 1

        # Only dry out while hotter than ambient; dryness is capped at 1.0 (totally dry).
        if cell_to_update.current_temperature > ambient_temperature:
            dryness_increase = drying_constant * (cell_to_update.current_temperature - ambient_temperature) * dt
            cell_to_update.dryness = min(1.0, cell_to_update.dryness + dryness_increase)

        # Ignition chance grows with temperature, dryness and burning neighbours.
        ignition_risk = (cell_to_update.current_temperature / cell_to_update.combustion_temp) * cell_to_update.dryness * ignition_constant
        neighbor_factor = amount_of_burning_neighbors + 1

        # Trees can't ignite randomly unless very dry or next to a burning neighbour.
        if cell_to_update.dryness >= 0.6 or neighbor_factor > 1:
            if cell_to_update.rng_generator.random() < ignition_risk * neighbor_factor * dt:
                cell_to_update.ignite()
        return cell_to_update


numpy_seed = 5346

rng_seeded = np.random.default_rng(numpy_seed)

# Templates every forest cell is cloned from; all share the same seeded generator.
cell_types = [
    # A tree that ignites at a lower temperature than the other tree type.
    Cell(
        combustion_temp=300,
        thermal_conductivity=0.2,
        dryness=0.1,
        fuel_amount=50,
        energy_density=2,
        is_ground=False,
        rng_generator=rng_seeded,
        base_color=[46, 156, 6],
    ),
    # A tree that is harder to ignite and starts completely wet.
    Cell(
        combustion_temp=430,
        thermal_conductivity=0.4,
        dryness=0.0,
        fuel_amount=40,
        energy_density=3,
        is_ground=False,
        rng_generator=rng_seeded,
        base_color=[43, 69, 14],
    ),
    # Bare ground: no fuel, no conductivity, never burns.
    Cell(
        combustion_temp=100_000,
        thermal_conductivity=0.0,
        dryness=0.0,
        fuel_amount=0,
        energy_density=0,
        is_ground=True,
        rng_generator=rng_seeded,
        base_color=[171, 111, 51],
    ),
]
# Probability of drawing each entry of cell_types (same order).
chances = [0.4, 0.4, 0.2]
shape = (10, 10)  # forest grid size (rows, columns)
starting_burning_amount = 2  # number of trees set alight before the first tick
time_limit = 60  # simulated duration; the run has int(time_limit / dt) ticks

# 1. Assign a random cell-type index to every grid position, weighted by chances.
indices = rng_seeded.choice(len(cell_types), size=shape, p=chances)

# 2. Per-index factory that np.vectorize applies across the index grid.
def clone_from_list(idx):
    """Return a fresh, independent copy of the template cell_types[idx]."""
    template = cell_types[idx]
    return template.clone()

v_clone = np.vectorize(clone_from_list)

# 3. Generate the final array of cloned objects
forest_original = v_clone(indices)
wind = [2, 1]

# Set exactly starting_burning_amount distinct trees alight. Positions that are
# not flammable or already burning are re-drawn, so no cell is picked twice.
_ignitable_count = sum(
    1 for cell in forest_original.flat if cell.is_flammable() and not cell.is_burning
)
if _ignitable_count < starting_burning_amount:
    # Without this check the rejection loop below would never terminate.
    raise ValueError(
        f"cannot start {starting_burning_amount} fires: only {_ignitable_count} ignitable cells"
    )

for _ in range(starting_burning_amount):
    rand_i = rng_seeded.integers(0, forest_original.shape[0])
    rand_j = rng_seeded.integers(0, forest_original.shape[1])
    while (not forest_original[rand_i, rand_j].is_flammable()
           or forest_original[rand_i, rand_j].is_burning):
        rand_i = rng_seeded.integers(0, forest_original.shape[0])
        rand_j = rng_seeded.integers(0, forest_original.shape[1])
    forest_original[rand_i, rand_j].ignite()

def clone_forest(forest) -> np.ndarray:
    """Return a copy of a 2-D forest grid in which every cell is clone()d.

    The copy can be simulated without mutating the original grid. The result
    is a 2-D numpy array of dtype object holding the cloned cells.
    """
    cloned_rows = [[cell.clone() for cell in row] for row in forest]
    return np.array(cloned_rows, dtype=object)

# Placeholder; reassigned with a cloned grid (by init() or the video cell) before use.
forest = []

def get_color_grid_from_forest(forest):
    """Render a grid of cells as an image array for matplotlib's imshow.

    Calls get_color() on every cell and returns a float array of shape
    forest.shape + (channels,), i.e. (rows, cols, 3) for a 2-D forest of RGB
    colours.
    """
    cells = np.asarray(forest, dtype=object)
    # Gather colours in row-major order as an (n_cells, channels) array ...
    flat_colors = np.stack([cell.get_color() for cell in cells.flat])
    # ... then restore the grid layout with the colour channels as the last axis.
    return flat_colors.reshape(cells.shape + flat_colors.shape[1:]).astype(float)



import sys
import time
# helper function I created that displays progress for animations and other long processes
def print_progress(current_iteration, max_iterations, display_eta = True, eta_update_rate = 3):
    """Write a single, carriage-return-refreshed progress line to stdout.

    Args:
        current_iteration: index of the step being processed; 0 restarts the timer.
        max_iterations: total number of steps (the 100% mark). A non-positive
            value is reported as 100% instead of dividing by zero.
        display_eta: if True, periodically append an estimate of the total run time.
        eta_update_rate: the estimate is written on iterations
            1, 1 + rate, 1 + 2*rate, ... (every iteration when rate is 1).

    The start time lives on the function object (print_progress.start_seconds),
    so it persists between calls and is reset whenever current_iteration == 0.
    """
    if current_iteration == 0 or not hasattr(print_progress, "start_seconds"):
        print_progress.start_seconds = time.time()

    progress = current_iteration / max_iterations if max_iterations > 0 else 1.0
    sys.stdout.write("\r")
    sys.stdout.write(str(int(100 * progress)) + "%")
    # Iteration 0 has no elapsed time to extrapolate from, so estimates start at
    # iteration 1. Shifting by one also keeps eta_update_rate == 1 working.
    if display_eta and current_iteration >= 1 and (current_iteration - 1) % eta_update_rate == 0:
        elapsed = time.time() - print_progress.start_seconds
        sys.stdout.write(" Estimated total time: " + str((1 / progress) * elapsed) + " seconds")
    sys.stdout.flush()

    # Timestamp of the latest call; not read by this function itself.
    print_progress.last_time_seconds = time.time()



def init():
    """Reset the global simulation grid to a fresh clone of forest_original."""
    global forest
    fresh_grid = clone_forest(forest_original)
    forest = fresh_grid


def advance(index, write_progress = False):
    """Advance the global forest by one simulation tick.

    Args:
        index: tick number; tick 0 first resets the forest via init().
        write_progress: if True, report progress against int(time_limit / dt) ticks.

    Every cell is computed from the same snapshot of the current grid, and the
    new grid then replaces the global forest.
    """
    global forest
    if write_progress:
        print_progress(index, int(time_limit/dt))
    if index == 0:
        init()
    # Size the output from the grid being stepped, not the module-level `shape`,
    # so a forest of any size is advanced correctly.
    rows, cols = forest.shape[0], forest.shape[1]
    next_forest = np.empty((rows, cols), dtype=object)
    for i in range(rows):
        for j in range(cols):
            next_forest[i, j] = Cell.update(forest, i, j, wind)
    forest = next_forest

## Single run as video

In [None]:
passed_frame_0 = False
forest = clone_forest(forest_original)
fig, ax = plt.subplots(figsize=(6, 6))
# Draw the starting grid once; the animation callbacks update this artist in place.
im = ax.imshow(get_color_grid_from_forest(forest), animated=True)

def anim_init():
    """FuncAnimation init_func: reset the simulation and redraw its first state."""
    global passed_frame_0
    passed_frame_0 = False
    init()
    im.set_array(get_color_grid_from_forest(forest))
    return (im, )

def anim_advance(frame):
    """FuncAnimation frame callback: step the simulation once and redraw.

    If frame 0 is requested again after it has already been simulated in this
    run, the current image is returned unchanged instead of restarting.
    """
    global passed_frame_0
    already_started = passed_frame_0
    if already_started and frame == 0:
        return (im, )
    passed_frame_0 = True

    advance(frame, True)
    im.set_array(get_color_grid_from_forest(forest))
    return (im, )

# One animation frame per simulation tick, drawn 20 ms apart with blitting.
anim = FuncAnimation(fig, anim_advance, init_func=anim_init,
                     frames=int(time_limit/dt), interval=20, blit=True)

# Render all frames into an HTML/JS player; as the cell's last expression it is shown inline.
HTML(anim.to_jshtml())