In [None]:
import json

import arviz as az
import numpy as np
import pandas as pd
import pymc as pm

from data_generation.data_generation import generate_synthetic_data
from src.model_functions import is_asleep, create_time_bins, calculate_DIC, run_model
from src.plot_functions import plot_screen_events, plot_posterior_sleep_wake, plot_logp, plot_DIC, plot_sleep_duration
from src.utils import get_time_from_bin, extract_user_data

az.style.use("arviz-doc")

# Synthetic data generation

In [None]:
# Generate the synthetic screen-event dataset used throughout this notebook.
# NOTE(review): no random seed is set here, so the data (and everything
# downstream) may differ between runs unless generate_synthetic_data seeds
# itself — confirm.
df = generate_synthetic_data()

# Preview the first few rows
df.head()

# Extract user data

In [None]:
# Select a single user's screen-event counts and build the matching time grid.
USER_ID = 'user_001'
n_days, bin_hours, observed_event_counts = extract_user_data(df, user=USER_ID)

# NOTE(review): bin_hours is not forwarded to create_time_bins; confirm both
# assume the same bin width.
n_bins, total_bins, time_bins = create_time_bins(n_days=n_days)

In [None]:
# Visualize the observed screen-event counts over all days for this user.
plot_screen_events(
    n_days=n_days,
    observed_event_counts=observed_event_counts,
)


# Models

Available model names (passed as `model_name` to `run_model` below): `pooled_pooled`, `independent_pooled`, `independent_independent`, `hyper_hyper`, `indipendent_hyper`.

In [None]:
# Choose which model variant to fit. Options:
#   'pooled_pooled', 'independent_pooled', 'independent_independent',
#   'indipendent_hyper', 'hyper_hyper'
model_name = 'hyper_hyper'

model_outputs = run_model(
    model_name,
    observed_event_counts,
    n_bins,
    n_days,
    total_bins,
    time_bins,
)
trace, posterior_predictive, log_likelihood, logp = model_outputs

In [None]:
# Posterior summary for the sleep/wake times and the two event rates.
summary_vars = ['tsleep', 'tawake', 'lambda_sleep', 'lambda_awake']
az.summary(trace, var_names=summary_vars)


In [None]:
# Trace plots for every posterior variable. The trailing semicolon suppresses
# the repr of the returned Axes array so only the figure is shown.
az.plot_trace(trace, legend=False);


#### For model selection, see `model_selection.py`.


# Extract sleep information

In [None]:
# Pool every chain and draw: each (chain, draw) pair is one posterior sample.
# Averaging over 'chain' first (as before) replaced each sample with a mean of
# n_chains draws, which narrows the plotted posterior spread.
# NOTE(review): this passes n_chains times more values than before; confirm
# plot_posterior_sleep_wake does not assume a fixed length.
tsleep_samples = trace['posterior']['tsleep'].values.flatten()
tawake_samples = trace['posterior']['tawake'].values.flatten()
plot_posterior_sleep_wake(tsleep_samples, tawake_samples)


In [None]:
# Per-day posterior-mean sleep onset / wake-up bins and nightly sleep duration.
BINS_PER_HOUR = 4  # NOTE(review): assumes 15-minute bins — confirm against bin_hours / create_time_bins

tsleep_mean = trace['posterior']['tsleep'].mean(dim=['chain', 'draw']).values
tawake_mean = trace['posterior']['tawake'].mean(dim=['chain', 'draw']).values

# Round to the nearest bin; int() alone truncates toward zero (22.9 -> 22).
sleep_times = [int(t) for t in np.rint(tsleep_mean)]
awake_times = [int(t) for t in np.rint(tawake_mean)]

# Durations come from the unrounded means so that discretizing the two
# endpoints separately cannot add up to a full bin of error.
sleep_hours = ((tawake_mean - tsleep_mean) / BINS_PER_HOUR).tolist()
plot_sleep_duration(sleep_hours)