In [349]:
import copy
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [355]:
# Simulation settings shared by the cells below.
init_params = {
    # NOTE(review): not read by any code visible in this notebook — confirm it is still needed.
    'n_CpGs': 500,
    # Number of cells each population is created with (see initCells).
    'n_cells': 1,
    # Chance that each allele of a daughter cell flips when she is created.
    'flip_rate': 0.01,
    # Chance per simulated day that a cell divides.
    'growth_rate': 0.2,
    # Chance per simulated day that a cell dies.
    'death_rate': 0.14,
    # Probabilities of the starting states [0, 1, 3] drawn in initCells.
    'init_site_state_probs': [1, 0, 0],
}

In [356]:
# Seed the shared random generator so that a fresh "Restart & Run All" replays
# exactly the same divisions, deaths and flips. Set SEED = None to go back to
# drawing fresh entropy on every kernel start.
SEED = 42
gen = np.random.default_rng(SEED)

In [357]:
class Cell:
    """A node of a singly linked list of cells.

    ``state`` is treated as a two-bit mask: the bit with value 1 and the bit
    with value 2 each stand for one allele, set when that allele is counted as
    methylated by ``numMethAlleles``. ``next_cell`` points at the following
    node, or is ``None`` for the last node.
    """

    def __init__(self, state, next_cell=None):
        self.next_cell = next_cell
        self.state = state

    def flip(self):
        """Toggle each allele bit independently with probability ``flip_rate``.

        One uniform draw is taken from ``gen`` per allele, first for bit 1 and
        then for bit 2.
        """
        for allele_mask in (1, 2):
            if gen.uniform() >= init_params['flip_rate']:
                continue
            self.state = self.state ^ allele_mask

    def getDaughter(self):
        """Return a new, unlinked Cell that copies this state and is then flipped."""
        child = Cell(self.state)
        child.flip()
        return child

    def numMethAlleles(self):
        """Return how many of the two allele bits are set (0, 1 or 2)."""
        first_allele = self.state & 1
        second_allele = 1 if self.state & 2 else 0
        return first_allele + second_allele
    
def getBetaValue(first_cell):
    """Return the methylated-allele fraction and the size of a cell list.

    Parameters
    ----------
    first_cell : node whose ``next_cell`` chain holds the population. The node
        itself is a placeholder head and is not counted.

    Returns
    -------
    (beta, n_cells) : ``beta`` is the number of methylated alleles divided by
        the total number of alleles (two per cell); ``n_cells`` is the number
        of cells after the head. For an empty population ``beta`` is NaN
        instead of raising ``ZeroDivisionError``, so a run in which every cell
        has died can still be recorded and plotted.
    """
    n_cells = 0
    meth_alleles = 0

    # Skip the placeholder head and walk the real cells.
    cur_cell = first_cell.next_cell
    while cur_cell is not None:
        n_cells += 1
        meth_alleles += cur_cell.numMethAlleles()
        cur_cell = cur_cell.next_cell

    if n_cells == 0:
        # No cells left: the fraction is undefined.
        return float('nan'), 0

    beta = meth_alleles / (2 * n_cells)
    return beta, n_cells

def initCells(init_params):
    """Build a fresh population as a linked list and return its head.

    The returned node is a placeholder ``Cell(-1)``; the real cells hang off
    its ``next_cell`` chain. ``init_params['n_cells']`` cells are created, and
    each one's state is drawn from ``[0, 1, 3]`` with the probabilities given
    in ``init_params['init_site_state_probs']`` (one draw from ``gen`` per cell).
    """
    head = Cell(-1)
    tail = head

    for _ in range(init_params['n_cells']):
        new_cell = Cell(gen.choice([0, 1, 3], p=init_params['init_site_state_probs']))
        tail.next_cell = new_cell
        tail = new_cell

    return head

def passDay(first_cell):
    """
    Advance the population behind ``first_cell`` by one day, in place.

    One uniform number is drawn per cell and compared against cumulative
    thresholds:
    below growth_rate                  - the cell divides
    below growth_rate + death_rate     - the cell dies
    anything else                      - the cell is left as it is

    Daughters are collected on a separate list and only attached to the end
    of the population once every existing cell has been visited, so a cell
    created today is not itself visited today.
    """
    daughters_head = Cell(-1)
    daughters_tail = daughters_head

    # `prev` is always the node *before* the cell being decided, which is what
    # makes unlinking a dead cell a single pointer assignment.
    prev = first_cell
    while True:
        candidate = prev.next_cell
        if candidate is None:
            break

        draw = gen.uniform()
        divide_below = init_params['growth_rate']

        if draw < divide_below:
            # Division: the mother stays in place, the daughter is queued.
            daughters_tail.next_cell = candidate.getDaughter()
            daughters_tail = daughters_tail.next_cell
            prev = candidate
            continue

        if draw < divide_below + init_params['death_rate']:
            # Death: unlink the cell; `prev` stays put so the cell that slid
            # into its position is decided next.
            prev.next_cell = candidate.next_cell
            continue

        # Nothing happens to this cell today.
        prev = candidate

    # `prev` is now the last surviving node (or the head); append the daughters.
    prev.next_cell = daughters_head.next_cell

In [358]:
nyears = 1

total_days = int(nyears * 365)

# One population per starting condition: population j puts all of the starting
# probability on entry j of the state list sampled by initCells.
# Each call gets its own copy of the settings with only the starting
# probabilities replaced, so the shared init_params dict keeps the values from
# the configuration cell instead of being left at the last loop iteration's
# [0, 0, 1].
first_cell_list = []
for j in range(3):
    first_cell_list.append(initCells({
        **init_params,
        'init_site_state_probs': [int(k == j) for k in range(3)],
    }))

In [359]:
# beta_list[j][i] is the beta value of population j at the start of day i.
beta_list = [[] for _ in range(3)]
print('Calculating...')
for i in range(1 + total_days):
    for j in range(3):
        # Measure first, then advance, so day 0 records the initial population.
        # Both calls walk every cell, so a day costs time proportional to the
        # current population size.
        beta, n_cells = getBetaValue(first_cell_list[j])

        # Progress report every tenth day: day, population index, cell count.
        if not i % 10:
            print(f'{i}, {j}: {n_cells}')

        beta_list[j].append(beta)
        passDay(first_cell_list[j])

Calculating...
0, 0: 1
0, 1: 1
0, 2: 1
10, 0: 1
10, 1: 4
10, 2: 6
20, 0: 4
20, 1: 1
20, 2: 10
30, 0: 4
30, 1: 3
30, 2: 18
40, 0: 3
40, 1: 2
40, 2: 43
50, 0: 12
50, 1: 2
50, 2: 83
60, 0: 23
60, 1: 4
60, 2: 128
70, 0: 47
70, 1: 5
70, 2: 229
80, 0: 60
80, 1: 4
80, 2: 450
90, 0: 156
90, 1: 11
90, 2: 836
100, 0: 252
100, 1: 20
100, 2: 1599
110, 0: 491
110, 1: 9
110, 2: 2911
120, 0: 895
120, 1: 16
120, 2: 5472
130, 0: 1604
130, 1: 18
130, 2: 9705
140, 0: 2910
140, 1: 11
140, 2: 17685
150, 0: 5103
150, 1: 27
150, 2: 31275
160, 0: 9044
160, 1: 44
160, 2: 55708


KeyboardInterrupt: 

In [None]:
# Elapsed time in years for each recorded step. One beta value is recorded per
# simulated day and the simulation counts 365 days per year, so step i is
# i / 365 years. (Previously this name was used without being defined anywhere
# in the notebook, which fails on a fresh kernel.)
timestep_list = np.arange(len(beta_list[0])) / 365

# np.array(beta_list) has one row per starting condition (this needs the three
# lists to have the same length). Transposing gives one row per day; melt then
# stacks the three runs into long form with columns index / variable / value.
plot_data = pd.melt(pd.DataFrame(np.array(beta_list)).T.reset_index(), ['index'])
plot_data['Year'] = plot_data['index'].map(lambda x: timestep_list[x])

fig, ax = plt.subplots(figsize=(8, 5))
fig.tight_layout(pad=3)

# One line per starting condition, coloured in the order of beta_list.
plot = sns.lineplot(ax=ax, x='Year', y='value', data=plot_data,
                    hue='variable',
                    palette=['green', 'orange', 'blue'], linestyle='solid')

# Beta is a fraction, so pad the [0, 1] range slightly and tick only 0, 0.5, 1.
ax.set_ylim((-0.1, 1.1))
ax.set_yticks([0, 0.5, 1])
# Horizontal axis label placed above the top-left corner of the axes.
ax.set_ylabel('$β$', rotation=0)
ax.yaxis.set_label_coords(0.0, 1.07)
ax.set_xlabel('Time (years)')
# Empty label list with no frame: hides the legend seaborn adds for `hue`.
ax.legend('', frameon=False)