# Deconvolve simulated data with periodic trajectory functions, vary number of samples

In [None]:
import torch
from ternadecov.simulator import *
from ternadecov.time_deconv import *

# Configure

In [None]:
# Prefer the first CUDA device, but fall back to the CPU so the notebook
# still runs on machines without a usable GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Floating-point precision used for torch tensors and numpy arrays respectively.
dtype = torch.float32
dtype_np = np.float32

# Load data

In [None]:
# Locations of the bulk and single-cell AnnData (.h5ad) inputs.
# NOTE(review): these are absolute paths on the original author's machine;
# adjust them before running elsewhere.
bulk_anndata_path = (
    "/home/nbarkas/disk1/work/deconvolution_method/datasets/ebov/"
    "load_data_python/ebov_bulk.h5ad"
)
sc_anndata_path = (
    "/home/nbarkas/disk1/work/deconvolution_method/datasets/ebov/"
    "load_data_python/ebov_sc.h5ad"
)

In [None]:
# read_h5ad is documented to take the file path and manages the HDF5 handle
# itself, so there is no need to open the files manually first.
bulk_anndata = anndata.read_h5ad(bulk_anndata_path)
sc_anndata = anndata.read_h5ad(sc_anndata_path)

In [None]:
# Keep only the bulk samples whose 'dpi_time' is zero or later.
non_negative_time = bulk_anndata.obs["dpi_time"] >= 0
bulk_anndata = bulk_anndata[non_negative_time, :]

In [None]:
# Bundle the single-cell reference and the (time-filtered) bulk samples into
# the dataset object consumed by the deconvolution model.
ebov_dataset = DeconvolutionDataset(
    sc_anndata=sc_anndata,
    sc_celltype_col="Subclustering_reduced",
    bulk_anndata=bulk_anndata,
    bulk_time_col="dpi_time",
    dtype_np=dtype_np,
    dtype=dtype,
    device=device,
    # Alternative previously tried here: 'overdispersed_bulk'
    feature_selection_method="overdispersed_bulk_and_high_sc",
)

# Run Deconvolution

In [None]:
# Reference model on the real EBOV data, configured with a degree-10
# polynomial basis. It is reused below as the template for the simulations.
pseudo_time_reg_deconv = TimeRegularizedDeconvolution(
    dataset=ebov_dataset,
    polynomial_degree=10,
    basis_functions="polynomial",
    device=device,
    dtype=dtype,
)

In [None]:
# Fit the reference model.
# NOTE(review): n_iters is 5_001 rather than 5_000, presumably so that the
# final iteration also lands on a log_frequency boundary - confirm.
pseudo_time_reg_deconv.fit_model(
    n_iters=5_001,
    verbose=True,
    log_frequency=1000,
)

# Examine Outputs

In [None]:
# Plot the losses of the reference model (pseudo_time_reg_deconv) as a
# quick check on the fit performed in the previous step.
pseudo_time_reg_deconv.plot_loss()

In [None]:
# Compute the composition trajectories of the reference model with
# n_intervals=1000, then draw them.
pseudo_time_reg_deconv.calculate_composition_trajectories(n_intervals=1000)
pseudo_time_reg_deconv.plot_composition_trajectories()

In [None]:
# Examine the distribution of the per-gene dispersions (phi_g) of the
# reference model.
pseudo_time_reg_deconv.plot_phi_g_distribution()

In [None]:
# Examine the distribution of the gene capture coefficients (beta_g) of the
# reference model.
pseudo_time_reg_deconv.plot_beta_g_distribution()
# Switch the y-axis of the current matplotlib axes to a logarithmic scale.
matplotlib.pyplot.yscale('log')

# Simulation -- Vary number of samples

In [None]:
# Kind of trajectory-generating function used for every simulation below.
trajectory_type = 'periodic'

# Sample sizes to sweep over: 10, 20, ..., 90.
n_samples = [10 * step for step in range(1, 10)]
n_samples

In [None]:
# Use a single trajectory for all iterations, so that every sample size is
# simulated from the same underlying coefficients.
# The number of cell types is read from the dataset, i.e. from the same
# attribute that is passed to simulate_data, so the two always agree.
trajectory_coef = sample_trajectories(
    type=trajectory_type,
    num_cell_types=pseudo_time_reg_deconv.dataset.num_cell_types,
)

In [None]:
# Result accumulators: one entry is appended per simulated sample size.
df_n = []
l1_error = []
# NOTE(review): shape_l1_error is collected but not used by the cells
# visible here.
shape_l1_error = []

# Value passed to the simulator as dirichlet_alpha for every run.
a = 10

for n in n_samples:

    # Simulate n bulk samples along the shared trajectory, using the gene and
    # cell-type dimensions and w_hat_gc of the reference dataset.
    sim_res = simulate_data(
        num_samples=n,
        num_cell_types=pseudo_time_reg_deconv.dataset.num_cell_types,
        num_genes=pseudo_time_reg_deconv.dataset.num_genes,
        w_hat_gc=torch.Tensor(pseudo_time_reg_deconv.dataset.w_hat_gc),
        trajectory_type=trajectory_type,
        dirichlet_alpha=a,
        trajectory_coef=trajectory_coef,
    )

    plot_simulated_proportions(sim_res)

    # Wrap the simulation output as a bulk AnnData object, with the reference
    # deconvolution supplied as the template.
    simulated_bulk = generate_anndata_from_sim(
        sim_res,
        reference_deconvolution=pseudo_time_reg_deconv,
    )

    # Same single-cell reference as before, paired with the simulated bulk
    # data; the simulated time column is called "time".
    ebov_simulated_dataset = DeconvolutionDataset(
        sc_anndata=sc_anndata,
        sc_celltype_col="Subclustering_reduced",
        bulk_anndata=simulated_bulk,
        bulk_time_col="time",
        dtype_np=dtype_np,
        dtype=dtype,
        device=device,
        feature_selection_method="common",
    )

    # Fresh model per sample size, configured like the reference model.
    pseudo_time_reg_deconv_sim = TimeRegularizedDeconvolution(
        dataset=ebov_simulated_dataset,
        polynomial_degree=10,
        basis_functions="polynomial",
        device=device,
        dtype=dtype,
    )

    pseudo_time_reg_deconv_sim.fit_model(
        n_iters=3_001,
        verbose=True,
        log_frequency=1000,
    )

    # Compare the fitted model against the simulated ground truth.
    errors = calculate_prediction_error(sim_res, pseudo_time_reg_deconv_sim)

    df_n.append(n)
    l1_error.append(errors['L1_error_norm'])
    shape_l1_error.append(errors['shape_L1_error'])

    # NOTE(review): calculate_composition_trajectories is not called
    # explicitly on this model before plotting, unlike for the reference
    # model; presumably an earlier call in this loop computes them - confirm.
    pseudo_time_reg_deconv_sim.plot_composition_trajectories()

In [None]:
# Plot the L1 error against the number of simulated samples.
error_df = pd.DataFrame(
    {
        'n': df_n,
        # .item() turns each stored error into a plain Python number.
        'l1': [err.item() for err in l1_error],
    }
)
error_df.plot(x='n', y='l1')

In [None]:
# Plot the Dirichlet alpha values stored in the parameter history of the
# reference model.
param_hist = pseudo_time_reg_deconv.param_store_hist
alphas = [param_hist[step]['dirichlet_alpha'] for step in range(len(param_hist))]

# Only the last 500 stored values are drawn.
matplotlib.pyplot.plot(alphas[-500:])
matplotlib.pyplot.title(r'Dirichlet $ \alpha $ Values')
matplotlib.pyplot.show()