In [None]:
import os, sys
sys.path.insert(0, os.path.abspath("../scr"))
import pickle
import torch
import numpy as np
import matplotlib.pyplot as plt
from sbi import analysis as analysis

In [None]:
from gls_spline import gls_spline
from stats_utils import build_transition_matrix, bin_trajectory
import config_real_data as config

In [None]:
from sbi.utils import MultipleIndependent
import torch.distributions as dists

In [None]:
class SplinePrior(MultipleIndependent):
    """Joint prior over independent components with entries ``2:`` re-centred.

    Sampling delegates to ``MultipleIndependent`` and then subtracts, per sample,
    the mean of entries ``2:`` from those entries, so that block always has zero
    mean. Entries ``0`` and ``1`` are returned unchanged.
    """

    def __init__(self, dists):
        """Build the joint prior.

        Args:
            dists: Sequence of distributions, forwarded unchanged to
                ``MultipleIndependent``.
        """
        super().__init__(dists, validate_args=None, arg_constraints={})

    def sample(self, sample_shape=torch.Size([])):
        """Draw samples whose entries ``2:`` have zero mean.

        Args:
            sample_shape: Shape of the batch of samples to draw; the default
                draws one unbatched sample.

        Returns:
            Tensor produced by ``MultipleIndependent.sample`` with the mean of
            entries ``2:`` (taken along the last dimension) removed from them.
        """
        samples = super().sample(sample_shape)
        # Indexing with an ellipsis handles the unbatched (1-D) and batched (2-D)
        # results with one expression; keepdim=True keeps the mean broadcastable
        # without reshaping it by hand.
        nodes = samples[..., 2:]
        samples[..., 2:] = nodes - nodes.mean(dim=-1, keepdim=True)
        return samples

In [None]:
# Parse the trajectory table from a text file outside the repository tree.
# NOTE(review): the next cell rebinds `data` from 'real_traj.npy', so on Run All the
# result of this parse is discarded — confirm which source is the intended one.
data = np.genfromtxt("../../smFE_data/30R50T4_ConstantForce.dat")

In [None]:
# Load the trajectory array from a NumPy file in the working directory.
# This rebinds `data`, replacing whatever the previous cell loaded. Later cells index
# it as data[:, i], i.e. they treat it as 2-D with one column per trajectory.
data = np.load('real_traj.npy')

In [None]:
# Hand-specified reference profile with 15 values (presumably one per spline knot —
# TODO confirm against config.N_knots).
# NOTE(review): only referenced from commented-out plotting code further down.
y_true = np.array([70. , 30. , 10. ,  2.1,  1.8,  3.1,  5.9,  6.6,  6.4, 10. ,  1.7,
        4.9,  8.4, 30. , 70. ])

# Shift the whole profile so that the smallest of the interior values
# (first two and last two entries excluded) becomes zero.
y_true = y_true - np.min(y_true[2:-2])

In [None]:
# Scratch check: draw one value uniformly between -5 and -3.
# Arguments are (low, high); NumPy documents the result as undefined when high < low,
# so the bounds are passed in ascending order.
np.random.uniform(-5, -3)

In [None]:
# Probability mass trimmed from each tail when estimating the range of the data.
xquant = 0.001
# Lower and upper quantiles over all entries of `data`, obtained in a single call.
data_min, data_max = np.quantile(data, [xquant, 1 - xquant])

In [None]:
# Show the trimmed data range computed in the previous cell.
print(data_min, data_max)

In [None]:
def calculate_transition_matrices(trajectory):
    """Summarise a trajectory as one flat tensor of transition matrices.

    The trajectory is binned on ``config.num_bins`` equal-width bins spanning
    ``[config.min_bin, config.max_bin]``. One transition matrix is built for every
    entry of ``config.lag_times``; NaN entries are replaced by 0.

    Args:
        trajectory: Values to discretise, passed as-is to ``bin_trajectory``.

    Returns:
        1-D ``torch.Tensor`` holding all matrices flattened in lag-time order.
    """
    bin_edges = np.linspace(config.min_bin, config.max_bin, config.num_bins + 1)
    n_bins = len(bin_edges) - 1
    discretised = bin_trajectory(trajectory, bin_edges)

    per_lag = []
    for lag_time in config.lag_times:
        per_lag.append(build_transition_matrix(discretised, n_bins, t=lag_time))

    cleaned = np.nan_to_num(np.array(per_lag), nan=0.0)
    return torch.tensor(cleaned).flatten()

In [None]:
# Visual sanity check: show the transition matrices of the first trajectory column,
# one panel per lag time. The panel count and matrix size follow the config instead
# of being hard-coded (assumes each matrix is num_bins x num_bins, as implied by the
# bin count passed to build_transition_matrix).
n_lags = len(config.lag_times)
f = calculate_transition_matrices(data[:, 0]).view(n_lags, config.num_bins, config.num_bins)
# squeeze=False keeps the axes indexable even when there is a single lag time.
fig, axes = plt.subplots(1, n_lags, figsize=(20, 40), squeeze=False)
axes = axes[0]
for i in range(n_lags):
    axes[i].imshow(f[i])

In [None]:
# Compute the summary statistics of the first 20 trajectory columns once, then write
# both the stacked tensor and one file per trajectory (previously every column was
# processed twice).
observations = [calculate_transition_matrices(data[:, i]) for i in range(20)]
torch.save(torch.stack(observations, dim=0), '../data/observations/real_obs_all.pt')
for i, observation in enumerate(observations):
    torch.save(observation, f'../data/observations/real_obs{i}.pt')

In [None]:
# Condition the posterior on the first observation, moved to the GPU.
# NOTE(review): neither `posterior` nor `all_obs` is assigned in any earlier cell of
# this notebook, so this raises NameError on Restart & Run All. Presumably `all_obs`
# is the tensor saved to real_obs_all.pt — confirm and load both explicitly.
posterior.set_default_x(all_obs[0].cuda())

In [None]:
# Pool samples from the posteriors of observations 0-6, take the pooled mean and plot
# the spline through it next to the deconvolution curve.
x_knots = np.linspace(config.min_x, config.max_x, config.N_knots)
x_axis = np.linspace(config.min_x, config.max_x, 1000)
samples_cat = []
for i in range(7):
    # NOTE: pickle.load can execute arbitrary code; only open trusted posterior files.
    with open(f'../data/posteriors/sequential_posterior_norm_0_obs{i}_round=24.pkl', 'rb') as handle:
        posterior = pickle.load(handle)

    posterior.set_default_x(torch.load(f'../data/observations/real_obs{i}.pt'))

    samples = posterior.sample((200000,)).cpu()
    samples_cat.append(samples)
samples = torch.cat(samples_cat, dim=0)
# Remove the per-sample mean of entries 2: so every sample's block is zero-mean.
samples[:, 2:] = samples[:, 2:] - torch.mean(samples[:, 2:], dim=1).reshape(-1, 1)
mean_posterior = samples.mean(dim=0)

# Outer knots: config offsets added to the first / last entry of the 2: block.
y_knots = np.zeros(config.N_knots)
y_knots[0] = config.max_G_0 + mean_posterior[2].numpy()
y_knots[-1] = config.max_G_0 + mean_posterior[-1].numpy()
y_knots[1] = config.max_G_1 + mean_posterior[2].numpy()
y_knots[-2] = config.max_G_1 + mean_posterior[-1].numpy()
y_knots[2:-2] = mean_posterior[2:].numpy()
y_axis = gls_spline(x_knots, y_knots, x_axis)
# Exactly one pooled curve is drawn, so it always carries the legend label. The former
# `if i == 0` test ran after the loop (where i == 6) and therefore never labelled it.
plt.plot(x_axis, y_axis, alpha=1, color='blue', label='SBI mean')

dc = np.genfromtxt('../../smFE_data/deconvolution_new_new+.csv', delimiter=',')
# Hand-tuned affine alignment of the deconvolution curve to the plot axes.
plt.plot(-dc[:, 0]*0.93+531.0, dc[:, 1]+2.7, color='red', label='deconvolution')

plt.xlabel(r'Molecular extension x', fontsize=18)
plt.ylabel(r'$G_0(x)$', fontsize=18)
plt.legend(loc='upper left', fontsize=14)
plt.ylim(-8, 10)
plt.xlim(490, 560)
plt.grid(True)

In [None]:
# Reload a pickled posterior; the `{0}` placeholder formats to the literal "0", so the
# file opened is ..._obs0_round=24.pkl.
# NOTE: pickle.load can execute arbitrary code; only open files from a trusted source.
with open(f'../data/posteriors/sequential_posterior_norm_0_obs{0}_round=24.pkl', 'rb') as handle:
    posterior = pickle.load(handle)

In [None]:
# Condition the posterior on the saved observation for trajectory 0, so later
# sample() / log_prob() calls do not need an explicit x.
posterior.set_default_x(torch.load('../data/observations/real_obs0.pt'))

In [None]:
# Draw 50000 posterior samples on the CPU, then remove from entries 2: of every
# sample the mean of that block, leaving each block zero-mean.
samples = posterior.sample((50000,)).cpu()
samples[:, 2:] -= samples[:, 2:].mean(dim=1, keepdim=True)

In [None]:
# Score the (already re-centred) samples under the posterior on the GPU and sort the
# log-probabilities; `.sort()` returns a (values, indices) pair.
# NOTE(review): requires CUDA, and `sorted_log_prob` is not read again in this notebook.
sorted_log_prob = posterior.log_prob(samples.cuda()).sort()

In [None]:
# Plot the 1-D marginals of the posterior samples; the return value is bound to `_`
# only to keep it out of the cell output.
_ = analysis.marginal_plot(samples.cpu(), points_colors='r', figsize=(32, 3))

In [None]:
# Draw the splines of a few randomly chosen posterior samples next to the
# deconvolution curve.
x_knots = np.linspace(config.min_x, config.max_x, config.N_knots)
x_axis = np.linspace(config.min_x, config.max_x, 1000)

# Choose the curves among all available samples; the previous hard-coded
# randperm(50000) silently assumed that exact sample count.
n_curves = 10
chosen = torch.randperm(samples.shape[0])[:n_curves]
for sample in samples[chosen].cpu():
    # NOTE(review): end-knot offsets are the literals 30 / 5 here, while other cells
    # use config.max_G_0 / config.max_G_1 — confirm they are meant to be equal.
    y_knots = np.zeros(config.N_knots)
    y_knots[0] = 30 + sample[2].numpy()
    y_knots[-1] = 30 + sample[-1].numpy()
    y_knots[1] = 5 + sample[2].numpy()
    y_knots[-2] = 5 + sample[-1].numpy()
    y_knots[2:-2] = sample[2:].numpy()
    # Diagnostic: mean of the interior knots of this sample.
    print(y_knots[2:-2].mean())

    y_axis = gls_spline(x_knots, y_knots, x_axis)
    plt.plot(x_axis, y_axis, alpha=0.8, color='blue')

# Axis limits only need to be set once, after all curves are drawn.
plt.ylim(-10, 7)
plt.xlim(490, 560)

dc = np.genfromtxt('../../smFE_data/deconvolution_new_new+.csv', delimiter=',')
# Hand-tuned affine alignment of the deconvolution curve; NOTE(review): the constants
# differ from the 0.93 / 531.0 / 2.7 used by the other plots in this notebook.
plt.plot(-dc[:, 0]*0.98+531.1, dc[:, 1]+3.4, color='red', label='deconvolution')
plt.xlabel(r'Molecular extension x', fontsize=18)
plt.ylabel(r'$G_0(x)$', fontsize=18)
plt.grid(True)

In [None]:
# Build one column of knot heights per sample by treating entries 2: of the sample as
# successive increments between neighbouring knots.
# NOTE(review): every other cell in this notebook uses entries 2: directly as knot
# heights, so this cumulative construction looks like a leftover from an older
# parameterisation — confirm before relying on `y_knotss`.
y_knotss = np.zeros((config.N_knots, samples.shape[0]))
# The outermost and second-outermost rows on both ends are pinned to 30 and 5.
y_knotss[0, :] = y_knotss[-1, :] = 30
y_knotss[1, :] = y_knotss[-2, :] = 5

for idx_s, sample in enumerate(samples.cpu()):
    for idx, difference in enumerate(sample[2:].numpy()):
        # NOTE(review): the last write lands on row 2 + len(sample[2:]); if that equals
        # N_knots - 2 it overwrites the pinned value 5 — verify the sizes.
        y_knotss[2 + idx + 1, idx_s] = y_knotss[2 + idx, idx_s] + difference
        # The interior rows are re-centred after every single increment, not once at the end.
        y_knotss[2: -2, idx_s] -= y_knotss[2: -2, idx_s].mean()

In [None]:
# Plot the posterior-mean spline of five separately pickled posteriors (file index
# 0-4), all conditioned on observation 0, and save the figure.
# Relies on `x_knots` and `x_axis` defined by an earlier cell.

for i in range(5):
    # NOTE: pickle.load can execute arbitrary code; only open trusted posterior files.
    with open(f'../data/posteriors/sequential_posterior_norm_{i}_round=24.pkl', 'rb') as handle:
        posterior = pickle.load(handle)

    posterior.set_default_x(torch.load('../data/observations/real_obs0.pt'))

    # One million draws per posterior; this cell is expensive to re-run.
    samples = posterior.sample((1000000,)).cpu()
    # Remove the per-sample mean of entries 2: before averaging over samples.
    samples[:, 2:] = samples[:, 2:] - torch.mean(samples[:, 2:], dim=1).reshape(-1, 1)
    mean_posterior = samples.mean(dim=0)
    # Outer knots: config offsets added to the first / last entry of the 2: block.
    y_knots = np.zeros(config.N_knots)
    y_knots[0] = config.max_G_0 + mean_posterior[2].numpy()
    y_knots[-1] = config.max_G_0 + mean_posterior[-1].numpy()
    y_knots[1] = config.max_G_1 + mean_posterior[2].numpy()
    y_knots[-2] = config.max_G_1 + mean_posterior[-1].numpy()
    y_knots[2:-2] = mean_posterior[2:].numpy()
    y_axis = gls_spline(x_knots, y_knots, x_axis)
    # Label only the first curve so the legend has a single 'SBI mean' entry.
    if i == 0:
        plt.plot(x_axis, y_axis, alpha=1, color='blue', label='SBI mean')
    else: 
        plt.plot(x_axis, y_axis, alpha=1, color='blue')

dc = np.genfromtxt('../../smFE_data/deconvolution_new_new+.csv', delimiter=',')
# Hand-tuned affine alignment of the deconvolution curve to the plot axes.
plt.plot(-dc[:, 0]*0.93+531.0, dc[:, 1]+2.7, color='red', label='deconvolution')

#plt.title(f'log(Dq/Dx) = {D}, $\kappa_l$ = {k}')
plt.xlabel(r'Molecular extension x', fontsize=18)
plt.ylabel(r'$G_0(x)$', fontsize=18)
plt.legend(loc='upper left', fontsize=14)
plt.ylim(-8, 10)
plt.xlim(490, 560)
plt.grid(True)
plt.savefig('inference_real_data_convolution_2.pdf', dpi=500)

In [None]:
# Plot the posterior mean with 16%-84% quantile error bars on the interior knots.
# Operates on whatever `samples` the kernel currently holds (on Run All: the draws of
# the last posterior from the previous cell) and on `x_knots` / `x_axis` from earlier.
samples[:, 2:] = samples[:, 2:] - torch.mean(samples[:, 2:], dim=1).reshape(-1, 1)
mean_posterior = samples.mean(dim=0)
# Outer knots: config offsets added to the first / last entry of the 2: block.
y_knots = np.zeros(config.N_knots)
y_knots[0] = config.max_G_0 + mean_posterior[2].numpy()
y_knots[-1] = config.max_G_0 + mean_posterior[-1].numpy()
y_knots[1] = config.max_G_1 + mean_posterior[2].numpy()
y_knots[-2] = config.max_G_1 + mean_posterior[-1].numpy()
y_knots[2:-2] = mean_posterior[2:].numpy()
#y_axis = gls_spline(x_knots, y_knots, x_axis)
#true_y = gls_spline(x_knots, y_true, x_axis)
#plt.plot(x_axis, true_y - np.min(true_y), color='red', linewidth=1.6, label='ground truth')
#plt.plot(x_axis, true_y - np.min(true_y), color='red')
dc = np.genfromtxt('../../smFE_data/deconvolution_new_new+.csv', delimiter=',')
# Row 0 holds the lower and row 1 the upper error-bar length; the four outer knots
# keep a zero-length bar.
y_knots_err = np.zeros((2, config.N_knots))

# Distance from the mean to the 16% and 84% quantiles of each interior-knot entry.
y_knots_err[:, 2:-2] = np.abs(np.quantile(samples.cpu().numpy(), [0.16, 0.84], axis=0)[:, 2:] - mean_posterior[2:].numpy())
y_axis = gls_spline(x_knots, y_knots, x_axis)
plt.plot(x_axis, y_axis, alpha=1, color='blue')
plt.plot(x_knots, y_knots, 'ob')
plt.errorbar(x_knots,  y_knots, yerr=y_knots_err, linestyle='', marker='o', color='blue', label='posterior mean')
# Hand-tuned affine alignment of the deconvolution curve to the plot axes.
plt.plot(-dc[:, 0]*0.93+531.0, dc[:, 1]+2.7, color='red', label='deconvolution')

#plt.title(f'log(Dq/Dx) = {D}, $\kappa_l$ = {k}')
plt.xlabel(r'Molecular extension x', fontsize=18)
plt.ylabel(r'$G_0(x)$', fontsize=18)
plt.legend(loc='upper left', fontsize=14)
plt.ylim(-8, 10)
plt.xlim(490, 560)
plt.grid(True)
#plt.savefig('posterior5_mean.pdf', dpi=500)

In [None]:

# Plot a hand-specified free-energy profile (used for the synthetic "fake
# experimental" data, per the output file name) against the deconvolution curve.
# Relies on `x_knots` and `x_axis` defined by an earlier cell.
y_knots = np.zeros(config.N_knots)
# Outer knots are pinned to the config values without any offset.
y_knots[0] = config.max_G_0
y_knots[-1] = config.max_G_0 
y_knots[1] = config.max_G_1 
y_knots[-2] = config.max_G_1 
# Eleven hard-coded interior knot heights; the assignment requires N_knots - 4 == 11.
y_knots[2:-2] = np.array([-3.2522, -1.8884, -1.5832, -0.5130,  0.4640,  1.1342, 3.0870,  4.9160,  4.7707,  -1.1000, -4.8352])
dc = np.genfromtxt('../../smFE_data/deconvolution_new_new+.csv', delimiter=',')
y_axis = gls_spline(x_knots, y_knots, x_axis)
plt.plot(x_axis, y_axis, alpha=1, color='blue')
plt.plot(x_knots, y_knots, 'ob')
# Hand-tuned affine alignment of the deconvolution curve to the plot axes.
plt.plot(-dc[:, 0]*0.93+531.0, dc[:, 1]+2.7, color='red', label='deconvolution')

#plt.title(f'log(Dq/Dx) = {D}, $\kappa_l$ = {k}')
plt.xlabel(r'Molecular extension x', fontsize=18)
plt.ylabel(r'$G_0(x)$', fontsize=18)
plt.legend(loc='upper left', fontsize=14)
plt.ylim(-8, 10)
plt.xlim(490, 560)
plt.grid(True)
plt.savefig('fake_experimental_free_energy.pdf', dpi=400)

In [None]:
# Reload the pickled posterior for observation 0 and condition it on that observation.
# NOTE: pickle.load can execute arbitrary code; only open files from a trusted source.
with open(f'../data/posteriors/sequential_posterior_norm_0_obs0_round=24.pkl', 'rb') as handle:
    posterior = pickle.load(handle)

posterior.set_default_x(torch.load('../data/observations/real_obs0.pt'))

In [None]:
# Draw 100000 posterior samples on the CPU. Unlike earlier cells, entries 2: are NOT
# re-centred here.
samples = posterior.sample((100000,)).cpu()

In [None]:
from brownian_integrator import brownian_integrator

In [None]:
# Display the per-parameter mean over all posterior samples.
torch.mean(samples.cpu(), dim=0)

In [None]:
# Free-energy profile for the synthetic trajectory: outer knots pinned to the config
# values, interior knots hand-specified (requires N_knots - 4 == 11).
y_knots = np.zeros(config.N_knots)
y_knots[0] = config.max_G_0
y_knots[-1] = config.max_G_0
y_knots[1] = config.max_G_1
y_knots[-2] = config.max_G_1
y_knots[2:-2] = np.array([-3.2522, -1.8884, -1.5832, -0.5130,  0.4640,  1.1342, 3.0870,  4.9160,  4.7707,  -1.1000, -4.8352])


# Integration settings.
T = 10_000_000
dt = 0.1 # \mu s
N = round(T / dt)
saving_freq = 10
Dx = 0.38 # nm^2 / \mu s
# Entry 0 is converted as Dx * 10**value by the other simulation cells of this
# notebook (i.e. it is used as log10(Dq / Dx)); the Dx factor was missing here.
# NOTE(review): `mean_posterior` comes from an earlier plotting cell, not from the
# samples drawn just above — confirm that is the intended source.
Dq = Dx * (10 ** mean_posterior[0].item())
k = 10 ** mean_posterior[1].item()

In [None]:
# Simulate one trajectory with the parameters set in the previous cell; both
# coordinates start at 510. Only `qs` is kept from the integrator's return value.
# With N = T / dt steps this is a long-running cell.
qs = brownian_integrator(
        x0=510,
        q0=510,
        Dx=Dx,
        Dq=Dq,
        x_knots=x_knots,
        y_knots=y_knots,
        k=k,
        N=N,
        dt=dt,
        fs=saving_freq
)

In [None]:
# Left panel: simulated trajectory versus time. Right panel: histogram of visited
# extensions with the corresponding potential of mean force overlaid.
fig, axes = plt.subplots(1, 2, sharey=True, gridspec_kw={'width_ratios': [4, 1]}, figsize=(10, 5))
# The axis is labelled in seconds while T is given in microseconds, so the end time is
# T * 1e-6. The former hard-coded 0-2 s range did not match T = 10_000_000 set for
# this simulation. NOTE(review): assumes `qs` spans the full simulated time T.
axes[0].scatter(np.linspace(0, T * 1e-6, len(qs)), qs, s=0.05, alpha=0.5)
axes[0].set_xlabel('Time [$s$]', fontsize=18)
axes[0].set_ylabel('Extension [nm]', fontsize=18)
axes[0].set_ylim(495, 555)
counts, bins, _ = axes[1].hist(qs, bins=100, density=True, orientation='horizontal')
# A second x-axis on the histogram panel carries the PMF scale.
axes2 = axes[1].twiny()
# PMF = -log(density), shifted so that its minimum is zero; empty bins give +inf.
pmf = -np.log(counts) - min(-np.log(counts))
# Plotted against the right-hand edge of each histogram bin.
axes2.plot(pmf, bins[1:], color='black', label='PMF')
axes2.legend()
axes2.set_xlim(-0.5, 10)
#plt.savefig('fake_experimental_trajectory.jpg', dpi=400)

In [None]:
# Same two-panel figure as for the simulated trajectory, but for the first column of
# the experimental data; the figure is written to disk.
qs_real = data[:, 0]
fig, axes = plt.subplots(1, 2, sharey=True, gridspec_kw={'width_ratios': [4, 1]}, figsize=(10, 5))
# The time axis is hard-coded to span 0-2 (labelled seconds) — presumably the
# recording length; TODO confirm.
axes[0].scatter(np.linspace(0, 2, len(qs_real)), qs_real, s=0.05, alpha=0.5)
axes[0].set_xlabel('Time [$s$]', fontsize=18)
axes[0].set_ylabel('Extension [nm]', fontsize=18)
axes[0].set_ylim(495, 555)
counts, bins, _ = axes[1].hist(qs_real, bins=100, density=True, orientation='horizontal')
# A second x-axis on the histogram panel carries the PMF scale.
axes2 = axes[1].twiny()
# PMF = -log(density), shifted so that its minimum is zero; empty bins give +inf.
pmf = -np.log(counts) - min(-np.log(counts))
# Plotted against the right-hand edge of each histogram bin.
axes2.plot(pmf, bins[1:], color='black', label='PMF')
axes2.legend()
axes2.set_xlim(-0.5, 10)
plt.savefig('experimental_trajectory.jpg', dpi=400)

In [None]:
# Reload the pickled posterior for observation 0 and condition it on that observation
# (same file and observation as loaded by an earlier cell).
# NOTE: pickle.load can execute arbitrary code; only open files from a trusted source.
with open(f'../data/posteriors/sequential_posterior_norm_0_obs0_round=24.pkl', 'rb') as handle:
    posterior = pickle.load(handle)

posterior.set_default_x(torch.load('../data/observations/real_obs0.pt'))

In [None]:
# Simulate trajectories from individual posterior samples and store one per column.
# The row count 2_000_000 matches N / saving_freq for the settings below.
# NOTE(review): the array has 20 columns but the loop runs once, so columns 1-19 stay
# all-zero (and are saved as such by the next cell) — confirm range(1) is intended.
simulated_trajectories = np.zeros((2_000_000, 20))

for i in range(1):
        # One unbatched-style draw: first row of the returned sample tensor.
        sample = posterior.sample().cpu()[0]

        x_knots = np.linspace(config.min_x, config.max_x, config.N_knots)
        # Outer knots: config offsets added to the first / last entry of the 2: block.
        y_knots = np.zeros(config.N_knots)
        y_knots[0] = config.max_G_0 + sample[2].numpy()
        y_knots[-1] = config.max_G_0 + sample[-1].numpy()
        y_knots[1] = config.max_G_1 + sample[2].numpy()
        y_knots[-2] = config.max_G_1 + sample[-1].numpy()
        y_knots[2:-2] = sample[2:].numpy()

        T = 2_000_000
        dt = 0.1 # \mu s
        N = round(T / dt)
        saving_freq = 10
        Dx = 0.38 # nm^2 / \mu s

        # Entries 0 and 1 are base-10 exponents; entry 0 scales Dx to give Dq.
        Dq = Dx * (10 ** sample[0].item())
        k = 10 ** sample[1].item()

        qs = brownian_integrator(
                x0=510,
                q0=540,
                Dx=Dx,
                Dq=Dq,
                x_knots=x_knots,
                y_knots=y_knots,
                k=k,
                N=N,
                dt=dt,
                fs=saving_freq
        )
        # Fails with a shape error unless the integrator returns exactly 2_000_000 values.
        simulated_trajectories[:, i] = qs


#np.save('posterior_trajectory_long.npy', qs)

In [None]:
# Write the trajectory array to the working directory; np.save appends the '.npy'
# extension, so the file is 'simulated_experiment.npy'.
np.save('simulated_experiment', simulated_trajectories)

In [None]:
# Two-panel figure for the most recently simulated trajectory `qs`: trajectory versus
# time on the left, histogram with the potential of mean force on the right.
fig, axes = plt.subplots(1, 2, sharey=True, gridspec_kw={'width_ratios': [4, 1]}, figsize=(10, 5))
# Time axis hard-coded to 0-2 (labelled seconds); this matches T = 2_000_000 with dt
# in microseconds as set by the preceding simulation cell.
axes[0].scatter(np.linspace(0, 2, len(qs)), qs, s=0.05, alpha=0.5)
axes[0].set_xlabel('Time [$s$]', fontsize=18)
axes[0].set_ylabel('Extension [nm]', fontsize=18)
axes[0].set_ylim(495, 555)
counts, bins, _ = axes[1].hist(qs, bins=100, density=True, orientation='horizontal')
# A second x-axis on the histogram panel carries the PMF scale.
axes2 = axes[1].twiny()
# PMF = -log(density), shifted so that its minimum is zero; empty bins give +inf.
pmf = -np.log(counts) - min(-np.log(counts))
# Plotted against the right-hand edge of each histogram bin.
axes2.plot(pmf, bins[1:], color='black', label='PMF')
axes2.legend()
axes2.set_xlim(-0.5, 10)
#plt.savefig('posterior_sample_trajectory_1.jpg', dpi=400)

In [None]:
# Posterior predictive check: simulate a trajectory for each of 20 posterior samples
# and compare the mean-aligned -log histograms with that of the experimental data.


def _mean_aligned_pmf(counts):
    """Return ``-log(counts)`` shifted so its mean over populated bins is zero.

    Empty bins (zero counts) yield ``+inf`` and are left out of the mean, so a single
    empty bin no longer turns the whole curve into nan / -inf.
    """
    with np.errstate(divide='ignore'):
        pmf = -np.log(counts)
    populated = np.isfinite(pmf)
    if populated.any():
        pmf = pmf - pmf[populated].mean()
    return pmf


# Loop-invariant settings, hoisted out of the sampling loop.
x_knots = np.linspace(config.min_x, config.max_x, config.N_knots)
hist_bins = np.linspace(505, 545, 100)
T = 2_000_000
dt = 0.1 # \mu s
N = round(T / dt)
saving_freq = 10
Dx = 0.38 # nm^2 / \mu s

for i in range(20):
    sample = posterior.sample()[0].cpu()

    # Outer knots: config offsets added to the first / last entry of the 2: block.
    y_knots = np.zeros(config.N_knots)
    y_knots[0] = config.max_G_0 + sample[2].numpy()
    y_knots[-1] = config.max_G_0 + sample[-1].numpy()
    y_knots[1] = config.max_G_1 + sample[2].numpy()
    y_knots[-2] = config.max_G_1 + sample[-1].numpy()
    y_knots[2:-2] = sample[2:].numpy()

    # Entries 0 and 1 are base-10 exponents; entry 0 scales Dx to give Dq. (The dead
    # placeholder assignments Dq = Dx and k = 0.08 that preceded these were removed.)
    Dq = Dx * (10 ** sample[0].item())
    k = 10 ** sample[1].item()

    qs = brownian_integrator(
            x0=510,
            q0=540,
            Dx=Dx,
            Dq=Dq,
            x_knots=x_knots,
            y_knots=y_knots,
            k=k,
            N=N,
            dt=dt,
            fs=saving_freq
    )
    counts_sim, bins_sim = np.histogram(qs, bins=hist_bins, density=True)
    # Only the first curve is labelled so the legend shows a single entry.
    plt.plot(bins_sim[1:], _mean_aligned_pmf(counts_sim), color='gray', alpha=0.5,
             label='Posterior samples' if i == 0 else None)

# Experimental reference: first column of the measured data, same bins.
counts_real, bins_real = np.histogram(data[:, 0], bins=hist_bins, density=True)
plt.plot(bins_real[1:], _mean_aligned_pmf(counts_real), color='red', alpha=1, label='Experimental')
plt.legend()
plt.ylim(-4, 7)
plt.ylabel(r'$G(q)$', fontsize=18)
plt.xlabel(r'$q$', fontsize=18)
plt.savefig('mean_aligned_ppc.pdf', dpi=300)

In [None]:
from sbi.utils import MultipleIndependent
import torch.distributions as dists

In [None]:
# NOTE(review): this class is also defined in an earlier cell of this notebook; the
# later definition silently shadows the earlier one. Keep a single definition (ideally
# in an importable module).
class SplinePrior(MultipleIndependent):
    """Joint prior over independent components with entries ``2:`` re-centred.

    Sampling delegates to ``MultipleIndependent`` and then subtracts, per sample,
    the mean of entries ``2:`` from those entries, so that block always has zero
    mean. Entries ``0`` and ``1`` are returned unchanged.
    """

    def __init__(self, dists):
        """Build the joint prior.

        Args:
            dists: Sequence of distributions, forwarded unchanged to
                ``MultipleIndependent``.
        """
        super().__init__(dists, validate_args=None, arg_constraints={})

    def sample(self, sample_shape=torch.Size([])):
        """Draw samples whose entries ``2:`` have zero mean.

        Args:
            sample_shape: Shape of the batch of samples to draw; the default
                draws one unbatched sample.

        Returns:
            Tensor produced by ``MultipleIndependent.sample`` with the mean of
            entries ``2:`` (taken along the last dimension) removed from them.
        """
        samples = super().sample(sample_shape)
        # Indexing with an ellipsis handles the unbatched (1-D) and batched (2-D)
        # results with one expression; keepdim=True keeps the mean broadcastable
        # without reshaping it by hand.
        nodes = samples[..., 2:]
        samples[..., 2:] = nodes - nodes.mean(dim=-1, keepdim=True)
        return samples

In [None]:
# All prior parameters live on this device.
device = 'cpu'

# First two components: uniform priors on [-2, 1] and [-2, 0]. Other cells of this
# notebook use these two entries as base-10 exponents.
priors = [
    dists.Uniform(torch.tensor([-2.], device=device), torch.tensor([1.], device=device)),
    dists.Uniform(torch.tensor([-2.], device=device), torch.tensor([0.], device=device)),
]
# Remaining eleven components: independent Normal(0, 3) priors.
priors.extend(
    dists.Normal(torch.tensor([0.], device=device), torch.tensor([3.], device=device))
    for _ in range(11)
)

prior = SplinePrior(priors)

In [None]:
# Accumulate entries 2: of one prior draw as successive increments between nodes, then
# centre the resulting nodes on zero.
differences = prior.sample()[2:]
# `spline_nodes` was never initialised before being indexed (NameError on a fresh
# kernel). n increments connect n + 1 nodes; node 0 is the zero reference level.
spline_nodes = torch.zeros(len(differences) + 1)
for idx, difference in enumerate(differences):
    spline_nodes[idx+1] = spline_nodes[idx] + difference

spline_nodes = spline_nodes - torch.mean(spline_nodes)

In [None]:

# Plot the splines of ten draws from the prior.
x_knots = np.linspace(config.min_x, config.max_x, config.N_knots)
x_axis = np.linspace(config.min_x, config.max_x, 1000)

for i in range(10):
    sample = prior.sample().numpy()
    # Sized from the config so it always matches `x_knots` (was a hard-coded 15). The
    # array is rebuilt for every draw, so the unused pre-loop initialisation is gone.
    y_knots = np.zeros(config.N_knots)
    # NOTE(review): end-knot offsets are the literals 30 / 5 here, while other cells
    # use config.max_G_0 / config.max_G_1 — confirm they are meant to be equal.
    y_knots[0] = 30 + sample[2]
    y_knots[-1] = 30 + sample[-1]
    y_knots[1] = 5 + sample[2]
    y_knots[-2] = 5 + sample[-1]
    y_knots[2: -2] = sample[2:]

    y_axis = gls_spline(x_knots, y_knots, x_axis)
    # Label the first curve only; without any labelled artist plt.legend() below had
    # nothing to show and emitted a warning.
    plt.plot(x_axis, y_axis, alpha=0.1, color='blue',
             label='Prior samples' if i == 0 else None)

plt.xlabel(r'Molecular extension x', fontsize=18)
plt.ylabel(r'$G_0(x)$', fontsize=18)
plt.legend(loc='upper left', fontsize=14)
plt.ylim(-7, 7)
plt.xlim(490, 560)
plt.grid(True)

In [None]:
# Fresh knot array with the two outermost knots on each side pinned to 30 and 5.
# NOTE(review): the length 15 is hard-coded here, whereas other cells size this array
# with config.N_knots — confirm the two agree.
y_knots = np.zeros(15)
y_knots[0] = y_knots[-1] = 30
y_knots[1] = y_knots[-2] = 5

In [None]:
# Display the configured prior knot count, for comparison with the next cell.
config.N_knots_prior

In [None]:
# Display how many entries a prior draw has after the first two.
len(prior.sample()[2:])

In [None]:
# Guard: entries 2: of a prior draw must match the configured prior knot count.
assert len(prior.sample()[2:]) == config.N_knots_prior, 'Prior has wrong dimension!'
# Treat entries 2: of a NEW prior draw as successive increments between knots.
# NOTE(review): every plotting cell in this notebook uses entries 2: directly as knot
# heights, so this looks like a leftover from an older parameterisation — confirm.
for idx, difference in enumerate(prior.sample()[2:].numpy()):
    # NOTE(review): the last write lands on index 2 + len(draw[2:]); with the 15-entry
    # y_knots and 11 increments that is index 13 == y_knots[-2], i.e. the pinned 5.
    y_knots[2 + idx + 1] = y_knots[2 + idx] + difference
    # The interior knots are re-centred after every single increment, not once at the end.
    y_knots[2: -2] -= y_knots[2: -2].mean()

In [None]:
# Display the mean of the interior knots (the previous cell re-centres them on zero).
y_knots[2: -2].mean()

In [None]:
# Display the full knot array for inspection.
y_knots

In [None]:
# Plot the 1-D marginals of whatever `samples` the kernel currently holds; on Run All
# these are the 100000 posterior draws made earlier (not re-centred), not prior draws.
_ = analysis.marginal_plot(samples.cpu(), points_colors='r', figsize=(32, 2))