# spatial_CR.ipynb 
Notebook to load and visualize the raw data produced by the spatial simulations.

In [2]:
# import of necessary files from other folders

import os
import sys
import ast

# libraries imports

import importlib
import numpy as np
import matplotlib.pyplot as plt
import tempfile
import seaborn as sns
import pandas as pd
import pickle

from IPython.display import Image, display
from PIL import Image
from pathlib import Path
from scipy.integrate import solve_ivp
from matplotlib.colors import LinearSegmentedColormap,TwoSlopeNorm


# Use a white background both for the whole figure canvas and for the
# plotting area of every axes created afterwards in this notebook.
plt.rcParams.update({
    'figure.facecolor': 'white',  # area behind all subplots
    'axes.facecolor': 'white',    # area inside each subplot
})

In [3]:
# Function to check if a value can be converted to a number (float or integer)
def is_number(value):
    """Return True if ``float(value)`` succeeds, False otherwise.

    Accepts everything ``float`` accepts: integer and decimal strings,
    exponent notation (``"1e-3"``) and the special values ``"inf"`` and
    ``"nan"``. Inputs of an unsupported type (e.g. ``None`` or a list)
    return False instead of raising.
    """
    try:
        float(value)
    except (TypeError, ValueError):
        # ValueError: unparsable string; TypeError: unsupported type such as None.
        return False
    return True

# Load the complete simulation data into a DataFrame

In [None]:
# Path to the results directory
results_directory_path = Path('insert_base_path/CRIMMES/MODEL/Data/spatial_CR_results')

# Definition of DataFrame columns. The first `n_name_columns` are parsed from
# the numeric tokens of each result folder's name; the rest are read from the
# folder's `all_data.pkl` file.
columns = ['met_sparsity','up_sparsity','replica','avg_consumed','param_dict','C','D','N_cr','R_cr','N_lv','A','g','R_sp','grid_sp','N_sp','sh_sp']
n_name_columns = 3  # met_sparsity, up_sparsity, replica


def _parse_number(token):
    """Return ``token`` as an int if it is an integer literal, otherwise as a float.

    Unlike a "contains a dot" test, this also handles tokens such as
    ``"1e-3"``, ``"inf"`` or ``"nan"`` that ``is_number`` accepts.
    """
    try:
        return int(token)
    except ValueError:
        return float(token)


# List to accumulate one row per result folder
rows = []

# Get all folders and sort them alphabetically
folders = sorted(results_directory_path.iterdir(), key=lambda p: p.name)

for folder in folders:
    if not folder.is_dir():
        continue

    pickle_file_path = folder / 'all_data.pkl'
    if not pickle_file_path.exists():
        continue

    # Keep only the parts of the folder name that are numbers
    parameters_list = [_parse_number(tok) for tok in folder.name.split('_') if is_number(tok)]
    if len(parameters_list) != n_name_columns:
        # Fail here with a clear message instead of a column-count error from
        # pd.DataFrame after every file has been loaded.
        raise ValueError(
            f"Expected {n_name_columns} numeric tokens in folder name "
            f"{folder.name!r}, found {len(parameters_list)}: {parameters_list}"
        )

    # NOTE: pickle.load can execute arbitrary code; only load result files
    # produced by your own simulations.
    with open(pickle_file_path, 'rb') as file:
        data = pickle.load(file)

    # NOTE(review): the previous version ended these assignments with a stray
    # trailing comma, so every field except `sh_sp` was stored as a 1-tuple
    # `(value,)`. The plotting cells index these values with `[0]`, which
    # appears to rely on that wrapper, so the layout is kept explicitly here.
    # Change both sides together if unwrapping.
    rows.append(parameters_list + [
        (data['average_consumed'],),   # avg_consumed
        (data['parameters'],),         # param_dict
        (data['uptake'],),             # C
        (data['D'],),                  # D
        (np.array(data['CR_N']),),     # N_cr
        (data['CR_R'],),               # R_cr
        (np.array(data['LV']),),       # N_lv
        (data['A'],),                  # A
        (data['g0'],),                 # g
        (data['current_R'],),          # R_sp
        (data['current_N'],),          # grid_sp
        (data['abundances'],),         # N_sp
        data['s_list'],                # sh_sp (was never tuple-wrapped)
    ])

# Build the DataFrame directly; concatenating onto an empty frame is redundant
# and deprecated in recent pandas versions.
df_complete = pd.DataFrame(rows, columns=columns)

In [4]:
# Sanity check: how many simulations were loaded, plus a preview of the first rows
n_loaded_rows = df_complete.shape[0]
print(n_loaded_rows)
df_complete.head()

NameError: name 'df_complete' is not defined

## Visualize some time series

In [None]:
import random

# Choose at most `n_to_plot` simulations to visualize.
n_to_plot = 20
random_seed = None  # set to an int for a reproducible selection

# Draw row *positions* (not index labels), because the plotting cells
# read rows with `.iloc`.
positions = list(range(len(df_complete)))
if len(positions) > n_to_plot:
    indexes_random = random.Random(random_seed).sample(positions, n_to_plot)
else:
    indexes_random = positions

In [None]:

# Extract the trajectories of the sampled simulations from df_complete.
# NOTE(review): `[0]` mirrors the indexing used in the grid cell; it appears to
# strip the 1-tuple wrapper the loading cell puts around each value. Confirm
# against the data layout.
cr_matrices = [np.asarray(df_complete['N_cr'].iloc[i])[0] for i in indexes_random]
lv_matrices = [np.asarray(df_complete['N_lv'].iloc[i])[0] for i in indexes_random]
sp_matrices = [np.asarray(df_complete['N_sp'].iloc[i])[0] for i in indexes_random]

# Palette shared with the grid plots (Set2 provides 8 colours).
species_colors_time_series = plt.cm.Set2.colors[:8]


def _plot_species_lines(ax, series_by_species, title):
    """Draw one line per species on ``ax`` and set its title.

    ``series_by_species`` yields one 1-D series per species. At most
    ``len(species_colors_time_series)`` species are drawn, because ``zip``
    stops at the shorter sequence.
    """
    for i, (series, color) in enumerate(zip(series_by_species, species_colors_time_series)):
        ax.plot(series, color=color, label=f'Species {i + 1}')
    ax.set_title(title)


n_samples = len(indexes_random)
if n_samples == 0:
    print('No simulations to plot.')
else:
    # One row per sampled simulation: CR, LV and spatial trajectories side by side.
    # squeeze=False keeps `axes` 2-D even when there is a single row.
    fig, axes = plt.subplots(n_samples, 3, figsize=(18, 3 * n_samples), squeeze=False)
    for row, (cr, lv, sp) in enumerate(zip(cr_matrices, lv_matrices, sp_matrices)):
        _plot_species_lines(axes[row, 0], cr.T, 'wm_CR')  # species indexed on the last axis
        _plot_species_lines(axes[row, 1], lv, 'wm_LV')    # species indexed on the first axis
        _plot_species_lines(axes[row, 2], sp.T, 'sp')     # species indexed on the last axis

    axes[0, 0].legend(fontsize='small')

    # Adjust layout and show the plot
    plt.tight_layout()
    plt.show()

## Visualize some grids

In [None]:
from matplotlib.colors import ListedColormap, BoundaryNorm

# For each sampled simulation, the index of the species with the largest value
# in every grid cell (argmax over axis 2 of the stored spatial state).
# NOTE(review): `[0]` appears to strip the 1-tuple wrapper the loading cell
# puts around each value. Confirm against the data layout.
grids = [np.argmax(np.array(df_complete['grid_sp'].iloc[i])[0], axis=2) for i in indexes_random]

# Discrete colormap: integer value k in 0..7 gets the k-th Set2 colour
# (bin edges at -0.5, 0.5, ..., 7.5).
species_colors = plt.get_cmap('Set2').colors[:8]
cmap = ListedColormap(species_colors)
norm = BoundaryNorm(np.arange(9) - .5, 8)

n_grids = len(grids)
if n_grids == 0:
    print('No grids to plot.')
else:
    # Allocate exactly enough axes (at most 5 per row) so that every axes
    # has a grid to show and no grid is indexed out of range.
    n_cols = min(n_grids, 5)
    n_rows = -(-n_grids // n_cols)  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    flat_axes = axes.ravel()

    for ax_grid, grid in zip(flat_axes, grids):
        ax_grid.imshow(grid, cmap=cmap, norm=norm, interpolation='none')

    # Hide axis decorations everywhere, including any unused trailing axes.
    for ax_grid in flat_axes:
        ax_grid.axis('off')

    # Adjust layout and show the plot
    plt.tight_layout()
    plt.show()