In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms

In [None]:
from cryo_sbi.utils.image_utils import (
    Mask,
    MRCtoTensor,
    LowPassFilter,
    FourierDownSample,
    NormalizeIndividual,
    circular_mask,
)

In [None]:
# Particle preprocessing: MRC file -> tensor, Fourier downsampling
# (FourierDownSample(256, final_size)), low-pass filtering, then per-image
# normalisation.
# NOTE(review): LowPassFilter's first argument (128) equals final_size —
# presumably the image size after downsampling; confirm before changing either.
final_size = 128
preprocessing_steps = [
    MRCtoTensor(),
    FourierDownSample(256, final_size),
    LowPassFilter(128, 15),
    NormalizeIndividual(),
]
transform = transforms.Compose(preprocessing_steps)

In [None]:
# Preprocess the stacks particles_01.mrc ... particles_69.mrc (zero-padded
# two-digit index) and concatenate them along the first (image) axis.
particles_transfomed = torch.cat(
    [
        transform(f"../../6wxb/particles/particles_{stack_idx:02d}.mrc")
        for stack_idx in range(1, 70)
    ],
    dim=0,
)

In [None]:
# First preprocessed particle, colour range clipped to [-4, 4].
plt.imshow(particles_transfomed[0, ...], vmin=-4, vmax=4, cmap="binary")

In [None]:
# Centred power spectra |F|^2 of every preprocessed particle. The shift is
# applied over the two image axes only, leaving the particle axis untouched.
images_fft = torch.fft.fftshift(torch.fft.fft2(particles_transfomed), dim=(-2, -1))
power_spectrum = images_fft.abs().square()

In [None]:
# Shape of the stacked power spectra.
power_spectrum.size()

In [None]:
def radial_profile(data, center=None):
    """Compute the radial (azimuthal) average of a 2D image.

    Adapted from
    https://stackoverflow.com/questions/21242011/most-efficient-way-to-calculate-radial-profile

    Every pixel is assigned to the integer bin ``int(r)``, where ``r`` is its
    Euclidean distance from ``center``. The profile is the mean of ``data``
    within each bin.

    Args:
        data (torch.Tensor): Real-valued image of shape (n_rows, n_cols).
        center (tuple, optional): Centre as ``(x, y)``, i.e. ``(column, row)``.
            Defaults to ``(n_cols // 2, n_rows // 2)``, the pixel where
            ``torch.fft.fftshift`` places the zero-frequency term.

    Returns:
        radialprofile (torch.Tensor): 1D tensor of length ``max(int(r)) + 1``
            whose entry ``k`` is the mean of the pixels with ``k <= r < k + 1``.
            A bin that contains no pixel evaluates to NaN (0 / 0).

    Raises:
        ValueError: If ``data`` is not 2D.
    """
    if data.ndim != 2:
        raise ValueError(
            f"radial_profile expects a 2D image, got shape {tuple(data.shape)}"
        )
    n_rows, n_cols = data.shape
    if center is None:
        center = (n_cols // 2, n_rows // 2)

    # "ij" indexing: y follows dim 0 (rows) and x follows dim 1 (columns). This
    # is the layout torch.meshgrid produces by default; stating it explicitly
    # avoids the warning newer torch versions emit when it is omitted.
    # Coordinates live on data's device so GPU tensors work as well.
    y, x = torch.meshgrid(
        torch.arange(n_rows, device=data.device),
        torch.arange(n_cols, device=data.device),
        indexing="ij",
    )
    # Distance from the centre, truncated to an integer bin index.
    r = torch.sqrt((x - center[0]) ** 2 + (y - center[1]) ** 2).to(torch.int64)
    r = r.ravel()
    # Per-bin sum of pixel values divided by the per-bin pixel count.
    tbin = torch.bincount(r, weights=data.ravel())
    nr = torch.bincount(r)
    radialprofile = tbin / nr
    return radialprofile

In [None]:
# Number of integer radial bins radial_profile returns for an N x N image centred
# at N // 2. The farthest pixel lies sqrt(2) * (N // 2) away and radii are
# truncated to ints, so there are floor(sqrt(2) * (N // 2)) + 1 bins.
# round(sqrt(2) * (N // 2)) is one short for some sizes, e.g. N = 256.
length_rotation = int(np.sqrt(2) * (power_spectrum.shape[-1] // 2)) + 1

In [None]:
# NOTE(review): this pre-allocation (one column longer than length_rotation) is
# replaced by the next cell before it is ever read.
rotational_avgs = torch.zeros((power_spectrum.shape[0], length_rotation + 1))

In [None]:
# Radially average each power spectrum about the zero-frequency pixel. fftshift
# places that pixel at (x, y) = (width // 2, height // 2).
# The number of bins is read off the profiles rather than predicted with
# round(sqrt(2 * (N // 2) ** 2)). That estimate is one short for some sizes
# (N = 256 gives 181, while radial_profile returns 182 bins), which made the
# row assignment fail with a shape mismatch.
spectrum_center = (power_spectrum.shape[-1] // 2, power_spectrum.shape[-2] // 2)
rotational_avgs = torch.stack(
    [radial_profile(spectrum, spectrum_center) for spectrum in power_spectrum]
)
length_rotation = rotational_avgs.shape[-1]

In [None]:
# Mean radial power profile over all particles. A tick is placed every 40 radial
# bins and labelled with bin * 1.06.
# NOTE(review): 1.06 is presumably a pixel size in Å — confirm. A radius
# measured in Fourier-space bins is not a real-space distance in Å, so the axis
# label may need revisiting.
tick_positions = np.arange(0, length_rotation, 40)
plt.plot(rotational_avgs.mean(dim=0), markersize=2, marker="o", linestyle="None")
plt.xticks(ticks=tick_positions, labels=[f"{x:.1f}" for x in tick_positions * 1.06])
plt.xlabel(r"Radius ($\AA$)")
plt.ylabel("Average intensity")
plt.ylim(0, 100000)

In [None]:
# Shape of the (n_images, n_radial_bins) profile matrix.
rotational_avgs.size()

In [None]:
# Pipeline for the raw micrograph: MRC -> tensor conversion only, with no
# downsampling, filtering or normalisation.
# NOTE(review): final_size is re-set here but this pipeline does not use it.
final_size = 128
transform = transforms.Compose([MRCtoTensor()])

In [None]:
# Shape of the current particles_transfomed tensor.
particles_transfomed.size()

In [None]:
# Raw micrograph to compare with the particle stack. Set the MICROGRAPH_PATH
# environment variable to point at a local copy. The default is the original,
# machine-specific absolute path.
# NOTE(review): this rebinds particles_transfomed (previously the particle
# stack) to the micrograph; later cells depend on which one it holds.
micrograph_path = os.environ.get(
    "MICROGRAPH_PATH",
    "/home/dingeldein/Desktop/FoilHole_24136450_Data_24136362_24136364_20200224_020919_Fractions_patch_aligned_doseweighted.mrc",
)
particles_transfomed = transform(micrograph_path)

In [None]:
# Global standardisation: subtract the mean of all pixels, then divide by their
# standard deviation (both taken from the un-normalised tensor).
particles_transfomed = particles_transfomed.sub(particles_transfomed.mean()).div(
    particles_transfomed.std()
)

In [None]:
# Centred power spectrum |F|^2. As for the particle stack above, the shift is
# restricted to the two image axes, so a leading batch axis (if present) is not
# rolled.
power_spectrum = (
    torch.fft.fftshift(torch.fft.fft2(particles_transfomed), dim=(-2, -1)).abs() ** 2
)

In [None]:
# Radial profile of the micrograph's power spectrum. The centre is given as
# (x, y) = (width // 2, height // 2), matching where fftshift puts the DC term.
# Using width // 2 for both coordinates mis-centred non-square micrographs.
plt.plot(
    radial_profile(
        power_spectrum,
        (power_spectrum.shape[-1] // 2, power_spectrum.shape[-2] // 2),
    )
)

In [None]:
# Side-by-side view of the normalised micrograph and the radial profile of its
# power spectrum. The profile is computed once. Its centre is
# (x, y) = (width // 2, height // 2), where fftshift puts the zero-frequency
# term, so non-square micrographs are centred correctly.
micrograph_profile = radial_profile(
    power_spectrum,
    (power_spectrum.shape[-1] // 2, power_spectrum.shape[-2] // 2),
)
length_rotation = len(micrograph_profile)
tick_positions = np.arange(0, length_rotation, 800)

fig, axes = plt.subplots(1, 2, figsize=(15, 5))
axes[0].imshow(particles_transfomed, cmap="binary", vmin=-1, vmax=1)
axes[1].plot(micrograph_profile)
axes[1].set_xticks(ticks=tick_positions, labels=[f"{x:.0f}" for x in tick_positions])
# A single x label. The ticks are unscaled bin indices; previously an Å label was
# set and then immediately overwritten through plt.xlabel on the same axes.
axes[1].set_xlabel("Radius (pixels)")

In [None]:
# Show the centred power spectrum |F|^2 of (up to) the first three images.
# The previous code plotted Re(F^2) = Re(F)^2 - Im(F)^2, which is not a power
# spectrum (hence its symmetric colour range). |F|^2 matches how power_spectrum
# is computed elsewhere in this notebook.
# Leading axes are flattened into one image axis. A single 2D image, such as the
# micrograph loaded above, is therefore shown as a batch of one. Indexing it with
# [i] produced 1D rows that fft2 rejects.
images_to_show = particles_transfomed.reshape(-1, *particles_transfomed.shape[-2:])
for image in images_to_show[:3]:
    spectrum = torch.fft.fftshift(torch.fft.fft2(image)).abs() ** 2
    plt.imshow(spectrum, cmap="binary", vmin=0, vmax=1e6)
    plt.xticks([])
    plt.yticks([])
    plt.show()

In [None]:
# Shape check: which tensor (stack or micrograph) particles_transfomed currently holds.
particles_transfomed.size()

In [None]:
# Show (up to) the first three images. The loop previously indexed [0] on every
# iteration, drawing the same image three times. Leading axes are flattened, so
# a single 2D image is shown once rather than being indexed into 1D rows.
images_to_show = particles_transfomed.reshape(-1, *particles_transfomed.shape[-2:])
for image in images_to_show[:3]:
    plt.imshow(image, cmap="binary")
    plt.show()

In [None]:
# Load the simulated particles.
# NOTE(review): despite the .mrc extension, torch.load deserialises a
# torch-serialised (pickle-based) object, not an MRC file — only load trusted
# files. The later torch.fft.fft2(particles) call implies this is a tensor of
# images, presumably (n_images, H, W) — confirm.
particles = torch.load("simulated_particles_5000.mrc")

In [None]:
# Simulated particle 214, unfiltered.
plt.imshow(particles[214, ...], cmap="binary")

In [None]:
# The same particle after the LowPassFilter(128, 15) used in the preprocessing
# pipeline at the top of the notebook.
lpf = LowPassFilter(128, 15)
low_passed_particle = lpf(particles[214])
plt.imshow(low_passed_particle, cmap="binary")

In [None]:
# Centred power spectra |F|^2 of the simulated particles. The shift covers the
# two image axes only, so the particle axis is not rolled.
images_fft = torch.fft.fftshift(torch.fft.fft2(particles), dim=(-2, -1))
power_spectrum = torch.abs(images_fft) ** 2

# Radial average of every spectrum about the zero-frequency pixel at
# (x, y) = (width // 2, height // 2). The number of bins is read off the
# profiles instead of round(sqrt(2 * (N // 2) ** 2)). That estimate is one short
# for some sizes (N = 256 gives 181 vs. 182 actual bins), which made the row
# assignment fail with a shape mismatch.
spectrum_center = (power_spectrum.shape[-1] // 2, power_spectrum.shape[-2] // 2)
rotational_avgs = torch.stack(
    [radial_profile(spectrum, spectrum_center) for spectrum in power_spectrum]
)
length_rotation = rotational_avgs.shape[-1]

In [None]:
# Mean radial power profile of the simulated particles, one marker per radial bin.
plt.plot(rotational_avgs.mean(0), linestyle="None", marker="o", markersize=2)

In [None]:
# Centred power spectrum of simulated particle 10, linear colour scale.
plt.imshow(power_spectrum[10, ...], cmap="binary")