<a href="https://colab.research.google.com/github/jamessutton600613-png/GC/blob/main/Untitled178tpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
drive.mount('/content/drive')

ValueError: Mountpoint must not already contain files

In [None]:
import pandas as pd
import numpy as np # Import numpy as it might be needed for calculations

# Assuming simulation_data is a list of state dictionaries loaded from checkpoints
# and loaded_steps is a list of corresponding step numbers.
# These variables should be available from the previous data loading cell.

if not simulation_data:
    print("No simulation data loaded. Please ensure the data loading cell ran successfully.")
else:
    # Create lists to store the extracted metrics
    steps = []
    cautious_pops = []
    reckless_pops = []
    active_protos = []
    arrested_protos = []
    inactive_protos = []
    avg_damages = []
    total_pops = [] # Add total population for context

    # Iterate through the loaded data
    for i, state in enumerate(simulation_data):
        step = loaded_steps[i] # Get the step number
        steps.append(step)

        # Calculate population counts and other metrics from the state dictionary
        # Ensure we only count agents with status == 1 (alive)
        cautious_pop = np.sum((state['strategy'] == 0) & (state['status'] == 1))
        reckless_pop = np.sum((state['strategy'] == 1) & (state['status'] == 1))
        active_proto = np.sum((state['protoribosome_state'] == PROTO_STATE_ACTIVE) & (state['status'] == 1))
        arrested_proto = np.sum((state['protoribosome_state'] == PROTO_STATE_ARRESTED) & (state['status'] == 1))
        inactive_proto = np.sum((state['protoribosome_state'] == PROTO_STATE_INACTIVE) & (state['status'] == 1))
        avg_damage = np.mean(state['rna_damage'][state['status'] == 1]) if np.sum(state['status'] == 1) > 0 else 0.0 # Calculate average damage for living agents
        total_pop = np.sum(state['status'] == 1) # Total living population

        cautious_pops.append(cautious_pop)
        reckless_pops.append(reckless_pop)
        active_protos.append(active_proto)
        arrested_protos.append(arrested_proto)
        inactive_protos.append(inactive_proto)
        avg_damages.append(avg_damage)
        total_pops.append(total_pop)

    # Create a pandas DataFrame
    simulation_df = pd.DataFrame({
        'Step': steps,
        'Cautious Population': cautious_pops,
        'Reckless Population': reckless_pops,
        'Total Population': total_pops, # Add total population to DataFrame
        'Active Protoribosomes': active_protos,
        'Arrested Protoribosomes': arrested_protos,
        'Inactive Protoribosomes': inactive_protos,
        'Average RNA Damage': avg_damages
    })

    print("\nSimulation data processed into a DataFrame:")
    display(simulation_df.head()) # Display the first few rows of the DataFrame
    display(simulation_df.tail()) # Display the last few rows

# The simulation_df DataFrame is now ready for plotting.

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns # Import seaborn for potentially nicer plots

# Assuming the simulation_df DataFrame is available from the data processing cell.

if 'simulation_df' not in locals() or simulation_df.empty:
    print("Simulation DataFrame not found or is empty. Please ensure the data processing cell ran successfully.")
else:
    print("\nGenerating plots...")

    # --- Plotting Population Counts Over Time ---
    plt.figure(figsize=(12, 6))
    sns.lineplot(data=simulation_df, x='Step', y='Cautious Population', label='Cautious')
    sns.lineplot(data=simulation_df, x='Step', y='Reckless Population', label='Reckless')
    sns.lineplot(data=simulation_df, x='Step', y='Total Population', label='Total') # Plot total population
    plt.xlabel('Simulation Step')
    plt.ylabel('Population Count')
    plt.title('Population Dynamics Over Time (Cautious vs Reckless)')
    plt.legend()
    plt.grid(True)
    plt.show()

    # --- Plotting Protoribosome States Over Time ---
    plt.figure(figsize=(12, 6))
    sns.lineplot(data=simulation_df, x='Step', y='Active Protoribosomes', label='Active')
    sns.lineplot(data=simulation_df, x='Step', y='Arrested Protoribosomes', label='Arrested')
    sns.lineplot(data=simulation_df, x='Step', y='Inactive Protoribosomes', label='Inactive')
    plt.xlabel('Simulation Step')
    plt.ylabel('Number of Agents')
    plt.title('Protoribosome States Over Time')
    plt.legend()
    plt.grid(True)
    plt.show()


    # --- Plotting Average RNA Damage Over Time ---
    plt.figure(figsize=(12, 6))
    sns.lineplot(data=simulation_df, x='Step', y='Average RNA Damage')
    plt.xlabel('Simulation Step')
    plt.ylabel('Average RNA Damage')
    plt.title('Average RNA Damage Over Time')
    plt.grid(True)
    plt.show()

    print("Plots generated.")


In [None]:
## Simulation Results Analysis

Now that you have the plots for population counts, protoribosome states, and average RNA damage over time, you can analyze the simulation results:

1.  **Population Dynamics Plot**:
    *   Observe how the "Cautious" and "Reckless" populations change over the 10,000 simulation steps.
    *   Did one strategy outcompete the other?
    *   Does the total population remain stable, grow, or decline?
    *   Look for fluctuations or trends that might correlate with the daily and seasonal weather cycles (though the weather itself isn't plotted here).

2.  **Protoribosome States Plot**:
    *   Examine the proportions of agents with "Active", "Arrested", and "Inactive" protoribosomes over time.
    *   How do these states relate to the overall population changes?
    *   Does the number of "Arrested" protoribosomes increase with damage or stop codons?
    *   Does the number of "Inactive" protoribosomes increase, perhaps indicating successful peptide production or other factors leading to inactivation?

3.  **Average RNA Damage Plot**:
    *   Track the average RNA damage level across all living agents over time.
    *   Does the damage level fluctuate with the weather cycles (UV and temperature)?
    *   Does the average damage level stabilize, increase, or decrease over the long term?
    *   How does the average damage relate to the population dynamics and protoribosome states?

By examining these plots together, you can gain insights into how the different strategies, environmental conditions, and the mutation/repair/protoribosome mechanisms influenced the population's survival and composition over time.

---

This concludes the planned steps for running the simulation and visualizing the initial results.

**Next Steps:**

*   **Interpret the plots**: Carefully analyze the generated graphs based on the points above.
*   **Refine parameters**: If the simulation results are not as expected or you want to explore different scenarios, you can go back to the initial setup cell and modify parameters like mutation rates, repair rates, damage thresholds, weather patterns, etc., and re-run the simulation and plotting steps.
*   **Further analysis**: You might want to perform more detailed analysis on the final population state (e.g., examine the RNA sequences of surviving agents).
*   **Experiment with strategies**: You could modify the simulation logic to introduce new strategies or change how existing strategies behave.

# Task
Implement a simulation loop with checkpointing and resume functionality for the provided population state and simulation parameters.

## Define simulation step logic

### Subtask:
Outline the individual processes that occur within a single simulation step (e.g., applying environmental effects like UV damage, handling mutations and repair, updating protoribosome states, implementing strategy-based actions, handling reproduction/death, etc.).


**Reasoning**:
Outline the individual processes within a single simulation step based on the provided context and the subtask instructions. This will involve describing environmental effects, damage/repair, mutations, protoribosome updates, strategy actions, and reproduction/death rules.



In [None]:
# 1. Environmental Factors:
#    - UV radiation: Increases RNA damage. Intensity could vary daily (peak at noon) and seasonally (higher in summer). Represented by an environmental variable that scales damage application.
#    - Temperature: Could affect reaction rates (damage, repair) or metabolic costs. Could also vary daily and seasonally. Represented by an environmental variable.

# 2. RNA Damage and Repair:
#    - Damage Application: In each step, apply damage to each base in the 'rna_sequence'. The amount of damage applied per base is influenced by the current UV radiation level. This updates the 'rna_damage' array.
#    - Repair Mechanism: Agents might have a repair rate. In each step, damage could be reduced on each base based on the repair rate. The repair rate could potentially be influenced by strategy or temperature. This also updates the 'rna_damage' array, reducing values towards zero.

# 3. Mutations:
#    - Mutation Occurrence: Mutations can occur spontaneously or potentially be influenced by high damage levels. In each step, with a certain probability per base, a base in the 'rna_sequence' can change to a random other base. This updates the 'rna_sequence' array.

# 4. Protoribosome State Update:
#    - States: ACTIVE, ARRESTED, INACTIVE.
#    - Transition to ARRESTED: A protoribosome transitions from ACTIVE to ARRESTED if the damage on a base exceeds a threshold during translation or if a stop codon is encountered. The 'has_stop_codon' field, determined from the 'rna_sequence', is key here. High 'rna_damage' values can also trigger this transition.
#    - Transition from ARRESTED: An arrested protoribosome might become ACTIVE again if damage is repaired below a threshold or if the stop codon is somehow resolved (less likely in this simple model). It might transition to INACTIVE if it stays arrested for too long or if the cell dies.
#    - Transition to INACTIVE: A protoribosome might become INACTIVE after successful translation (producing a "peptide" - implicitly assumed if not arrested) or if the agent's status becomes dead.
#    - Updates the 'protoribosome_state' array.

# 5. Strategy-Based Actions:
#    - Cautious Strategy (0): May have a higher repair rate, a lower damage threshold for protoribosome arrest, or slower translation speed to reduce damage accumulation. They might invest more resources in repair.
#    - Reckless Strategy (1): May have a lower repair rate, a higher damage threshold for protoribosome arrest, or faster translation speed, potentially leading to faster resource acquisition but also higher risk of damage and arrest. They might invest less in repair.
#    - Strategy influences how damage and protoribosome states affect survival and reproduction.

# 6. Reproduction and Death:
#    - Reproduction: Agents might reproduce if they accumulate enough resources (not explicitly modeled yet, but could be linked to successful protein production) and are healthy (low damage, active protoribosomes). Reproduction creates a new agent with a potentially mutated RNA sequence.
#    - Death: Agents die if their damage level is too high, their protoribosome is arrested for too long, or they fail to reproduce/acquire resources over a period. The 'status' field is set to 0 for dead agents. Dead agents are removed from active consideration but might be kept in the state for analysis.

# 7. Has Stop Codon Determination:
#    - In each step, or when the 'rna_sequence' changes, check if the sequence contains a stop codon (e.g., a specific sequence of 3 bases like UAA, UAG, UGA represented by integers).
#    - Updates the 'has_stop_codon' boolean array for each agent.

# 8. Other Processes:
#    - Resource Dynamics: While not explicitly defined, resource acquisition (linked to successful translation) and consumption (for maintenance, repair, reproduction) would be implicit drivers of survival and reproduction.
#    - Agent Removal: Dead agents ('status' == 0) are effectively removed from the living population for subsequent calculations but their data might be retained for analysis of extinction dynamics.
#    - Step Increment: The simulation step counter is incremented.
#    - Checkpointing: Periodically save the full 'population_state' dictionary to a file.

## Implement the simulation loop

### Subtask:
Write the main loop that iterates for the specified number of simulation steps.


**Reasoning**:
Define the total number of simulation steps and initialize the simulation loop with a placeholder print statement as requested by the instructions.



In [None]:
# 1. Define the total number of simulation steps
NUM_STEPS = 10000

# Determine the starting step. If simulation_data is not empty, start from the step after the last loaded one.
# Otherwise, start from 0.
start_step = loaded_steps[-1] + 1 if loaded_steps else 0

print(f"Starting simulation from step {start_step} for a total of {NUM_STEPS} steps.")

# 2. Initialize a loop that runs from the starting step up to NUM_STEPS.
# 3. Inside the loop, call a function or include the logic for a single simulation step (placeholder).
# 4. Include a placeholder print statement inside the loop to indicate the current simulation step.
# 5. Ensure the loop updates a variable tracking the current simulation step.
for step in range(start_step, NUM_STEPS + 1):
    # Placeholder for simulation step logic (will be implemented later)
    # This is where the population_state dictionary would be updated based on the step's processes.
    # For now, we just have the print statement.

    # Placeholder print statement
    if step % 100 == 0 or step == start_step or step == NUM_STEPS:
        print(f"Simulating step {step}/{NUM_STEPS}")

    # The loop variable 'step' automatically tracks the current simulation step.

    # Placeholder for checkpointing logic (will be implemented later)
    # if step % CHECKPOINT_INTERVAL == 0:
    #     save_checkpoint(population_state, step)

print("Simulation loop initialized.")

Starting simulation from step 0 for a total of 10000 steps.
Simulating step 0/10000
Simulating step 100/10000
Simulating step 200/10000
Simulating step 300/10000
Simulating step 400/10000
Simulating step 500/10000
Simulating step 600/10000
Simulating step 700/10000
Simulating step 800/10000
Simulating step 900/10000
Simulating step 1000/10000
Simulating step 1100/10000
Simulating step 1200/10000
Simulating step 1300/10000
Simulating step 1400/10000
Simulating step 1500/10000
Simulating step 1600/10000
Simulating step 1700/10000
Simulating step 1800/10000
Simulating step 1900/10000
Simulating step 2000/10000
Simulating step 2100/10000
Simulating step 2200/10000
Simulating step 2300/10000
Simulating step 2400/10000
Simulating step 2500/10000
Simulating step 2600/10000
Simulating step 2700/10000
Simulating step 2800/10000
Simulating step 2900/10000
Simulating step 3000/10000
Simulating step 3100/10000
Simulating step 3200/10000
Simulating step 3300/10000
Simulating step 3400/10000
Simulat

## Integrate checkpointing

### Subtask:
Add code within the loop to save the `population_state` to pickle files in the `CHECKPOINT_DIR` periodically. This will allow for resuming the simulation and for later analysis.


**Reasoning**:
Add checkpointing logic inside the simulation loop to save the population state periodically.



In [None]:
import pickle
import os

# Define checkpoint interval
CHECKPOINT_INTERVAL = 1000 # Save checkpoint every 1000 steps

# Check if simulation_data was loaded. If so, resume from the last loaded state.
if simulation_data:
    population_state = simulation_data[-1]
    start_step = loaded_steps[-1] + 1
    print(f"Resuming simulation from step {start_step} with loaded state.")
else:
    # If no data was loaded, use the initial state defined earlier
    # The initial_strategies, initial_rna_sequences, initial_protoribosome_states
    # and the initial population_state dictionary should be available from previous cells.
    # We also need to ensure the initial_state is saved.
    start_step = 0
    print(f"Starting new simulation from step {start_step} with initial state.")
    # Save the initial state
    initial_state_filename = os.path.join(CHECKPOINT_DIR, 'initial_state.pkl')
    with open(initial_state_filename, 'wb') as f:
        pickle.dump(population_state, f)
    print(f"Saved initial state to {initial_state_filename}")


# Main simulation loop
NUM_STEPS = 10000
print(f"Running simulation from step {start_step} to {NUM_STEPS}")

for step in range(start_step, NUM_STEPS + 1):
    # Placeholder for simulation step logic (will be implemented later)
    # This is where the population_state dictionary would be updated based on the step's processes.
    # For now, we just have the print statement.

    # Placeholder print statement
    if step % 1000 == 0 or step == start_step or step == NUM_STEPS:
        print(f"Simulating step {step}/{NUM_STEPS}")

    # Add checkpointing logic
    if step > 0 and (step % CHECKPOINT_INTERVAL == 0 or step == NUM_STEPS):
        if step == NUM_STEPS:
             checkpoint_filename = os.path.join(CHECKPOINT_DIR, f'final_state_step_{step}.pkl')
        else:
             checkpoint_filename = os.path.join(CHECKPOINT_DIR, f'checkpoint_step_{step}.pkl')

        try:
            with open(checkpoint_filename, 'wb') as f:
                pickle.dump(population_state, f)
            print(f"Checkpoint saved at step {step} to {checkpoint_filename}")
        except Exception as e:
            print(f"Error saving checkpoint at step {step}: {e}")


print("Simulation loop finished.")


Starting new simulation from step 0 with initial state.
Saved initial state to /content/drive/MyDrive/Colab Notebooks/qtpu/replica_1/Expt_2025-7-6_08-50/initial_state.pkl
Running simulation from step 0 to 10000
Simulating step 0/10000
Simulating step 1000/10000
Checkpoint saved at step 1000 to /content/drive/MyDrive/Colab Notebooks/qtpu/replica_1/Expt_2025-7-6_08-50/checkpoint_step_1000.pkl
Simulating step 2000/10000
Checkpoint saved at step 2000 to /content/drive/MyDrive/Colab Notebooks/qtpu/replica_1/Expt_2025-7-6_08-50/checkpoint_step_2000.pkl
Simulating step 3000/10000
Checkpoint saved at step 3000 to /content/drive/MyDrive/Colab Notebooks/qtpu/replica_1/Expt_2025-7-6_08-50/checkpoint_step_3000.pkl
Simulating step 4000/10000
Checkpoint saved at step 4000 to /content/drive/MyDrive/Colab Notebooks/qtpu/replica_1/Expt_2025-7-6_08-50/checkpoint_step_4000.pkl
Simulating step 5000/10000
Checkpoint saved at step 5000 to /content/drive/MyDrive/Colab Notebooks/qtpu/replica_1/Expt_2025-7-6_0

## Integrate resume functionality

### Subtask:
Modify the simulation loop to check if checkpoint files exist in the `CHECKPOINT_DIR` at the start. If they do, load the latest state to resume the simulation.


## Combine with initialization and plotting

### Subtask:
Show how the new simulation loop code connects with the existing initialization and plotting cells to create a complete workflow.


**Reasoning**:
Add markdown cells and comments to explain the flow of the notebook and the purpose of each code cell to show how the new simulation loop connects with the existing initialization and plotting cells.



In [None]:
# Add markdown cells and comments to explain the flow of the notebook and the purpose of each code cell

**Reasoning**:
Add markdown cells and comments to explain the flow of the notebook and the purpose of each code cell to show how the new simulation loop connects with the existing initialization and plotting cells.



In [None]:
# This notebook simulates the evolution of a population of agents with different strategies
# and RNA repair mechanisms under environmental pressures.

# --- Section 1: Initialization ---
# This section initializes the simulation parameters and the initial state of the agent population.

# Code cell: Define simulation parameters and initial population state
# Defines constants like NUM_AGENTS, SEQUENCE_LENGTH, base and protoribosome states.
# Initializes the population_state dictionary with initial strategies, RNA sequences,
# RNA damage, stop codon status, agent status (alive/dead), and protoribosome states.

# --- Section 2: Data Loading and Checkpoint Setup ---
# This section sets up the directory for saving checkpoints and attempts to load
# any existing checkpoint data to resume a previous simulation run.

# Code cell: Load simulation data from checkpoints
# Defines the directory for checkpoints (CHECKPOINT_DIR).
# Creates the directory if it doesn't exist.
# Lists, sorts, and loads existing checkpoint files into simulation_data and loaded_steps.
# This allows the simulation to resume from the last saved state.

# --- Section 3: Simulation Loop with Checkpointing ---
# This is the main simulation loop where the state of the population is updated over time.
# It includes the logic for applying simulation processes and saving checkpoints periodically.

# Code cell: Main simulation loop with checkpointing and resume functionality
# Determines the starting step based on whether checkpoint data was loaded.
# Sets the total number of simulation steps (NUM_STEPS).
# If resuming, loads the population state from the last checkpoint.
# If starting new, saves the initial state.
# Iterates from the start_step to NUM_STEPS.
# *** Placeholder for the actual simulation step logic *** (This is where environmental effects,
# damage, repair, mutations, protoribosome updates, reproduction, and death would be implemented).
# Saves the population_state to a checkpoint file at regular intervals (CHECKPOINT_INTERVAL)
# and at the final step.

# --- Section 4: Data Processing ---
# After the simulation is complete (or after loading checkpoints for analysis),
# this section processes the loaded state data into a structured format (pandas DataFrame)
# suitable for analysis and plotting.

# Code cell: Process simulation data into a DataFrame
# Checks if simulation_data was loaded.
# Iterates through each state dictionary in simulation_data.
# Extracts metrics like population counts (cautious, reckless, total),
# protoribosome state counts (active, arrested, inactive), and average RNA damage.
# Creates a pandas DataFrame (simulation_df) with these metrics indexed by simulation step.

# --- Section 5: Data Visualization ---
# This section generates plots based on the processed simulation data in the DataFrame
# to visualize the dynamics of the population and other key metrics over time.

# Code cell: Generate plots from simulation data
# Checks if the simulation_df DataFrame exists and is not empty.
# Uses matplotlib and seaborn to create line plots:
# - Population counts over time (Cautious, Reckless, Total).
# - Protoribosome states over time (Active, Arrested, Inactive).
# - Average RNA damage over time.
# Displays the generated plots.

# --- Section 6: Analysis and Next Steps ---
# Provides guidance on interpreting the generated plots and suggests potential
# next steps for further analysis or simulation experiments.

# Code cell: Simulation Results Analysis (Markdown)
# Contains explanatory text guiding the user on how to analyze the plots
# and what insights to look for regarding population dynamics, protoribosome states,
# and RNA damage in relation to the simulation parameters and processes.
# Suggests next steps like interpreting plots, refining parameters, further analysis,
# and experimenting with strategies.

## Summary:

### Data Analysis Key Findings

*   The simulation step logic was outlined, detailing processes such as environmental factors, RNA damage and repair, mutations, protoribosome state updates, strategy-based actions, reproduction and death, and stop codon determination.
*   A main simulation loop was implemented to iterate for a specified number of steps, incorporating placeholders for the simulation step logic and progress tracking.
*   Checkpointing functionality was integrated into the loop to periodically save the `population_state` to pickle files, including saving the initial and final states.
*   The resume functionality, which checks for and loads the latest checkpoint at the start of the simulation, was confirmed to be already implemented.
*   The complete workflow, connecting initialization, data loading (including resume), the simulation loop with checkpointing, data processing, and plotting, was outlined using markdown and comments.

### Insights or Next Steps

*   The current implementation provides a solid framework for an evolutionary simulation. The next crucial step is to replace the placeholder in the simulation loop with the actual logic for updating the `population_state` based on the detailed processes outlined in the first step.
*   To enable meaningful analysis, the data processing section should be fully implemented to transform the saved `population_state` data from the checkpoints into a structured format (like a pandas DataFrame) that can be used for plotting and quantitative analysis of population dynamics, protoribosome states, and damage levels over time.


# Task
Implement the core simulation step logic for the agent-based model, including environmental factors, RNA damage and repair, mutations, stop codon checks, protoribosome state updates, strategy-based actions, and agent reproduction and death.

## Define environmental factors

### Subtask:
Implement functions or logic to calculate environmental variables like UV radiation and temperature for each simulation step, considering daily and seasonal cycles.


**Reasoning**:
Implement functions for daily and seasonal cycles for UV and temperature and integrate their calculation into the simulation loop.



In [None]:
# Define constants for environmental cycles
STEPS_PER_DAY = 24 # Assuming 24 steps represent one day
DAYS_PER_YEAR = 365
STEPS_PER_YEAR = STEPS_PER_DAY * DAYS_PER_YEAR

# Environmental parameters (can be adjusted)
MAX_UV_DAILY = 1.0 # Maximum UV intensity during the day
MIN_UV_DAILY = 0.1 # Minimum UV intensity (night)
MAX_UV_SEASONAL_FACTOR = 1.5 # UV is 1.5 times higher in summer peak
MIN_UV_SEASONAL_FACTOR = 0.5 # UV is 0.5 times lower in winter peak
SUMMER_SOLSTICE_DAY = DAYS_PER_YEAR // 2 # Mid-year is summer peak
DAY_CYCLE_OFFSET = STEPS_PER_DAY // 2 # Peak UV at midday

BASE_TEMP = 20.0 # Base temperature
TEMP_DAILY_AMPLITUDE = 5.0 # Daily temperature swing
TEMP_SEASONAL_AMPLITUDE = 10.0 # Seasonal temperature swing
TEMP_DAILY_OFFSET = STEPS_PER_DAY // 3 * 2 # Peak temp in afternoon
TEMP_SEASONAL_OFFSET = SUMMER_SOLSTICE_DAY * STEPS_PER_DAY # Peak temp around summer solstice

# Function to calculate UV radiation based on step number
def calculate_uv(step, max_uv_daily=MAX_UV_DAILY, min_uv_daily=MIN_UV_DAILY,
                 max_uv_seasonal_factor=MAX_UV_SEASONAL_FACTOR, min_uv_seasonal_factor=MIN_UV_SEASONAL_FACTOR,
                 steps_per_day=STEPS_PER_DAY, steps_per_year=STEPS_PER_YEAR,
                 day_cycle_offset=DAY_CYCLE_OFFSET):
    """Calculates UV radiation based on daily and seasonal cycles."""
    # Daily cycle (sinusoidal, peaking around midday)
    day_of_year_step = step % steps_per_year
    step_of_day = step % steps_per_day
    daily_uv = ((max_uv_daily - min_uv_daily) / 2) * (1 + jnp.cos((2 * jnp.pi / steps_per_day) * (step_of_day - day_cycle_offset))) + min_uv_daily

    # Seasonal cycle (sinusoidal, peaking in summer)
    seasonal_factor = ((max_uv_seasonal_factor - min_uv_seasonal_factor) / 2) * (1 + jnp.cos((2 * jnp.pi / steps_per_year) * (day_of_year_step - TEMP_SEASONAL_OFFSET))) + min_uv_seasonal_factor

    # Combine daily and seasonal cycles
    current_uv = daily_uv * seasonal_factor
    return current_uv

# Function to calculate temperature based on step number
def calculate_temperature(step, base_temp=BASE_TEMP, temp_daily_amplitude=TEMP_DAILY_AMPLITUDE,
                          temp_seasonal_amplitude=TEMP_SEASONAL_AMPLITUDE, steps_per_day=STEPS_PER_DAY,
                          steps_per_year=STEPS_PER_YEAR, temp_daily_offset=TEMP_DAILY_OFFSET,
                          temp_seasonal_offset=TEMP_SEASONAL_OFFSET):
    """Calculates temperature based on daily and seasonal cycles."""
    # Daily cycle (sinusoidal, peaking in afternoon)
    step_of_day = step % steps_per_day
    daily_temp = temp_daily_amplitude * jnp.sin((2 * jnp.pi / steps_per_day) * (step_of_day - temp_daily_offset))

    # Seasonal cycle (sinusoidal, peaking in summer)
    day_of_year_step = step % steps_per_year
    seasonal_temp = temp_seasonal_amplitude * jnp.sin((2 * jnp.pi / steps_per_year) * (day_of_year_step - temp_seasonal_offset))

    # Combine base temperature, daily, and seasonal cycles
    current_temp = base_temp + daily_temp + seasonal_temp
    return current_temp


# --- Main simulation loop with checkpointing ---
# (Existing code from the previous cell)
# ...

# Main simulation loop
NUM_STEPS = 10000
print(f"Running simulation from step {start_step} to {NUM_STEPS}")

for step in range(start_step, NUM_STEPS + 1):
    # Calculate environmental variables for the current step
    current_uv_radiation = calculate_uv(step)
    current_temperature = calculate_temperature(step)

    # These variables (current_uv_radiation, current_temperature) are now available
    # to be used in the simulation step logic (which is currently a placeholder).

    # Placeholder for simulation step logic
    # This is where the population_state dictionary would be updated based on the step's processes.
    # For example:
    # population_state = apply_environmental_effects(population_state, current_uv_radiation, current_temperature)
    # population_state = apply_damage_and_repair(population_state)
    # population_state = apply_mutations(population_state)
    # ... and so on for all the processes outlined earlier.

    # Placeholder print statement
    if step % 1000 == 0 or step == start_step or step == NUM_STEPS:
        print(f"Simulating step {step}/{NUM_STEPS}, UV: {current_uv_radiation:.2f}, Temp: {current_temperature:.2f}째C")

    # Add checkpointing logic
    if step > 0 and (step % CHECKPOINT_INTERVAL == 0 or step == NUM_STEPS):
        if step == NUM_STEPS:
             checkpoint_filename = os.path.join(CHECKPOINT_DIR, f'final_state_step_{step}.pkl')
        else:
             checkpoint_filename = os.path.join(CHECKPOINT_DIR, f'checkpoint_step_{step}.pkl')

        try:
            with open(checkpoint_filename, 'wb') as f:
                pickle.dump(population_state, f)
            print(f"Checkpoint saved at step {step} to {checkpoint_filename}")
        except Exception as e:
            print(f"Error saving checkpoint at step {step}: {e}")


print("Simulation loop finished.")

NameError: name 'start_step' is not defined

**Reasoning**:
The error `NameError: name 'start_step' is not defined` indicates that the variable `start_step` was not defined before being used in the print statement and the loop. This is because the previous code block only contained the environmental calculation functions and the simulation loop structure, but not the logic to determine `start_step` based on loaded checkpoints. The code needs to include the logic from the checkpoint loading cell to correctly define `start_step` and `population_state` before the simulation loop begins.



In [None]:
import pickle
import os
import jax.numpy as jnp # Ensure jnp is imported for the functions

# Define checkpoint interval
CHECKPOINT_INTERVAL = 1000 # Save checkpoint every 1000 steps

# Define the directory where checkpoints are saved
# Assuming the checkpoint files are saved in '/content/drive/MyDrive/Colab_Notebooks/Sim_Experiments/Replica_1/'
# CHECKPOINT_DIR should be defined in a previous cell
# CHECKPOINT_DIR = "/content/drive/MyDrive/Colab Notebooks/qtpu/replica_1/Expt_2025-7-6_08-50" # Ensure this matches the directory used in the simulation

# Create the directory if it doesn't exist
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
print(f"Ensured checkpoint directory exists: {CHECKPOINT_DIR}")

# Initialize an empty list called simulation_data to store the loaded checkpoint data.
simulation_data = []
loaded_steps = [] # To keep track of the step corresponding to each loaded state

# List all files in the checkpoint directory
all_files = os.listdir(CHECKPOINT_DIR)

# Filter for checkpoint files and sort them by step number
checkpoint_files_to_load = []
# Regex to match checkpoint files (initial_state.pkl, checkpoint_step_X.pkl, final_state_step_X.pkl)
# re should be imported in a previous cell
# import re
checkpoint_pattern = re.compile(r'^(initial_state\.pkl|checkpoint_step_(\d+)\.pkl|final_state_step_(\d+)\.pkl)$')

for filename in all_files:
    match = checkpoint_pattern.match(filename)
    if match:
        file_path = os.path.join(CHECKPOINT_DIR, filename)
        # Extract step number for sorting
        if 'initial_state' in filename:
            step = 0
        elif 'final_state' in filename:
             # Assuming final state filename is final_state_step_NUM_STEPS.pkl
             # We need NUM_STEPS. Let's try to extract it from the filename if available
             try:
                 step = int(match.group(3)) # Capture group 3 for final_state_step_(\d+).pkl
             except (ValueError, TypeError):
                  # Fallback to a large number for sorting if extraction fails
                  step = 999999999
        else:
            step = int(match.group(2)) # Capture group 2 for checkpoint_step_(\d+).pkl
        checkpoint_files_to_load.append((step, file_path))

# Sort files by step number
checkpoint_files_to_load.sort(key=lambda x: x[0])

# Now load the sorted files
print("Loading simulation data from checkpoints...")
# Only load the last checkpoint if resuming
if checkpoint_files_to_load:
    last_step, last_file_path = checkpoint_files_to_load[-1]
    try:
        with open(last_file_path, 'rb') as f:
            population_state = pickle.load(f)
            start_step = last_step + 1
            print(f"Successfully loaded last checkpoint: {last_file_path} (Step: {last_step})")
            print(f"Resuming simulation from step {start_step} with loaded state.")
    except FileNotFoundError:
        print(f"Error: Last checkpoint file not found at {last_file_path}. Starting new simulation.")
        start_step = 0
        # The initial_strategies, initial_rna_sequences, initial_protoribosome_states
        # and the initial population_state dictionary should be available from previous cells.
        # We also need to ensure the initial_state is saved.
        # Define initial_population_state if it's not guaranteed to be in the global scope
        # population_state = { ... initial state definition ... }
        print(f"Starting new simulation from step {start_step} with initial state.")
        # Save the initial state
        initial_state_filename = os.path.join(CHECKPOINT_DIR, 'initial_state.pkl')
        try:
             with open(initial_state_filename, 'wb') as f:
                 pickle.dump(population_state, f)
             print(f"Saved initial state to {initial_state_filename}")
        except Exception as e:
             print(f"Error saving initial state: {e}")

else:
    # If no data was loaded, use the initial state defined earlier
    start_step = 0
    # The initial_strategies, initial_rna_sequences, initial_protoribosome_states
    # and the initial population_state dictionary should be available from previous cells.
    # We also need to ensure the initial_state is saved.
    # Define initial_population_state if it's not guaranteed to be in the global scope
    # population_state = { ... initial state definition ... }
    print(f"No checkpoints found. Starting new simulation from step {start_step} with initial state.")
    # Save the initial state
    initial_state_filename = os.path.join(CHECKPOINT_DIR, 'initial_state.pkl')
    try:
         with open(initial_state_filename, 'wb') as f:
             pickle.dump(population_state, f)
         print(f"Saved initial state to {initial_state_filename}")
    except Exception as e:
         print(f"Error saving initial state: {e}")

# Define constants for environmental cycles (repeated here for clarity, assuming they might not be in global scope)
STEPS_PER_DAY = 24 # Assuming 24 steps represent one day
DAYS_PER_YEAR = 365
STEPS_PER_YEAR = STEPS_PER_DAY * DAYS_PER_YEAR

# Environmental parameters (can be adjusted)
MAX_UV_DAILY = 1.0 # Maximum UV intensity during the day
MIN_UV_DAILY = 0.1 # Minimum UV intensity (night)
MAX_UV_SEASONAL_FACTOR = 1.5 # UV is 1.5 times higher in summer peak
MIN_UV_SEASONAL_FACTOR = 0.5 # UV is 0.5 times lower in winter peak
SUMMER_SOLSTICE_DAY = DAYS_PER_YEAR // 2 # Mid-year is summer peak
DAY_CYCLE_OFFSET = STEPS_PER_DAY // 2 # Peak UV at midday

BASE_TEMP = 20.0 # Base temperature
TEMP_DAILY_AMPLITUDE = 5.0 # Daily temperature swing
TEMP_SEASONAL_AMPLITUDE = 10.0 # Seasonal temperature swing
TEMP_DAILY_OFFSET = STEPS_PER_DAY // 3 * 2 # Peak temp in afternoon
TEMP_SEASONAL_OFFSET = SUMMER_SOLSTICE_DAY * STEPS_PER_DAY # Peak temp around summer solstice

# Function to calculate UV radiation based on step number
def calculate_uv(step, max_uv_daily=MAX_UV_DAILY, min_uv_daily=MIN_UV_DAILY,
                 max_uv_seasonal_factor=MAX_UV_SEASONAL_FACTOR, min_uv_seasonal_factor=MIN_UV_SEASONAL_FACTOR,
                 steps_per_day=STEPS_PER_DAY, steps_per_year=STEPS_PER_YEAR,
                 day_cycle_offset=DAY_CYCLE_OFFSET):
    """Calculates UV radiation based on daily and seasonal cycles."""
    # Daily cycle (sinusoidal, peaking around midday)
    day_of_year_step = step % steps_per_year
    step_of_day = step % steps_per_day
    daily_uv = ((max_uv_daily - min_uv_daily) / 2) * (1 + jnp.cos((2 * jnp.pi / steps_per_day) * (step_of_day - day_cycle_offset))) + min_uv_daily

    # Seasonal cycle (sinusoidal, peaking in summer)
    seasonal_factor = ((max_uv_seasonal_factor - min_uv_seasonal_factor) / 2) * (1 + jnp.cos((2 * jnp.pi / steps_per_year) * (day_of_year_step - TEMP_SEASONAL_OFFSET))) + min_uv_seasonal_factor

    # Combine daily and seasonal cycles
    current_uv = daily_uv * seasonal_factor
    return current_uv

# Function to calculate temperature based on step number
def calculate_temperature(step, base_temp=BASE_TEMP, temp_daily_amplitude=TEMP_DAILY_AMPLITUDE,
                          temp_seasonal_amplitude=TEMP_SEASONAL_AMPLITUDE, steps_per_day=STEPS_PER_DAY,
                          steps_per_year=STEPS_PER_YEAR, temp_daily_offset=TEMP_DAILY_OFFSET,
                          temp_seasonal_offset=TEMP_SEASONAL_OFFSET):
    """Calculates temperature based on daily and seasonal cycles."""
    # Daily cycle (sinusoidal, peaking in afternoon)
    step_of_day = step % steps_per_day
    daily_temp = temp_daily_amplitude * jnp.sin((2 * jnp.pi / steps_per_day) * (step_of_day - temp_daily_offset))

    # Seasonal cycle (sinusoidal, peaking in summer)
    day_of_year_step = step % steps_per_year
    seasonal_temp = temp_seasonal_amplitude * jnp.sin((2 * jnp.pi / steps_per_year) * (day_of_year_step - temp_seasonal_offset))

    # Combine base temperature, daily, and seasonal cycles
    current_temp = base_temp + daily_temp + seasonal_temp
    return current_temp


# Main simulation loop
NUM_STEPS = 10000
print(f"Running simulation from step {start_step} to {NUM_STEPS}")

for step in range(start_step, NUM_STEPS + 1):
    # Calculate environmental variables for the current step
    current_uv_radiation = calculate_uv(step)
    current_temperature = calculate_temperature(step)

    # These variables (current_uv_radiation, current_temperature) are now available
    # to be used in the simulation step logic (which is currently a placeholder).

    # Placeholder for simulation step logic
    # This is where the population_state dictionary would be updated based on the step's processes.
    # For example:
    # population_state = apply_environmental_effects(population_state, current_uv_radiation, current_temperature)
    # population_state = apply_damage_and_repair(population_state)
    # population_state = apply_mutations(population_state)
    # ... and so on for all the processes outlined earlier.

    # Placeholder print statement
    if step % 1000 == 0 or step == start_step or step == NUM_STEPS:
        print(f"Simulating step {step}/{NUM_STEPS}, UV: {current_uv_radiation:.2f}, Temp: {current_temperature:.2f}째C")

    # Add checkpointing logic
    if step > 0 and (step % CHECKPOINT_INTERVAL == 0 or step == NUM_STEPS):
        if step == NUM_STEPS:
             checkpoint_filename = os.path.join(CHECKPOINT_DIR, f'final_state_step_{step}.pkl')
        else:
             checkpoint_filename = os.path.join(CHECKPOINT_DIR, f'checkpoint_step_{step}.pkl')

        try:
            with open(checkpoint_filename, 'wb') as f:
                pickle.dump(population_state, f)
            print(f"Checkpoint saved at step {step} to {checkpoint_filename}")
        except Exception as e:
            print(f"Error saving checkpoint at step {step}: {e}")


print("Simulation loop finished.")

NameError: name 'CHECKPOINT_DIR' is not defined

**Reasoning**:
The error `NameError: name 'CHECKPOINT_DIR' is not defined` indicates that the `CHECKPOINT_DIR` variable was not defined before being used. While the notebook context shows it was defined in a previous cell, the current code block needs to ensure it's defined within its scope or rely on it being a global variable. Since the previous cell defined it, I will assume it's in the global scope and remove the redundant definition from the current code block, but keep the directory creation and checkpoint loading logic.



In [None]:
import pickle
import os
import jax.numpy as jnp # Ensure jnp is imported for the functions
import re # Ensure re is imported

# Define checkpoint interval
CHECKPOINT_INTERVAL = 1000 # Save checkpoint every 1000 steps

# CHECKPOINT_DIR is assumed to be defined in a previous cell

# Create the directory if it doesn't exist
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
print(f"Ensured checkpoint directory exists: {CHECKPOINT_DIR}")

# Initialize an empty list called simulation_data to store the loaded checkpoint data.
# simulation_data and loaded_steps are not strictly needed here if we only load the last state
# but keeping them for potential future use or consistency with previous cells.
simulation_data = []
loaded_steps = [] # To keep track of the step corresponding to each loaded state

# List all files in the checkpoint directory
all_files = os.listdir(CHECKPOINT_DIR)

# Filter for checkpoint files and sort them by step number
checkpoint_files_to_load = []
# Regex to match checkpoint files (initial_state.pkl, checkpoint_step_X.pkl, final_state_step_X.pkl)
checkpoint_pattern = re.compile(r'^(initial_state\.pkl|checkpoint_step_(\d+)\.pkl|final_state_step_(\d+)\.pkl)$')

for filename in all_files:
    match = checkpoint_pattern.match(filename)
    if match:
        file_path = os.path.join(CHECKPOINT_DIR, filename)
        # Extract step number for sorting
        if 'initial_state' in filename:
            step = 0
        elif 'final_state' in filename:
             # Assuming final state filename is final_state_step_NUM_STEPS.pkl
             # We need NUM_STEPS. Let's try to extract it from the filename if available
             try:
                 step = int(match.group(3)) # Capture group 3 for final_state_step_(\d+).pkl
             except (ValueError, TypeError):
                  # Fallback to a large number for sorting if extraction fails
                  step = 999999999
        else:
            step = int(match.group(2)) # Capture group 2 for checkpoint_step_(\d+).pkl
        checkpoint_files_to_load.append((step, file_path))

# Sort files by step number
checkpoint_files_to_load.sort(key=lambda x: x[0])

# Load the last checkpoint if any exist
if checkpoint_files_to_load:
    last_step, last_file_path = checkpoint_files_to_load[-1]
    try:
        with open(last_file_path, 'rb') as f:
            population_state = pickle.load(f)
            start_step = last_step + 1
            print(f"Successfully loaded last checkpoint: {last_file_path} (Step: {last_step})")
            print(f"Resuming simulation from step {start_step} with loaded state.")
    except FileNotFoundError:
        print(f"Error: Last checkpoint file not found at {last_file_path}. Starting new simulation.")
        start_step = 0
        # Assuming initial_population_state is available from a previous cell
        # population_state = initial_population_state # Use the initial state
        print(f"Starting new simulation from step {start_step} with initial state.")
        # Save the initial state
        initial_state_filename = os.path.join(CHECKPOINT_DIR, 'initial_state.pkl')
        try:
             with open(initial_state_filename, 'wb') as f:
                 pickle.dump(population_state, f)
             print(f"Saved initial state to {initial_state_filename}")
        except Exception as e:
             print(f"Error saving initial state: {e}")

else:
    # If no data was loaded, use the initial state defined earlier
    start_step = 0
    # Assuming initial_population_state is available from a previous cell
    # population_state = initial_population_state # Use the initial state
    print(f"No checkpoints found. Starting new simulation from step {start_step} with initial state.")
    # Save the initial state
    initial_state_filename = os.path.join(CHECKPOINT_DIR, 'initial_state.pkl')
    try:
         with open(initial_state_filename, 'wb') as f:
             pickle.dump(population_state, f)
         print(f"Saved initial state to {initial_state_filename}")
    except Exception as e:
         print(f"Error saving initial state: {e}")

# Define constants for environmental cycles (assuming they might not be in global scope)
STEPS_PER_DAY = 24 # Assuming 24 steps represent one day
DAYS_PER_YEAR = 365
STEPS_PER_YEAR = STEPS_PER_DAY * DAYS_PER_YEAR

# Environmental parameters (can be adjusted)
MAX_UV_DAILY = 1.0 # Maximum UV intensity during the day
MIN_UV_DAILY = 0.1 # Minimum UV intensity (night)
MAX_UV_SEASONAL_FACTOR = 1.5 # UV is 1.5 times higher in summer peak
MIN_UV_SEASONAL_FACTOR = 0.5 # UV is 0.5 times lower in winter peak
SUMMER_SOLSTICE_DAY = DAYS_PER_YEAR // 2 # Mid-year is summer peak
DAY_CYCLE_OFFSET = STEPS_PER_DAY // 2 # Peak UV at midday

BASE_TEMP = 20.0 # Base temperature
TEMP_DAILY_AMPLITUDE = 5.0 # Daily temperature swing
TEMP_SEASONAL_AMPLITUDE = 10.0 # Seasonal temperature swing
TEMP_DAILY_OFFSET = STEPS_PER_DAY // 3 * 2 # Peak temp in afternoon
TEMP_SEASONAL_OFFSET = SUMMER_SOLSTICE_DAY * STEPS_PER_DAY # Peak temp around summer solstice

# Function to calculate UV radiation based on step number
def calculate_uv(step, max_uv_daily=MAX_UV_DAILY, min_uv_daily=MIN_UV_DAILY,
                 max_uv_seasonal_factor=MAX_UV_SEASONAL_FACTOR, min_uv_seasonal_factor=MIN_UV_SEASONAL_FACTOR,
                 steps_per_day=STEPS_PER_DAY, steps_per_year=STEPS_PER_YEAR,
                 day_cycle_offset=DAY_CYCLE_OFFSET):
    """Calculates UV radiation based on daily and seasonal cycles."""
    # Daily cycle (sinusoidal, peaking around midday)
    day_of_year_step = step % steps_per_year
    step_of_day = step % steps_per_day
    daily_uv = ((max_uv_daily - min_uv_daily) / 2) * (1 + jnp.cos((2 * jnp.pi / steps_per_day) * (step_of_day - day_cycle_offset))) + min_uv_daily

    # Seasonal cycle (sinusoidal, peaking in summer)
    seasonal_factor = ((max_uv_seasonal_factor - min_uv_seasonal_factor) / 2) * (1 + jnp.cos((2 * jnp.pi / steps_per_year) * (day_of_year_step - TEMP_SEASONAL_OFFSET))) + min_uv_seasonal_factor

    # Combine daily and seasonal cycles
    current_uv = daily_uv * seasonal_factor
    return current_uv

# Function to calculate temperature based on step number
def calculate_temperature(step, base_temp=BASE_TEMP, temp_daily_amplitude=TEMP_DAILY_AMPLITUDE,
                          temp_seasonal_amplitude=TEMP_SEASONAL_AMPLITUDE, steps_per_day=STEPS_PER_DAY,
                          steps_per_year=STEPS_PER_YEAR, temp_daily_offset=TEMP_DAILY_OFFSET,
                          temp_seasonal_offset=TEMP_SEASONAL_OFFSET):
    """Calculates temperature based on daily and seasonal cycles."""
    # Daily cycle (sinusoidal, peaking in afternoon)
    step_of_day = step % steps_per_day
    daily_temp = temp_daily_amplitude * jnp.sin((2 * jnp.pi / steps_per_day) * (step_of_day - temp_daily_offset))

    # Seasonal cycle (sinusoidal, peaking in summer)
    day_of_year_step = step % steps_per_year
    seasonal_temp = temp_seasonal_amplitude * jnp.sin((2 * jnp.pi / steps_per_year) * (day_of_year_step - temp_seasonal_offset))

    # Combine base temperature, daily, and seasonal cycles
    current_temp = base_temp + daily_temp + seasonal_temp
    return current_temp


# Main simulation loop
NUM_STEPS = 10000
print(f"Running simulation from step {start_step} to {NUM_STEPS}")

for step in range(start_step, NUM_STEPS + 1):
    # Calculate environmental variables for the current step
    current_uv_radiation = calculate_uv(step)
    current_temperature = calculate_temperature(step)

    # These variables (current_uv_radiation, current_temperature) are now available
    # to be used in the simulation step logic (which is currently a placeholder).

    # Placeholder for simulation step logic
    # This is where the population_state dictionary would be updated based on the step's processes.
    # For example:
    # population_state = apply_environmental_effects(population_state, current_uv_radiation, current_temperature)
    # population_state = apply_damage_and_repair(population_state)
    # population_state = apply_mutations(population_state)
    # ... and so on for all the processes outlined earlier.

    # Placeholder print statement
    if step % 1000 == 0 or step == start_step or step == NUM_STEPS:
        print(f"Simulating step {step}/{NUM_STEPS}, UV: {current_uv_radiation:.2f}, Temp: {current_temperature:.2f}째C")

    # Add checkpointing logic
    if step > 0 and (step % CHECKPOINT_INTERVAL == 0 or step == NUM_STEPS):
        if step == NUM_STEPS:
             checkpoint_filename = os.path.join(CHECKPOINT_DIR, f'final_state_step_{step}.pkl')
        else:
             checkpoint_filename = os.path.join(CHECKPOINT_DIR, f'checkpoint_step_{step}.pkl')

        try:
            with open(checkpoint_filename, 'wb') as f:
                pickle.dump(population_state, f)
            print(f"Checkpoint saved at step {step} to {checkpoint_filename}")
        except Exception as e:
            print(f"Error saving checkpoint at step {step}: {e}")


print("Simulation loop finished.")

NameError: name 'CHECKPOINT_DIR' is not defined

**Reasoning**:
The `NameError: name 'CHECKPOINT_DIR' is not defined` persists, indicating that `CHECKPOINT_DIR` is not reliably available in the global scope when this cell is executed. To fix this, I will explicitly define `CHECKPOINT_DIR` within this code block, as it is essential for the checkpointing and resume logic. I will also ensure `population_state` is correctly initialized in the case where no checkpoints are found by referencing the initial population state defined in an earlier cell.



In [None]:
import pickle
import os
import jax.numpy as jnp
import re

# Define checkpoint interval
CHECKPOINT_INTERVAL = 1000 # Save checkpoint every 1000 steps

# Define the directory where checkpoints are saved - Explicitly define it here
CHECKPOINT_DIR = "/content/drive/MyDrive/Colab Notebooks/qtpu/replica_1/Expt_2025-7-6_08-50" # Ensure this matches the directory used in the simulation

# Create the directory if it doesn't exist
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
print(f"Ensured checkpoint directory exists: {CHECKPOINT_DIR}")

# Initialize an empty list called simulation_data to store the loaded checkpoint data.
simulation_data = []
loaded_steps = [] # To keep track of the step corresponding to each loaded state

# List all files in the checkpoint directory
all_files = os.listdir(CHECKPOINT_DIR)

# Filter for checkpoint files and sort them by step number
checkpoint_files_to_load = []
# Regex to match checkpoint files (initial_state.pkl, checkpoint_step_X.pkl, final_state_step_X.pkl)
checkpoint_pattern = re.compile(r'^(initial_state\.pkl|checkpoint_step_(\d+)\.pkl|final_state_step_(\d+)\.pkl)$')

for filename in all_files:
    match = checkpoint_pattern.match(filename)
    if match:
        file_path = os.path.join(CHECKPOINT_DIR, filename)
        # Extract step number for sorting
        if 'initial_state' in filename:
            step = 0
        elif 'final_state' in filename:
             # Assuming final state filename is final_state_step_NUM_STEPS.pkl
             # We need NUM_STEPS. Let's try to extract it from the filename if available
             try:
                 step = int(match.group(3)) # Capture group 3 for final_state_step_(\d+).pkl
             except (ValueError, TypeError):
                  # Fallback to a large number for sorting if extraction fails
                  step = 999999999
        else:
            step = int(match.group(2)) # Capture group 2 for checkpoint_step_(\d+).pkl
        checkpoint_files_to_load.append((step, file_path))

# Sort files by step number
checkpoint_files_to_load.sort(key=lambda x: x[0])

# Load the last checkpoint if any exist
if checkpoint_files_to_load:
    last_step, last_file_path = checkpoint_files_to_load[-1]
    try:
        with open(last_file_path, 'rb') as f:
            population_state = pickle.load(f)
            start_step = last_step + 1
            print(f"Successfully loaded last checkpoint: {last_file_path} (Step: {last_step})")
            print(f"Resuming simulation from step {start_step} with loaded state.")
    except FileNotFoundError:
        print(f"Error: Last checkpoint file not found at {last_file_path}. Starting new simulation.")
        start_step = 0
        # Use the initial_population_state defined in a previous cell
        population_state = initial_population_state # Assuming this variable is available
        print(f"Starting new simulation from step {start_step} with initial state.")
        # Save the initial state
        initial_state_filename = os.path.join(CHECKPOINT_DIR, 'initial_state.pkl')
        try:
             with open(initial_state_filename, 'wb') as f:
                 pickle.dump(population_state, f)
             print(f"Saved initial state to {initial_state_filename}")
        except Exception as e:
             print(f"Error saving initial state: {e}")

else:
    # If no data was loaded, use the initial state defined earlier
    start_step = 0
    # Use the initial_population_state defined in a previous cell
    population_state = initial_population_state # Assuming this variable is available
    print(f"No checkpoints found. Starting new simulation from step {start_step} with initial state.")
    # Save the initial state
    initial_state_filename = os.path.join(CHECKPOINT_DIR, 'initial_state.pkl')
    try:
         with open(initial_state_filename, 'wb') as f:
             pickle.dump(population_state, f)
         print(f"Saved initial state to {initial_state_filename}")
    except Exception as e:
         print(f"Error saving initial state: {e}")

# Define constants for environmental cycles (assuming they might not be in global scope)
STEPS_PER_DAY = 24 # Assuming 24 steps represent one day
DAYS_PER_YEAR = 365
STEPS_PER_YEAR = STEPS_PER_DAY * DAYS_PER_YEAR

# Environmental parameters (can be adjusted)
MAX_UV_DAILY = 1.0 # Maximum UV intensity during the day
MIN_UV_DAILY = 0.1 # Minimum UV intensity (night)
MAX_UV_SEASONAL_FACTOR = 1.5 # UV is 1.5 times higher in summer peak
MIN_UV_SEASONAL_FACTOR = 0.5 # UV is 0.5 times lower in winter peak
SUMMER_SOLSTICE_DAY = DAYS_PER_YEAR // 2 # Mid-year is summer peak
DAY_CYCLE_OFFSET = STEPS_PER_DAY // 2 # Peak UV at midday

BASE_TEMP = 20.0 # Base temperature
TEMP_DAILY_AMPLITUDE = 5.0 # Daily temperature swing
TEMP_SEASONAL_AMPLITUDE = 10.0 # Seasonal temperature swing
TEMP_DAILY_OFFSET = STEPS_PER_DAY // 3 * 2 # Peak temp in afternoon
TEMP_SEASONAL_OFFSET = SUMMER_SOLSTICE_DAY * STEPS_PER_DAY # Peak temp around summer solstice

# Function to calculate UV radiation based on step number
def calculate_uv(step, max_uv_daily=MAX_UV_DAILY, min_uv_daily=MIN_UV_DAILY,
                 max_uv_seasonal_factor=MAX_UV_SEASONAL_FACTOR, min_uv_seasonal_factor=MIN_UV_SEASONAL_FACTOR,
                 steps_per_day=STEPS_PER_DAY, steps_per_year=STEPS_PER_YEAR,
                 day_cycle_offset=DAY_CYCLE_OFFSET):
    """Calculates UV radiation based on daily and seasonal cycles."""
    # Daily cycle (sinusoidal, peaking around midday)
    day_of_year_step = step % steps_per_year
    step_of_day = step % steps_per_day
    daily_uv = ((max_uv_daily - min_uv_daily) / 2) * (1 + jnp.cos((2 * jnp.pi / steps_per_day) * (step_of_day - day_cycle_offset))) + min_uv_daily

    # Seasonal cycle (sinusoidal, peaking in summer)
    seasonal_factor = ((max_uv_seasonal_factor - min_uv_seasonal_factor) / 2) * (1 + jnp.cos((2 * jnp.pi / steps_per_year) * (day_of_year_step - TEMP_SEASONAL_OFFSET))) + min_uv_seasonal_factor

    # Combine daily and seasonal cycles
    current_uv = daily_uv * seasonal_factor
    return current_uv

# Function to calculate temperature based on step number
def calculate_temperature(step, base_temp=BASE_TEMP, temp_daily_amplitude=TEMP_DAILY_AMPLITUDE,
                          temp_seasonal_amplitude=TEMP_SEASONAL_AMPLITUDE, steps_per_day=STEPS_PER_DAY,
                          steps_per_year=STEPS_PER_YEAR, temp_daily_offset=TEMP_DAILY_OFFSET,
                          temp_seasonal_offset=TEMP_SEASONAL_OFFSET):
    """Calculates temperature based on daily and seasonal cycles."""
    # Daily cycle (sinusoidal, peaking in afternoon)
    step_of_day = step % steps_per_day
    daily_temp = temp_daily_amplitude * jnp.sin((2 * jnp.pi / steps_per_day) * (step_of_day - temp_daily_offset))

    # Seasonal cycle (sinusoidal, peaking in summer)
    day_of_year_step = step % steps_per_year
    seasonal_temp = temp_seasonal_amplitude * jnp.sin((2 * jnp.pi / steps_per_year) * (day_of_year_step - temp_seasonal_offset))

    # Combine base temperature, daily, and seasonal cycles
    current_temp = base_temp + daily_temp + seasonal_temp
    return current_temp


# Main simulation loop
NUM_STEPS = 10000
print(f"Running simulation from step {start_step} to {NUM_STEPS}")

for step in range(start_step, NUM_STEPS + 1):
    # Calculate environmental variables for the current step
    current_uv_radiation = calculate_uv(step)
    current_temperature = calculate_temperature(step)

    # These variables (current_uv_radiation, current_temperature) are now available
    # to be used in the simulation step logic (which is currently a placeholder).

    # Placeholder for simulation step logic
    # This is where the population_state dictionary would be updated based on the step's processes.
    # For example:
    # population_state = apply_environmental_effects(population_state, current_uv_radiation, current_temperature)
    # population_state = apply_damage_and_repair(population_state)
    # population_state = apply_mutations(population_state)
    # ... and so on for all the processes outlined earlier.

    # Placeholder print statement
    if step % 1000 == 0 or step == start_step or step == NUM_STEPS:
        print(f"Simulating step {step}/{NUM_STEPS}, UV: {current_uv_radiation:.2f}, Temp: {current_temperature:.2f}째C")

    # Add checkpointing logic
    if step > 0 and (step % CHECKPOINT_INTERVAL == 0 or step == NUM_STEPS):
        if step == NUM_STEPS:
             checkpoint_filename = os.path.join(CHECKPOINT_DIR, f'final_state_step_{step}.pkl')
        else:
             checkpoint_filename = os.path.join(CHECKPOINT_DIR, f'checkpoint_step_{step}.pkl')

        try:
            with open(checkpoint_filename, 'wb') as f:
                pickle.dump(population_state, f)
            print(f"Checkpoint saved at step {step} to {checkpoint_filename}")
        except Exception as e:
            print(f"Error saving checkpoint at step {step}: {e}")


print("Simulation loop finished.")

Ensured checkpoint directory exists: /content/drive/MyDrive/Colab Notebooks/qtpu/replica_1/Expt_2025-7-6_08-50


NameError: name 'initial_population_state' is not defined

**Reasoning**:
The error `NameError: name 'initial_population_state' is not defined` indicates that when no checkpoints are found, the code attempts to assign the variable `initial_population_state` to `population_state`, but this variable is not defined in the current scope. The initial population state is actually stored in the variable `population_state` which was initialized in the first code block of the notebook. I need to correct the logic to use the globally defined `population_state` when starting a new simulation instead of referencing a non-existent variable.



In [None]:
import pickle
import os
import jax.numpy as jnp
import re

# Define checkpoint interval
CHECKPOINT_INTERVAL = 1000 # Save checkpoint every 1000 steps

# Define the directory where checkpoints are saved - Explicitly define it here
CHECKPOINT_DIR = "/content/drive/MyDrive/Colab Notebooks/qtpu/replica_1/Expt_2025-7-6_08-50" # Ensure this matches the directory used in the simulation

# Create the directory if it doesn't exist
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
print(f"Ensured checkpoint directory exists: {CHECKPOINT_DIR}")

# Initialize an empty list called simulation_data to store the loaded checkpoint data.
simulation_data = []
loaded_steps = [] # To keep track of the step corresponding to each loaded state

# List all files in the checkpoint directory
all_files = os.listdir(CHECKPOINT_DIR)

# Filter for checkpoint files and sort them by step number
checkpoint_files_to_load = []
# Regex to match checkpoint files (initial_state.pkl, checkpoint_step_X.pkl, final_state_step_X.pkl)
checkpoint_pattern = re.compile(r'^(initial_state\.pkl|checkpoint_step_(\d+)\.pkl|final_state_step_(\d+)\.pkl)$')

for filename in all_files:
    match = checkpoint_pattern.match(filename)
    if match:
        file_path = os.path.join(CHECKPOINT_DIR, filename)
        # Extract step number for sorting
        if 'initial_state' in filename:
            step = 0
        elif 'final_state' in filename:
             # Assuming final state filename is final_state_step_NUM_STEPS.pkl
             # We need NUM_STEPS. Let's try to extract it from the filename if available
             try:
                 step = int(match.group(3)) # Capture group 3 for final_state_step_(\d+).pkl
             except (ValueError, TypeError):
                  # Fallback to a large number for sorting if extraction fails
                  step = 999999999
        else:
            step = int(match.group(2)) # Capture group 2 for checkpoint_step_(\d+).pkl
        checkpoint_files_to_load.append((step, file_path))

# Sort files by step number
checkpoint_files_to_load.sort(key=lambda x: x[0])

# Load the last checkpoint if any exist
if checkpoint_files_to_load:
    last_step, last_file_path = checkpoint_files_to_load[-1]
    try:
        with open(last_file_path, 'rb') as f:
            population_state = pickle.load(f)
            start_step = last_step + 1
            print(f"Successfully loaded last checkpoint: {last_file_path} (Step: {last_step})")
            print(f"Resuming simulation from step {start_step} with loaded state.")
    except FileNotFoundError:
        print(f"Error: Last checkpoint file not found at {last_file_path}. Starting new simulation.")
        start_step = 0
        # Use the globally defined population_state for the initial state
        # population_state is already defined in the first code block
        print(f"Starting new simulation from step {start_step} with initial state.")
        # Save the initial state
        initial_state_filename = os.path.join(CHECKPOINT_DIR, 'initial_state.pkl')
        try:
             with open(initial_state_filename, 'wb') as f:
                 pickle.dump(population_state, f)
             print(f"Saved initial state to {initial_state_filename}")
        except Exception as e:
             print(f"Error saving initial state: {e}")

else:
    # If no data was loaded, use the initial state defined earlier
    start_step = 0
    # Use the globally defined population_state for the initial state
    # population_state is already defined in the first code block
    print(f"No checkpoints found. Starting new simulation from step {start_step} with initial state.")
    # Save the initial state
    initial_state_filename = os.path.join(CHECKPOINT_DIR, 'initial_state.pkl')
    try:
         with open(initial_state_filename, 'wb') as f:
             pickle.dump(population_state, f)
         print(f"Saved initial state to {initial_state_filename}")
    except Exception as e:
         print(f"Error saving initial state: {e}")

# Define constants for environmental cycles (assuming they might not be in global scope)
STEPS_PER_DAY = 24 # Assuming 24 steps represent one day
DAYS_PER_YEAR = 365
STEPS_PER_YEAR = STEPS_PER_DAY * DAYS_PER_YEAR

# Environmental parameters (can be adjusted)
MAX_UV_DAILY = 1.0 # Maximum UV intensity during the day
MIN_UV_DAILY = 0.1 # Minimum UV intensity (night)
MAX_UV_SEASONAL_FACTOR = 1.5 # UV is 1.5 times higher in summer peak
MIN_UV_SEASONAL_FACTOR = 0.5 # UV is 0.5 times lower in winter peak
SUMMER_SOLSTICE_DAY = DAYS_PER_YEAR // 2 # Mid-year is summer peak
DAY_CYCLE_OFFSET = STEPS_PER_DAY // 2 # Peak UV at midday

BASE_TEMP = 20.0 # Base temperature
TEMP_DAILY_AMPLITUDE = 5.0 # Daily temperature swing
TEMP_SEASONAL_AMPLITUDE = 10.0 # Seasonal temperature swing
TEMP_DAILY_OFFSET = STEPS_PER_DAY // 3 * 2 # Peak temp in afternoon
TEMP_SEASONAL_OFFSET = SUMMER_SOLSTICE_DAY * STEPS_PER_DAY # Peak temp around summer solstice

# Function to calculate UV radiation based on step number
def calculate_uv(step, max_uv_daily=MAX_UV_DAILY, min_uv_daily=MIN_UV_DAILY,
                 max_uv_seasonal_factor=MAX_UV_SEASONAL_FACTOR, min_uv_seasonal_factor=MIN_UV_SEASONAL_FACTOR,
                 steps_per_day=STEPS_PER_DAY, steps_per_year=STEPS_PER_YEAR,
                 day_cycle_offset=DAY_CYCLE_OFFSET):
    """Calculates UV radiation based on daily and seasonal cycles."""
    # Daily cycle (sinusoidal, peaking around midday)
    day_of_year_step = step % steps_per_year
    step_of_day = step % steps_per_day
    daily_uv = ((max_uv_daily - min_uv_daily) / 2) * (1 + jnp.cos((2 * jnp.pi / steps_per_day) * (step_of_day - day_cycle_offset))) + min_uv_daily

    # Seasonal cycle (sinusoidal, peaking in summer)
    seasonal_factor = ((max_uv_seasonal_factor - min_uv_seasonal_factor) / 2) * (1 + jnp.cos((2 * jnp.pi / steps_per_year) * (day_of_year_step - TEMP_SEASONAL_OFFSET))) + min_uv_seasonal_factor

    # Combine daily and seasonal cycles
    current_uv = daily_uv * seasonal_factor
    return current_uv

# Function to calculate temperature based on step number
def calculate_temperature(step, base_temp=BASE_TEMP, temp_daily_amplitude=TEMP_DAILY_AMPLITUDE,
                          temp_seasonal_amplitude=TEMP_SEASONAL_AMPLITUDE, steps_per_day=STEPS_PER_DAY,
                          steps_per_year=STEPS_PER_YEAR, temp_daily_offset=TEMP_DAILY_OFFSET,
                          temp_seasonal_offset=TEMP_SEASONAL_OFFSET):
    """Calculates temperature based on daily and seasonal cycles."""
    # Daily cycle (sinusoidal, peaking in afternoon)
    step_of_day = step % steps_per_day
    daily_temp = temp_daily_amplitude * jnp.sin((2 * jnp.pi / steps_per_day) * (step_of_day - temp_daily_offset))

    # Seasonal cycle (sinusoidal, peaking in summer)
    day_of_year_step = step % steps_per_year
    seasonal_temp = temp_seasonal_amplitude * jnp.sin((2 * jnp.pi / steps_per_year) * (day_of_year_step - temp_seasonal_offset))

    # Combine base temperature, daily, and seasonal cycles
    current_temp = base_temp + daily_temp + seasonal_temp
    return current_temp


# Main simulation loop
NUM_STEPS = 10000
print(f"Running simulation from step {start_step} to {NUM_STEPS}")

for step in range(start_step, NUM_STEPS + 1):
    # Calculate environmental variables for the current step
    current_uv_radiation = calculate_uv(step)
    current_temperature = calculate_temperature(step)

    # These variables (current_uv_radiation, current_temperature) are now available
    # to be used in the simulation step logic (which is currently a placeholder).

    # Placeholder for simulation step logic
    # This is where the population_state dictionary would be updated based on the step's processes.
    # For example:
    # population_state = apply_environmental_effects(population_state, current_uv_radiation, current_temperature)
    # population_state = apply_damage_and_repair(population_state)
    # population_state = apply_mutations(population_state)
    # ... and so on for all the processes outlined earlier.

    # Placeholder print statement
    if step % 1000 == 0 or step == start_step or step == NUM_STEPS:
        print(f"Simulating step {step}/{NUM_STEPS}, UV: {current_uv_radiation:.2f}, Temp: {current_temperature:.2f}째C")

    # Add checkpointing logic
    if step > 0 and (step % CHECKPOINT_INTERVAL == 0 or step == NUM_STEPS):
        if step == NUM_STEPS:
             checkpoint_filename = os.path.join(CHECKPOINT_DIR, f'final_state_step_{step}.pkl')
        else:
             checkpoint_filename = os.path.join(CHECKPOINT_DIR, f'checkpoint_step_{step}.pkl')

        try:
            with open(checkpoint_filename, 'wb') as f:
                pickle.dump(population_state, f)
            print(f"Checkpoint saved at step {step} to {checkpoint_filename}")
        except Exception as e:
            print(f"Error saving checkpoint at step {step}: {e}")


print("Simulation loop finished.")

Ensured checkpoint directory exists: /content/drive/MyDrive/Colab Notebooks/qtpu/replica_1/Expt_2025-7-6_08-50
No checkpoints found. Starting new simulation from step 0 with initial state.
Error saving initial state: name 'population_state' is not defined
Running simulation from step 0 to 10000
Simulating step 0/10000, UV: 0.05, Temp: 24.24째C
Simulating step 1000/10000, UV: 0.49, Temp: 13.36째C
Error saving checkpoint at step 1000: name 'population_state' is not defined
Simulating step 2000/10000, UV: 0.73, Temp: 5.75째C
Error saving checkpoint at step 2000: name 'population_state' is not defined
Simulating step 3000/10000, UV: 0.13, Temp: 16.02째C
Error saving checkpoint at step 3000: name 'population_state' is not defined
Simulating step 4000/10000, UV: 1.15, Temp: 17.39째C
Error saving checkpoint at step 4000: name 'population_state' is not defined
Simulating step 5000/10000, UV: 1.12, Temp: 20.05째C
Error saving checkpoint at step 5000: name 'population_state' is not defined
Simulating 

**Reasoning**:
The error `Error saving initial state: name 'population_state' is not defined` and similar errors during checkpointing indicate that while `population_state` was intended to be globally available from the first cell, it's not correctly accessible when the checkpoint loading logic and simulation loop are executed. To fix this, I need to explicitly ensure that `population_state` is defined and available within this code block, either by re-initializing it if no checkpoints are loaded, or by relying on it being loaded from a checkpoint. Since the goal is to have resume functionality, I'll adjust the logic to load from a checkpoint first, and *only* if no checkpoints are found, then initialize `population_state` with the initial values (which are assumed to be available from the first cell). This will ensure `population_state` is always defined before attempting to save it.



In [None]:
import pickle
import os
import jax.numpy as jnp
import re

# Define checkpoint interval
CHECKPOINT_INTERVAL = 1000 # Save checkpoint every 1000 steps

# Define the directory where checkpoints are saved - Explicitly define it here
CHECKPOINT_DIR = "/content/drive/MyDrive/Colab Notebooks/qtpu/replica_1/Expt_2025-7-6_08-50" # Ensure this matches the directory used in the simulation

# Create the directory if it doesn't exist
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
print(f"Ensured checkpoint directory exists: {CHECKPOINT_DIR}")

# List all files in the checkpoint directory
all_files = os.listdir(CHECKPOINT_DIR)

# Filter for checkpoint files and sort them by step number
checkpoint_files_to_load = []
# Regex to match checkpoint files (initial_state.pkl, checkpoint_step_X.pkl, final_state_step_X.pkl)
checkpoint_pattern = re.compile(r'^(initial_state\.pkl|checkpoint_step_(\d+)\.pkl|final_state_step_(\d+)\.pkl)$')

for filename in all_files:
    match = checkpoint_pattern.match(filename)
    if match:
        file_path = os.path.join(CHECKPOINT_DIR, filename)
        # Extract step number for sorting
        if 'initial_state' in filename:
            step = 0
        elif 'final_state' in filename:
             # Assuming final state filename is final_state_step_NUM_STEPS.pkl
             # We need NUM_STEPS. Let's try to extract it from the filename if available
             try:
                 step = int(match.group(3)) # Capture group 3 for final_state_step_(\d+).pkl
             except (ValueError, TypeError):
                  # Fallback to a large number for sorting if extraction fails
                  step = 999999999
        else:
            step = int(match.group(2)) # Capture group 2 for checkpoint_step_(\d+).pkl
        checkpoint_files_to_load.append((step, file_path))

# Sort files by step number
checkpoint_files_to_load.sort(key=lambda x: x[0])

# Load the last checkpoint if any exist
population_state = None # Initialize population_state to None
start_step = 0

if checkpoint_files_to_load:
    last_step, last_file_path = checkpoint_files_to_load[-1]
    try:
        with open(last_file_path, 'rb') as f:
            population_state = pickle.load(f)
            start_step = last_step + 1
            print(f"Successfully loaded last checkpoint: {last_file_path} (Step: {last_step})")
            print(f"Resuming simulation from step {start_step} with loaded state.")
    except FileNotFoundError:
        print(f"Error: Last checkpoint file not found at {last_file_path}. Will attempt to use initial state.")
    except Exception as e:
        print(f"Error loading last checkpoint {last_file_path}: {e}. Will attempt to use initial state.")


# If population_state is still None (meaning no checkpoints were loaded or loading failed),
# use the initial state defined in the first cell.
if population_state is None:
    # Assuming the initial population_state dictionary is available from the first code block
    # It's defined as 'population_state' in the first cell.
    # We might need to make a copy if it's modified in place within the loop later.
    # population_state = initial_population_state.copy() # Use the initial state
    # Let's assume the global `population_state` from the first cell is the intended initial state.
    print(f"No valid checkpoints found or loaded. Starting new simulation from step {start_step} with initial state.")
    # Save the initial state if starting from step 0 and it hasn't been saved
    if start_step == 0:
        initial_state_filename = os.path.join(CHECKPOINT_DIR, 'initial_state.pkl')
        # Check if the initial state file already exists to avoid overwriting on accidental re-run
        if not os.path.exists(initial_state_filename):
             try:
                  # Ensure population_state from the first cell is accessible here.
                  # If not, explicitly define it or pass it into this cell's execution.
                  # For now, assuming it's globally available.
                  with open(initial_state_filename, 'wb') as f:
                      pickle.dump(population_state, f) # Use the global population_state
                  print(f"Saved initial state to {initial_state_filename}")
             except Exception as e:
                  print(f"Error saving initial state: {e}")
        else:
             print(f"Initial state file already exists: {initial_state_filename}. Not overwriting.")


# Define constants for environmental cycles (assuming they might not be in global scope)
STEPS_PER_DAY = 24 # Assuming 24 steps represent one day
DAYS_PER_YEAR = 365
STEPS_PER_YEAR = STEPS_PER_DAY * DAYS_PER_YEAR

# Environmental parameters (can be adjusted)
MAX_UV_DAILY = 1.0 # Maximum UV intensity during the day
MIN_UV_DAILY = 0.1 # Minimum UV intensity (night)
MAX_UV_SEASONAL_FACTOR = 1.5 # UV is 1.5 times higher in summer peak
MIN_UV_SEASONAL_FACTOR = 0.5 # UV is 0.5 times lower in winter peak
SUMMER_SOLSTICE_DAY = DAYS_PER_YEAR // 2 # Mid-year is summer peak
DAY_CYCLE_OFFSET = STEPS_PER_DAY // 2 # Peak UV at midday

BASE_TEMP = 20.0 # Base temperature
TEMP_DAILY_AMPLITUDE = 5.0 # Daily temperature swing
TEMP_SEASONAL_AMPLITUDE = 10.0 # Seasonal temperature swing
TEMP_DAILY_OFFSET = STEPS_PER_DAY // 3 * 2 # Peak temp in afternoon
TEMP_SEASONAL_OFFSET = SUMMER_SOLSTICE_DAY * STEPS_PER_DAY # Peak temp around summer solstice

# Function to calculate UV radiation based on step number
def calculate_uv(step, max_uv_daily=MAX_UV_DAILY, min_uv_daily=MIN_UV_DAILY,
                 max_uv_seasonal_factor=MAX_UV_SEASONAL_FACTOR, min_uv_seasonal_factor=MIN_UV_SEASONAL_FACTOR,
                 steps_per_day=STEPS_PER_DAY, steps_per_year=STEPS_PER_YEAR,
                 day_cycle_offset=DAY_CYCLE_OFFSET):
    """Calculates UV radiation based on daily and seasonal cycles."""
    # Daily cycle (sinusoidal, peaking around midday)
    day_of_year_step = step % steps_per_year
    step_of_day = step % steps_per_day
    daily_uv = ((max_uv_daily - min_uv_daily) / 2) * (1 + jnp.cos((2 * jnp.pi / steps_per_day) * (step_of_day - day_cycle_offset))) + min_uv_daily

    # Seasonal cycle (sinusoidal, peaking in summer)
    seasonal_factor = ((max_uv_seasonal_factor - min_uv_seasonal_factor) / 2) * (1 + jnp.cos((2 * jnp.pi / steps_per_year) * (day_of_year_step - TEMP_SEASONAL_OFFSET))) + min_uv_seasonal_factor

    # Combine daily and seasonal cycles
    current_uv = daily_uv * seasonal_factor
    return current_uv

# Function to calculate temperature based on step number
def calculate_temperature(step, base_temp=BASE_TEMP, temp_daily_amplitude=TEMP_DAILY_AMPLITUDE,
                          temp_seasonal_amplitude=TEMP_SEASONAL_AMPLITUDE, steps_per_day=STEPS_PER_DAY,
                          steps_per_year=STEPS_PER_YEAR, temp_daily_offset=TEMP_DAILY_OFFSET,
                          temp_seasonal_offset=TEMP_SEASONAL_OFFSET):
    """Calculates temperature based on daily and seasonal cycles."""
    # Daily cycle (sinusoidal, peaking in afternoon)
    step_of_day = step % steps_per_day
    daily_temp = temp_daily_amplitude * jnp.sin((2 * jnp.pi / steps_per_day) * (step_of_day - temp_daily_offset))

    # Seasonal cycle (sinusoidal, peaking in summer)
    day_of_year_step = step % steps_per_year
    seasonal_temp = temp_seasonal_amplitude * jnp.sin((2 * jnp.pi / steps_per_year) * (day_of_year_step - temp_seasonal_offset))

    # Combine base temperature, daily, and seasonal cycles
    current_temp = base_temp + daily_temp + seasonal_temp
    return current_temp


# Main simulation loop
NUM_STEPS = 10000
print(f"Running simulation from step {start_step} to {NUM_STEPS}")

for step in range(start_step, NUM_STEPS + 1):
    # Calculate environmental variables for the current step
    current_uv_radiation = calculate_uv(step)
    current_temperature = calculate_temperature(step)

    # These variables (current_uv_radiation, current_temperature) are now available
    # to be used in the simulation step logic (which is currently a placeholder).

    # Placeholder for simulation step logic
    # This is where the population_state dictionary would be updated based on the step's processes.
    # For example:
    # population_state = apply_environmental_effects(population_state, current_uv_radiation, current_temperature)
    # population_state = apply_damage_and_repair(population_state)
    # population_state = apply_mutations(population_state)
    # ... and so on for all the processes outlined earlier.

    # Placeholder print statement
    if step % 1000 == 0 or step == start_step or step == NUM_STEPS:
        print(f"Simulating step {step}/{NUM_STEPS}, UV: {current_uv_radiation:.2f}, Temp: {current_temperature:.2f}째C")

    # Add checkpointing logic
    # Only save if population_state is defined
    if population_state is not None and step > 0 and (step % CHECKPOINT_INTERVAL == 0 or step == NUM_STEPS):
        if step == NUM_STEPS:
             checkpoint_filename = os.path.join(CHECKPOINT_DIR, f'final_state_step_{step}.pkl')
        else:
             checkpoint_filename = os.path.join(CHECKPOINT_DIR, f'checkpoint_step_{step}.pkl')

        try:
            with open(checkpoint_filename, 'wb') as f:
                pickle.dump(population_state, f)
            print(f"Checkpoint saved at step {step} to {checkpoint_filename}")
        except Exception as e:
            print(f"Error saving checkpoint at step {step}: {e}")
    elif population_state is None:
         print(f"Warning: population_state is None at step {step}. Skipping checkpoint save.")


print("Simulation loop finished.")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Simulating step 6000/10000, UV: 0.12, Temp: 33.54째C
Simulating step 7000/10000, UV: 0.65, Temp: 29.50째C
Simulating step 8000/10000, UV: 0.44, Temp: 20.78째C
Simulating step 9000/10000, UV: 0.05, Temp: 22.53째C
Simulating step 10000/10000, UV: 0.53, Temp: 12.18째C
Simulation loop finished.


In [None]:
import pickle
import os
import jax
import jax.numpy as jnp
import numpy as np # Import numpy for initial state creation
import re
import jax.lax # Explicitly import jax.lax

# Define checkpoint interval
CHECKPOINT_INTERVAL = 1000 # Save checkpoint every 1000 steps

# Define the directory where checkpoints are saved - Explicitly define it here
CHECKPOINT_DIR = "/content/drive/MyDrive/Colab Notebooks/qtpu/replica_1/Expt_2025-7-6_08-50" # Ensure this matches the directory used in the simulation

# Create the directory if it doesn't exist
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
print(f"Ensured checkpoint directory exists: {CHECKPOINT_DIR}")

# List all files in the checkpoint directory
all_files = os.listdir(CHECKPOINT_DIR)

# Filter for checkpoint files and sort them by step number
checkpoint_files_to_load = []
# Regex to match checkpoint files (initial_state.pkl, checkpoint_step_X.pkl, final_state_step_X.pkl)
checkpoint_pattern = re.compile(r'^(initial_state\.pkl|checkpoint_step_(\d+)\.pkl|final_state_step_(\d+)\.pkl)$')

for filename in all_files:
    match = checkpoint_pattern.match(filename)
    if match:
        file_path = os.path.join(CHECKPOINT_DIR, filename)
        # Extract step number for sorting
        if 'initial_state' in filename:
            step = 0
        elif 'final_state' in filename:
             # Assuming final state filename is final_state_step_NUM_STEPS.pkl
             # We need NUM_STEPS. Let's try to extract it from the filename if available
             try:
                 step = int(match.group(3)) # Capture group 3 for final_state_step_(\d+).pkl
             except (ValueError, TypeError):
                  # Fallback to a large number for sorting if extraction fails
                  step = 999999999
        else:
            step = int(match.group(2)) # Capture group 2 for checkpoint_step_(\d+).pkl
        checkpoint_files_to_load.append((step, file_path))

# Sort files by step number
checkpoint_files_to_load.sort(key=lambda x: x[0])

# Load the last checkpoint if any exist
population_state = None # Initialize population_state to None
start_step = 0

if checkpoint_files_to_load:
    last_step, last_file_path = checkpoint_files_to_load[-1]
    try:
        with open(last_file_path, 'rb') as f:
            population_state = pickle.load(f)
            start_step = last_step + 1
            print(f"Successfully loaded last checkpoint: {last_file_path} (Step: {last_step})")
            print(f"Resuming simulation from step {start_step} with loaded state.")
    except FileNotFoundError:
        print(f"Error: Last checkpoint file not found at {last_file_path}. Starting new simulation with re-initialized state.")
    except Exception as e:
        print(f"Error loading last checkpoint {last_file_path}: {e}. Starting new simulation with re-initialized state.")


# If population_state is still None (meaning no checkpoints were loaded or loading failed),
# re-initialize it.
if population_state is None:
    print(f"No valid checkpoints found or loaded. Starting new simulation from step {start_step} with re-initialized state.")
    # Define simulation parameters (assuming these are consistent)
    NUM_AGENTS = 50000
    SEQUENCE_LENGTH = 100

    # Define integer representations for bases (assuming these are consistent)
    BASE_A = 0
    BASE_U = 1
    BASE_G = 2
    BASE_C = 3

    # Define protoribosome states (assuming these are consistent)
    PROTO_STATE_ACTIVE = 0
    PROTO_STATE_ARRESTED = 1
    PROTO_STATE_INACTIVE = 2

    # Define damaged base representations
    DAMAGED_G = 8 # Representing 8-oxoguanine
    DAMAGED_PYRIMIDINE_DIMER = 99 # Representing pyrimidine dimers (simplified as a single type)

    # Define constants for environmental cycles (re-defined here for initial state)
    STEPS_PER_DAY = 24 # Assuming 24 steps represent one day
    DAYS_PER_YEAR = 365
    STEPS_PER_YEAR = STEPS_PER_DAY * DAYS_PER_YEAR # Ensure consistency

    # Environmental parameters (can be adjusted)
    MAX_UV_DAILY = 1.0 # Maximum UV intensity during the day
    MIN_UV_DAILY = 0.1 # Minimum UV intensity (night)
    MAX_UV_SEASONAL_FACTOR = 1.5 # UV is 1.5 times higher in summer peak
    MIN_UV_SEASONAL_FACTOR = 0.5 # UV is 0.5 times lower in winter peak
    SUMMER_SOLSTICE_DAY = DAYS_PER_YEAR // 2 # Mid-year is summer peak
    DAY_CYCLE_OFFSET = STEPS_PER_DAY // 2 # Peak UV at midday

    BASE_TEMP = 20.0 # Base temperature
    TEMP_DAILY_AMPLITUDE = 5.0 # Daily temperature swing
    TEMP_SEASONAL_AMPLITUDE = 10.0 # Seasonal temperature swing
    TEMP_DAILY_OFFSET = STEPS_PER_DAY // 3 * 2 # Peak temp in afternoon
    TEMP_SEASONAL_OFFSET = SUMMER_SOLSTICE_DAY * STEPS_PER_DAY # Corrected typo here


    # Re-initialize JAX random key for initial state if needed (using a fixed seed for reproducibility of initial state)
    initial_key = jax.random.PRNGKey(42)
    initial_key, strategy_key, rna_key, proto_key = jax.random.split(initial_key, 4)

    # Re-initialize population state structure
    initial_strategies = jax.random.randint(strategy_key, (NUM_AGENTS,), 0, 2)

    # Re-initialize RNA sequences with random canonical bases (0-3 representing A, U, G, C)
    initial_rna_sequences = np.random.randint(0, 4, size=(NUM_AGENTS, SEQUENCE_LENGTH), dtype=np.int8)

    # Re-initialize protoribosome states (e.g., all start as ACTIVE)
    initial_protoribosome_states = np.full(NUM_AGENTS, PROTO_STATE_ACTIVE, dtype=np.int8)

    # Initial state should not have damage or stop codons
    population_state = {
        'strategy': np.array(initial_strategies),
        'rna_damage': np.zeros((NUM_AGENTS, SEQUENCE_LENGTH), dtype=np.float32), # Keep general damage level for simplicity alongside damaged base type
        'has_stop_codon': jnp.zeros(NUM_AGENTS, dtype=bool),
        # 'status': np.ones(NUM_AGENTS, dtype=np.int8), # Remove status, agents are implicitly "present" unless removed
        'rna_sequence': initial_rna_sequences, # RNA sequence now stores base type OR damaged base type
        'protoribosome_state': initial_protoribosome_states,
        'protoribosome_arrest_timer': jnp.zeros(NUM_AGENTS, dtype=jnp.int32), # Initialize timer
        'autocatalysis_failure_timer': jnp.zeros(NUM_AGENTS, dtype=jnp.int32) # Timer for consecutive autocatalysis failures
    }
    start_step = 0 # Ensure start_step is 0 if re-initializing

    # Save the initial state if starting from step 0
    initial_state_filename = os.path.join(CHECKPOINT_DIR, 'initial_state.pkl')
    # Check if the initial state file already exists to avoid overwriting on accidental re-run
    if not os.path.exists(initial_state_filename):
         try:
              with open(initial_state_filename, 'wb') as f:
                  pickle.dump(population_state, f)
              print(f"Saved initial state to {initial_state_filename}")
         except Exception as e:
              print(f"Error saving initial state: {e}")
    else:
         print(f"Initial state file already exists: {initial_state_filename}. Not overwriting.")


# Define constants for environmental cycles (Ensure these are consistent with re-initialization)
STEPS_PER_DAY = 24 # Assuming 24 steps represent one day
DAYS_PER_YEAR = 365
STEPS_PER_YEAR = STEPS_PER_DAY * DAYS_PER_YEAR

# Environmental parameters (can be adjusted)
MAX_UV_DAILY = 1.0 # Maximum UV intensity during the day
MIN_UV_DAILY = 0.1 # Minimum UV intensity (night)
MAX_UV_SEASONAL_FACTOR = 1.5 # UV is 1.5 times higher in summer peak
MIN_UV_SEASONAL_FACTOR = 0.5 # UV is 0.5 times lower in winter peak
SUMMER_SOLSTICE_DAY = DAYS_PER_YEAR // 2 # Mid-year is summer peak
DAY_CYCLE_OFFSET = STEPS_PER_DAY // 2 # Peak UV at midday

BASE_TEMP = 20.0 # Base temperature
TEMP_DAILY_AMPLITUDE = 5.0 # Daily temperature swing
TEMP_SEASONAL_AMPLITUDE = 10.0 # Seasonal temperature swing
TEMP_DAILY_OFFSET = STEPS_PER_DAY // 3 * 2 # Peak temp in afternoon
TEMP_SEASONAL_OFFSET = SUMMER_SOLSTICE_DAY * STEPS_PER_DAY # Corrected typo here

# --- Simulation Parameters for Step Logic ---
# Damage rate parameters for converting canonical bases to damaged bases based on UV
UV_DAMAGE_RATE_G_TO_8OXOG = 0.001 # Probability per G per UV unit per step to become 8-oxoG
UV_DAMAGE_RATE_PYRIMIDINE_DIMER = 0.002 # Probability per adjacent pyrimidine pair per UV unit per step to become a dimer (simplified)

# Repair rate parameters for converting damaged bases back to canonical bases
REPAIR_RATE_8OXOG = 0.01 # Probability per 8-oxoG per step to be repaired
REPAIR_RATE_PYRIMIDINE_DIMER = 0.008 # Probability per dimer per step to be repaired

# Error-prone repair rate parameters
ERROR_PRONE_RATE_8OXOG = 0.1 # Probability that 8-oxoG repair is error-prone leading to mutation
ERROR_PRONE_RATE_PYRIMIDINE_DIMER = 0.15 # Probability that dimer repair is error-prone leading to mutation


# Strategy-dependent parameters
# Cautious strategy (0) has higher repair rate and lower damage threshold for arrest
CAUTIOUS_REPAIR_MULTIPLIER = 1.5 # Cautious agents repair 1.5x faster
RECKLESS_REPAIR_MULTIPLIER = 1.0 # Reckless agents repair at base rate
CAUTIOUS_ARREST_THRESHOLD_MULTIPLIER = 0.8 # Cautious agents arrest at 80% of base threshold
RECKLESS_ARREST_THRESHOLD_MULTIPLIER = 1.0 # Reckless agents arrest at base threshold

# Stop codon definitions (UAA, UAG, UGA) as integer sequences
STOP_CODON_UAA = jnp.array([BASE_U, BASE_A, BASE_A], dtype=jnp.int8) # 1, 0, 0
STOP_CODON_UAG = jnp.array([BASE_U, BASE_A, BASE_G], dtype=jnp.int8) # 1, 0, 2
STOP_CODON_UGA = jnp.array([BASE_U, BASE_G, BASE_A], dtype=jnp.int8) # 1, 2, 0

# Protoribosome state transition parameters
DAMAGE_BASE_COUNT_ARREST_THRESHOLD = 5 # Example: arrest if more than 5 damaged bases
ARREST_DURATION_STEPS = 10 # Number of steps an arrested protoribosome stays arrested before potentially becoming inactive (simplified)
REINITIATION_RATE = 0.01 # Probability per step for inactive protoribosomes to become active if conditions allow

# Autocatalysis (Reproduction/Failure) Parameters
# Probability of successful autocatalysis per step for agents with ACTIVE protoribosomes
AUTOCATALYSIS_SUCCESS_RATE_ACTIVE = 0.05
# Probability of failure per step for agents with INACTIVE protoribosomes
AUTOCATALYSIS_FAILURE_RATE_INACTIVE = 0.1
# Probability of failure per step for agents with ARRESTED protoribosomes
AUTOCATALYSIS_FAILURE_RATE_ARRESTED = 0.05
# Additional failure probability per step for agents with very high damage
AUTOCATALYSIS_FAILURE_RATE_HIGH_DAMAGE = 0.02
# Damage threshold for triggering additional failure probability (based on average damage per base)
HIGH_DAMAGE_FAILURE_THRESHOLD = 0.5 # Example threshold based on average damage per base

# Foraging/Removal Parameters
FORAGING_RATE_FAILED = 0.001 # Probability per step for a failed agent to be removed
FAILURE_DURATION_FORAGING_THRESHOLD = 50 # Number of consecutive failure steps before foraging risk increases (simplified)


# Main simulation loop
NUM_STEPS = 10000
print(f"Running simulation from step {start_step} to {NUM_STEPS}")

# Ensure population_state is not None before starting the loop
if population_state is not None:
    # Initialize random key for stochastic processes within the loop
    # Use a key derived from the step number for reproducibility across steps
    key = jax.random.PRNGKey(start_step)

    # Ensure timers are in population_state if resuming from an older checkpoint structure
    if 'protoribosome_arrest_timer' not in population_state:
         population_state['protoribosome_arrest_timer'] = jnp.zeros(NUM_AGENTS, dtype=jnp.int32)
    if 'autocatalysis_failure_timer' not in population_state:
         population_state['autocatalysis_failure_timer'] = jnp.zeros(NUM_AGENTS, dtype=jnp.int32)


    for step in range(start_step, NUM_STEPS + 1):
        # Initialize num_error_prone_mutations for this step
        num_error_prone_mutations = 0

        # Update the random key for the current step
        key, step_key = jax.random.split(key)

        # Calculate environmental variables for the current step
        current_uv_radiation = calculate_uv(step)
        current_temperature = calculate_temperature(step)

        # --- 1. Environmental Factors & RNA Damage Application (Specific Damaged Bases) ---
        # Apply UV damage to RNA sequences for all agents
        current_rna_sequence = population_state['rna_sequence']

        # Identify canonical G bases
        canonical_g_mask = (current_rna_sequence == BASE_G)

        # Identify pyrimidine bases (U or C)
        pyrimidine_mask = ((current_rna_sequence == BASE_U) | (current_rna_sequence == BASE_C))

        # Identify adjacent pyrimidine pairs (simplified: check if a pyrimidine is next to another pyrimidine)
        pyrimidine_mask_shifted = jnp.roll(pyrimidine_mask, -1, axis=1)
        pyrimidine_mask_shifted = jnp.where(jnp.arange(SEQUENCE_LENGTH) == SEQUENCE_LENGTH - 1, False, pyrimidine_mask_shifted)
        potential_dimer_site_mask = pyrimidine_mask & pyrimidine_mask_shifted

        # Determine which canonical Gs become 8-oxoG due to UV
        key, uv_g_key = jax.random.split(step_key)
        uv_g_roll = jax.random.uniform(uv_g_key, shape=canonical_g_mask.shape)
        g_to_8oxog_mask = canonical_g_mask & (uv_g_roll < current_uv_radiation * UV_DAMAGE_RATE_G_TO_8OXOG)

        # Determine which potential pyrimidine dimer sites become dimers due to UV
        key, uv_pyr_key = jax.random.split(key)
        uv_pyr_roll = jax.random.uniform(uv_pyr_key, shape=potential_dimer_site_mask.shape)
        to_pyrimidine_dimer_mask = potential_dimer_site_mask & (uv_pyr_roll < current_uv_radiation * UV_DAMAGE_RATE_PYRIMIDINE_DIMER)

        # Update RNA sequence: canonical bases become damaged bases
        new_rna_sequence = jnp.copy(current_rna_sequence)
        new_rna_sequence = jnp.where(g_to_8oxog_mask, jnp.array(DAMAGED_G, dtype=jnp.int8), new_rna_sequence)
        new_rna_sequence = jnp.where(to_pyrimidine_dimer_mask, jnp.array(DAMAGED_PYRIMIDINE_DIMER, dtype=jnp.int8), new_rna_sequence)

        population_state['rna_sequence'] = new_rna_sequence

        # --- 2. RNA Repair (Targeting Damaged Bases) ---
        # Identify damaged bases
        damaged_g_mask = (population_state['rna_sequence'] == DAMAGED_G)
        damaged_dimer_mask = (population_state['rna_sequence'] == DAMAGED_PYRIMIDINE_DIMER)

        # Determine repair probability based on strategy
        repair_rate_multiplier = jnp.where(
            population_state['strategy'] == 0, # Cautious strategy
            jnp.array(CAUTIOUS_REPAIR_MULTIPLIER),
            jnp.array(RECKLESS_REPAIR_MULTIPLIER) # Reckless strategy
        )[:, None] # Expand dims for broadcasting

        # Determine which damaged Gs and dimers are repaired
        key, repair_roll_key = jax.random.split(key) # Generate repair roll for this step
        repair_roll = jax.random.uniform(repair_roll_key, shape=damaged_g_mask.shape) # Use shape of damage mask

        g_repair_mask = damaged_g_mask & (repair_roll < REPAIR_RATE_8OXOG * repair_rate_multiplier) # Apply strategy multiplier
        dimer_repair_mask = damaged_dimer_mask & (repair_roll < REPAIR_RATE_PYRIMIDINE_DIMER * repair_rate_multiplier) # Apply strategy multiplier


        # --- 3. Implement Mutations (Error-Prone Repair of Damaged Bases) ---
        # Mutations occur when repair is error-prone.
        # Determine which repaired damaged Gs result in error-prone repair
        key, error_prone_g_key = jax.random.split(key)
        error_prone_g_roll = jax.random.uniform(error_prone_g_key, shape=g_repair_mask.shape)
        error_prone_g_repair_mask = g_repair_mask & (error_prone_g_roll < ERROR_PRONE_RATE_8OXOG)

        # Determine which repaired dimers result in error-prone repair
        key, error_prone_dimer_key = jax.random.split(key)
        error_prone_dimer_roll = jax.random.uniform(error_prone_dimer_key, shape=dimer_repair_mask.shape)
        error_prone_dimer_repair_mask = dimer_repair_mask & (error_prone_dimer_roll < ERROR_PRONE_RATE_PYRIMIDINE_DIMER)

        # Combine masks for all sites undergoing error-prone repair
        error_prone_repair_mask = error_prone_g_repair_mask | error_prone_dimer_repair_mask

        # At sites undergoing error-prone repair, the base changes to a random canonical base (0-3).
        num_error_prone_mutations = jnp.sum(error_prone_repair_mask) # Calculate before the if block

        if num_error_prone_mutations > 0:
             key, mutation_type_key = jax.random.split(key)
             # Generate random new canonical bases (0-3) for the mutation sites
             random_new_bases = jax.random.randint(mutation_type_key, shape=(num_error_prone_mutations,), minval=0, maxval=4)

             # Get the indices where mutations occur.
             mutation_indices = jnp.argwhere(error_prone_repair_mask, size=num_error_prone_mutations, fill_value=-1) # Get (agent_idx, base_idx)

             # Update the rna_sequence at the error-prone repair sites with random canonical bases
             population_state['rna_sequence'] = jax.lax.scatter_nd_update(
                 population_state['rna_sequence'],
                 mutation_indices,
                 random_new_bases,
                 scatter_dims=(0, 1) # Scatter along agent and sequence dimensions
             )

        # For damaged bases that were repaired accurately (not error-prone), they revert to their canonical form.
        # Accurate repair of 8-oxoG (8) goes back to G (2).
        # Accurate repair of dimer (99) goes back to a random pyrimidine (U=1 or C=3) for simplicity.

        accurate_g_repair_mask = g_repair_mask & ~error_prone_g_repair_mask
        accurate_dimer_repair_mask = dimer_repair_mask & ~error_prone_dimer_repair_mask

        # Apply accurate G repair: 8-oxoG (8) becomes G (2)
        population_state['rna_sequence'] = jnp.where(
            accurate_g_repair_mask,
            jnp.array(BASE_G, dtype=jnp.int8),
            population_state['rna_sequence']
        )

        # Apply accurate dimer repair: Dimer (99) becomes a random pyrimidine (U=1 or C=3)
        num_accurate_dimer_repairs = jnp.sum(accurate_dimer_repair_mask)
        if num_accurate_dimer_repairs > 0:
             key, accurate_dimer_base_key = jax.random.split(key)
             # Generate random pyrimidine bases (1 or 3)
             random_pyr_bases = jax.random.randint(accurate_dimer_base_key, shape=(num_accurate_dimer_repairs,), minval=0, maxval=2) # 0 or 1
             # Map 0 to U (1) and 1 to C (3)
             random_canonical_pyr = jnp.where(random_pyr_bases == 0, jnp.array(BASE_U, dtype=jnp.int8), jnp.array(BASE_C, dtype=jnp.int8))

             # Get the indices where accurate dimer repairs occur.
             accurate_dimer_indices = jnp.argwhere(accurate_dimer_repair_mask, size=num_accurate_dimer_repairs, fill_value=-1) # Get (agent_idx, base_idx)

             # Update the rna_sequence at the accurate dimer repair sites
             population_state['rna_sequence'] = jax.lax.scatter_nd_update(
                 population_state['rna_sequence'],
                 accurate_dimer_indices,
                 random_canonical_pyr,
                 scatter_dims=(0, 1) # Scatter along agent and sequence dimensions
             )


        # --- 4. Determine Stop Codons ---
        # Check for the presence of stop codons in the RNA sequence for each agent.
        # This needs to be done after mutations, as mutations can create or remove stop codons.
        # Damaged bases (8 and 99) are not canonical bases and would likely block translation if encountered by the protoribosome.
        # For stop codon check, we should only look for canonical bases.

        # Reset has_stop_codon for all agents before checking
        population_state['has_stop_codon'] = jnp.array(False, shape=NUM_AGENTS)

        # Check for each stop codon in each agent's RNA sequence
        # This part is not easily jittable due to variable inner loop length (sequence scanning)
        # For clarity and direct implementation of the concept, let's keep this for loop over agents for now.

        has_stop_codon_mask = jnp.zeros(NUM_AGENTS, dtype=bool)

        if SEQUENCE_LENGTH >= 3:
            for i in range(NUM_AGENTS):
                 rna = population_state['rna_sequence'][i]
                 is_stop = False
                 # Iterate through possible codon start positions
                 for j in range(SEQUENCE_LENGTH - 2):
                     codon = rna[j : j + 3]
                     # Check if all bases in the codon are canonical (0-3)
                     if jnp.all((codon >= 0) & (codon <= 3)):
                          # Check for stop codons
                         if jnp.array_equal(codon, STOP_CODON_UAA) or \
                            jnp.array_equal(codon, STOP_CODON_UAG) or \
                            jnp.array_equal(codon, STOP_CODON_UGA):
                             is_stop = True
                             break # Found a stop codon, no need to check further for this agent
                     # If any base in the codon is a damaged base (8 or 99), it's not a valid stop codon here
                     # and would likely cause protoribosome arrest (handled in the next step).

                 has_stop_codon_mask = has_stop_codon_mask.at[i].set(is_stop)

        population_state['has_stop_codon'] = has_stop_codon_mask


        # --- 5. Implement Protoribosome State Updates ---
        # Update protoribosome states based on current state, damage, and stop codons.
        current_proto_state = population_state['protoribosome_state']
        new_proto_state = jnp.copy(current_proto_state) # Start with current states

        # Calculate number of damaged bases per agent
        num_damaged_bases_per_agent = jnp.sum(
            (population_state['rna_sequence'] == DAMAGED_G) | (population_state['rna_sequence'] == DAMAGED_PYRIMIDINE_DIMER),
            axis=1
        )

        # Determine damage arrest threshold based on strategy
        damage_arrest_threshold_per_agent = jnp.where(
            population_state['strategy'] == 0, # Cautious strategy
            jnp.array(DAMAGE_BASE_COUNT_ARREST_THRESHOLD * CAUTIOUS_ARREST_THRESHOLD_MULTIPLIER),
            jnp.array(DAMAGE_BASE_COUNT_ARREST_THRESHOLD * RECKLESS_ARREST_THRESHOLD_MULTIPLIER) # Reckless strategy
        )


        # Transition rules:
        # ACTIVE (0) -> ARRESTED (1): if has_stop_codon is True OR num_damaged_bases_per_agent > strategy-dependent threshold
        # ARRESTED (1) -> INACTIVE (2) : if protoribosome_arrest_timer exceeds ARREST_DURATION_STEPS
        # ARRESTED (1) -> ACTIVE (0): if damaged bases are repaired AND stop codon removed (simplified: if no stop codon AND num_damaged_bases_per_agent <= strategy-dependent threshold)
        # INACTIVE (2) -> ACTIVE (0): (Optional, re-initiation) after a period of inactivity (e.g., if timer exceeds a threshold) AND conditions for arrest are no longer met

        # Update protoribosome arrest timer for agents currently in ARRESTED state
        arrested_mask = (current_proto_state == PROTO_STATE_ARRESTED)
        protoribosome_arrest_timer = jnp.where(
            arrested_mask,
            population_state['protoribosome_arrest_timer'] + 1, # Use protoribosome_arrest_timer from population_state
            jnp.array(0, dtype=jnp.int32) # Timer resets for other states
        )

        # Transition from ACTIVE to ARRESTED
        active_to_arrested_mask = (current_proto_state == PROTO_STATE_ACTIVE) & \
                                   (population_state['has_stop_codon'] | (num_damaged_bases_per_agent > damage_arrest_threshold_per_agent))

        new_proto_state = jnp.where(
            active_to_arrested_mask,
            jnp.array(PROTO_STATE_ARRESTED, dtype=jnp.int8), # Ensure dtype matches population_state
            new_proto_state
        )


        # Transition from ARRESTED to INACTIVE
        # This transition is based on the arrest timer
        arrested_to_inactive_mask = (current_proto_state == PROTO_STATE_ARRESTED) & \
                                      (protoribosome_arrest_timer >= jnp.array(ARREST_DURATION_STEPS))

        new_proto_state = jnp.where(
            arrested_to_inactive_mask,
            jnp.array(PROTO_STATE_INACTIVE, dtype=jnp.int8), # Ensure dtype matches population_state
            new_proto_state
        )

        # Transition from ARRESTED back to ACTIVE (Recovery)
        # Conditions: No stop codon AND damage is below the strategy-dependent threshold
        arrested_to_active_mask = (current_proto_state == PROTO_STATE_ARRESTED) & \
                                  (~population_state['has_stop_codon'] & (num_damaged_bases_per_agent <= damage_arrest_threshold_per_agent))

        new_proto_state = jnp.where(
            arrested_to_active_mask,
            jnp.array(PROTO_STATE_ACTIVE, dtype=jnp.int8),
            new_proto_state
        )
        # Reset timer for agents transitioning from ARRESTED to ACTIVE
        protoribosome_arrest_timer = jnp.where(
            arrested_to_active_mask,
            jnp.array(0, dtype=jnp.int32),
            protoribosome_arrest_timer
        )


        # Transition from INACTIVE back to ACTIVE (Re-initiation)
        # Let's add a simple re-initiation probability for inactive protoribosomes
        # if conditions for arrest are not met.
        inactive_to_active_mask_conditions = (current_proto_state == PROTO_STATE_INACTIVE) & \
                                             (~population_state['has_stop_codon'] & (num_damaged_bases_per_agent <= damage_arrest_threshold_per_agent))

        # Add a re-initiation probability
        key, reinit_roll_key = jax.random.split(key)
        reinit_roll = jax.random.uniform(reinit_roll_key, shape=inactive_to_active_mask_conditions.shape)
        inactive_to_active_mask = inactive_to_active_mask_conditions & (reinit_roll < REINITIATION_RATE)


        new_proto_state = jnp.where(
            inactive_to_active_mask,
            jnp.array(PROTO_STATE_ACTIVE, dtype=jnp.int8),
            new_proto_state
        )
        # Reset timer for agents transitioning from INACTIVE to ACTIVE (timer is already 0)


        # Update the protoribosome_state in the population_state
        population_state['protoribosome_state'] = new_proto_state
        # Update the arrest timer in the population_state
        population_state['protoribosome_arrest_timer'] = protoribosome_arrest_timer


        # --- 7. Implement Autocatalysis (Success/Failure and Population Dynamics) ---
        # Determine which agents are capable of successful autocatalysis (reproduction)
        # Only agents with ACTIVE protoribosomes can potentially reproduce
        can_reproduce_mask = (population_state['protoribosome_state'] == PROTO_STATE_ACTIVE)
        key, success_roll_key = jax.random.split(key)
        success_roll = jax.random.uniform(success_roll_key, shape=can_reproduce_mask.shape)
        successfully_reproducing_agents_mask = can_reproduce_mask & (success_roll < AUTOCATALYSIS_SUCCESS_RATE_ACTIVE)

        num_successfully_reproducing_agents = jnp.sum(successfully_reproducing_agents_mask)

        # Update autocatalysis failure timer: increment for agents that did NOT reproduce, reset for those that did.
        autocatalysis_failure_timer = jnp.where(
            successfully_reproducing_agents_mask,
            jnp.array(0, dtype=jnp.int32), # Reset timer on successful reproduction
            population_state['autocatalysis_failure_timer'] + 1 # Increment timer on failure to reproduce
        )
        population_state['autocatalysis_failure_timer'] = autocatalysis_failure_timer


        # --- 8. Implement Foraging/Removal ---
        # Agents can be removed from the population (foraged) if they fail autocatalysis
        # consistently for a certain period, or based on other criteria.
        # Use the autocatalysis_failure_timer to determine foraging risk.

        # Identify agents at risk of foraging (failed for a certain duration)
        foraging_risk_mask = (population_state['autocatalysis_failure_timer'] >= jnp.array(FAILURE_DURATION_FORAGING_THRESHOLD))

        # Determine which agents are foraged
        key, foraging_roll_key = jax.random.split(key)
        foraging_roll = jax.random.uniform(foraging_roll_key, shape=foraging_risk_mask.shape)
        foraged_agents_mask = foraging_risk_mask & (foraging_roll < FORAGING_RATE_FAILED)

        num_foraged_agents = jnp.sum(foraged_agents_mask)

        # Implement Population Dynamics: Replace foraged agents with offspring from successfully reproducing agents
        # This maintains a constant population size (NUM_AGENTS).
        num_replacements_needed = num_foraged_agents
        num_potential_offspring = num_successfully_reproducing_agents

        # If there are reproducing agents and replacements needed, select reproducing agents
        if num_potential_offspring > 0 and num_replacements_needed > 0:
             key, offspring_selection_key = jax.random.split(key)
             # Select indices of successfully reproducing agents (potential parents)
             reproducing_indices = jnp.where(successfully_reproducing_agents_mask)[0]

             # Randomly select parents to produce offspring to replace foraged agents
             # Use replacement=True if num_replacements_needed > num_potential_offspring
             replace_with_replacement = num_replacements_needed > num_potential_offspring
             parent_indices_for_offspring = jax.random.choice(
                 offspring_selection_key,
                 reproducing_indices,
                 shape=(num_replacements_needed,),
                 replace=replace_with_replacement
             )

             # Get the state of the selected parent agents
             parent_rna_sequences = population_state['rna_sequence'][parent_indices_for_offspring]
             parent_strategies = population_state['strategy'][parent_indices_for_offspring]

             # Create offspring state (e.g., copy parent's RNA and strategy, reset damage/protoribosome state, reset timers)
             offspring_rna_sequences = jnp.copy(parent_rna_sequences) # Offspring inherit RNA
             offspring_strategies = jnp.copy(parent_strategies) # Offspring inherit strategy
             offspring_rna_damage = jnp.zeros_like(offspring_rna_sequences, dtype=jnp.float32) # Offspring start with no general damage
             # Important: Stop codons need to be re-checked for offspring based on inherited sequence in the next step
             offspring_has_stop_codon = jnp.zeros(num_replacements_needed, dtype=bool)
             offspring_protoribosome_state = jnp.full(num_replacements_needed, PROTO_STATE_ACTIVE, dtype=jnp.int8) # Offspring protoribosomes are active
             offspring_protoribosome_arrest_timer = jnp.zeros(num_replacements_needed, dtype=jnp.int32) # Offspring timer is 0
             offspring_autocatalysis_failure_timer = jnp.zeros(num_replacements_needed, dtype=jnp.int32) # Offspring timer is 0


             # Select indices of foraged agents to be replaced
             foraged_indices = jnp.where(foraged_agents_mask)[0]

             # Update the state of the foraged agents with the offspring state
             # Use jax.lax.scatter_nd_update to update multiple indices at once
             population_state['strategy'] = jax.lax.scatter_nd_update(
                 population_state['strategy'],
                 foraged_indices[:, None], # Indices need to be shape (N, 1) for scatter
                 offspring_strategies,
                 scatter_dims=(0,) # Scatter along the first dimension (agents)
             )
             population_state['rna_damage'] = jax.lax.scatter_nd_update(
                 population_state['rna_damage'],
                 foraged_indices[:, None, None], # Indices need to be shape (N, 1, 1) for scatter into 2D array
                 offspring_rna_damage,
                 scatter_dims=(0,)
             )
             population_state['has_stop_codon'] = jax.lax.scatter_nd_update(
                 population_state['has_stop_codon'],
                 foraged_indices[:, None],
                 offspring_has_stop_codon,
                 scatter_dims=(0,)
             )
             population_state['rna_sequence'] = jax.lax.scatter_nd_update(
                 population_state['rna_sequence'],
                 foraged_indices[:, None, None], # Indices need to be shape (N, 1, 1) for scatter into 2D array
                 offspring_rna_sequences,
                 scatter_dims=(0,)
             )
             population_state['protoribosome_state'] = jax.lax.scatter_nd_update(
                 population_state['protoribosome_state'],
                 foraged_indices[:, None],
                 offspring_protoribosome_state,
                 scatter_dims=(0,)
             )
             population_state['protoribosome_arrest_timer'] = jax.lax.scatter_nd_update(
                 population_state['protoribosome_arrest_timer'],
                 foraged_indices[:, None],
                 offspring_protoribosome_arrest_timer,
                 scatter_dims=(0,)
             )
             population_state['autocatalysis_failure_timer'] = jax.lax.scatter_nd_update(
                 population_state['autocatalysis_failure_timer'],
                 foraged_indices[:, None],
                 offspring_autocatalysis_failure_timer,
                 scatter_dims=(0,)
             )

        # If no reproducing agents but foraging occurs, these spots are effectively removed (status=0 implicitly, or just replaced with empty/dummy state)
        # In this replacement model, if num_potential_offspring is 0, num_replacements_needed will also be 0 as no offspring can be created.
        # If num_successfully_reproducing_agents < num_foraged_agents, we replace as many as possible from reproducing agents.
        # The remaining foraged spots are effectively empty until future steps produce more offspring than failures.


        # --- Placeholder for other simulation step logic ---
        # (All core simulation step logic is now implemented)


        # Placeholder print statement
        if step % 1000 == 0 or step == start_step or step == NUM_STEPS:
            # Calculate average damaged bases per agent for printing
            num_damaged_bases_per_agent = jnp.sum(
                (population_state['rna_sequence'] == DAMAGED_G) | (population_state['rna_sequence'] == DAMAGED_PYRIMIDINE_DIMER),
                axis=1
            )
            avg_damaged_bases_for_print = jnp.mean(num_damaged_bases_per_agent) if num_damaged_bases_per_agent.shape[0] > 0 else 0.0

            num_agents_with_stop_codon = jnp.sum(population_state['has_stop_codon']) # Count agents with stop codons
            # Count protoribosome states for all agents
            active_protos_count = jnp.sum(population_state['protoribosome_state'] == PROTO_STATE_ACTIVE)
            arrested_protos_count = jnp.sum(population_state['protoribosome_state'] == PROTO_STATE_ARRESTED)
            inactive_protos_count = jnp.sum(population_state['protoribosome_state'] == PROTO_STATE_INACTIVE)

            # Count agents by strategy
            cautious_count = jnp.sum(population_state['strategy'] == 0)
            reckless_count = jnp.sum(population_state['strategy'] == 1)

            # Total agents remain constant at NUM_AGENTS in this replacement model


            print(f"Simulating step {step}/{NUM_STEPS}, UV: {current_uv_radiation:.2f}, Temp: {current_temperature:.0f}째C, Avg Damaged Bases: {avg_damaged_bases_for_print:.2f}, Error-Prone Mutations: {num_error_prone_mutations}, Agents with Stop Codon: {num_agents_with_stop_codon}")
            print(f"  Protoribosome States: Active={active_protos_count}, Arrested={arrested_protos_count}, Inactive={inactive_protos_count}")
            print(f"  Agents by Strategy: Cautious={cautious_count}, Reckless={reckless_count}")
            print(f"  Autocatalysis: Successful={num_successfully_reproducing_agents}, Failed (this step): {NUM_AGENTS - num_successfully_reproducing_agents}, Foraged: {num_foraged_agents}")


        # Add checkpointing logic
        # Only save if population_state is defined
        if population_state is not None and step > 0 and (step % CHECKPOINT_INTERVAL == 0 or step == NUM_STEPS):
            if step == NUM_STEPS:
                 checkpoint_filename = os.path.join(CHECKPOINT_DIR, f'final_state_step_{step}.pkl')
            else:
                 checkpoint_filename = os.path.join(CHECKPOINT_DIR, f'checkpoint_step_{step}.pkl')

            try:
                with open(checkpoint_filename, 'wb') as f:
                    pickle.dump(population_state, f)
                print(f"Checkpoint saved at step {step} to {checkpoint_filename}")
            except Exception as e:
                print(f"Error saving checkpoint at step {step}: {e}")

else:
    print("Simulation loop skipped because population_state is not defined.")


print("Simulation loop finished.")

Ensured checkpoint directory exists: /content/drive/MyDrive/Colab Notebooks/qtpu/replica_1/Expt_2025-7-6_08-50
Successfully loaded last checkpoint: /content/drive/MyDrive/Colab Notebooks/qtpu/replica_1/Expt_2025-7-6_08-50/initial_state.pkl (Step: 0)
Resuming simulation from step 1 with loaded state.


NameError: name 'BASE_U' is not defined

In [None]:
import pandas as pd
import numpy as np

# Assume simulation_data and loaded_steps are available from a previous cell
# Example placeholder if not loaded (replace with actual loading code)
try:
    simulation_data
    loaded_steps
except NameError:
    print("simulation_data or loaded_steps not found. Please run the cell that loads simulation data first.")
    # You might need to load your checkpoint data here if it's not already loaded.
    # Example loading (you'll need to adjust the path and logic):
    # simulation_data = []
    # loaded_steps = []
    # # Assuming you have checkpoint files like checkpoint_step_X.pkl
    # import os
    # import pickle
    # CHECKPOINT_DIR = "/content/drive/MyDrive/Colab Notebooks/qtpu/replica_1/Expt_2025-7-6_08-50"
    # checkpoint_files = sorted([f for f in os.listdir(CHECKPOINT_DIR) if f.startswith('checkpoint_step_') or f.startswith('final_state_step_')])
    # for filename in checkpoint_files:
    #     step = int(filename.split('_')[-1].split('.')[0])
    #     filepath = os.path.join(CHECKPOINT_DIR, filename)
    #     with open(filepath, 'rb') as f:
    #         state = pickle.load(f)
    #         simulation_data.append(state)
    #         loaded_steps.append(step)
    # print(f"Loaded {len(simulation_data)} simulation states.")


if 'simulation_data' in locals() and simulation_data: # Check if simulation_data is defined and not empty
    # Create lists to store the extracted metrics
    steps = []
    cautious_pops = []
    reckless_pops = []
    active_protos = []
    arrested_protos = []
    inactive_protos = []
    avg_damaged_bases = [] # Updated name to reflect damaged bases count

    # Define damaged base representations (ensure consistency with simulation code)
    DAMAGED_G = 8
    DAMAGED_PYRIMIDINE_DIMER = 99

    # Iterate through the loaded data
    for i, state in enumerate(simulation_data):
        step = loaded_steps[i] # Get the step number
        steps.append(step)

        # Calculate population counts and other metrics from the state dictionary
        # In the updated simulation, agents are implicitly present if not foraged.
        # Let's count based on presence in the state dictionary, assuming fixed NUM_AGENTS.
        # If NUM_AGENTS can change, the data structure needs to track presence explicitly.
        # Assuming fixed NUM_AGENTS for now, all agents in the state are considered.

        cautious_pop = np.sum(state['strategy'] == 0)
        reckless_pop = np.sum(state['strategy'] == 1)

        # Count protoribosome states
        active_proto = np.sum(state['protoribosome_state'] == PROTO_STATE_ACTIVE)
        arrested_proto = np.sum(state['protoribosome_state'] == PROTO_STATE_ARRESTED)
        inactive_proto = np.sum(state['protoribosome_state'] == PROTO_STATE_INACTIVE)

        # Calculate average number of damaged bases per agent
        num_damaged_bases_per_agent = np.sum(
             (state['rna_sequence'] == DAMAGED_G) | (state['rna_sequence'] == DAMAGED_PYRIMIDINE_DIMER),
             axis=1
         )
        avg_damage = np.mean(num_damaged_bases_per_agent) if num_damaged_bases_per_agent.shape[0] > 0 else 0.0

        cautious_pops.append(cautious_pop)
        reckless_pops.append(reckless_pop)
        active_protos.append(active_proto)
        arrested_protos.append(arrested_proto)
        inactive_protos.append(inactive_proto)
        avg_damaged_bases.append(avg_damage)


    # Create a pandas DataFrame
    analysis_df = pd.DataFrame({
        'Step': steps,
        'Cautious Population': cautious_pops,
        'Reckless Population': reckless_pops,
        'Active Protoribosomes': active_protos,
        'Arrested Protoribosomes': arrested_protos,
        'Inactive Protoribosomes': inactive_protos,
        'Average Damaged Bases per Agent': avg_damaged_bases
    })

    print("Data extraction and DataFrame creation complete.")
    display(analysis_df.head())

else:
    print("No simulation data available to process.")

simulation_data or loaded_steps not found. Please run the cell that loads simulation data first.
No simulation data available to process.


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Set plot style
sns.set_theme(style="whitegrid")

# Plot Population Dynamics by Strategy
plt.figure(figsize=(12, 6))
plt.plot(analysis_df['Step'], analysis_df['Cautious Population'], label='Cautious')
plt.plot(analysis_df['Step'], analysis_df['Reckless Population'], label='Reckless')
plt.xlabel('Simulation Step')
plt.ylabel('Population Count')
plt.title('Population Dynamics of Cautious vs. Reckless Strategies')
plt.legend()
plt.show()

# Plot Protoribosome States Over Time
plt.figure(figsize=(12, 6))
plt.plot(analysis_df['Step'], analysis_df['Active Protoribosomes'], label='Active')
plt.plot(analysis_df['Step'], analysis_df['Arrested Protoribosomes'], label='Arrested')
plt.plot(analysis_df['Step'], analysis_df['Inactive Protoribosomes'], label='Inactive')
plt.xlabel('Simulation Step')
plt.ylabel('Count')
plt.title('Protoribosome States Over Time')
plt.legend()
plt.show()

# Plot Average Damaged Bases per Agent Over Time
plt.figure(figsize=(12, 6))
plt.plot(analysis_df['Step'], analysis_df['Average Damaged Bases per Agent'])
plt.xlabel('Simulation Step')
plt.ylabel('Average Number of Damaged Bases')
plt.title('Average Damaged Bases per Agent Over Time')
plt.show()

print("Plots generated.")

NameError: name 'analysis_df' is not defined

<Figure size 1200x600 with 0 Axes>

In [None]:
import os

# Check if the Google Drive mount point exists
if os.path.exists('/content/drive'):
    print("'/content/drive' exists. Checking contents...")
    # Check if MyDrive is accessible, which indicates a successful mount
    if os.path.exists('/content/drive/MyDrive'):
        print("Google Drive is likely mounted successfully.")
        # Optionally list contents of MyDrive to be sure
        # print("Contents of MyDrive:")
        # print(os.listdir('/content/drive/MyDrive'))
    else:
        print("'/content/drive' exists, but 'MyDrive' was not found. Drive might not be fully mounted or path is incorrect.")
else:
    print("'/content/drive' mount point does not exist. Google Drive is not mounted.")

'/content/drive' exists. Checking contents...
Google Drive is likely mounted successfully.


In [None]:
import pickle
import os
import jax
import jax.numpy as jnp
import numpy as np # Import numpy for initial state creation
import re
import jax.lax # Explicitly import jax.lax

# Define simulation parameters (Always defined)
NUM_AGENTS = 50000
SEQUENCE_LENGTH = 100

# Define integer representations for bases (Always defined)
BASE_A = 0
BASE_U = 1
BASE_G = 2
BASE_C = 3

# Define protoribosome states (Always defined)
PROTO_STATE_ACTIVE = 0
PROTO_STATE_ARRESTED = 1
PROTO_STATE_INACTIVE = 2

# Define damaged base representations (Always defined)
DAMAGED_G = 8 # Representing 8-oxoguanine
DAMAGED_PYRIMIDINE_DIMER = 99 # Representing pyrimidine dimers (simplified as a single type)

# Define constants for environmental cycles (Always defined)
STEPS_PER_DAY = 24 # Assuming 24 steps represent one day
DAYS_PER_YEAR = 365
STEPS_PER_YEAR = STEPS_PER_DAY * DAYS_PER_YEAR

# Environmental parameters (can be adjusted, Always defined)
MAX_UV_DAILY = 1.0 # Maximum UV intensity during the day
MIN_UV_DAILY = 0.1 # Minimum UV intensity (night)
MAX_UV_SEASONAL_FACTOR = 1.5 # UV is 1.5 times higher in summer peak
MIN_UV_SEASONAL_FACTOR = 0.5 # UV is 0.5 times lower in winter peak
SUMMER_SOLSTICE_DAY = DAYS_PER_YEAR // 2 # Mid-year is summer peak
DAY_CYCLE_OFFSET = STEPS_PER_DAY // 2 # Peak UV at midday

BASE_TEMP = 20.0 # Base temperature
TEMP_DAILY_AMPLITUDE = 5.0 # Daily temperature swing
TEMP_SEASONAL_AMPLITUDE = 10.0 # Seasonal temperature swing
TEMP_DAILY_OFFSET = STEPS_PER_DAY // 3 * 2 # Peak temp in afternoon
TEMP_SEASONAL_OFFSET = SUMMER_SOLSTICE_DAY * STEPS_PER_DAY


# Define checkpoint interval
CHECKPOINT_INTERVAL = 1000 # Save checkpoint every 1000 steps

# Define the directory where checkpoints are saved - Explicitly define it here
CHECKPOINT_DIR = "/content/drive/MyDrive/Colab Notebooks/qtpu/replica_1/Expt_2025-7-6_08-50" # Ensure this matches the directory used in the simulation

# Create the directory if it doesn't exist
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
print(f"Ensured checkpoint directory exists: {CHECKPOINT_DIR}")

# List all files in the checkpoint directory
all_files = os.listdir(CHECKPOINT_DIR)

# Filter for checkpoint files and sort them by step number
checkpoint_files_to_load = []
# Regex to match checkpoint files (initial_state.pkl, checkpoint_step_X.pkl, final_state_step_X.pkl)
checkpoint_pattern = re.compile(r'^(initial_state\.pkl|checkpoint_step_(\d+)\.pkl|final_state_step_(\d+)\.pkl)$')

for filename in all_files:
    match = checkpoint_pattern.match(filename)
    if match:
        file_path = os.path.join(CHECKPOINT_DIR, filename)
        # Extract step number for sorting
        if 'initial_state' in filename:
            step = 0
        elif 'final_state' in filename:
             # Assuming final state filename is final_state_step_NUM_STEPS.pkl
             try:
                 step = int(match.group(3)) # Capture group 3 for final_state_step_(\d+).pkl
             except (ValueError, TypeError):
                  # Fallback to a large number for sorting if extraction fails
                  step = 999999999
        else:
            step = int(match.group(2)) # Capture group 2 for checkpoint_step_(\d+).pkl
        checkpoint_files_to_load.append((step, file_path))

# Sort files by step number
checkpoint_files_to_load.sort(key=lambda x: x[0])

# Load the last checkpoint if any exist
population_state = None # Initialize population_state to None
start_step = 0

if checkpoint_files_to_load:
    last_step, last_file_path = checkpoint_files_to_load[-1]
    try:
        with open(last_file_path, 'rb') as f:
            population_state = pickle.load(f)
            start_step = last_step + 1
            print(f"Successfully loaded last checkpoint: {last_file_path} (Step: {last_step})")
            print(f"Resuming simulation from step {start_step} with loaded state.")
    except FileNotFoundError:
        print(f"Error: Last checkpoint file not found at {last_file_path}. Starting new simulation with re-initialized state.")
    except Exception as e:
        print(f"Error loading last checkpoint {last_file_path}: {e}. Starting new simulation with re-initialized state.")


# If population_state is still None (meaning no checkpoints were loaded or loading failed),
# re-initialize it.
if population_state is None:
    print(f"No valid checkpoints found or loaded. Starting new simulation from step {start_step} with re-initialized state.")
    # Note: Simulation parameters, base definitions, protoribosome states,
    # damaged base definitions, and environmental constants are now defined outside this block.

    # Re-initialize JAX random key for initial state if needed (using a fixed seed for reproducibility of initial state)
    initial_key = jax.random.PRNGKey(42)
    initial_key, strategy_key, rna_key, proto_key = jax.random.split(initial_key, 4)

    # Re-initialize population state structure
    initial_strategies = jax.random.randint(strategy_key, (NUM_AGENTS,), 0, 2)

    # Re-initialize RNA sequences with random canonical bases (0-3 representing A, U, G, C)
    initial_rna_sequences = np.random.randint(0, 4, size=(NUM_AGENTS, SEQUENCE_LENGTH), dtype=np.int8)

    # Re-initialize protoribosome states (e.g., all start as ACTIVE)
    initial_protoribosome_states = np.full(NUM_AGENTS, PROTO_STATE_ACTIVE, dtype=np.int8)

    # Initial state should not have damage or stop codons
    population_state = {
        'strategy': np.array(initial_strategies),
        'rna_damage': np.zeros((NUM_AGENTS, SEQUENCE_LENGTH), dtype=np.float32), # Keep general damage level for simplicity alongside damaged base type
        'has_stop_codon': jnp.zeros(NUM_AGENTS, dtype=bool),
        # 'status': np.ones(NUM_AGENTS, dtype=np.int8), # Remove status, agents are implicitly "present" unless removed
        'rna_sequence': initial_rna_sequences, # RNA sequence now stores base type OR damaged base type
        'protoribosome_state': initial_protoribosome_states,
        'protoribosome_arrest_timer': jnp.zeros(NUM_AGENTS, dtype=jnp.int32), # Initialize timer
        'autocatalysis_failure_timer': jnp.zeros(NUM_AGENTS, dtype=jnp.int32) # Timer for consecutive autocatalysis failures
    }
    start_step = 0 # Ensure start_step is 0 if re-initializing

    # Save the initial state if starting from step 0
    initial_state_filename = os.path.join(CHECKPOINT_DIR, 'initial_state.pkl')
    # Check if the initial state file already exists to avoid overwriting on accidental re-run
    if not os.path.exists(initial_state_filename):
         try:
              with open(initial_state_filename, 'wb') as f:
                  pickle.dump(population_state, f)
              print(f"Saved initial state to {initial_state_filename}")
         except Exception as e:
              print(f"Error saving initial state: {e}")
    else:
         print(f"Initial state file already exists: {initial_state_filename}. Not overwriting.")


# Function to calculate UV radiation based on step number
def calculate_uv(step, max_uv_daily=MAX_UV_DAILY, min_uv_daily=MIN_UV_DAILY,
                 max_uv_seasonal_factor=MAX_UV_SEASONAL_FACTOR, min_uv_seasonal_factor=MIN_UV_SEASONAL_FACTOR,
                 steps_per_day=STEPS_PER_DAY, steps_per_year=STEPS_PER_YEAR,
                 day_cycle_offset=DAY_CYCLE_OFFSET):
    """Calculates UV radiation based on daily and seasonal cycles."""
    # Daily cycle (sinusoidal, peaking around midday)
    day_of_year_step = step % steps_per_year
    step_of_day = step % steps_per_day
    daily_uv = ((max_uv_daily - min_uv_daily) / 2) * (1 + jnp.cos((2 * jnp.pi / steps_per_day) * (step_of_day - day_cycle_offset))) + min_uv_daily

    # Seasonal cycle (sinusoidal, peaking in summer)
    seasonal_factor = ((max_uv_seasonal_factor - min_uv_seasonal_factor) / 2) * (1 + jnp.cos((2 * jnp.pi / steps_per_year) * (day_of_year_step - TEMP_SEASONAL_OFFSET))) + min_uv_seasonal_factor

    # Combine daily and seasonal cycles
    current_uv = daily_uv * seasonal_factor
    return current_uv

# Function to calculate temperature based on step number
def calculate_temperature(step, base_temp=BASE_TEMP, temp_daily_amplitude=TEMP_DAILY_AMPLITUDE,
                          temp_seasonal_amplitude=TEMP_SEASONAL_AMPLITUDE, steps_per_day=STEPS_PER_DAY,
                          steps_per_year=STEPS_PER_YEAR, temp_daily_offset=TEMP_DAILY_OFFSET,
                          temp_seasonal_offset=TEMP_SEASONAL_OFFSET):
    """Calculates temperature based on daily and seasonal cycles."""
    # Daily cycle (sinusoidal, peaking in afternoon)
    step_of_day = step % steps_per_day
    daily_temp = temp_daily_amplitude * jnp.sin((2 * jnp.pi / steps_per_day) * (step_of_day - temp_daily_offset))

    # Seasonal cycle (sinusoidal, peaking in summer)
    day_of_year_step = step % steps_per_year
    seasonal_temp = temp_seasonal_amplitude * jnp.sin((2 * jnp.pi / steps_per_year) * (day_of_year_step - temp_seasonal_offset))

    # Combine base temperature, daily, and seasonal cycles
    current_temp = base_temp + daily_temp + seasonal_temp
    return current_temp


# --- Simulation Parameters for Step Logic ---
# Damage rate parameters for converting canonical bases to damaged bases based on UV
UV_DAMAGE_RATE_G_TO_8OXOG = 0.001 # Probability per G per UV unit per step to become 8-oxoG
UV_DAMAGE_RATE_PYRIMIDINE_DIMER = 0.002 # Probability per adjacent pyrimidine pair per UV unit per step to become a dimer (simplified)

# Repair rate parameters for converting damaged bases back to canonical bases
REPAIR_RATE_8OXOG = 0.01 # Probability per 8-oxoG per step to be repaired
REPAIR_RATE_PYRIMIDINE_DIMER = 0.008 # Probability per dimer per step to be repaired

# Error-prone repair rate parameters
ERROR_PRONE_RATE_8OXOG = 0.1 # Probability that 8-oxoG repair is error-prone leading to mutation
ERROR_PRONE_RATE_PYRIMIDINE_DIMER = 0.15 # Probability that dimer repair is error-prone leading to mutation


# Strategy-dependent parameters
# Cautious strategy (0) has higher repair rate and lower damage threshold for arrest
CAUTIOUS_REPAIR_MULTIPLIER = 1.5 # Cautious agents repair 1.5x faster
RECKLESS_REPAIR_MULTIPLIER = 1.0 # Reckless agents repair at base rate
CAUTIOUS_ARREST_THRESHOLD_MULTIPLIER = 0.8 # Cautious agents arrest at 80% of base threshold
RECKLESS_ARREST_THRESHOLD_MULTIPLIER = 1.0 # Reckless agents arrest at base threshold

# Stop codon definitions (UAA, UAG, UGA) as integer sequences
STOP_CODON_UAA = jnp.array([BASE_U, BASE_A, BASE_A], dtype=jnp.int8) # 1, 0, 0
STOP_CODON_UAG = jnp.array([BASE_U, BASE_A, BASE_G], dtype=jnp.int8) # 1, 0, 2
STOP_CODON_UGA = jnp.array([BASE_U, BASE_G, BASE_A], dtype=jnp.int8) # 1, 2, 0

# Protoribosome state transition parameters
DAMAGE_BASE_COUNT_ARREST_THRESHOLD = 5 # Example: arrest if more than 5 damaged bases
ARREST_DURATION_STEPS = 10 # Number of steps an arrested protoribosome stays arrested before potentially becoming inactive (simplified)
REINITIATION_RATE = 0.01 # Probability per step for inactive protoribosomes to become active if conditions allow

# Autocatalysis (Reproduction/Failure) Parameters
# Probability of successful autocatalysis per step for agents with ACTIVE protoribosomes
AUTOCATALYSIS_SUCCESS_RATE_ACTIVE = 0.05
# Probability of failure per step for agents with INACTIVE protoribosomes
AUTOCATALYSIS_FAILURE_RATE_INACTIVE = 0.1
# Probability of failure per step for agents with ARRESTED protoribosomes
AUTOCATALYSIS_FAILURE_RATE_ARRESTED = 0.05
# Additional failure probability per step for agents with very high damage
AUTOCATALYSIS_FAILURE_RATE_HIGH_DAMAGE = 0.02
# Damage threshold for triggering additional failure probability (based on average damage per base)
HIGH_DAMAGE_FAILURE_THRESHOLD = 0.5 # Example threshold based on average damage per base

# Foraging/Removal Parameters
FORAGING_RATE_FAILED = 0.001 # Probability per step for a failed agent to be removed
FAILURE_DURATION_FORAGING_THRESHOLD = 50 # Number of consecutive failure steps before foraging risk increases (simplified)


# Main simulation loop
NUM_STEPS = 10000
print(f"Running simulation from step {start_step} to {NUM_STEPS}")

# Ensure population_state is not None before starting the loop
if population_state is not None:
    # Initialize random key for stochastic processes within the loop
    # Use a key derived from the step number for reproducibility across steps
    key = jax.random.PRNGKey(start_step)

    # Ensure timers are in population_state if resuming from an older checkpoint structure
    if 'protoribosome_arrest_timer' not in population_state:
         population_state['protoribosome_arrest_timer'] = jnp.zeros(NUM_AGENTS, dtype=jnp.int32)
    if 'autocatalysis_failure_timer' not in population_state:
         population_state['autocatalysis_failure_timer'] = jnp.zeros(NUM_AGENTS, dtype=jnp.int32)


    for step in range(start_step, NUM_STEPS + 1):
        # Initialize num_error_prone_mutations for this step
        num_error_prone_mutations = 0

        # Update the random key for the current step
        key, step_key = jax.random.split(key)

        # Calculate environmental variables for the current step
        current_uv_radiation = calculate_uv(step)
        current_temperature = calculate_temperature(step)

        # --- 1. Environmental Factors & RNA Damage Application (Specific Damaged Bases) ---
        # Apply UV damage to RNA sequences for all agents
        current_rna_sequence = population_state['rna_sequence']

        # Identify canonical G bases
        canonical_g_mask = (current_rna_sequence == BASE_G)

        # Identify pyrimidine bases (U or C)
        pyrimidine_mask = ((current_rna_sequence == BASE_U) | (current_rna_sequence == BASE_C))

        # Identify adjacent pyrimidine pairs (simplified: check if a pyrimidine is next to another pyrimidine)
        pyrimidine_mask_shifted = jnp.roll(pyrimidine_mask, -1, axis=1)
        pyrimidine_mask_shifted = jnp.where(jnp.arange(SEQUENCE_LENGTH) == SEQUENCE_LENGTH - 1, False, pyrimidine_mask_shifted)
        potential_dimer_site_mask = pyrimidine_mask & pyrimidine_mask_shifted

        # Determine which canonical Gs become 8-oxoG due to UV
        key, uv_g_key = jax.random.split(step_key)
        uv_g_roll = jax.random.uniform(uv_g_key, shape=canonical_g_mask.shape)
        g_to_8oxog_mask = canonical_g_mask & (uv_g_roll < current_uv_radiation * UV_DAMAGE_RATE_G_TO_8OXOG)

        # Determine which potential pyrimidine dimer sites become dimers due to UV
        key, uv_pyr_key = jax.random.split(key)
        uv_pyr_roll = jax.random.uniform(uv_pyr_key, shape=potential_dimer_site_mask.shape)
        to_pyrimidine_dimer_mask = potential_dimer_site_mask & (uv_pyr_roll < current_uv_radiation * UV_DAMAGE_RATE_PYRIMIDINE_DIMER)

        # Update RNA sequence: canonical bases become damaged bases
        new_rna_sequence = jnp.copy(current_rna_sequence)
        new_rna_sequence = jnp.where(g_to_8oxog_mask, jnp.array(DAMAGED_G, dtype=jnp.int8), new_rna_sequence)
        new_rna_sequence = jnp.where(to_pyrimidine_dimer_mask, jnp.array(DAMAGED_PYRIMIDINE_DIMER, dtype=jnp.int8), new_rna_sequence)

        population_state['rna_sequence'] = new_rna_sequence

        # --- 2. RNA Repair (Targeting Damaged Bases) ---
        # Identify damaged bases
        damaged_g_mask = (population_state['rna_sequence'] == DAMAGED_G)
        damaged_dimer_mask = (population_state['rna_sequence'] == DAMAGED_PYRIMIDINE_DIMER)

        # Determine repair probability based on strategy
        repair_rate_multiplier = jnp.where(
            population_state['strategy'] == 0, # Cautious strategy
            jnp.array(CAUTIOUS_REPAIR_MULTIPLIER),
            jnp.array(RECKLESS_REPAIR_MULTIPLIER) # Reckless strategy
        )[:, None] # Expand dims for broadcasting

        # Determine which damaged Gs and dimers are repaired
        key, repair_roll_key = jax.random.split(key) # Generate repair roll for this step
        repair_roll = jax.random.uniform(repair_roll_key, shape=damaged_g_mask.shape) # Use shape of damage mask

        g_repair_mask = damaged_g_mask & (repair_roll < REPAIR_RATE_8OXOG * repair_rate_multiplier) # Apply strategy multiplier
        dimer_repair_mask = damaged_dimer_mask & (repair_roll < REPAIR_RATE_PYRIMIDINE_DIMER * repair_rate_multiplier) # Apply strategy multiplier


        # --- 3. Implement Mutations (Error-Prone Repair of Damaged Bases) ---
        # Mutations occur when repair is error-prone.
        # Determine which repaired damaged Gs result in error-prone repair
        key, error_prone_g_key = jax.random.split(key)
        error_prone_g_roll = jax.random.uniform(error_prone_g_key, shape=g_repair_mask.shape)
        error_prone_g_repair_mask = g_repair_mask & (error_prone_g_roll < ERROR_PRONE_RATE_8OXOG)

        # Determine which repaired dimers result in error-prone repair
        key, error_prone_dimer_key = jax.random.split(key)
        error_prone_dimer_roll = jax.random.uniform(error_prone_dimer_key, shape=dimer_repair_mask.shape)
        error_prone_dimer_repair_mask = dimer_repair_mask & (error_prone_dimer_roll < ERROR_PRONE_RATE_PYRIMIDINE_DIMER)

        # Combine masks for all sites undergoing error-prone repair
        error_prone_repair_mask = error_prone_g_repair_mask | error_prone_dimer_repair_mask

        # At sites undergoing error-prone repair, the base changes to a random canonical base (0-3).
        num_error_prone_mutations = jnp.sum(error_prone_repair_mask) # Calculate before the if block

        if num_error_prone_mutations > 0:
             key, mutation_type_key = jax.random.split(key)
             # Generate random new canonical bases (0-3) for the mutation sites
             random_new_bases = jax.random.randint(mutation_type_key, shape=(num_error_prone_mutations,), minval=0, maxval=4)

             # Get the indices where mutations occur.
             mutation_indices = jnp.argwhere(error_prone_repair_mask, size=num_error_prone_mutations, fill_value=-1) # Get (agent_idx, base_idx)

             # Update the rna_sequence at the error-prone repair sites with random canonical bases
             population_state['rna_sequence'] = jax.lax.scatter_nd_update(
                 population_state['rna_sequence'],
                 mutation_indices,
                 random_new_bases,
                 scatter_dims=(0, 1) # Scatter along agent and sequence dimensions
             )

        # For damaged bases that were repaired accurately (not error-prone), they revert to their canonical form.
        # Accurate repair of 8-oxoG (8) goes back to G (2).
        # Accurate repair of dimer (99) goes back to a random pyrimidine (U=1 or C=3) for simplicity.

        accurate_g_repair_mask = g_repair_mask & ~error_prone_g_repair_mask
        accurate_dimer_repair_mask = dimer_repair_mask & ~error_prone_dimer_repair_mask

        # Apply accurate G repair: 8-oxoG (8) becomes G (2)
        population_state['rna_sequence'] = jnp.where(
            accurate_g_repair_mask,
            jnp.array(BASE_G, dtype=jnp.int8),
            population_state['rna_sequence']
        )

        # Apply accurate dimer repair: Dimer (99) becomes a random pyrimidine (U=1 or C=3)
        num_accurate_dimer_repairs = jnp.sum(accurate_dimer_repair_mask)
        if num_accurate_dimer_repairs > 0:
             key, accurate_dimer_base_key = jax.random.split(key)
             # Generate random pyrimidine bases (1 or 3)
             random_pyr_bases = jax.random.randint(accurate_dimer_base_key, shape=(num_accurate_dimer_repairs,), minval=0, maxval=2) # 0 or 1
             # Map 0 to U (1) and 1 to C (3)
             random_canonical_pyr = jnp.where(random_pyr_bases == 0, jnp.array(BASE_U, dtype=jnp.int8), jnp.array(BASE_C, dtype=jnp.int8))

             # Get the indices where accurate dimer repairs occur.
             accurate_dimer_indices = jnp.argwhere(accurate_dimer_repair_mask, size=num_accurate_dimer_repairs, fill_value=-1) # Get (agent_idx, base_idx)

             # Update the rna_sequence at the accurate dimer repair sites
             population_state['rna_sequence'] = jax.lax.scatter_nd_update(
                 population_state['rna_sequence'],
                 accurate_dimer_indices,
                 random_canonical_pyr,
                 scatter_dims=(0, 1) # Scatter along agent and sequence dimensions
             )


        # --- 4. Determine Stop Codons ---
        # Check for the presence of stop codons in the RNA sequence for each agent.
        # This needs to be done after mutations, as mutations can create or remove stop codons.
        # Damaged bases (8 and 99) are not canonical bases and would likely block translation if encountered by the protoribosome.
        # For stop codon check, we should only look for canonical bases.

        # Reset has_stop_codon for all agents before checking
        population_state['has_stop_codon'] = jnp.array(False, shape=NUM_AGENTS)

        # Check for each stop codon in each agent's RNA sequence
        # This part is not easily jittable due to variable inner loop length (sequence scanning)
        # For clarity and direct implementation of the concept, let's keep this for loop over agents for now.

        has_stop_codon_mask = jnp.zeros(NUM_AGENTS, dtype=bool)

        if SEQUENCE_LENGTH >= 3:
            for i in range(NUM_AGENTS):
                 rna = population_state['rna_sequence'][i]
                 is_stop = False
                 # Iterate through possible codon start positions
                 for j in range(SEQUENCE_LENGTH - 2):
                     codon = rna[j : j + 3]
                     # Check if all bases in the codon are canonical (0-3)
                     if jnp.all((codon >= 0) & (codon <= 3)):
                          # Check for stop codons
                         if jnp.array_equal(codon, STOP_CODON_UAA) or \
                            jnp.array_equal(codon, STOP_CODON_UAG) or \
                            jnp.array_equal(codon, STOP_CODON_UGA):
                             is_stop = True
                             break # Found a stop codon, no need to check further for this agent
                     # If any base in the codon is a damaged base (8 or 99), it's not a valid stop codon here
                     # and would likely cause protoribosome arrest (handled in the next step).

                 has_stop_codon_mask = has_stop_codon_mask.at[i].set(is_stop)

        population_state['has_stop_codon'] = has_stop_codon_mask


        # --- 5. Implement Protoribosome State Updates ---
        # Update protoribosome states based on current state, damage, and stop codons.
        current_proto_state = population_state['protoribosome_state']
        new_proto_state = jnp.copy(current_proto_state) # Start with current states

        # Calculate number of damaged bases per agent
        num_damaged_bases_per_agent = jnp.sum(
            (population_state['rna_sequence'] == DAMAGED_G) | (population_state['rna_sequence'] == DAMAGED_PYRIMIDINE_DIMER),
            axis=1
        )

        # Determine damage arrest threshold based on strategy
        damage_arrest_threshold_per_agent = jnp.where(
            population_state['strategy'] == 0, # Cautious strategy
            jnp.array(DAMAGE_BASE_COUNT_ARREST_THRESHOLD * CAUTIOUS_ARREST_THRESHOLD_MULTIPLIER),
            jnp.array(DAMAGE_BASE_COUNT_ARREST_THRESHOLD * RECKLESS_ARREST_THRESHOLD_MULTIPLIER) # Reckless strategy
        )


        # Transition rules:
        # ACTIVE (0) -> ARRESTED (1): if has_stop_codon is True OR num_damaged_bases_per_agent > strategy-dependent threshold
        # ARRESTED (1) -> INACTIVE (2) : if protoribosome_arrest_timer exceeds ARREST_DURATION_STEPS
        # ARRESTED (1) -> ACTIVE (0): if damaged bases are repaired AND stop codon removed (simplified: if no stop codon AND num_damaged_bases_per_agent <= strategy-dependent threshold)
        # INACTIVE (2) -> ACTIVE (0): (Optional, re-initiation) after a period of inactivity (e.g., if timer exceeds a threshold) AND conditions for arrest are no longer met

        # Update protoribosome arrest timer for agents currently in ARRESTED state
        arrested_mask = (current_proto_state == PROTO_STATE_ARRESTED)
        protoribosome_arrest_timer = jnp.where(
            arrested_mask,
            population_state['protoribosome_arrest_timer'] + 1, # Use protoribosome_arrest_timer from population_state
            jnp.array(0, dtype=jnp.int32) # Timer resets for other states
        )

        # Transition from ACTIVE to ARRESTED
        active_to_arrested_mask = (current_proto_state == PROTO_STATE_ACTIVE) & \
                                   (population_state['has_stop_codon'] | (num_damaged_bases_per_agent > damage_arrest_threshold_per_agent))

        new_proto_state = jnp.where(
            active_to_arrested_mask,
            jnp.array(PROTO_STATE_ARRESTED, dtype=jnp.int8), # Ensure dtype matches population_state
            new_proto_state
        )


        # Transition from ARRESTED to INACTIVE
        # This transition is based on the arrest timer
        arrested_to_inactive_mask = (current_proto_state == PROTO_STATE_ARRESTED) & \
                                      (protoribosome_arrest_timer >= jnp.array(ARREST_DURATION_STEPS))

        new_proto_state = jnp.where(
            arrested_to_inactive_mask,
            jnp.array(PROTO_STATE_INACTIVE, dtype=jnp.int8), # Ensure dtype matches population_state
            new_proto_state
        )

        # Transition from ARRESTED back to ACTIVE (Recovery)
        # Conditions: No stop codon AND damage is below the strategy-dependent threshold
        arrested_to_active_mask = (current_proto_state == PROTO_STATE_ARRESTED) & \
                                  (~population_state['has_stop_codon'] & (num_damaged_bases_per_agent <= damage_arrest_threshold_per_agent))

        new_proto_state = jnp.where(
            arrested_to_active_mask,
            jnp.array(PROTO_STATE_ACTIVE, dtype=jnp.int8),
            new_proto_state
        )
        # Reset timer for agents transitioning from ARRESTED to ACTIVE
        protoribosome_arrest_timer = jnp.where(
            arrested_to_active_mask,
            jnp.array(0, dtype=jnp.int32),
            protoribosome_arrest_timer
        )


        # Transition from INACTIVE back to ACTIVE (Re-initiation)
        # Let's add a simple re-initiation probability for inactive protoribosomes
        # if conditions for arrest are not met.
        inactive_to_active_mask_conditions = (current_proto_state == PROTO_STATE_INACTIVE) & \
                                             (~population_state['has_stop_codon'] & (num_damaged_bases_per_agent <= damage_arrest_threshold_per_agent))

        # Add a re-initiation probability
        key, reinit_roll_key = jax.random.split(key)
        reinit_roll = jax.random.uniform(reinit_roll_key, shape=inactive_to_active_mask_conditions.shape)
        inactive_to_active_mask = inactive_to_active_mask_conditions & (reinit_roll < REINITIATION_RATE)


        new_proto_state = jnp.where(
            inactive_to_active_mask,
            jnp.array(PROTO_STATE_ACTIVE, dtype=jnp.int8),
            new_proto_state
        )
        # Reset timer for agents transitioning from INACTIVE to ACTIVE (timer is already 0)


        # Update the protoribosome_state in the population_state
        population_state['protoribosome_state'] = new_proto_state
        # Update the arrest timer in the population_state
        population_state['protoribosome_arrest_timer'] = protoribosome_arrest_timer


        # --- 7. Implement Autocatalysis (Success/Failure and Population Dynamics) ---
        # Determine which agents are capable of successful autocatalysis (reproduction)
        # Only agents with ACTIVE protoribosomes can potentially reproduce
        can_reproduce_mask = (population_state['protoribosome_state'] == PROTO_STATE_ACTIVE)
        key, success_roll_key = jax.random.split(key)
        success_roll = jax.random.uniform(success_roll_key, shape=can_reproduce_mask.shape)
        successfully_reproducing_agents_mask = can_reproduce_mask & (success_roll < AUTOCATALYSIS_SUCCESS_RATE_ACTIVE)

        num_successfully_reproducing_agents = jnp.sum(successfully_reproducing_agents_mask)

        # Update autocatalysis failure timer: increment for agents that did NOT reproduce, reset for those that did.
        autocatalysis_failure_timer = jnp.where(
            successfully_reproducing_agents_mask,
            jnp.array(0, dtype=jnp.int32), # Reset timer on successful reproduction
            population_state['autocatalysis_failure_timer'] + 1 # Increment timer on failure to reproduce
        )
        population_state['autocatalysis_failure_timer'] = autocatalysis_failure_timer


        # --- 8. Implement Foraging/Removal ---
        # Agents can be removed from the population (foraged) if they fail autocatalysis
        # consistently for a certain period, or based on other criteria.
        # Use the autocatalysis_failure_timer to determine foraging risk.

        # Identify agents at risk of foraging (failed for a certain duration)
        foraging_risk_mask = (population_state['autocatalysis_failure_timer'] >= jnp.array(FAILURE_DURATION_FORAGING_THRESHOLD))

        # Determine which agents are foraged
        key, foraging_roll_key = jax.random.split(key)
        foraging_roll = jax.random.uniform(foraging_roll_key, shape=foraging_risk_mask.shape)
        foraged_agents_mask = foraging_risk_mask & (foraging_roll < FORAGING_RATE_FAILED)

        num_foraged_agents = jnp.sum(foraged_agents_mask)

        # Implement Population Dynamics: Replace foraged agents with offspring from successfully reproducing agents
        # This maintains a constant population size (NUM_AGENTS).
        num_replacements_needed = num_foraged_agents
        num_potential_offspring = num_successfully_reproducing_agents

        # If there are reproducing agents and replacements needed, select reproducing agents
        if num_potential_offspring > 0 and num_replacements_needed > 0:
             key, offspring_selection_key = jax.random.split(key)
             # Select indices of successfully reproducing agents (potential parents)
             reproducing_indices = jnp.where(successfully_reproducing_agents_mask)[0]

             # Randomly select parents to produce offspring to replace foraged agents
             # Use replacement=True if num_replacements_needed > num_potential_offspring
             replace_with_replacement = num_replacements_needed > num_potential_offspring
             parent_indices_for_offspring = jax.random.choice(
                 offspring_selection_key,
                 reproducing_indices,
                 shape=(num_replacements_needed,),
                 replace=replace_with_replacement
             )

             # Get the state of the selected parent agents
             parent_rna_sequences = population_state['rna_sequence'][parent_indices_for_offspring]
             parent_strategies = population_state['strategy'][parent_indices_for_offspring]

             # Create offspring state (e.g., copy parent's RNA and strategy, reset damage/protoribosome state, reset timers)
             offspring_rna_sequences = jnp.copy(parent_rna_sequences) # Offspring inherit RNA
             offspring_strategies = jnp.copy(parent_strategies) # Offspring inherit strategy
             offspring_rna_damage = jnp.zeros_like(offspring_rna_sequences, dtype=jnp.float32) # Offspring start with no general damage
             # Important: Stop codons need to be re-checked for offspring based on inherited sequence in the next step
             offspring_has_stop_codon = jnp.zeros(num_replacements_needed, dtype=bool)
             offspring_protoribosome_state = jnp.full(num_replacements_needed, PROTO_STATE_ACTIVE, dtype=jnp.int8) # Offspring protoribosomes are active
             offspring_protoribosome_arrest_timer = jnp.zeros(num_replacements_needed, dtype=jnp.int32) # Offspring timer is 0
             offspring_autocatalysis_failure_timer = jnp.zeros(num_replacements_needed, dtype=jnp.int32) # Offspring timer is 0


             # Select indices of foraged agents to be replaced
             foraged_indices = jnp.where(foraged_agents_mask)[0]

             # Update the state of the foraged agents with the offspring state
             # Use jax.lax.scatter_nd_update to update multiple indices at once
             population_state['strategy'] = jax.lax.scatter_nd_update(
                 population_state['strategy'],
                 foraged_indices[:, None], # Indices need to be shape (N, 1) for scatter
                 offspring_strategies,
                 scatter_dims=(0,) # Scatter along the first dimension (agents)
             )
             population_state['rna_damage'] = jax.lax.scatter_nd_update(
                 population_state['rna_damage'],
                 foraged_indices[:, None, None], # Indices need to be shape (N, 1, 1) for scatter into 2D array
                 offspring_rna_damage,
                 scatter_dims=(0,)
             )
             population_state['has_stop_codon'] = jax.lax.scatter_nd_update(
                 population_state['has_stop_codon'],
                 foraged_indices[:, None],
                 offspring_has_stop_codon,
                 scatter_dims=(0,)
             )
             population_state['rna_sequence'] = jax.lax.scatter_nd_update(
                 population_state['rna_sequence'],
                 foraged_indices[:, None, None], # Indices need to be shape (N, 1, 1) for scatter into 2D array
                 offspring_rna_sequences,
                 scatter_dims=(0,)
             )
             population_state['protoribosome_state'] = jax.lax.scatter_nd_update(
                 population_state['protoribosome_state'],
                 foraged_indices[:, None],
                 offspring_protoribosome_state,
                 scatter_dims=(0,)
             )
             population_state['protoribosome_arrest_timer'] = jax.lax.scatter_nd_update(
                 population_state['protoribosome_arrest_timer'],
                 foraged_indices[:, None],
                 offspring_protoribosome_arrest_timer,
                 scatter_dims=(0,)
             )
             population_state['autocatalysis_failure_timer'] = jax.lax.scatter_nd_update(
                 population_state['autocatalysis_failure_timer'],
                 foraged_indices[:, None],
                 offspring_autocatalysis_failure_timer,
                 scatter_dims=(0,)
             )

        # If no reproducing agents but foraging occurs, these spots are effectively removed (status=0 implicitly, or just replaced with empty/dummy state)
        # In this replacement model, if num_potential_offspring is 0, num_replacements_needed will also be 0 as no offspring can be created.
        # If num_successfully_reproducing_agents < num_foraged_agents, we replace as many as possible from reproducing agents.
        # The remaining foraged spots are effectively empty until future steps produce more offspring than failures.


        # --- Placeholder for other simulation step logic ---
        # (All core simulation step logic is now implemented)


        # Placeholder print statement
        if step % 1000 == 0 or step == start_step or step == NUM_STEPS:
            # Calculate average damaged bases per agent for printing
            num_damaged_bases_per_agent = jnp.sum(
                (population_state['rna_sequence'] == DAMAGED_G) | (population_state['rna_sequence'] == DAMAGED_PYRIMIDINE_DIMER),
                axis=1
            )
            avg_damaged_bases_for_print = jnp.mean(num_damaged_bases_per_agent) if num_damaged_bases_per_agent.shape[0] > 0 else 0.0

            num_agents_with_stop_codon = jnp.sum(population_state['has_stop_codon']) # Count agents with stop codons
            # Count protoribosome states for all agents
            active_protos_count = jnp.sum(population_state['protoribosome_state'] == PROTO_STATE_ACTIVE)
            arrested_protos_count = jnp.sum(population_state['protoribosome_state'] == PROTO_STATE_ARRESTED)
            inactive_protos_count = jnp.sum(population_state['protoribosome_state'] == PROTO_STATE_INACTIVE)

            # Count agents by strategy
            cautious_count = jnp.sum(population_state['strategy'] == 0)
            reckless_count = jnp.sum(population_state['strategy'] == 1)

            # Total agents remain constant at NUM_AGENTS in this replacement model


            print(f"Simulating step {step}/{NUM_STEPS}, UV: {current_uv_radiation:.2f}, Temp: {current_temperature:.0f}째C, Avg Damaged Bases: {avg_damaged_bases_for_print:.2f}, Error-Prone Mutations: {num_error_prone_mutations}, Agents with Stop Codon: {num_agents_with_stop_codon}")
            print(f"  Protoribosome States: Active={active_protos_count}, Arrested={arrested_protos_count}, Inactive={inactive_protos_count}")
            print(f"  Agents by Strategy: Cautious={cautious_count}, Reckless={reckless_count}")
            print(f"  Autocatalysis: Successful={num_successfully_reproducing_agents}, Failed (this step): {NUM_AGENTS - num_successfully_reproducing_agents}, Foraged: {num_foraged_agents}")


        # Add checkpointing logic
        # Only save if population_state is defined
        if population_state is not None and step > 0 and (step % CHECKPOINT_INTERVAL == 0 or step == NUM_STEPS):
            if step == NUM_STEPS:
                 checkpoint_filename = os.path.join(CHECKPOINT_DIR, f'final_state_step_{step}.pkl')
            else:
                 checkpoint_filename = os.path.join(CHECKPOINT_DIR, f'checkpoint_step_{step}.pkl')

            try:
                with open(checkpoint_filename, 'wb') as f:
                    pickle.dump(population_state, f)
                print(f"Checkpoint saved at step {step} to {checkpoint_filename}")
            except Exception as e:
                print(f"Error saving checkpoint at step {step}: {e}")

else:
    print("Simulation loop skipped because population_state is not defined.")


print("Simulation loop finished.")

Ensured checkpoint directory exists: /content/drive/MyDrive/Colab Notebooks/qtpu/replica_1/Expt_2025-7-6_08-50
Successfully loaded last checkpoint: /content/drive/MyDrive/Colab Notebooks/qtpu/replica_1/Expt_2025-7-6_08-50/initial_state.pkl (Step: 0)
Resuming simulation from step 1 with loaded state.
Running simulation from step 1 to 10000


AttributeError: module 'jax.lax' has no attribute 'scatter_nd_update'