In [1]:
from get_models import Progress_Bar, Encoder, Decoder, CovarianceMatrix, thermometer_encode_df

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
import scipy.stats as stats
from scipy.stats import chi2

#Load dataframes: test_scores has high dimensional test score data of hfsme tests
test_scores_df = pd.read_csv(os.getcwd()+'/test_scores.csv')

#test_scores_df_encoded is a thermometer encoding of test_scores_df for the encoder network
test_scores_df_encoded = thermometer_encode_df(test_scores_df, test_scores_df.columns[1:])

#time_df contains data that changes with time, e.g.: age or time since medication switch
time_df = pd.read_csv(os.getcwd()+'/time_df.csv')
time_df['intercept'] = np.ones(time_df.shape[0])

#baseline_df contains features that characterizes patients at baseline
baseline_df = pd.read_csv(os.getcwd()+'/baseline_df.csv')

# 'sex' has no influence:
baseline_df['sex'] = np.random.randint(2, size=baseline_df.shape[0])

df_effects = pd.merge(baseline_df, time_df, on='patient_id', how='inner')

fixed_effects_keys = ['family_affected', 'sco_surg', '≤3', 'onset_age', 'presym_diag', 'presymptomatic', 'stand_lost', 'stand_gained', 'stand_never', 'sex']
random_effects_keys = ['intercept', 'since_medication', 'since_switch']

p = len(fixed_effects_keys)
q = len(random_effects_keys)

#vae latent dimension
latent_dim = 2

def get_ind(id, df):
    return np.where(df['patient_id'] == id)[0]

def get_design_matrix(df_effects, fixed_effects_keys, random_effects_keys, r=1, include_interaction=False):
    patient_id = df_effects['patient_id'].unique()

    X_list = [torch.from_numpy(np.array(df_effects.loc[get_ind(id, df_effects), fixed_effects_keys])).to(torch.float32) for id in patient_id]

    if include_interaction==True:
        for key in random_effects_keys[1:]:
            X_list = [torch.cat((X_i, X_i[:,1:] * torch.from_numpy(np.array(df_effects.loc[get_ind(patient_id[j], df_effects), key])).unsqueeze(-1)
                                ), -1).to(torch.float32) for j,X_i in enumerate(X_list)]

    X_list = [torch.cat((torch.from_numpy(np.array(df_effects.loc[get_ind(patient_id[j], df_effects), 'age'])).unsqueeze(-1), X_i), -1).to(torch.float32) for j,X_i in enumerate(X_list)]


    Z_list = [torch.from_numpy(np.array(df_effects.loc[get_ind(id, df_effects), random_effects_keys])).to(torch.float32) for id in patient_id]
    Z_list = [torch.block_diag(*[i for j in range(r)]) for i in Z_list]   
    X_list = [torch.block_diag(*[i for j in range(r)]) for i in X_list]
    return X_list, Z_list

#get the design matrices
X_list, Z_list = get_design_matrix(df_effects, fixed_effects_keys, random_effects_keys, r=latent_dim, include_interaction=False)

pat_ind = np.cumsum([0]+[int(len(X_i)/latent_dim) for X_i in X_list])
# initialize Encoder and Decoder Models and the Mixed Model Parameters. mode='diagonal': Diagonal Covariance Matrix, mode='full': Full Covariance Matrix,
def initialize(latent_dim, mode='diagonal'):
    encoder = Encoder(
        input_dim=np.shape(test_scores_df_encoded)[-1],
        hidden_dims=[150], 
        output_dim=latent_dim, 
        act=torch.nn.Tanh())

    decoder = Decoder(
        item_positions=np.concatenate([[i]*a for i,a in enumerate(np.array(test_scores_df[test_scores_df.columns[1:]].max(0)).astype(np.int32))]),                            
        input_dim=latent_dim,
        hidden_dims=[150], 
        act=torch.nn.Tanh())

    var_param = CovarianceMatrix(q*latent_dim, mode=mode)    
    return encoder, decoder, var_param
patients = torch.from_numpy(np.array(baseline_df['patient_id']))
num_patients = len(patients)
def eval_vae(X_list, Z_list, var_param, encoder):
    with torch.no_grad():
        pat_ind_batch = [torch.arange(pat_ind[i],pat_ind[i+1]) for i in patients]
        prior = Normal(torch.zeros(torch.Size([latent_dim])), torch.ones(torch.Size([latent_dim])))

        test_data = torch.concatenate([torch.from_numpy(np.array(test_scores_df_encoded.loc[ind])).to(torch.float32) for ind in pat_ind_batch])

        mu, log_sig = encoder.encode(test_data)

        eps = prior.sample(torch.Size([log_sig.size(dim=0)])) 
        z = mu + log_sig.exp() * eps

        z_list = [z[ind].flatten().to(torch.float32) for ind in pat_ind_batch]

        Phi, sigma = var_param()
        N = sum([len(Z_i) for Z_i in Z_list])

        V_list = [Z_i @ Phi @ Z_i.t() + torch.eye(Z_i.size(0)) * sigma for Z_i in Z_list]
        V_inv_list = [V_i.inverse() for V_i in V_list]
        
        Xt_V_inv_X = torch.stack([X_i.t() @ V_i_inv @ X_i for X_i, V_i_inv in zip(X_list, V_inv_list)]).sum(dim=0)
        Xt_V_inv_y = torch.stack([X_i.t() @ V_i_inv @ y_i for X_i, V_i_inv, y_i in zip(X_list, V_inv_list, z_list)]).sum(dim=0)

        EBLUE = Xt_V_inv_X.inverse() @ Xt_V_inv_y
        EBLUP_list = [Phi @ Z_i.t() @ V_i_inv @ (y_i - X_i @ EBLUE) for X_i, Z_i, V_i_inv, y_i in zip(X_list, Z_list, V_inv_list, z_list)]

        residual_list = [y_i - X_i @ EBLUE for y_i, X_i in zip(z_list, X_list)]
        z_pred = torch.cat([X_i @ EBLUE + Z_i @ EBLUP_i for X_i, Z_i, EBLUP_i in zip(X_list, Z_list, EBLUP_list)]).reshape((-1, latent_dim))

        log_det_V = torch.stack([V_i.det().clamp(min=1e-12).log() for V_i in V_list]).sum()
        const = torch.log(torch.tensor(2.0 * torch.pi))
        rt_V_inv_r = torch.stack([r_i.t() @ V_i_inv @ r_i for r_i, V_i_inv in zip(residual_list, V_inv_list)]).sum()

        nML = 0.5 * (log_det_V + rt_V_inv_r + N * const) 
        return mu, z, z_pred, nML

def likelihood_ratio(L_full, L_red):
    return 2 * (L_full - L_red)


fixed_effects_keys_full = ['family_affected', 'sco_surg', '≤3', 'onset_age', 'presym_diag', 'presymptomatic', 'stand_lost', 'stand_gained', 'stand_never', 'sex']
random_effects_keys_full = ['intercept', 'since_medication', 'since_switch']
# reduced model without fixed effect 'sex' 
fixed_effects_keys_red = ['family_affected', 'sco_surg', '≤3', 'onset_age', 'presym_diag', 'presymptomatic', 'stand_lost', 'stand_gained', 'stand_never']
random_effects_keys_red = ['intercept', 'since_medication', 'since_switch']

# get design matrix for the full model
X_list_full, Z_list_full = get_design_matrix(df_effects, fixed_effects_keys_full, random_effects_keys_full, r=latent_dim)
# get design matrix for the reduced model
X_list_red, Z_list_red = get_design_matrix(df_effects, fixed_effects_keys_red, random_effects_keys_red, r=latent_dim)
encoder_full, decoder_full, var_param_full = initialize(latent_dim) 
optimizer_vae_full = torch.optim.Adam([
        {'params': var_param_full.parameters(), 'lr': 0.1},  
        {'params': encoder_full.parameters(), 'lr': 0.01},  
        {'params': decoder_full.parameters(), 'lr': 0.01},  
])
encoder_red, decoder_red, var_param_red = initialize(latent_dim) 
optimizer_vae_red  = torch.optim.Adam([
        {'params': var_param_red.parameters(), 'lr': 0.1},  
        {'params': encoder_red.parameters(), 'lr': 0.01},  
        {'params': decoder_red.parameters(), 'lr': 0.01},  
])
# Lade die gespeicherten Daten
#all_epochs_info = torch.load(r'C:\Users\yanni\OneDrive\Desktop\BachelorArbeit2024\Code\trained_models\complete_models_all_epochs.pth')
lrt_results = []
for i in range(0,38):
    #epoch_to_load = f'epoch_{i}'
    #loaded_data = all_epochs_info[epoch_to_load]
    loaded_data = torch.load(fr'C:\Users\yanni\OneDrive\Desktop\BachelorArbeit2024\Code\trained_models\models_epoch_{i}.pth')

    # Verwendung der geladenen Daten
    encoder_full.load_state_dict(loaded_data['encoder_full_state_dict'])
    decoder_full.load_state_dict(loaded_data['decoder_full_state_dict'])
    var_param_full = loaded_data['var_param_full']
    optimizer_vae_full.load_state_dict(loaded_data['optimizer_vae_full_state_dict'])
    encoder_red.load_state_dict(loaded_data['encoder_red_state_dict'])
    decoder_red.load_state_dict(loaded_data['decoder_red_state_dict'])
    var_param_red = loaded_data['var_param_red']
    optimizer_vae_red.load_state_dict(loaded_data['optimizer_vae_red_state_dict'])
    X_list_full = loaded_data['X_list_full']
    X_list_red = loaded_data['X_list_red']
    Z_list_full = loaded_data['Z_list_full']
    Z_list_red = loaded_data['Z_list_red']

    nML_full = eval_vae(X_list_full,Z_list_full, var_param_full, encoder_full)[3]
    nML_red = eval_vae(X_list_red, Z_list_red, var_param_red, encoder_red)[3]

    lrt_results.append(likelihood_ratio(nML_full,nML_red))

# plot the histogr<am of the LRT-statistic
plt.hist(lrt_results, bins=50, density=True, alpha=0.6, color='g', label='Histogramm für 35 Simulationen')

#x = np.linspace(-8000, 8000, 100)
#plt.plot(x, chi2.pdf(x + 500, df=1), 'r-', lw=2, label='Chi-Quadrat-Verteilung (df=1)')

# Beschriftungen hinzufügen
plt.title('Histogramm der Likelihood-Ratio-Teststatistiken')
plt.xlabel('Teststatistik')
plt.ylabel('Häufigkeit')
plt.legend()

# Histogramm anzeigen
plt.show()

RuntimeError: Error(s) in loading state_dict for Encoder:
	size mismatch for mu.weight: copying a param with shape torch.Size([1, 150]) from checkpoint, the shape in current model is torch.Size([2, 150]).
	size mismatch for mu.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([2]).
	size mismatch for log_sig.weight: copying a param with shape torch.Size([1, 150]) from checkpoint, the shape in current model is torch.Size([2, 150]).
	size mismatch for log_sig.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([2]).