# COSMOS (Calculation Of Susceptibility through Multiple Orientation Sampling)
### MOSAIC: Multi-Orientation Sampling And Inversion reConstruction

In [None]:
# Run this codeblock to mount your Google Drive in Google Colab.
# Outside Colab the `google.colab` package does not exist; skip the mount
# instead of failing so the notebook can still run end-to-end locally with
# ROOT pointing at a local directory.
try:
    from google.colab import drive
except ImportError:
    drive = None
    print("google.colab not available; skipping Google Drive mount.")
else:
    drive.mount("/content/drive/")

In [None]:
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch

# TODO: set your root directory here (a plain string is fine; it is wrapped
# in a Path below).
root_setting = "./"
# root_setting = '/content/drive/MyDrive/DS2_Project/MOSAIC'  # example for Google Drive
ROOT = Path(root_setting)

# Later cells load their inputs relative to ROOT, so warn early if it is missing.
if not ROOT.exists():
    print(f"Root directory {ROOT} does not exist. Please check the path.")

In [None]:
# Load the reference image stored under ROOT as a float32 tensor and preview it.
img = torch.tensor(np.load(ROOT / "img.npy"), dtype=torch.float32)

fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(img, cmap="gray")
ax.set_axis_off()
plt.show()

In [None]:
def dipole_kernel(
    matrix_size: tuple[int, int],
    voxel_size: tuple[float, float] = (1.0, 1.0),
    B0_dir: tuple[float, float] = (0.0, 1.0),
) -> torch.Tensor:
    """Build a 2D dipole kernel sampled on the unshifted FFT frequency grid.

    The kernel is ``D(k) = 1/3 - (k . b)^2 / |k|^2``, where ``b`` is ``B0_dir``
    normalized to unit length. Samples are laid out in the same order as the
    output of ``torch.fft.fftn`` (DC at index ``[0, 0]``), so the result can
    multiply an image spectrum directly.

    Args:
        matrix_size: Kernel shape ``(n0, n1)``; pass the image shape in the
            same axis order so the kernel broadcasts against its spectrum.
        voxel_size: Sample spacing along axis 0 and axis 1.
        B0_dir: Direction ``(b0, b1)``; component 0 pairs with axis 0 and
            component 1 with axis 1. Normalized internally.

    Returns:
        float32 tensor of shape ``matrix_size``. The DC sample equals 1/3.

    Raises:
        ValueError: If ``B0_dir`` has zero length.
    """
    b = np.asarray(B0_dir, dtype=np.float64)
    b_norm = np.linalg.norm(b)
    if b_norm == 0:
        raise ValueError("B0_dir must be a non-zero vector")
    # The formula assumes a unit direction; normalizing makes scaled inputs safe.
    b = b / b_norm

    # fftfreq returns k / (n * d) already in FFT order for both even and odd n.
    # A centered arange(-n/2, n/2) grid plus fftshift matches this only for
    # even n; for odd n it is offset by half a sample and not symmetric.
    k0 = np.fft.fftfreq(int(matrix_size[0]), d=voxel_size[0])
    k1 = np.fft.fftfreq(int(matrix_size[1]), d=voxel_size[1])
    K0, K1 = np.meshgrid(k0, k1, indexing="ij")  # shape (n0, n1)

    k_sq = K0**2 + K1**2
    k_par = K0 * b[0] + K1 * b[1]
    # Only the DC sample has k_sq == 0, and k_par is 0 there too, so dividing
    # by 1 yields D = 1/3 without adding a scale-dependent epsilon everywhere.
    safe_k_sq = np.where(k_sq > 0, k_sq, 1.0)
    D = 1 / 3 - k_par**2 / safe_k_sq
    return torch.tensor(D, dtype=torch.float32)

In [None]:
num_dirs = 5
# The kernel depends on the field direction only through (k . b)^2, so b and
# -b give identical kernels. Spread the angles over [0, pi) rather than
# [0, 2*pi) so an even num_dirs does not yield duplicated orientations.
# Scaling an integer range (instead of a float-step arange) also guarantees
# exactly num_dirs angles.
angles = torch.arange(num_dirs, dtype=torch.float32) * (torch.pi / num_dirs)
dirs = torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)

img_shape = img.shape

# matrix_size follows the image's own (axis 0, axis 1) order so the kernel
# has the same shape as the image spectrum, including for non-square images.
kernels = torch.stack(
    [
        dipole_kernel(
            matrix_size=(img_shape[0], img_shape[1]),
            voxel_size=(1.0, 1.0),
            B0_dir=(d[0].item(), d[1].item()),
        )
        for d in dirs
    ],
    dim=0,
)

# Forward model: multiply the image spectrum by every kernel at once
# (broadcast over the leading direction axis) and return to image space.
img_k = torch.fft.fftn(img, dim=(-2, -1))
imgs = torch.fft.ifftn(img_k * kernels, dim=(-2, -1)).real

plt.figure(figsize=(5 * num_dirs, 5))
for i, sim in enumerate(imgs):
    plt.subplot(1, num_dirs, i + 1)
    plt.imshow(sim, cmap="gray")
    plt.title(f"Direction {i + 1}")
    plt.axis("off")
plt.tight_layout()
plt.show()

In [None]:
imgs_k = torch.fft.fftn(imgs, dim=(-2, -1))
print("Shape of imgs_k:", imgs_k.shape)

# At each k-space sample the measurements satisfy
#     imgs_k[i, k] = kernels[i, k] * x(k),   i = 0 .. num_dirs - 1,
# i.e. num_dirs equations in one unknown. Its least-squares solution (what
# torch.linalg.pinv of the (num_dirs, 1) column computes) is
#     x(k) = sum_i kernels[i, k] * imgs_k[i, k] / sum_i kernels[i, k]^2,
# and 0 where every kernel vanishes. Evaluating it for all samples at once
# replaces H * W separate pinv/SVD calls and also works for num_dirs == 1.
numerator = (kernels * imgs_k).sum(dim=0)
denominator = (kernels**2).sum(dim=0)
# A zero denominator means all kernels are 0 there, so the numerator is 0 too;
# dividing by 1 reproduces pinv's zero solution without a 0/0.
denominator = torch.where(denominator > 0, denominator, torch.ones_like(denominator))
recon_k = (numerator / denominator).to(torch.complex64)

recon = torch.fft.ifftn(recon_k, dim=(-2, -1)).real

In [None]:
# Show the reconstruction on an intensity window taken from the reference
# image: floor at 0, ceiling at the reference's 98th percentile.
vmin = 0
vmax_label = np.percentile(img, 98) * 1.0

fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(recon, cmap="gray", vmin=vmin, vmax=vmax_label)
ax.set_title("Reconstructed Image")
ax.set_axis_off()
plt.show()