In [9]:
import numpy as np
import os
from scipy.stats import multivariate_normal
import nibabel as nib
from tqdm import tqdm

# Volume dimensions in NumPy array-axis order (axis 0, axis 1, axis 2).
# Every mean and covariance in this script is expressed in this same order.
VOLUME_SHAPE = 90, 120, 120

def generate_gaussian_structure(mean, cov, shape, threshold=0.002, noise_level=0.0001, rng=None):
    """Generate a 3D Gaussian blob and its binary segmentation.

    The Gaussian density is evaluated at every integer voxel index of a
    volume with the given ``shape``. ``mean`` and ``cov`` are expressed in
    array-axis order (axis 0, axis 1, axis 2) -- the same order as ``shape``
    -- in voxel units.

    Args:
        mean: Length-3 centre of the Gaussian, in voxel indices.
        cov: 3x3 covariance matrix accepted by ``scipy.stats.multivariate_normal``.
        shape: 3-tuple volume shape.
        threshold: Voxels whose peak-normalised density is strictly greater
            than this value are labelled 1 in the segmentation.
        noise_level: Upper bound of the uniform noise added to background
            (segmentation == 0) voxels only.
        rng: Optional ``numpy.random.Generator`` used for the noise. When
            ``None`` the global ``np.random`` state is used (original behaviour).

    Returns:
        ``(volume, seg)``: ``volume`` is float32, clipped to [0, 1], with its
        peak voxel equal to 1; ``seg`` is a uint8 mask of the same shape.
    """
    # One row per voxel holding its (axis 0, axis 1, axis 2) index, C order.
    grid = np.indices(shape).reshape(len(shape), -1).T

    # Normalise in log space: exp(logpdf - max) == pdf / max(pdf), but it
    # cannot underflow to an all-zero map (which made pdf / max(pdf) NaN)
    # when the mean lies far outside the volume or the covariance is thin.
    log_density = multivariate_normal(mean=mean, cov=cov).logpdf(grid)
    probs = np.exp(log_density - np.max(log_density)).reshape(shape)

    # Segmentation from the clean probability map (probs >= 0, no abs needed).
    seg = (probs > threshold).astype(np.uint8)

    # Add uniform noise to background voxels only; foreground stays clean.
    if rng is None:
        noise = np.random.uniform(0, noise_level, shape)
    else:
        noise = rng.uniform(0, noise_level, shape)
    volume = probs + noise * (seg == 0)

    # Ensure volume values are within [0, 1] and store as float32.
    volume = np.clip(volume, 0, 1).astype(np.float32)
    return volume, seg

def save_nifti(volume, path, affine=None):
    """Save a 3D volume as a NIfTI file.

    Args:
        volume: Array to store (callers in this file pass float32 volumes and
            uint8 segmentations).
        path: Destination file path (callers in this file use ``.nii.gz``).
        affine: Optional 4x4 voxel-to-world affine. ``None`` (the default)
            uses the identity, i.e. unit voxels with array axes mapped
            directly onto world axes -- the previous hard-coded behaviour.
    """
    if affine is None:
        affine = np.eye(4)
    img = nib.Nifti1Image(volume, affine)
    nib.save(img, path)
    
def jitter_covariance(base_cov, jitter_strength=0.1):
    """Return a randomly perturbed, positive-definite copy of ``base_cov``.

    A Gaussian noise matrix scaled by ``jitter_strength`` is symmetrised and
    added to ``base_cov``; every eigenvalue of the sum is then floored at
    1e-5 so the result is usable as a covariance.
    """
    perturbation = np.random.randn(*base_cov.shape) * jitter_strength
    perturbed = base_cov + 0.5 * (perturbation + perturbation.T)

    # Eigen-decompose and floor the spectrum to remove tiny/negative eigenvalues.
    values, vectors = np.linalg.eigh(perturbed)
    floored = np.maximum(values, 1e-5)
    return vectors @ np.diag(floored) @ vectors.T


def jitter_mean(base_mean, max_offset=3):
    """Shift each of the three coordinates of ``base_mean`` by a random integer.

    Offsets are drawn independently and uniformly from
    ``[-max_offset, max_offset]`` (inclusive); the result is a plain list.
    """
    low, high = -max_offset, max_offset + 1  # randint's upper bound is exclusive
    shifts = np.random.randint(low, high, size=3)
    shifted = np.asarray(base_mean) + shifts
    return shifted.tolist()




def perturb_covariance(base_cov, eigval_jitter=0.1, angle_jitter_deg=5):
    """Jitter the eigenvalues of a 3x3 covariance and apply a small random rotation.

    Each eigenvalue is multiplied by an independent factor drawn uniformly from
    ``[1 - eigval_jitter, 1 + eigval_jitter]``; the ellipsoid is then rotated
    by an angle drawn uniformly from ``[-angle_jitter_deg, angle_jitter_deg]``
    degrees about a random unit axis. Positive scale factors preserve the sign
    of every eigenvalue, so a positive-definite base stays positive-definite.

    Args:
        base_cov: 3x3 covariance matrix. It is symmetrised as
            ``(base_cov + base_cov.T) / 2`` first, because ``np.linalg.eigh``
            would otherwise silently use only its lower triangle.
        eigval_jitter: Relative eigenvalue jitter, must lie in ``[0, 1)``.
        angle_jitter_deg: Maximum rotation angle, in degrees.

    Returns:
        An exactly symmetric 3x3 covariance matrix.

    Raises:
        ValueError: If ``base_cov`` is not 3x3 or ``eigval_jitter`` is
            outside ``[0, 1)``.
    """
    base_cov = np.asarray(base_cov, dtype=float)
    if base_cov.shape != (3, 3):
        raise ValueError(f"base_cov must be 3x3, got shape {base_cov.shape}")
    if not 0 <= eigval_jitter < 1:
        raise ValueError(f"eigval_jitter must be in [0, 1), got {eigval_jitter}")

    # Eigendecompose the symmetrised base covariance.
    eigvals, eigvecs = np.linalg.eigh((base_cov + base_cov.T) / 2)

    # Jitter eigenvalues by a relative factor (draw order kept as before).
    scale_factors = 1 + np.random.uniform(-eigval_jitter, eigval_jitter, size=3)
    new_eigvals = eigvals * scale_factors

    # Small random rotation via Rodrigues' formula:
    # R = I + sin(theta) K + (1 - cos(theta)) K^2, K = cross-product matrix of the axis.
    angle_rad = np.deg2rad(angle_jitter_deg)
    axis = np.random.randn(3)
    axis /= np.linalg.norm(axis)
    theta = np.random.uniform(-angle_rad, angle_rad)
    K = np.array([
        [0.0, -axis[2], axis[1]],
        [axis[2], 0.0, -axis[0]],
        [-axis[1], axis[0], 0.0],
    ])
    R = np.eye(3) + np.sin(theta) * K + (1 - np.cos(theta)) * (K @ K)

    # Rotate the eigenbasis and rebuild: (R V) diag(lambda) (R V)^T.
    rotated_basis = R @ eigvecs
    cov = rotated_basis @ np.diag(new_eigvals) @ rotated_basis.T
    # Floating-point products are only approximately symmetric; make it exact.
    return (cov + cov.T) / 2




def generate_dataset(mean, cov, n_samples, output_dir, prefix='dataset',
                     volume_shape=None, max_offset=3, eigval_jitter=0.1,
                     angle_jitter_deg=5):
    """Write ``n_samples`` jittered Gaussian volumes and segmentations to disk.

    For each sample the mean is shifted with ``jitter_mean`` and the
    covariance perturbed with ``perturb_covariance``; the volume and mask from
    ``generate_gaussian_structure`` are saved as NIfTI files under
    ``output_dir/<index>/``. The jittered mean (as ints) and covariance
    (flattened, 2 decimals) are encoded in the file names.

    Args:
        mean: Base Gaussian centre (array-axis order, voxel units).
        cov: Base 3x3 covariance.
        n_samples: Number of samples to write.
        output_dir: Root directory; created if missing.
        prefix: Label shown in the progress bar.
        volume_shape: Volume shape; ``None`` uses the module-level
            ``VOLUME_SHAPE`` (read at call time, as before).
        max_offset: Forwarded to ``jitter_mean``.
        eigval_jitter: Forwarded to ``perturb_covariance``.
        angle_jitter_deg: Forwarded to ``perturb_covariance``.

    Returns:
        List of ``(vol_path, seg_path)`` tuples, one per sample.
    """
    if volume_shape is None:
        volume_shape = VOLUME_SHAPE
    os.makedirs(output_dir, exist_ok=True)

    # Zero-pad indices so directories/files sort lexicographically in order;
    # keeps the previous 2-digit minimum and widens only past 100 samples.
    width = max(2, len(str(max(n_samples - 1, 0))))

    written = []
    for i in tqdm(range(n_samples), desc=f"Generating {prefix}"):
        jittered_mean = jitter_mean(mean, max_offset=max_offset)
        jittered_cov = perturb_covariance(cov, eigval_jitter=eigval_jitter,
                                          angle_jitter_deg=angle_jitter_deg)

        volume, seg = generate_gaussian_structure(jittered_mean, jittered_cov, volume_shape)

        vol_mean = "_".join(str(int(m)) for m in jittered_mean)
        vol_cov = "_".join(f"{c:.2f}" for c in np.ravel(jittered_cov))

        idx = f"{i:0{width}d}"
        subj_dir = os.path.join(output_dir, idx)
        os.makedirs(subj_dir, exist_ok=True)
        vol_path = os.path.join(subj_dir, f"{idx}_vol_mean_{vol_mean}_cov_{vol_cov}.nii.gz")
        seg_path = os.path.join(subj_dir, f"{idx}_seg_mean_{vol_mean}_cov_{vol_cov}.nii.gz")

        save_nifti(volume, vol_path)
        save_nifti(seg, seg_path)
        written.append((vol_path, seg_path))
    return written



def line_covariance(direction, length, thickness):
    """Covariance of a thin Gaussian "line" oriented along ``direction``.

    Returns ``thickness * I + (length - thickness) * u u^T`` with
    ``u = direction / |direction|``: variance ``length`` along ``u`` and
    variance ``thickness`` in every orthogonal direction. ``direction`` is in
    the array-axis order used by ``generate_gaussian_structure``.
    """
    u = np.asarray(direction, dtype=float)
    u = u / np.linalg.norm(u)
    return thickness * np.eye(u.size) + (length - thickness) * np.outer(u, u)


if __name__ == "__main__":
    output_base = "./synthetic_mri_datasets"

    # Shape knobs shared by every dataset below.
    length = 4000    # variance along the long axis (voxels^2)
    thickness = 2    # variance across the long axis (voxels^2)
    n_samples = 10

    # Every dataset is centred in the volume.
    center = [s // 2 for s in VOLUME_SHAPE]

    # Directions are in array-axis order (axis 0, axis 1, axis 2). The dataset
    # names treat axes 0/1/2 as x/y/z (the NIfTI voxel axes written by
    # save_nifti with its identity affine). Add entries here for other shapes,
    # e.g. [0, 1, 0] for a line along y.
    directions = {
        "diagonal_xy_fixed_z": [1, 1, 0],  # along x = y, thin in z
        "diagonal_xyz": [1, 1, 1],         # main diagonal through the volume
        "vertical_z": [0, 0, 1],           # elongated along z, thin in x and y
        "vertical_x": [1, 0, 0],           # elongated along x, thin in y and z
    }

    for name, direction in directions.items():
        output_dir = os.path.join(output_base, name)
        print(f"Generating {name}...")
        print("Saving to:", output_dir)
        generate_dataset(
            mean=center,
            cov=line_covariance(direction, length, thickness),
            n_samples=n_samples,
            output_dir=output_dir,
            prefix=name,
        )


Generating diagonal_xy_fixed_z...
Saving to: ./synthetic_mri_datasets/diagonal_xy_fixed_z


Generating diagonal_xy_fixed_z: 100%|██████████| 10/10 [00:03<00:00,  2.77it/s]


Generating diagonal_xyz...
Saving to: ./synthetic_mri_datasets/diagonal_xyz


Generating diagonal_xyz: 100%|██████████| 10/10 [00:03<00:00,  2.81it/s]


Generating vertical_z...
Saving to: ./synthetic_mri_datasets/vertical_z


Generating vertical_z: 100%|██████████| 10/10 [00:03<00:00,  2.77it/s]


Generating vertical_x...
Saving to: ./synthetic_mri_datasets/vertical_x


Generating vertical_x: 100%|██████████| 10/10 [00:03<00:00,  2.79it/s]
