In [1]:
import os
import json
import numpy as np
import nibabel as nib
import cv2
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
import albumentations as A
from scipy.spatial.transform import Rotation as R


def get_obb_corners(center, size, angle_2d):
    """
    Return the four corners of a 2D oriented bounding box (OBB).

    Parameters
    ----------
    center : array-like
        Box center. Only the first two components are used, so a longer
        (e.g. 3D) vector may be passed.
    size : array-like
        Full box extents along the box's own axes. Only the first two
        components are used.
    angle_2d : float
        Box rotation in degrees, applied with the standard rotation matrix
        (counter-clockwise in an x-right, y-up frame).

    Returns
    -------
    numpy.ndarray
        ``(4, 2)`` float array of corners, in the order
        (-w/2, -h/2), (+w/2, -h/2), (+w/2, +h/2), (-w/2, +h/2) before rotation.

    Lists and tuples are accepted as well as NumPy arrays.
    """
    center = np.asarray(center, dtype=float)
    # asarray lets callers pass plain sequences; `list / 2.0` would raise TypeError.
    half_w, half_h = np.asarray(size, dtype=float)[:2] / 2.0
    corners = np.array([
        [-half_w, -half_h],
        [ half_w, -half_h],
        [ half_w,  half_h],
        [-half_w,  half_h],
    ])
    theta = np.deg2rad(angle_2d)
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t,  cos_t]])
    # Corners are row vectors, so p' = R p becomes p'^T = p^T R^T.
    return corners @ rotation.T + center[:2]

def augment_image(image, angle, translation=(0, 0), pivot=None):
    """
    Rotate an image about a pivot point, then translate it, keeping its size.

    Parameters
    ----------
    image : numpy.ndarray
        Image of shape (H, W). Multi-channel (H, W, C) arrays that
        ``cv2.warpAffine`` supports are accepted too; only the first two
        dimensions define the geometry.
    angle : float
        Rotation in degrees. Positive values rotate counter-clockwise in the
        displayed image (OpenCV convention, origin at the top-left corner).
    translation : (tx, ty)
        Shift in pixels along x (columns) and y (rows), applied after rotation.
    pivot : (x, y) or None
        Rotation center in (column, row) order. Defaults to the image center
        ``(W / 2, H / 2)``.

    Returns
    -------
    numpy.ndarray
        The transformed image with the same height and width as the input.
        Output pixels that map outside the input are filled by reflecting
        the border.

    Note: cv2 uses (x, y) ordering, i.e. (column, row).
    """
    h, w = image.shape[:2]
    if pivot is None:
        pivot = (w / 2, h / 2)
    # OpenCV's argument parser is strict about the `center` type. Cast NumPy
    # scalars (e.g. float32 box-center components) to built-in floats.
    center = (float(pivot[0]), float(pivot[1]))
    M = cv2.getRotationMatrix2D(center, float(angle), 1.0)
    # Adding to the last column shifts the output after the rotation.
    M[0, 2] += translation[0]
    M[1, 2] += translation[1]
    return cv2.warpAffine(image, M, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)

def augment_and_save_data(subject_ids, source_base_path_template, output_base_path, box_info_path,
                          num_augmentations=5, seed=None, start_id=5000):
    """
    Augment each subject's three localizer views and save the images together
    with the updated box information.

    For every subject, the canonicalised originals are written to
    ``<output_base_path>/<subject_id>/<subject_id>_localizer_0000{1,2,3}.nii.gz``
    (1: sagittal, 2: coronal, 3: axial). Each augmented copy gets its own
    numeric ID (``start_id``, ``start_id + 1``, ...) and a folder with the same
    file layout. All box entries are collected in
    ``<output_base_path>/combined_ground_truth.json``.

    For each augmented copy:
    - one rotation angle per view is drawn uniformly from [-20, 20) degrees
      and applied about the box center, so the center is fixed by the rotation;
    - one 3D translation (tx, ty, tz) is drawn uniformly from [-20, 20) pixels
      and applied consistently across the three views;
    - the stored box center is shifted by the translation, and each view's
      angle is increased by that view's rotation. The size is unchanged.

    Parameters
    ----------
    subject_ids : iterable of str
        Subject IDs. Each must be a key of the box-info JSON.
    source_base_path_template : str
        ``str.format`` template with one placeholder, filled with the subject
        ID to locate the folder of ``.nii.gz`` localizers.
    output_base_path : str
        Destination directory (created if needed).
    box_info_path : str
        JSON mapping subject ID -> {"center": [X, Y, Z], "size": [sx, sy, sz],
        "angles": [angle_sag, angle_cor, angle_axi]} (degrees).
    num_augmentations : int, optional
        Augmented copies per subject. 0 saves only the originals.
    seed : int or None, optional
        ``None`` (default) draws from NumPy's global random state, as before.
        An integer makes the augmentation reproducible via a dedicated generator.
    start_id : int, optional
        First ID assigned to augmented samples (default 5000).

    Subjects without box information, without an input folder, or with fewer
    than three localizer files are reported and skipped.
    """
    if "{" not in source_base_path_template:
        print("Warning: source_base_path_template has no '{}' placeholder; "
              "every subject will be read from the same folder.")
    os.makedirs(output_base_path, exist_ok=True)
    with open(box_info_path, 'r') as f:
        box_info = json.load(f)
    # seed=None keeps using NumPy's global random state.
    rng = np.random if seed is None else np.random.default_rng(seed)
    combined_box_info = {}
    augmented_id = start_id

    for subject_id in subject_ids:
        if subject_id not in box_info:
            print(f"Subject {subject_id}: No box information found; skipping.")
            continue
        input_folder = source_base_path_template.format(subject_id)
        if not os.path.isdir(input_folder):
            print(f"Subject {subject_id}: Input folder {input_folder} does not exist; skipping.")
            continue
        views = _load_localizer_views(input_folder)
        if views is None:
            print(f"Subject {subject_id}: Fewer than 3 localizer files are available.")
            continue
        sagittal, coronal, axial = views

        original_data = box_info[subject_id]
        original_center = np.array(original_data["center"], dtype=np.float32)  # [X, Y, Z]
        size = np.array(original_data["size"], dtype=np.float32)
        angles = np.array(original_data["angles"], dtype=np.float32)

        _save_localizer_views(output_base_path, subject_id, views)
        combined_box_info[subject_id] = original_data

        # cv2 takes points as (x, y) = (column, row). The sagittal array is
        # indexed [Y, Z], coronal [X, Z] and axial [X, Y], so each view's
        # pivot and translation list its second axis first.
        pivot_sag = (original_center[2], original_center[1])
        pivot_cor = (original_center[2], original_center[0])
        pivot_axi = (original_center[1], original_center[0])

        for _ in range(num_augmentations):
            # Keep this draw order stable. With seed=None, results depend on
            # the global NumPy random sequence.
            rotate_sag = rng.uniform(-20, 20)
            rotate_cor = rng.uniform(-20, 20)
            rotate_axi = rng.uniform(-20, 20)
            trans = rng.uniform(-20, 20, size=3)  # (tx, ty, tz), shared by all views

            sag_aug = augment_image(sagittal, rotate_sag, translation=(trans[2], trans[1]), pivot=pivot_sag)
            cor_aug = augment_image(coronal, rotate_cor, translation=(trans[2], trans[0]), pivot=pivot_cor)
            axi_aug = augment_image(axial, rotate_axi, translation=(trans[1], trans[0]), pivot=pivot_axi)

            aug_id = str(augmented_id)
            _save_localizer_views(output_base_path, aug_id, (sag_aug, cor_aug, axi_aug))
            combined_box_info[aug_id] = {
                # Rotation is about the center, so only the translation moves it.
                "center": (original_center + trans).tolist(),
                "size": size.tolist(),
                "angles": (angles + np.array([rotate_sag, rotate_cor, rotate_axi])).tolist(),
            }
            augmented_id += 1

    with open(os.path.join(output_base_path, "combined_ground_truth.json"), 'w') as f:
        json.dump(combined_box_info, f, indent=4)
    print(f"Augmentation complete. Data saved to {output_base_path}")


def _load_localizer_views(input_folder):
    """
    Load the first three ``.nii.gz`` files in ``input_folder`` (sorted by name)
    as closest-canonical, squeezed arrays (sagittal, coronal, axial).

    Returns None when fewer than three such files exist. The sagittal /
    coronal / axial order relies on the file naming, as the pipeline assumes.
    """
    localizer_files = sorted(
        os.path.join(input_folder, name)
        for name in os.listdir(input_folder)
        if name.endswith(".nii.gz")
    )
    if len(localizer_files) < 3:
        return None
    return tuple(
        np.squeeze(nib.as_closest_canonical(nib.load(path)).get_fdata())
        for path in localizer_files[:3]
    )


def _save_localizer_views(output_base_path, sample_id, views):
    """
    Write (sagittal, coronal, axial) as float32 NIfTI files to
    ``<output_base_path>/<sample_id>/<sample_id>_localizer_0000{1,2,3}.nii.gz``.

    An identity affine is stored, so the source voxel spacing and orientation
    are not kept. Presumably box coordinates are treated as array indices
    downstream; confirm before relying on physical units.
    """
    folder = os.path.join(output_base_path, sample_id)
    os.makedirs(folder, exist_ok=True)
    for index, view in enumerate(views, start=1):
        nib.save(nib.Nifti1Image(view.astype(np.float32), np.eye(4)),
                 os.path.join(folder, f"{sample_id}_localizer_{index:05d}.nii.gz"))

def display_augmented_data(sagittal, coronal, axial, center, size, angles, subject_id):
    """
    Show the sagittal, coronal and axial views side by side, each overlaid
    with its in-plane oriented bounding box in red.

    ``center`` and ``size`` are 3-component (X, Y, Z) sequences. ``angles``
    holds one in-plane rotation per view, in degrees, ordered
    (sagittal, coronal, axial). The figure title shows ``subject_id``.
    """
    center_arr = np.array(center)
    size_arr = np.array(size)
    # (panel title, image, indices of the two 3D axes spanning the view plane)
    panels = (
        ('Sagittal', sagittal, [1, 2]),  # Y-Z plane
        ('Coronal', coronal, [0, 2]),    # X-Z plane
        ('Axial', axial, [0, 1]),        # X-Y plane
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle(f"Subject ID: {subject_id}", fontsize=16)

    for panel_idx, (title, image, plane_axes) in enumerate(panels):
        corners = get_obb_corners(center_arr[plane_axes], size_arr[plane_axes], angles[panel_idx])
        ax = axes[panel_idx]
        # Transposing puts the first array axis on the horizontal; with
        # origin='lower', index 0 of the second axis sits at the bottom.
        ax.imshow(image.T, cmap='gray', origin='lower')
        ax.add_patch(Polygon(corners, edgecolor='r', fill=False))
        ax.set_title(title)

    # Leave headroom for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

def display_data_from_json(output_base_path, json_path):
    """
    Plot every sample listed in a combined ground-truth JSON file.

    Each top-level key is a sample ID (original subject or augmented ID). Its
    three views are read from
    ``<output_base_path>/<id>/<id>_localizer_0000{1,2,3}.nii.gz`` (sagittal,
    coronal, axial). They are shown with the center/size/angles stored under
    that key. Samples whose folder or any of the three files is missing are
    skipped silently.
    """
    with open(json_path, 'r') as f:
        data = json.load(f)

    for sample_id, info in data.items():
        sample_dir = os.path.join(output_base_path, sample_id)
        if not os.path.exists(sample_dir):
            continue
        view_paths = (
            os.path.join(sample_dir, f"{sample_id}_localizer_00001.nii.gz"),
            os.path.join(sample_dir, f"{sample_id}_localizer_00002.nii.gz"),
            os.path.join(sample_dir, f"{sample_id}_localizer_00003.nii.gz"),
        )
        if not all(os.path.exists(path) for path in view_paths):
            continue
        sag, cor, axi = [nib.load(path).get_fdata().squeeze() for path in view_paths]
        display_augmented_data(
            sag, cor, axi,
            np.array(info['center']),
            np.array(info['size']),
            np.array(info['angles']),
            subject_id=sample_id,
        )

if __name__ == "__main__":
    # Input/output locations. These are absolute placeholders; adjust them
    # to the local data layout.
    # NOTE(review): source_base is passed as a str.format template that is
    # filled with each subject ID, but it contains no '{}' placeholder. As
    # written, every subject would read the same folder. Confirm the intended path.
    source_base = "/32CH_LOCALIZER_0001"
    output_base = "/niix_Data_fMRS_local_struct_augmentation"
    box_info = "/obb_box_information_localizer.json"

    subjects = [
        # 41xx series
        '4102_1', '4102_2', '4105_1', '4105_2', '4106', '4107', '4108_1',
        '4108_2', '4109', '4112', '4113', '4115', '4116', '4117', '4118',
        '4125', '4126', '4128', '4130', '4131', '4133', '4134', '4135',
        '4137', '4139', '4140', '4141', '4142', '4143', '4144', '4146',
        '4148', '4149', '4150', '4151',
        # 42xx series
        '4204', '4205', '4206', '4207', '4208',
    ]

    # Passing num_augmentations=0 writes only the canonicalised originals.
    augment_and_save_data(
        subject_ids=subjects,
        source_base_path_template=source_base,
        output_base_path=output_base,
        box_info_path=box_info,
        num_augmentations=30,
    )
    # To inspect the result visually afterwards:
    # display_data_from_json(output_base, os.path.join(output_base, "combined_ground_truth.json"))


  check_for_updates()


Augmentation complete. Data saved to /home/sseok24/Desktop/21011961_seok/niix_Data_fMRS_local_struct_augmentation_final_fitmatching111_copy


In [2]:
def get_T(P):
    """Return the 4x4 float32 homogeneous translation matrix for offsets P[0], P[1], P[2]."""
    translation = np.eye(4, dtype=np.float32)
    # Explicit indexing, so a too-short P raises IndexError.
    translation[:3, 3] = (P[0], P[1], P[2])
    return translation

def get_R1(P):
    """
    Return the 4x4 float32 rotation about the x-axis by P[3] degrees.

    Layout: +sin above the diagonal and -sin below it in the (y, z) block.
    """
    theta = np.deg2rad(P[3])
    c, s = np.cos(theta), np.sin(theta)
    rot = np.eye(4, dtype=np.float32)
    rot[1, 1], rot[1, 2] = c, s
    rot[2, 1], rot[2, 2] = -s, c
    return rot

def get_R2(P):
    """
    Return the 4x4 float32 rotation about the y-axis by P[4] degrees.

    Layout: +sin at (0, 2) and -sin at (2, 0).
    """
    theta = np.deg2rad(P[4])
    c, s = np.cos(theta), np.sin(theta)
    rot = np.eye(4, dtype=np.float32)
    rot[0, 0], rot[0, 2] = c, s
    rot[2, 0], rot[2, 2] = -s, c
    return rot

def get_R3(P):
    """
    Return the 4x4 float32 rotation about the z-axis by P[5] degrees.

    Layout: +sin at (0, 1) and -sin at (1, 0).
    """
    theta = np.deg2rad(P[5])
    c, s = np.cos(theta), np.sin(theta)
    rot = np.eye(4, dtype=np.float32)
    rot[0, 0], rot[0, 1] = c, s
    rot[1, 0], rot[1, 1] = -s, c
    return rot

def get_Z(P):
    """Return the 4x4 float32 diagonal scaling matrix diag(P[6], P[7], P[8], 1)."""
    scales = np.array([P[6], P[7], P[8], 1], dtype=np.float32)
    return np.diag(scales)

def get_affine(P):
    """
    Build a 3x4 affine matrix from a 9-parameter vector.

    P = [tx, ty, tz, rx, ry, rz, sx, sy, sz], with rotations in degrees.
    The three rotations are combined as R = R3 @ R2 @ R1, and the full matrix is
    T @ R.T @ Z. Applied to a column vector, this scales first, then rotates by
    the transpose of the combined rotation, then translates. The homogeneous
    bottom row is dropped from the result.
    """
    combined_rotation = get_R3(P) @ get_R2(P) @ get_R1(P)
    # The transpose is used to go between the two coordinate conventions.
    homogeneous = get_T(P) @ combined_rotation.T @ get_Z(P)
    return homogeneous[:3]

# Build a 3x4 affine for every entry (original subjects and augmented samples)
# of the combined ground-truth JSON and save them to a new JSON keyed by the
# same IDs. The input is presumably the combined_ground_truth.json written by
# augment_and_save_data. Paths are absolute placeholders; adjust to the local layout.
box_info_path = "/combined_ground_truth.json"
with open(box_info_path, 'r') as f:
    box_info = json.load(f)

# Maps each sample ID to its 3x4 affine as a nested list (JSON-serialisable).
affine_dict = {}
for subject_id, info in box_info.items():
    # Per-entry box parameters: center [X, Y, Z], per-view angles in degrees
    # [sagittal, coronal, axial], and box size [sx, sy, sz].
    center = np.array(info["center"])
    angles = np.array(info["angles"])
    size   = np.array(info["size"])
    # Pack into get_affine's 9-parameter layout
    # P = [tx, ty, tz, rx, ry, rz, sx, sy, sz]: center -> translation,
    # angles -> rotations (get_R1/R2/R3), size -> scaling (get_Z).
    P = np.concatenate([center, angles, size]).astype(np.float32)
    # 3x4 affine (get_affine drops the homogeneous bottom row).
    affine = get_affine(P)
    # ndarray is not JSON-serialisable, so store it as nested lists.
    affine_dict[subject_id] = affine.tolist()

# Write all affines to combined_ground_truth_affine.json.
output_path = "/combined_ground_truth_affine.json"
with open(output_path, 'w') as f:
    json.dump(affine_dict, f, indent=4)

# The reported count includes augmented samples, not only original subjects.
print(f"Saved affine matrices for {len(affine_dict)} subjects to {output_path}")


Saved affine matrices for 1240 subjects to /home/sseok24/Desktop/21011961_seok/niix_Data_fMRS_local_struct_augmentation_final_fitmatching111_copy/combined_ground_truth_affine.json
