### University of California, Berkeley
### Chem 274B: Software Engineering Fundamentals for Molecular Sciences 
### Final Project
### Creators:  Francine Bianca Oca, Kassady Marasigan and Korede Ogundele
### Date Created: December 5, 2023

This file contains functions and plots that pertain to our cellular automata model.

In [None]:
import random

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import plotly.express as px

# 1. Requirement to display an image of the current CA configuration

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def display_cellular_automaton(grid, cmap='viridis'):
    """
    Display an image of cellular automaton at a certain time step.

    Every distinct state found in the grid is drawn in its own colour,
    sampled evenly from `cmap`, and the colorbar carries one tick per state
    labelled with that state's value.

    Parameters:
    - grid: 2D array (or nested sequence) representing the cellular automaton grid
    - cmap: colormap to sample the state colours from (default is 'viridis')

    Returns:
    - None (displays the image)

    Raises:
    - ValueError: if grid is empty or is not two-dimensional
    """
    grid = np.asarray(grid)
    if grid.ndim != 2 or grid.size == 0:
        raise ValueError("grid must be a non-empty 2D array")

    # Replace each state value by its position in the sorted list of unique
    # states, so arbitrary state values (e.g. 1, 5, 9) map to 0, 1, 2, ...
    unique_states, inverse = np.unique(grid, return_inverse=True)
    # The shape of `inverse` differs between NumPy versions; force grid shape.
    state_indices = inverse.reshape(grid.shape)
    n_states = len(unique_states)

    # Resample the colormap to exactly one colour per state
    discrete_cmap = plt.get_cmap(cmap, n_states)

    # vmin/vmax centre each integer index inside its own colour bin
    image = plt.imshow(
        state_indices,
        interpolation='nearest',
        cmap=discrete_cmap,
        origin='upper',
        vmin=-0.5,
        vmax=n_states - 0.5,
    )

    # Display colorbar with one tick per state, labelled with the state value
    cbar = plt.colorbar(image, ticks=np.arange(n_states))
    cbar.set_ticklabels([str(state) for state in unique_states])
    cbar.set_label('Cell States')

    # Show the plot
    plt.show()

# Example usage:
# Build a random 10x10 grid whose cells each hold one of three states (0, 1 or 2)
# and render it. Substitute the grid produced by your own simulation step.
grid = np.random.randint(low=0, high=3, size=(10, 10))
display_cellular_automaton(grid)



# 2. Requirement to count the number of cells in the CA that are in a particular state

Note: it is not yet decided whether the `grid` argument will be used in the final version of this function.

In [None]:
def display_state_count(grid, state, step_number):
    """
    Displays the number of individuals in a given state at a specific step.

    Parameters
    ----------
    grid :
        A 2D array (or list of rows) representing the cellular automata grid.
    state : int
        the state to count individuals for (1: homozygous dominant, 2: heterozygous, 3: recessive)
    step_number : int
        the step number in the simulation (only used in the printed message).

    Returns
    -------
    state_count : int
        number of individuals in given state
    """
    # Walk every cell of every row and count the ones matching `state`
    count = sum(1 for row in grid for cell in row if cell == state)

    print(f"Step {step_number}: Number of individuals in state {state}: {count}")
    return count

# Example usage: report how many cells of a small 3x3 grid are in state 1 at step 10
grid = [
    [0, 1, 0],
    [1, 0, 1],
    [0, 1, 0],
]
state, step_number = 1, 10

display_state_count(grid, state, step_number)


# 3. Bar graph that displays allele frequency of a chosen generation

In [None]:
def _read_state_counts(step_number, file_path):
    """Return the first three counts logged for `step_number`, or None if that step is absent."""
    prefix = f'Step {step_number}:'
    with open(file_path, 'r') as file:
        for line in file:
            if line.startswith(prefix):
                # Each comma-separated field ends in '=<count>'; keep the number.
                # NOTE(review): assumes lines look like 'Step N: a=1, b=2, c=3' -- confirm against the writer.
                fields = line.strip().split(', ')
                return tuple(int(field.split('=')[-1]) for field in fields[:3])
    return None


def plot_state_counts(step_number, file_path='Data/filename.txt'):
    """
    This function plots the proportion of individuals in each state at a given step on a bar graph.

    Parameters
    ----------
    step_number : int
        The step number we would like to view states for
    file_path : str, optional
        Path of the text file containing the state counts. The file name and
        path will vary depending on the specific application.

    Returns
    -------
    None
        Shows a bar graph with states and their proportion of the total
        population. If the step is not found in the file, or its total
        population is zero, an error message is printed and nothing is plotted.
    """
    counts = _read_state_counts(step_number, file_path)
    if counts is None:
        # report an error if the requested step is not found
        print(f"ERROR: Generation {step_number} not found.")
        return

    # Order follows the file: homozygous dominant, heterozygous, recessive
    homozygous_dominant_count, heterozygous_count, recessive_count = counts

    # get counts as a proportion of total population
    total_population = homozygous_dominant_count + heterozygous_count + recessive_count
    if total_population == 0:
        # Avoid dividing by zero when the population is empty
        print(f"ERROR: Generation {step_number} has a total population of zero.")
        return

    homozygous_dominant_percentage = homozygous_dominant_count / total_population
    heterozygous_percentage = heterozygous_count / total_population
    recessive_percentage = recessive_count / total_population

    # Plot
    labels = ['Homozygous Dominant', 'Heterozygous', 'Recessive']
    proportions = [homozygous_dominant_percentage, heterozygous_percentage, recessive_percentage]

    plt.bar(labels, proportions, color=['royalblue', 'mediumslateblue', 'mediumorchid'])
    plt.title(f'Allele Percentages for Generation {step_number}')
    plt.xlabel('Genotype')
    plt.ylabel('Proportion of Population')
    plt.show()

# 4. Filled area plot to visualize how allele frequency changes over 100 generations