<a href="https://colab.research.google.com/github/leonardogolinelli/rna_velo_tests/blob/main/generate_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Colab-only step: mount Google Drive at /content/drive so that files kept
# there can be reached from this notebook.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import sys
# Put the top level of the mounted Drive on the module search path so the
# import below can be resolved.
sys.path.append('/content/drive/My Drive/')
# Project-local simulator module; presumably stored in the Drive folder added
# above (it is not a published package) - confirm its location.
import data_simulator as sim_data

In [None]:
def init_parameters(n_genes=2000):
  """Draw the random per-gene parameters used by the simulation.

  The draws are made in a fixed order (alfa, beta, gamma, relative_t_switch),
  so with the default ``n_genes`` a seeded ``np.random`` state yields the same
  values as before the gene count became a parameter.

  Args:
    n_genes: Number of genes to draw parameters for. Defaults to 2000.

  Returns:
    A tuple ``(t_final, alfa, beta, gamma, relative_t_switch,
    n_simulated_data)`` where

    * ``t_final`` is the constant 10,
    * ``alfa`` is an array of ``n_genes`` samples from uniform [1, 10),
    * ``beta`` is an array of ``n_genes`` samples from uniform [0.5, 2),
    * ``gamma`` is an array of ``n_genes`` samples from uniform [0.1, 0.99),
    * ``relative_t_switch`` is an array of ``n_genes`` samples from Beta(5, 5),
    * ``n_simulated_data`` is the constant 1500.

    NOTE(review): alfa/beta/gamma are presumably the transcription, splicing
    and degradation rates expected by ``data_simulator`` - confirm there.
  """
  t_final = 10
  n_simulated_data = 1500

  alfa = np.random.uniform(low=1, high=10, size=n_genes)
  beta = np.random.uniform(low=0.5, high=2, size=n_genes)
  gamma = np.random.uniform(low=0.1, high=0.99, size=n_genes)
  # Beta(5, 5) is symmetric around 0.5; it replaced an earlier
  # uniform(0.01, 0.99) draw for the relative switch time.
  relative_t_switch = np.random.beta(5, 5, n_genes)

  return t_final, alfa, beta, gamma, relative_t_switch, n_simulated_data

In [None]:
def simulate_data(t_final, alfa, beta, gamma, relative_t_switch, n_simulated_data):
  """Run ``sim_data.module`` for one parameter set and return its six outputs.

  The simulator is always started from ``u0 = 0`` and ``s0 = 0`` with
  ``periodic_alfa`` switched off; every other argument is forwarded unchanged.

  Returns:
    The tuple ``(t, t_switch, exp_u, exp_s, u_stoch, s_stoch)`` exactly as
    produced by ``sim_data.module``.
  """
  simulator_args = dict(
      u0=0,
      s0=0,
      alfa=alfa,
      beta=beta,
      gamma=gamma,
      t_final=t_final,
      n_simulated_data=n_simulated_data,
      relative_t_switch=relative_t_switch,
      periodic_alfa=False,
  )
  # Unpack explicitly so a simulator returning anything other than six values
  # fails here rather than in the caller.
  t, t_switch, exp_u, exp_s, u_stoch, s_stoch = sim_data.module(**simulator_args)
  return t, t_switch, exp_u, exp_s, u_stoch, s_stoch

In [None]:
def format_data(unspliced_raw, spliced_raw, dic, ind, t, t_switch, exp_u, exp_s, u_stoch, s_stoch, alfa, beta, gamma):
  """Store one simulation result in ``dic`` and in the two raw matrices.

  A new DataFrame is registered under the key ``str(ind)`` and filled with the
  columns time, switch_time, exp_u, exp_s, u_stoch, s_stoch, alfa, beta and
  gamma (in that order). Column ``ind`` of ``unspliced_raw`` and
  ``spliced_raw`` is then overwritten with ``u_stoch`` and ``s_stoch``.

  All three containers are modified in place and also returned.

  Returns:
    ``(dic, unspliced_raw, spliced_raw)``.
  """
  # Register the frame before filling it, so it is already reachable through
  # dic while the columns are being added.
  frame = dic[str(ind)] = pd.DataFrame()

  column_values = (
      ('time', t),
      ('switch_time', t_switch),
      ('exp_u', exp_u),
      ('exp_s', exp_s),
      ('u_stoch', u_stoch),
      ('s_stoch', s_stoch),
      ('alfa', alfa),
      ('beta', beta),
      ('gamma', gamma),
  )
  for label, values in column_values:
    frame[label] = values

  for matrix, counts in ((unspliced_raw, u_stoch), (spliced_raw, s_stoch)):
    matrix[:, ind] = counts

  return dic, unspliced_raw, spliced_raw

In [None]:
def plot_simulation(dic):
  """Plot exp_u and exp_s against time for 16 randomly chosen genes.

  A 4 x 4 grid of panels is drawn; each panel shows the ``exp_u`` and ``exp_s``
  columns of one entry of ``dic`` plotted over its ``time`` column. Entries are
  picked at random with replacement, so the same gene may appear twice.

  Args:
    dic: Mapping whose values are DataFrames with at least the columns
      ``time``, ``exp_u`` and ``exp_s``.

  Raises:
    ValueError: If ``dic`` is empty.
  """
  gene_keys = list(dic)
  if not gene_keys:
    raise ValueError("plot_simulation needs at least one simulated gene in dic")

  nrows = 4
  ncols = 4
  # Sample positions among the entries actually present instead of assuming
  # exactly 2000 genes; for a 2000-entry dic the draw is unchanged.
  picks = np.random.randint(0, len(gene_keys), nrows * ncols)

  # A single figure holding the whole grid (no separate empty figure).
  _, axes = plt.subplots(nrows, ncols, figsize=(18, 9))
  for ax, pick in zip(axes.flat, picks):
    gene = dic[gene_keys[pick]]
    ax.plot(gene['time'], gene['exp_u'])
    ax.plot(gene['time'], gene['exp_s'])

  # Show the figure when the function is called, not when it is defined.
  plt.show()

In [None]:
def plot_phase_plane(dic):
  """Plot exp_s against exp_u for 16 randomly chosen genes.

  A 4 x 4 grid of panels is drawn; each panel shows one entry of ``dic`` with
  its ``exp_u`` column on the x axis and its ``exp_s`` column on the y axis.
  Entries are picked at random with replacement, so the same gene may appear
  twice.

  Args:
    dic: Mapping whose values are DataFrames with at least the columns
      ``exp_u`` and ``exp_s``.

  Raises:
    ValueError: If ``dic`` is empty.
  """
  gene_keys = list(dic)
  if not gene_keys:
    raise ValueError("plot_phase_plane needs at least one simulated gene in dic")

  nrows = 4
  ncols = 4
  # Sample positions among the entries actually present instead of assuming
  # exactly 2000 genes; for a 2000-entry dic the draw is unchanged.
  picks = np.random.randint(0, len(gene_keys), nrows * ncols)

  # A single figure holding the whole grid (no separate empty figure).
  _, axes = plt.subplots(nrows, ncols, figsize=(18, 9))
  for ax, pick in zip(axes.flat, picks):
    gene = dic[gene_keys[pick]]
    ax.plot(gene['exp_u'], gene['exp_s'])

  plt.show()

In [None]:
def generate_data(plot=False):
  """Simulate every gene and collect the results.

  Parameters are drawn with ``init_parameters``; each gene is then simulated
  with ``simulate_data`` and stored with ``format_data``.

  Args:
    plot: When true, call ``plot_simulation`` and ``plot_phase_plane`` on the
      collected results before returning.

  Returns:
    ``(dic, unspliced_raw, spliced_raw)`` where ``dic`` maps ``str(gene index)``
    to the per-gene DataFrame built by ``format_data`` and the two arrays hold
    ``u_stoch`` / ``s_stoch`` with one column per gene.
  """
  t_final, alfa_seq, beta_seq, gamma_seq, relative_t_switch_seq, n_simulated_data = init_parameters()

  # Take the gene count from the drawn parameters so it cannot drift from
  # whatever init_parameters produces.
  n_genes = len(alfa_seq)
  # NOTE(review): the row count has to equal the length of the u_stoch/s_stoch
  # vectors returned by the simulator; n_simulated_data is 1500, so 3000
  # presumably means the simulator emits twice that many points - confirm.
  n_rows = 3000

  dic = {}
  unspliced_raw = np.zeros((n_rows, n_genes))
  spliced_raw = np.zeros((n_rows, n_genes))

  per_gene = zip(alfa_seq, beta_seq, gamma_seq, relative_t_switch_seq)
  for i, (alfa, beta, gamma, relative_t_switch) in enumerate(per_gene):
    t, t_switch, exp_u, exp_s, u_stoch, s_stoch = simulate_data(
        t_final, alfa, beta, gamma, relative_t_switch, n_simulated_data)
    dic, unspliced_raw, spliced_raw = format_data(
        unspliced_raw, spliced_raw, dic, i, t, t_switch, exp_u, exp_s,
        u_stoch, s_stoch, alfa, beta, gamma)

  if plot:
    plot_simulation(dic)
    plot_phase_plane(dic)

  return dic, unspliced_raw, spliced_raw

In [None]:
#dic, unspliced_raw, spliced_raw = generate_data()