In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial.transform import Rotation
import torch

from cryo_sbi.inference.models import build_models
from cryo_sbi import CryoEmSimulator
from cryo_sbi.inference import priors

In [None]:
# ---- Paths ------------------------------------------------------------------
config_dir = "../experiments/benchmark_hsp90/"
# NOTE(review): the result files loaded below come from results/raw_data/,
# not from data_dir (results/raw_results/) — confirm which location is current.
data_dir = "../experiments/benchmark_hsp90/results/raw_results/"
plot_dir = "../experiments/benchmark_hsp90/results/plots/"
file_name = '23_03_09_results'      # File name

# ---- Sampling budgets -------------------------------------------------------
num_samples_stats = 20_000          # Number of simulations for computing posterior stats
num_samples_SBC = 10_000            # Number of simulations for SBC
num_posterior_samples_SBC = 4_096   # Number of posterior samples for each SBC simulation
num_samples_posterior = 50_000      # Number of samples to draw from posterior
batch_size_sampling = 100           # Batch size for sampling posterior

# ---- Compute resources ------------------------------------------------------
num_workers = 24                    # Number of CPU cores
device = 'cuda'                     # Device for computations

# ---- Output switches --------------------------------------------------------
save_data = False
save_figures = False                # gates the `if save_figures:` savefig calls

In [None]:
# Figure 2b inputs; the plot below reads the 'posterior_samples', 'indices'
# and 'images' entries.
fig2b_path = "../experiments/benchmark_hsp90/results/raw_data/23_04_03_final_posterior_snr01_examples.pt"
results_fig2b = torch.load(fig2b_path)

In [None]:
# Figure 2b: posterior over phi for three example particles. The red line marks
# the true value and the inset shows the particle image.
example_ids = [5, 11, 18]
inset_left = [0.5, 0.05, 0.05]   # inset x-offset per panel, in axes coordinates
fig, axes = plt.subplots(1, 3, figsize=(10, 3))
for ax, example_id, left in zip(axes, example_ids, inset_left):
    particle_samples = results_fig2b['posterior_samples'][:, example_id].flatten().numpy()
    ax.hist(particle_samples, bins=np.arange(0, 20, 0.3), histtype="step", color="blue", linewidth=1.1)
    ax.set_yticks([])
    ax.set_xlabel(r'$\phi$', fontsize=18)
    ax.axvline(results_fig2b['indices'][example_id], color='red')

    ax_inset = ax.inset_axes((left, 0.30, 0.4, 0.90))
    ax_inset.set_xticks([])
    ax_inset.set_yticks([])
    ax_inset.imshow(results_fig2b["images"][example_id])

if save_figures:
    fig.savefig('../data/trained_posteriors/benchmark_hsp90/results/plots/fig1_b.pdf', dpi=400, bbox_inches='tight')

In [None]:
# Release the Figure 2b results; no later cell reads results_fig2b.
del results_fig2b

In [None]:
# Figure 2c inputs; the statistics cell below reads the 'posterior_samples'
# and 'indices' entries.
fig2c_path = "../experiments/benchmark_hsp90/results/raw_data/23_04_03_final_posteriorstats_training.pt"
results_fig2c = torch.load(fig2c_path)

In [None]:
# Sanity-check the loaded results before deriving statistics from them.
# results_fig2c is indexed by string keys, so unpacking it directly into
# (mean_distance, confidence_widths) would iterate over its KEYS, not values.
assert {'posterior_samples', 'indices'} <= set(results_fig2c), \
    f"unexpected keys in results_fig2c: {list(results_fig2c)}"

In [None]:
# Per-particle summary statistics for results_fig2c:
#   mean_distance     - posterior mean minus the true value stored in 'indices'
#   confidence_widths - width of the central 95% interval (2.5% / 97.5% quantiles)
# Both operands of the subtraction are flattened: if the posterior mean kept a
# trailing singleton dimension, (N, 1) - (N,) would silently broadcast to (N, N).
posterior_mean = results_fig2c['posterior_samples'].mean(dim=0).reshape(-1)
mean_distance = (posterior_mean - results_fig2c['indices'].reshape(-1)).numpy()
posterior_quantiles = np.quantile(results_fig2c['posterior_samples'].numpy(), [0.025, 0.975], axis=0)
confidence_widths = (posterior_quantiles[1] - posterior_quantiles[0]).flatten()
assert mean_distance.shape == confidence_widths.shape, "per-particle statistics have mismatched lengths"

In [None]:
# Figure 2c: distribution over particles of the posterior-mean bias (left) and
# of the interval width (right).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
_ = ax1.hist(mean_distance, bins=np.arange(-10, 10, 0.2), density=True)
_ = ax2.hist(confidence_widths, bins=np.arange(0, 20, 0.5), density=True)
ax1.set_xlabel(r'$\overline{\phi} - \phi_{True}$', fontsize=13)
# confidence_widths are 2.5%-97.5% quantile ranges, i.e. central 95% intervals
# (about +/-2 sigma for a Gaussian), so the label must not say 3 sigma.
ax2.set_xlabel('Width of 95% confidence interval', fontsize=13)
ax1.set_yticklabels([])
ax2.set_yticklabels([])
if save_figures:
    fig.savefig('../data/trained_posteriors/benchmark_hsp90/results/plots/fig1_c.pdf')

In [None]:
# Release the Figure 2c results; no later cell reads results_fig2c.
del results_fig2c

In [None]:
# Interval widths as a function of SNR (violin rows) and defocus (panels),
# read from precomputed (mean_diff, confidence) files.
snr_values = np.logspace(np.log10(0.5), -2, 5)
defocus_values = np.linspace(0.5, 2, 4)
labels = [f'{snr:.2f}' for snr in snr_values]
fig, axes = plt.subplots(1, len(defocus_values), figsize=(17, 5), sharex=True, sharey=True)
for idx, defoc in enumerate(defocus_values):
    data = []
    for snr in snr_values:
        # file names encode the SNR rounded to two decimals
        mean_diff, confidence = torch.load(f'../experiments/benchmark_hsp90/results/raw_data/23_03_21_final_posteriorstats_snr={round(snr, 2)}_defocus={defoc}.pt')
        data.append(confidence)
    axes[idx].violinplot(data, vert=False, showextrema=False)
    if idx == 0:  # y axes are shared (sharey=True), so ticks are set once
        axes[idx].set_yticks(range(1, len(labels) + 1))
        axes[idx].set_yticklabels(labels)
    axes[idx].set_title(f'Defocus = {defoc}', fontsize=14)
# NOTE(review): the quantiles behind `confidence` are computed outside this
# notebook; confirm "3 sigma" matches them (in-notebook widths use 2.5%/97.5%).
axes[0].set_xlabel(r'Width of 3$\sigma$-confidence interval', fontsize=14)
axes[0].set_ylabel('SNR', fontsize=14)
if save_figures:
    # distinct name so this does not overwrite the mean-distance violin figure
    fig.savefig(f'{plot_dir}DEFOCUS_SNR_VIOLIN_width.pdf', dpi=500)

In [None]:
# Bias of the posterior mean as a function of SNR (violin rows) and defocus
# (panels), read from precomputed (mean_diff, confidence) files.
snr_values = np.logspace(np.log10(0.5), -2, 5)
defocus_values = np.linspace(0.5, 2, 4)
labels = [f'{snr:.2f}' for snr in snr_values]
fig, axes = plt.subplots(1, len(defocus_values), figsize=(17, 5), sharex=True, sharey=True)
for idx, defoc in enumerate(defocus_values):
    data = []
    for snr in snr_values:
        # file names encode the SNR rounded to two decimals
        mean_diff, confidence = torch.load(f'../experiments/benchmark_hsp90/results/raw_data/23_03_21_final_posteriorstats_snr={round(snr, 2)}_defocus={defoc}.pt')
        data.append(mean_diff)
    axes[idx].violinplot(data, vert=False, showextrema=False)
    if idx == 0:  # y axes are shared (sharey=True), so ticks are set once
        axes[idx].set_yticks(range(1, len(labels) + 1))
        axes[idx].set_yticklabels(labels)
    axes[idx].set_title(f'Defocus = {defoc}', fontsize=14)
axes[0].set_xlabel(r'$\overline{\phi} - \phi_{True}$', fontsize=14)
axes[0].set_ylabel('SNR', fontsize=14)
if save_figures:
    fig.savefig(f'{plot_dir}DEFOCUS_SNR_VIOLIN_mean.pdf', dpi=500)

In [None]:
# Orientation-benchmark inputs: posterior results and the quaternion list.
# NOTE(review): quaternion_list.npy is read relative to the working directory,
# unlike the other inputs — confirm the notebook runs from the intended folder.
fig2d_path = "../experiments/benchmark_hsp90/results/raw_data/23_04_03_final_posterior_quats.pt"
results_fig2d = torch.load(fig2d_path)
quats = np.load('quaternion_list.npy')

In [None]:
# Per-particle summary statistics for results_fig2d:
#   mean_distance     - posterior mean minus the true value stored in 'indices'
#   confidence_widths - width of the central 95% interval (2.5% / 97.5% quantiles)
# Both operands of the subtraction are flattened: if the posterior mean kept a
# trailing singleton dimension, (N, 1) - (N,) would silently broadcast to (N, N).
posterior_mean = results_fig2d['posterior_samples'].mean(dim=0).reshape(-1)
mean_distance = (posterior_mean - results_fig2d['indices'].reshape(-1)).numpy()
posterior_quantiles = np.quantile(results_fig2d['posterior_samples'].numpy(), [0.025, 0.975], axis=0)
confidence_widths = (posterior_quantiles[1] - posterior_quantiles[0]).flatten()
assert mean_distance.shape == confidence_widths.shape, "per-particle statistics have mismatched lengths"

In [None]:
# quats was already loaded with results_fig2d, so it is not reloaded here.
# Later cells pair quats[i] with confidence_widths[i] (by index and by scatter
# colour), so the two must have the same length.
assert len(quats) == len(confidence_widths), "quaternions and posterior statistics are misaligned"

In [None]:
# Orientation coverage on the unit sphere: rotate the +z axis by every
# quaternion in `quats`, then scatter the resulting directions coloured by the
# interval width of the matching particle.
# NOTE(review): colouring assumes quats[i] and confidence_widths[i] describe the
# same particle, as the (theta, phi) scatter further down also assumes — confirm.
# scipy's Rotation.from_quat reads each quaternion as scalar-last (x, y, z, w).
unit_vector = np.array([0, 0, 1])
# reshape keeps a (0, 3) array if quats is empty, so column slicing still works
points = np.array(
    [np.matmul(Rotation.from_quat(quat).as_matrix(), unit_vector) for quat in quats]
).reshape(-1, 3)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# The original attached a colorbar to a scatter without colour data, which
# rendered a meaningless 0-1 scale; the widths are now mapped explicitly.
im = ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=0.1, c=confidence_widths)

ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])
ax.view_init(elev=30, azim=-40)
ax.set_xlim([-1, 1])
ax.set_ylim([-1, 1])
ax.set_zlim([-1, 1])

cbar = fig.colorbar(im)
cbar.set_label('Width of 95% confidence interval')  # widths are 2.5%-97.5% quantile ranges

plt.tight_layout()
if save_figures:
    # save through `fig` before show(), so the file always contains this figure
    fig.savefig(f'{plot_dir}rotation_posterior_widths_snr001_widres50_128x128.pdf', dpi=600)
plt.show()

In [None]:
def asSpherical(xyz):
    """Convert a Cartesian point to spherical coordinates.

    Args:
        xyz: indexable with at least three entries; xyz[0], xyz[1] and xyz[2]
            are taken as x, y and z (any further entries are ignored).

    Returns:
        list ``[r, theta, phi]``: the radius, the polar angle arccos(z / r)
        measured from +z (in [0, pi]) and the azimuth arctan2(y, x)
        (in [-pi, pi]). theta is NaN for the origin (0 / 0).
    """
    x, y, z = (xyz[axis] for axis in range(3))
    radius = np.sqrt(x * x + y * y + z * z)
    return [radius, np.arccos(z / radius), np.arctan2(y, x)]

# (theta, phi) of every orientation direction in `points`; the radius is dropped.
spherical = [asSpherical(direction) for direction in points]
coordinates = np.array([[theta, phi] for _, theta, phi in spherical])

In [None]:
# Directions of the orientations whose interval width is below
# `width_threshold`, scaled to length `direction_length`. x1, y1, z1 are only
# referenced by commented-out overlay lines in the model plots below.
# The result is stored in `narrow_points` rather than overwriting `points`,
# which the (theta, phi) coordinates cell depends on; re-running that cell
# after this one would otherwise silently use the filtered subset.
width_threshold = 2.4
direction_length = 60
scaled_axis = np.array([0, 0, direction_length])
narrow_points = []
for i in np.where(confidence_widths < width_threshold)[0]:
    rot_mat = Rotation.from_quat(quats[i]).as_matrix()
    narrow_points.append(np.matmul(rot_mat, scaled_axis))
# reshape gives a (0, 3) array when nothing passes the threshold, so the column
# slicing below does not raise IndexError.
narrow_points = np.array(narrow_points).reshape(-1, 3)

x1 = narrow_points[:, 0]
y1 = narrow_points[:, 1]
z1 = narrow_points[:, 2]


In [None]:
# 3D scatter of one entry of the HSP90 model array, seen from a fixed camera.
# NOTE(review): unpacking into xx, yy, zz assumes the selected entry has a
# leading axis of length 3 (x, y, z rows) — confirm the array layout.
models = np.load("../data/protein_models/hsp90_models.npy")
xx, yy, zz = models[0, 19]

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.view_init(elev=10, azim=190)

im = ax.scatter(xx, yy, zz, s=5)
#im = ax.scatter(x1, y1, z1, s=1)


In [None]:
# 3D scatter of one entry of the HSP90 model array, seen from a fixed camera.
# NOTE(review): unpacking into xx, yy, zz assumes the selected entry has a
# leading axis of length 3 (x, y, z rows) — confirm the array layout.
models = np.load("../data/protein_models/hsp90_models.npy")
xx, yy, zz = models[19, 0]

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.view_init(elev=10, azim=190)

im = ax.scatter(xx, yy, zz, s=5)
#im = ax.scatter(x1, y1, z1, s=1)


In [None]:
# 3D scatter of the first entry of the HSP90 model array, seen from a fixed camera.
# NOTE(review): unpacking into xx, yy, zz assumes the selected entry has a
# leading axis of length 3 (x, y, z rows) — confirm the array layout.
models = np.load("../data/protein_models/hsp90_models.npy")
first_model = models[0, 0]
xx, yy, zz = first_model

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.view_init(elev=10, azim=190)
im = ax.scatter(xx, yy, zz, s=5)

In [None]:
# Figure 2d: interval width versus orientation. Each point is one rotated +z
# axis in spherical coordinates (x: azimuth varphi, y: polar angle theta);
# marker size and colour both encode the width.
fig, ax = plt.subplots(1, 1, figsize=(7, 4))

scatter_plot = ax.scatter(coordinates[:, 1], coordinates[:, 0], s=0.5 * confidence_widths, c=confidence_widths)
cbar = fig.colorbar(scatter_plot)
# confidence_widths are 2.5%-97.5% quantile ranges, i.e. central 95% intervals.
cbar.set_label(label='Width of 95% confidence interval', size=13)
cbar.ax.tick_params(labelsize=12)
ax.set_ylabel(r'$\theta$', fontsize=18)
ax.set_xlabel(r'$\varphi$', fontsize=18)
ax.set_yticks(
    ticks=[0, np.pi / 4, np.pi / 2, 3 * np.pi / 4, np.pi],
    labels=['0', r'$\frac{1}{4}\pi$', r'$\frac{1}{2}\pi$', r'$\frac{3}{4}\pi$', r'$\pi$'],
    fontsize=12
)
ax.set_xticks(
    ticks=[-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi],
    labels=[r'$-\pi$', r'$-\frac{1}{2}\pi$', r'0', r'$\frac{1}{2}\pi$', r'$\pi$'],
    fontsize=12
)

if save_figures:
    fig.savefig('../../cryo_sbi_experimental/data/trained_posteriors/benchmark_hsp90/results/plots/fig2d.pdf', dpi=400, bbox_inches='tight')

In [None]:
# Posterior results for the three noise models, in Gaussian / colored / gradient order.
stats_files = (
    "../experiments/benchmark_hsp90/results/raw_data/23_03_21_final_posterior_gaussian_stats.pt",
    "../experiments/benchmark_hsp90/results/raw_data/23_03_21_final_posterior_colored_stats.pt",
    "../experiments/benchmark_hsp90/results/raw_data/23_03_21_final_posterior_gradient_stats.pt",
)
gaussian_posterior, colored_posterior, gradient_posterior = map(torch.load, stats_files)

In [None]:
# Gaussian-noise posteriors: bias of the posterior mean and width of the central
# 95% interval (2.5% / 97.5% quantiles), one value per particle. The posterior
# mean is flattened too, so a trailing singleton dimension cannot broadcast the
# subtraction into an (N, N) matrix.
samples = gaussian_posterior['posterior_samples']
indices = gaussian_posterior['indices']
mean_distance_gaussian = (samples.mean(dim=0).reshape(-1) - indices.reshape(-1)).numpy()
posterior_quantiles = np.quantile(samples.numpy(), [0.025, 0.975], axis=0)
confidence_widths_gaussian = (posterior_quantiles[1] - posterior_quantiles[0]).flatten()
assert mean_distance_gaussian.shape == confidence_widths_gaussian.shape, "per-particle statistics have mismatched lengths"

In [None]:
# Colored-noise posteriors: bias of the posterior mean and width of the central
# 95% interval (2.5% / 97.5% quantiles), one value per particle. The posterior
# mean is flattened too, so a trailing singleton dimension cannot broadcast the
# subtraction into an (N, N) matrix.
samples = colored_posterior['posterior_samples']
indices = colored_posterior['indices']
mean_distance_colored = (samples.mean(dim=0).reshape(-1) - indices.reshape(-1)).numpy()
posterior_quantiles = np.quantile(samples.numpy(), [0.025, 0.975], axis=0)
confidence_widths_colored = (posterior_quantiles[1] - posterior_quantiles[0]).flatten()
assert mean_distance_colored.shape == confidence_widths_colored.shape, "per-particle statistics have mismatched lengths"

In [None]:
# Gradient-noise posteriors: bias of the posterior mean and width of the central
# 95% interval (2.5% / 97.5% quantiles), one value per particle. The posterior
# mean is flattened too, so a trailing singleton dimension cannot broadcast the
# subtraction into an (N, N) matrix.
samples = gradient_posterior['posterior_samples']
indices = gradient_posterior['indices']
mean_distance_gradient = (samples.mean(dim=0).reshape(-1) - indices.reshape(-1)).numpy()
posterior_quantiles = np.quantile(samples.numpy(), [0.025, 0.975], axis=0)
confidence_widths_gradient = (posterior_quantiles[1] - posterior_quantiles[0]).flatten()
assert mean_distance_gradient.shape == confidence_widths_gradient.shape, "per-particle statistics have mismatched lengths"

In [None]:
# SBC results for the three noise models, in Gaussian / colored / gradient order.
sbc_files = (
    "../experiments/benchmark_hsp90/results/raw_data/23_03_21_final_posterior_gaussian_SBC.pt",
    "../experiments/benchmark_hsp90/results/raw_data/23_03_21_final_posterior_colored_SBC.pt",
    "../experiments/benchmark_hsp90/results/raw_data/23_03_21_final_posterior_gradient_SBC.pt",
)
gaussian_sbc, colored_sbc, gradient_sbc = map(torch.load, sbc_files)

# Stack along dim=1 in the same model order (one column per model for 1-D curves).
sbc_runs = (gaussian_sbc, colored_sbc, gradient_sbc)
levels = torch.stack([run['levels'] for run in sbc_runs], dim=1)
coverages = torch.stack([run['coverages'] for run in sbc_runs], dim=1)

In [None]:
# Noise misspecification: posteriors under Gaussian, colored and gradient noise,
# compared by posterior-mean bias (left), width of the central 95% interval
# (middle), and expected coverage vs. credible level (right; dashed y = x line
# for reference).
noise_labels = ['Gaussian', 'Colored', 'Gradient']
mean_distances = [mean_distance_gaussian, mean_distance_colored, mean_distance_gradient]
interval_widths = [confidence_widths_gaussian, confidence_widths_colored, confidence_widths_gradient]

fig, axes = plt.subplots(1, 3, figsize=(14, 4))

for label, distances in zip(noise_labels, mean_distances):
    axes[0].hist(distances, bins=np.arange(-10, 10, 0.5), histtype='step', density=True, label=label, linewidth=2)
axes[0].set_xlabel(r'$\overline{\phi} - \phi_{True}$', fontsize=13)
axes[0].set_yticklabels([])

for label, widths in zip(noise_labels, interval_widths):
    axes[1].hist(widths, bins=np.arange(0, 20, 0.5), histtype='step', density=True, label=label, linewidth=2)
# widths are 2.5%-97.5% quantile ranges, i.e. central 95% intervals (not 3 sigma)
axes[1].set_xlabel('Width of 95% confidence interval', fontsize=13)
axes[1].set_yticklabels([])
axes[1].legend(fontsize=13)

# columns of levels/coverages follow the Gaussian, colored, gradient stacking order
axes[2].plot(levels, coverages)
axes[2].plot([0, 1], [0, 1], linestyle="--", linewidth=2, color="black")
axes[2].set_xlabel('Credible level', fontsize=13)
axes[2].set_ylabel('Expected coverage', fontsize=13)

# Previously saved unconditionally, ignoring the save_figures switch.
if save_figures:
    fig.savefig(f'{plot_dir}hsp90_noise_missspecification.pdf', dpi=400)

In [None]:
# Example particle per noise model (top row) and its posterior over phi with the
# true value in red (bottom row).
examples = [
    ('Gaussian noise', gaussian_posterior, 0),
    ('Colored noise', colored_posterior, 13),
    ('Gradient noise', gradient_posterior, 20),
]

fig, axes = plt.subplots(2, 3, figsize=(10, 5))

for column, (title, posterior, particle) in enumerate(examples):
    image_ax, hist_ax = axes[0, column], axes[1, column]
    image_ax.set_title(title)
    image_ax.imshow(posterior['images'][particle])
    hist_ax.hist(posterior['posterior_samples'][:, particle], bins=np.arange(0, 20, 0.2), histtype='step')
    hist_ax.axvline(posterior['indices'][particle], color='red')
    hist_ax.set_yticks([])
    hist_ax.set_xlabel(r'$\phi$', fontsize=13)

# images need no tick marks
for image_ax in axes[0]:
    image_ax.set_yticks([])
    image_ax.set_xticks([])

#fig.savefig(f'{plot_dir}hsp90_noise_example_particles.pdf', dpi=400)