In [None]:
import importlib
import logging
import shutil
import time
import uuid
from pathlib import Path

import humanize
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import clear_output

import astrocast.preprocessing as prep
import astrocast.sim as sim

In [None]:
def generate_signal_function(N: int, burst_center: int, burst_width: float) -> callable:
    """
    Generate a periodic signal function containing a single Gaussian burst.

    One period of the signal is sampled at the integer timepoints ``0 .. N-1`` and normalized
    so that its largest sample is exactly 1.0. Samples far from ``burst_center`` decay towards
    0, so the signal is close to (but not exactly) 0 at the edges when the burst sits well
    inside the window.

    Args:
      N: The length of one period of the signal in timepoints. Must be >= 1.
      burst_center: The center of the Gaussian burst, in timepoints.
      burst_width: The standard deviation of the Gaussian burst, in timepoints. Must be > 0.

    Returns:
      A function ``f(x)`` mapping an integer timepoint to a float in ``[0, 1]``. Timepoints
      outside ``[0, N)`` wrap around via ``x % N``, so the signal repeats every N steps.

    Raises:
      ValueError: If ``N < 1``, if ``burst_width`` is not strictly positive, or if the burst
        lies so far outside the window that every sample underflows to 0 (normalization
        would otherwise divide by zero and return NaN).
    """
    if N < 1:
        raise ValueError(f"N must be >= 1, got {N}")
    # `not x > 0` also rejects NaN, which a plain `x <= 0` check would let through.
    if not burst_width > 0:
        raise ValueError(f"burst_width must be > 0, got {burst_width}")

    # Integer timepoints 0 .. N-1 as floats (same values as np.linspace(0, N - 1, N)).
    t = np.arange(N, dtype=float)
    gaussian = np.exp(-((t - burst_center) ** 2) / (2 * burst_width ** 2))

    # Normalize so the maximum sample is 1; guard against a burst that underflowed entirely.
    peak = gaussian.max()
    if peak == 0:
        raise ValueError(
            f"burst_center={burst_center} is too far outside [0, {N - 1}] for "
            f"burst_width={burst_width}; all samples underflow to 0"
        )
    normalized_gaussian = gaussian / peak

    def signal_function(x: int) -> float:
        # Wrap around so the signal repeats with period N (also handles negative x).
        return float(normalized_gaussian[x % N])

    return signal_function


# Parameters of the example burst signal: one Gaussian burst in the middle of a 24-step period.
N = 24  # number of timepoints per period
burst_center = N // 2  # burst peaks halfway through the period
burst_width = 0.5  # spread of the burst, in timepoints
signal_func = generate_signal_function(N, burst_center, burst_width)

In [None]:
# Scratch check: a single (x, y) coordinate stored as a 2x1 column vector.
x, y = 3, 5
px = np.array([x, y]).reshape(2, 1)
px.shape

In [None]:
# Pick up any edits to astrocast.sim without restarting the kernel.
importlib.reload(sim)

# --- Logging and visualization ---------------------------------------------------------
dl = sim.DataLogger(
    save_path=Path("exp"),
    save_checkpoint_interval=None,
    overwrite=True,
)
messenger_param = {
    "print_messages": False,
    "log_path": Path("log.txt"),
    "save_log_every": 1,
}
vis_param = {
    "dpi": 200,
    "display_interval": 5,
    "save_interval": 5,
    "override": True,
    "img_folder": Path("./imgs"),
}

# --- Environment -----------------------------------------------------------------------
# todo: external calcium concentration
# todo: thickness of interacting pixels
# todo: unify steps and dt
env_dict = {
    "molecules": ["glutamate"],
    "diffusion_rate": 0.01,
    "degradation_factor": 0.1,
    "dt": 25,
}

# --- Glutamate release -----------------------------------------------------------------
grm_dict = {
    "num_dendrites": 16,  # alternative value tried: 32
    "z_thickness": 3,
    "jitter": 2,
    "release_amplitude": 500,
    "stochastic_probability": 0.1,
    "signal_function": lambda x: 0,  # constant 0 for every timepoint
}

# --- Astrocytes ------------------------------------------------------------------------
ast_dict = {
    "num_branches": 4,
    "radius": 3,
    "max_branch_radius": 10,
    "min_radius": 0.001,
    "start_spawn_radius": 0.05,
    "spawn_length": 3,
    "spawn_radius_factor": 0.5,
    "growth_factor": 0.1,
    "max_tries": 5,
    "glutamate_uptake_rate": 0.1,
    "glu_v_max": 10,
    "glu_k_m": 500,
    "allow_pruning": True,
    "atp_cost_per_glutamate": -18,
    "atp_cost_per_unit_surface": 1,
    # "molecules": {"glutamate": 0, "calcium": 0, "ATP": 1e-6},
    "trend_history": 100,
    "min_steepness": 0.005,
    "min_trend_amplitude": 0.01,
    "repellent_volume_factor": 0,
    "repellent_surface_factor": 0,
    "repellent_concentration": 0,  # alternative value tried: 1
}

# --- Assemble the simulation -----------------------------------------------------------
my_sim = sim.Simulation(
    num_astrocytes=16,  # alternative value tried: 24
    grid_size=(100, 100),
    border=5,
    environment_param=env_dict,
    glutamate_release_param=grm_dict,
    astrocyte_param=ast_dict,
    vis_param=vis_param,
    messenger_param=messenger_param,
    data_logger=dl,
)



In [None]:
# Per-panel plotting options handed to the simulation; panels D, E and F share the same blur.
blur_options = {"blur_sigma": 0.3, "blur_radius": 1}
plot_param = {"C": {"line_thickness": (0.2, 1.5)}}
# dict(...) gives every panel its own copy, so panels can be tweaked independently later.
plot_param.update({panel: dict(blur_options) for panel in ("D", "E", "F")})

n_steps = 100
my_sim.run_simulation_step(n_steps, plot_param=plot_param)

In [None]:
# Draw the astrocytes as lines, with line_thickness range (0.2, 1.5).
# Grid-based alternatives, kept for reference:
# my_sim.vis.plot_astrocyte_by_grid(molecule="glutamate")
# my_sim.vis.plot_astrocyte_by_grid(molecule="glutamate", blur_sigma=0.3, blur_radius=1)
visualizer = my_sim.vis
visualizer.plot_astrocyte_by_line(line_thickness=(0.2, 1.5))


In [None]:
# Distribution of branch end radii across all astrocytes, on a log10 scale.
# `sns` (seaborn) is imported in the imports cell at the top of the notebook.
vis = my_sim.vis
bodies, branches = vis.get_bodies_and_branches()
display(len(branches))

# End radii can be None; drop those before converting to an array.
raw_end_radii = [branch.end.radius for branch in branches]
end_radii = np.array([radius for radius in raw_end_radii if radius is not None], dtype=float)

if end_radii.size == 0:
    # np.min/np.max raise on an empty array, so report instead of crashing.
    print("No branch end radii available.")
else:
    # log10 is only defined for strictly positive values (0 -> -inf, negative -> NaN).
    positive_radii = end_radii[end_radii > 0]
    n_skipped = end_radii.size - positive_radii.size
    if n_skipped:
        print(f"Skipping {n_skipped} non-positive radii in the log10 histogram.")
    if positive_radii.size:
        sns.displot(np.log10(positive_radii))
    display(end_radii.min(), end_radii.max())


In [None]:
# Show the messages collected by the simulation's message logger.
message_logger = my_sim.message_logger
message_logger.get_messages()

In [None]:
# Load the images written to the visualization folder and export them as a uniquely named .avi.
# `prep` (astrocast.preprocessing) is imported in the imports cell; reload picks up local edits.
importlib.reload(prep)

# Named `video_io` rather than `io` so the stdlib `io` module name is not shadowed.
video_io = prep.IO()
# Reuse the simulation's configured image folder instead of repeating the path literal.
arr = video_io.load(vis_param["img_folder"], lazy=False)
# uuid4 is random; uuid1 would embed this machine's MAC address and a timestamp in the file name.
video_io.save(f"export_{uuid.uuid4().hex}.avi", data=arr, overwrite=True)

In [None]:
# Disabled: manual plotting / figure saving / log saving steps, kept for reference.
# NOTE(review): `img_dir` and `step_counter` are not defined anywhere in this notebook;
# define them before re-enabling the savefig line.
# my_sim.plot(last_n_messages=15)
# display(my_sim.fig)
# my_sim.fig.savefig(img_dir.joinpath(f"img_{step_counter}.png"), dpi=(80))
# 
# my_sim.data_logger.save(Path("log.txt"))

In [None]:
# Plot the branch history of the first astrocyte in the simulation.
first_astrocyte = my_sim.astrocytes[0]
first_astrocyte.plot_branch_history()


In [None]:
# Mean radius per branch (average of start and end radius) for one astrocyte, summarized as
# mean +- std over its branches.
astrocyte_index = 2
branches = my_sim.astrocytes[astrocyte_index].branches

thickness = []
for branch in branches:
    radii = (branch.start.radius, branch.end.radius)
    # Radii can be None; skip those branches rather than crash in np.mean.
    if any(radius is None for radius in radii):
        continue
    # Average the two radii. The previous np.mean(start + end) averaged a single scalar sum,
    # i.e. it returned start + end instead of the mean.
    thickness.append(np.mean(radii))

if thickness:
    print(f"{np.mean(thickness):.2f} +- {np.std(thickness):.2f}")
else:
    print(f"Astrocyte {astrocyte_index} has no branches with both radii set.")

In [None]:
# Summary statistics of the per-branch mean radii instead of dumping every value.
pd.Series(thickness, name="mean_branch_radius", dtype=float).describe()

In [None]:
# Disabled: log-log histogram (1000 bins) of `thickness`; uncomment to inspect its distribution.
# fig, ax = plt.subplots(1, 1)
# ax.set_xscale('log')
# ax.set_yscale('log')
# _ = ax.hist(thickness, bins=1000)

In [None]:
# Look up one branch by its short id across all astrocytes, report its extracellular glutamate
# trend and plot that branch's extracellular glutamate history.
branch_id = "92ede384"

# Flatten all branches of all astrocytes (loop variable not named `ast`, a stdlib module name).
branches = [branch for astrocyte in my_sim.astrocytes for branch in astrocyte.branches]
branch0 = next((branch for branch in branches if branch.get_short_id() == branch_id), None)

if branch0 is None:
    # Say explicitly that nothing matched instead of printing a bare "None".
    print(f"No branch with short id {branch_id!r} found among {len(branches)} branches.")
else:
    print(branch0)

    trend = branch0.get_trend("glutamate", intra=False)
    # Trend magnitude and its size relative to ast_dict['min_trend_amplitude'], in percent.
    print(f"{trend:.1E}, {trend / ast_dict['min_trend_amplitude'] * 100:.1f}%")

    fig, ax = plt.subplots()

    hist = branch0.extracellular_history["glutamate"]
    ax.plot(hist)
    ax.set(
        title=f"Extracellular glutamate history, branch {branch_id}",
        xlabel="history index",
        ylabel="glutamate",
    )


In [None]:
# Glutamate history recorded by the environment grid.
history = my_sim.environment_grid.history["glutamate"]

fig, ax = plt.subplots()
ax.plot(history)
# Trailing `;` suppresses the Axes.set return value in the cell output.
ax.set(title="Environment glutamate history", xlabel="history index", ylabel="glutamate");

In [None]:



# Example usage: evaluate `signal_func` (built from N, burst_center and burst_width in the
# signal-definition cell) over several periods to show that it wraps around every N timepoints.
n_periods = 10
x_len = N * n_periods  # 10 periods of N = 24 -> 240 timepoints
y_values = [signal_func(x) for x in range(x_len)]

fig, ax = plt.subplots()
ax.plot(range(x_len), y_values)
# Trailing `;` suppresses the Axes.set return value in the cell output.
ax.set(title=f"signal_func over {n_periods} periods (N={N})", xlabel="timepoint", ylabel="signal");

In [None]:
# Glutamate concentration and amount for every branch of every astrocyte, summarized as
# mean +- std.
branches = [branch for astrocyte in my_sim.astrocytes for branch in astrocyte.branches]

# Query both values per branch in one pass, keeping the original call order.
glu_readings = [
    (branch.get_concentration("glutamate"), branch.get_amount("glutamate"))
    for branch in branches
]
glu_conc = [conc for conc, _ in glu_readings]
glu_amount = [amount for _, amount in glu_readings]

print(f"Conc: {np.mean(glu_conc):.2E} +- {np.std(glu_conc):.2E}")
print(f"Amount: {np.mean(glu_amount):.2E} +- {np.std(glu_amount):.2E}")

In [None]:
# Glutamate values from index 0 of the environment grid's shared "glutamate" array.
glu_conc = my_sim.environment_grid.shared_arrays["glutamate"][0]
# Scientific notation, matching the per-branch summary, so small concentrations do not round
# to 0.00; the "Env" prefix distinguishes this line from the per-branch "Conc:" output.
print(f"Env conc: {np.mean(glu_conc):.2E} +- {np.std(glu_conc):.2E}")