In [3]:
import numpy as np
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.optim import Adam
from torchmin import minimize

In [5]:
# Parameters
num_patients = 200

#unknown parameters
mean_test_score = 40 # fixed intercept
b_1 = -8.0 # effect of affected family
b_2 = 0.0 # effect of sex 
noise_var = 0.1 # variance of noise
var_time_effect = 1.0
var_intercept = 1.0

def generate_data(num_patients):
    data = []
    ground_truth = []
    for i in range(num_patients):
        time_points = []
        while len(time_points) < 5:
            length = np.random.uniform(4, 10)
            time_points = np.sort([np.random.uniform(0, length) for i in range(int(length) + 2)])
            time_points = time_points[np.diff(np.concatenate(([0.], time_points))) > 0.25]
        
        # draw fixed effect characteristics
        family_affected = np.random.binomial(1, 0.6)  # 1 for family affected, 0 for not affected
        sex = np.random.binomial(1, 0.5)  # 1 for male, 0 for female   
    
        # draw random effects
        time_effect = np.random.normal(0.0, np.sqrt(var_time_effect))
        random_intercept = np.random.normal(0.0, np.sqrt(var_intercept))

        for j,t in enumerate(time_points):
            # draw noise
            noise = np.random.normal(0.0, np.sqrt(noise_var))

            # simulate patient trajectories
            score = mean_test_score + \
                    b_1 * family_affected + \
                    b_2 * sex + \
                    random_intercept + \
                    time_effect * t + \
                    noise
                
        data.append([i, t, family_affected, sex, score])
        ground_truth.append([i, t, random_intercept, time_effect, noise])
    return data

df =  pd.DataFrame(generate_data(num_patients), columns=['patient_id', 'years_after_treatment', 'family_affected', 'sex', 'test_score'])
print(df)

     patient_id  years_after_treatment  family_affected  sex  test_score
0             0               6.934165                1    1   23.412482
1             1               6.862036                1    0   19.463645
2             2               5.873720                1    1   43.923450
3             3               9.018667                0    0   37.071385
4             4               5.857655                1    1   38.174593
..          ...                    ...              ...  ...         ...
195         195               6.390530                1    0   27.337463
196         196               8.760204                0    0   43.285878
197         197               7.097517                0    1   52.116599
198         198               7.152390                1    1   30.861974
199         199               9.327538                0    1   29.201576

[200 rows x 5 columns]
