In [None]:
import matplotlib.pyplot as plt
import numpy

# Helpers
from dataclasses import dataclass
import pytest

def is_convergent(values: list[float]) -> bool:
    """Tell whether a series has flattened out at its most recent sample.

    The series counts as convergent when both its slope and its curvature,
    estimated numerically at the last sample, are zero within the tolerance
    applied by ``is_zero``.

    Args:
        values: Samples of the series, oldest first.

    Returns:
        True when the first and second numerical derivatives at the last
        sample are both approximately zero. False otherwise, including when
        fewer than two samples are available (``numpy.gradient`` cannot
        differentiate such a series).
    """
    if len(values) < 2:
        return False

    derivative_first = numpy.gradient(values)
    # The curvature is the derivative of the slope, not of the raw values.
    derivative_second = numpy.gradient(derivative_first)

    angle_nth = derivative_first[-1]
    curvature_nth = derivative_second[-1]
    return bool(is_zero(angle_nth) and is_zero(curvature_nth))

def is_zero(value, limit: float = 5e-7) -> bool:
    """Return whether ``value`` lies strictly within ``limit`` of zero.

    Args:
        value: Number to test.
        limit: Exclusive upper bound on ``abs(value)``.
    """
    return abs(value) < limit

In [None]:
@dataclass(frozen=True)
class Population:
    """Immutable snapshot of a population evolving under the logistic map.

    Each generation is derived from the previous one by
    ``x_next = r * x * (1 - x)``. Instances are never modified in place: every
    method that advances time returns a new ``Population``.

    Attributes:
        generation: 1-based number of the current generation; the methods of
            this class always set it to ``len(history)``.
        value: Population value of the current generation.
        growth_rate: The parameter ``r`` of the map.
        init: Value of the first generation.
        history: Values of all generations so far, oldest first.
    """

    generation: int
    value: float
    growth_rate: float
    init: float
    # frozen=True only prevents rebinding the attribute; the list object itself
    # is still mutable, so treat it as read-only.
    history: list[float]

    @classmethod
    def new(cls, r: float, x0: float) -> 'Population':
        """Create the first generation.

        Args:
            r: Growth rate of the map.
            x0: Initial population value.
        """
        return Population(
            generation = 1,
            value = x0,
            growth_rate = r,
            init = x0,
            history = [x0],
        )

    # Population

    def _with_history(self, history: list[float]) -> 'Population':
        """Build the snapshot described by ``history``, keeping the growth rate."""
        return Population(
            generation = len(history),
            value = history[-1],
            growth_rate = self.growth_rate,
            init = history[0],
            history = history,
        )

    def next_generation(self) -> 'Population':
        """Return the population advanced by exactly one generation."""
        r: float = self.growth_rate
        x: float = self.value
        return self._with_history(self.history + [r * x * (1 - x)])

    def next_nth_generations(self, n: int) -> 'Population':
        """Return the population advanced by ``n`` more generations."""
        return self.until_nth_generation(n + self.generation)

    def until_nth_generation(self, N: int) -> 'Population':
        """Return the population advanced until it has reached generation ``N``.

        If the population is already at generation ``N`` or later, ``self`` is
        returned unchanged.
        """
        if self.generation >= N:
            return self

        r: float = self.growth_rate
        x: float = r * self.value * (1 - self.value)
        # Grow one working list and build a single snapshot at the end, rather
        # than copying the whole history for every intermediate generation.
        history = self.history + [x]
        while len(history) < N:
            x = r * x * (1 - x)
            history.append(x)
        return self._with_history(history)

    # Trend and convergence

    def is_convergent(self) -> bool:
        """Return whether the history has levelled off, per the module-level ``is_convergent``."""
        return is_convergent(self.history)

    def grow_until_stable(self, limit=5000) -> 'Population':
        """Advance until the history is convergent or ``limit`` generations are reached.

        The population is first brought to at least generation 5, then grown
        in batches of up to 10 generations, checking convergence after each
        batch. Batches are clamped so the result never goes past generation
        ``limit`` (apart from the initial warm-up to generation 5).

        Args:
            limit: Upper bound on the generation number to simulate.
        """
        population = self.until_nth_generation(5)
        while population.generation < limit and not population.is_convergent():
            target = min(population.generation + 10, limit)
            population = population.until_nth_generation(target)
        return population

    # Helpers

    def print(self) -> 'Population':
        """Print the instance's attribute dictionary and return ``self`` for chaining."""
        print(self.__dict__)
        return self

    def validate(self, dict):
        """Assert that every field equals the entry of the same name in ``dict``.

        Args:
            dict: Mapping with the keys ``'growth_rate'``, ``'init'``,
                ``'value'``, ``'history'`` and ``'generation'``. The parameter
                name shadows the builtin; it is kept for existing callers.

        Raises:
            AssertionError: If any field differs from the expected value.
        """
        assert self.growth_rate == dict['growth_rate']
        assert self.init == dict['init']
        assert self.value == dict['value']
        assert self.history == dict['history']
        assert self.generation == dict['generation']


# TESTS

# A new population starts at generation 1 with only the initial value recorded.
population = Population.new(r=1.0, x0 = 0.5)
population.validate({'generation': 1, 'value': 0.5, 'growth_rate': 1.0, 'init': 0.5, 'history': [0.5]})

# One step of x -> r * x * (1 - x) with r = 1.0 and x = 0.5 gives 0.25.
population = population.next_generation()
population.validate({'generation': 2, 'value': 0.25, 'growth_rate': 1.0, 'init': 0.5, 'history': [0.5, 0.25]})

# Advancing to an absolute generation number keeps the complete history.
population = population.until_nth_generation(15)
population.validate({'generation': 15, 'value': 0.05357062532685648, 'growth_rate': 1.0, 'init': 0.5, 'history': [0.5, 0.25, 0.1875, 0.15234375, 0.1291351318359375, 0.11245924956165254, 0.09981216674968249, 0.08984969811841606, 0.08177672986644556, 0.07508929631879595, 0.06945089389714401, 0.06462746723403166, 0.06045075771294583, 0.05679646360487655, 0.05357062532685648]})

# Advancing by a relative count: 15 + 10 = generation 25.
population = population.next_nth_generations(10)
population.validate({'generation': 25, 'value': 0.03433841067421475, 'growth_rate': 1.0, 'init': 0.5, 'history': [0.5, 0.25, 0.1875, 0.15234375, 0.1291351318359375, 0.11245924956165254, 0.09981216674968249, 0.08984969811841606, 0.08177672986644556, 0.07508929631879595, 0.06945089389714401, 0.06462746723403166, 0.06045075771294583, 0.05679646360487655, 0.05357062532685648, 0.05070081342894604, 0.04813024094658925, 0.04581372085301251, 0.04371482383461476, 0.041803838011723354, 0.04005627713921295, 0.03845177180095951, 0.036973233046326444, 0.035606213084428476, 0.03433841067421475]})

# After 25 generations the series has not yet levelled off.
assert population.is_convergent() == False

# Growing until stable stops at the first convergence check that passes.
# NOTE: `population` stays bound to this result and is plotted by a later cell.
population = population.grow_until_stable(limit=5000)
assert population.is_convergent() == True
assert population.generation == 1415
assert population.value == 0.0007027267701217214

### Plot functions

In [None]:
def plot_population_over_time(population: Population):
    """Draw the population's value history as a line, one point per generation.

    The x axis is the 0-based index ``0 .. population.generation - 1`` and the
    y axis is ``population.history``. The line is added through pyplot.
    """
    generations = range(population.generation)
    plt.plot(generations, population.history)

def plot_convergence_values(list_population: list[Population]):
    """Scatter each population's current value against its growth rate.

    One point is drawn per entry of ``list_population``: x is the entry's
    ``growth_rate`` and y is its ``value``.
    """
    growth_rates = [member.growth_rate for member in list_population]
    current_values = [member.value for member in list_population]

    plt.scatter(growth_rates, current_values)

In [None]:
# Plot the history of whatever population is currently bound to the
# module-level name `population`.
plot_population_over_time(population)

## Calculate populations with different growth rates

In [None]:
# Sweep the growth rate r over the half-open range [1.0, 4.0) in increments of
# `step`, starting every population from the same initial value `x` and
# iterating each one to generation 2000.
step = 0.01
r_range = numpy.arange(1.0, 4.0, step)
x = 0.5

list_population = []

for r in r_range:
    result = Population.new(r = r, x0 = x).until_nth_generation(2000)
    # Append in place: rebuilding the list with `+` copies it on every pass.
    list_population.append(result)

## Plot convergence values

In [None]:
plot_convergence_values(list_population)

## Population vs Generation vs Growth rate

In [None]:
%matplotlib inline
import pylab as pl
from IPython import display

# Number of leading generations drawn for each population.
gen_limit = 50

# Draw one figure per growth rate, replacing the previous figure each time so
# the cell output plays like an animation.
# NOTE: this loop rebinds the module-level names `population`, `x` and `y`.
for population in list_population:
    x = range(gen_limit)
    # Assumes every history holds at least gen_limit values; a shorter history
    # would leave x and y with different lengths.
    y = population.history[0:gen_limit]

    # wait=True postpones clearing until the next output is ready to replace it.
    display.clear_output(wait=True)
    # Same y range for every frame so the frames are directly comparable.
    plt.ylim([0.0, 1.0])
    plt.plot(x, y)
    plt.xlabel('generation')
    plt.ylabel('population')
    plt.legend([f'growth rate = {population.growth_rate}'], loc='upper left')
    plt.show()