In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm

class BrainSliceConditionalDiffusion(nn.Module):
    """
    Implements a conditional diffusion model over 3D brain volumes by slicing
    along multiple axes and processing the slices with a 2D diffusion U-Net.
    Optionally performs inpainting where ground-truth measurements are available.
    """

    # Axis permutations cycled through by ``step`` so successive iterations
    # slice the volume along different views (last axis = channels is kept).
    _PERMUTATIONS = (
        [0, 1, 2, 3],  # original volume
        [1, 2, 0, 3],  # rotate axes: H-W-D
        [2, 0, 1, 3],  # rotate axes: W-D-H
    )
    # Random circular shifts are drawn from [-_MAX_SHIFT, _MAX_SHIFT) per axis
    # (torch.randint's upper bound is exclusive).
    _MAX_SHIFT = 16

    def __init__(self, image_shape, channels_indices, unet, buffer=5):
        """
        Args:
            image_shape (tuple): Shape of the input volume, e.g., (D, H, W).
                Stored as ``self.image_shape``; not read by the methods below.
            channels_indices (list[int]): Which channels in the reference images to use.
            unet (nn.Module): 2D U-Net model for slice-wise denoising.
            buffer (int): Number of extra timesteps to pad at the start/end.
                Stored as ``self.buffer``; not read by the methods below.
        """
        super().__init__()
        self.image_shape = image_shape        # kept instead of silently discarded
        self.unet = unet                      # 2D denoiser network
        self.buffer = buffer                  # padding around time range
        self.channels = channels_indices      # channels used for conditioning

    def _device(self):
        """Return the device holding the U-Net's parameters.

        Falls back to CUDA-if-available-else-CPU when the U-Net exposes no
        parameters, which matches the previous hard-coded choice.
        """
        params = getattr(self.unet, "parameters", None)
        first = next(iter(params()), None) if callable(params) else None
        if first is not None:
            return first.device
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def inpainting(self, pred_x_0, integer_brain_coord, brain2_slices_rs, reference_image):
        """
        Replace predicted voxels at known coordinates with ground-truth values.

        Args:
            pred_x_0 (Tensor): Current clean volume prediction, shape (D,H,W,C).
                Modified in place and also returned.
            integer_brain_coord (LongTensor): Indices of known voxels, shape (N,3).
            brain2_slices_rs (Tensor): Reference slices, shape (N, num_channels);
                columns ``self.channels`` are written into the volume.
            reference_image (Tensor): Full conditioning volume. Currently unused;
                kept for interface compatibility.
        Returns:
            Tensor: Updated pred_x_0 with inpainted values.
        """
        index = (
            integer_brain_coord[:, 0],
            integer_brain_coord[:, 1],
            integer_brain_coord[:, 2],
        )
        with torch.no_grad():
            # Direct assignment is equivalent to the former "zero then +=" but
            # single-pass; .to(pred_x_0) matches its dtype and device so that
            # index_put does not reject a dtype/device mismatch.
            pred_x_0[index] = brain2_slices_rs[:, self.channels].to(pred_x_0)
            # NOTE: an alternative explored earlier was to overwrite only the
            # last channel with ``reference_image``.
        return pred_x_0

    def predict_x_0(self, t, x_t, noise_scheduler):
        """
        Use the U-Net to predict the denoised image x_0 from noisy input x_t.

        Args:
            t (int): Current timestep index.
            x_t (Tensor): Noisy input at timestep t.
            noise_scheduler: Not used here; kept for interface compatibility.
        Returns:
            Tensor: Predicted clean image x_0 on CPU.
        """
        # Run on whatever device the U-Net's weights live on, so a CPU (or
        # non-default GPU) model works even when CUDA is available.
        device = self._device()
        timesteps = torch.tensor([t], dtype=torch.long, device=device)
        # The U-Net is used as a direct x_0 predictor; element 0 of its
        # tuple output is taken as the prediction.
        x_0_pred = self.unet(x_t.to(device), timesteps, return_dict=False)[0]
        return x_0_pred.cpu()

    def _denoise_slices(self, t, slices, conditioning, noise_scheduler, batch_size, iterator=None):
        """Denoise a stack of 2D slices in minibatches along dim 0.

        Each minibatch of ``slices`` is concatenated with the matching
        minibatch of ``conditioning`` along dim 1 (the channel axis for an
        (S, C, H, W) stack) and passed to :meth:`predict_x_0`. If
        ``batch_size`` does not divide the number of slices, the final
        minibatch is simply smaller.

        Args:
            t (int): Diffusion timestep.
            slices (Tensor): Noisy slices stacked along dim 0.
            conditioning (Tensor): Conditioning slices, same length along dim 0.
            noise_scheduler: Forwarded to :meth:`predict_x_0`.
            batch_size (int): Maximum number of slices per U-Net call.
            iterator: Optional tqdm-like object for progress postfixes.

        Returns:
            Tensor: Predictions with the same shape, dtype and device as ``slices``.

        Raises:
            ValueError: If ``batch_size < 1`` or the slice counts differ.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if slices.shape[0] != conditioning.shape[0]:
            raise ValueError(
                f"slice count mismatch: {slices.shape[0]} inputs vs "
                f"{conditioning.shape[0]} conditioning slices"
            )
        pred_buf = torch.zeros_like(slices)
        chunks = zip(torch.split(slices, batch_size), torch.split(conditioning, batch_size))
        for j, (x_chunk, c_chunk) in enumerate(chunks):
            start = j * batch_size
            batch = torch.cat([x_chunk, c_chunk], dim=1)
            pred_buf[start:start + x_chunk.shape[0]] = self.predict_x_0(t, batch, noise_scheduler)
            if iterator is not None:
                iterator.set_postfix({"state": f"running batch {j}"})
        return pred_buf

    def step(
        self, i, t, iterator, pred_x_0,
        noise_scheduler, num_steps,
        integer_brain_coord, brain2_slices_rs, reference_image,
        batch_size, verbose=True, with_inpainting=True,
    ):
        """
        Perform one diffusion step: add noise, roll and slice the volume,
        denoise the slices in batches, reassemble and unroll the volume, and
        optionally inpaint known voxels.

        Args:
            i (int): Index of current step in loop; selects the axis permutation.
            t (int): Diffusion timestep.
            iterator: tqdm iterator for progress updates.
            pred_x_0 (Tensor): Current predicted clean volume; spatial axes are 0-2.
            noise_scheduler: Scheduler providing ``add_noise``.
            num_steps (int): Total number of denoising steps (not used here).
            integer_brain_coord, brain2_slices_rs, reference_image: inpainting data;
                ``reference_image`` is also the conditioning volume.
            batch_size (int): Maximum number of slices per U-Net call.
            verbose (bool): Not used here; progress goes through ``iterator``.
            with_inpainting (bool): Flag to apply inpainting.

        Returns:
            Tensor: Updated pred_x_0 after this step.
        """
        with torch.no_grad():
            # 1. Forward-diffuse the current clean estimate to timestep t. The
            #    timestep lives on the samples' device, avoiding a round-trip.
            timestep = torch.tensor([t], dtype=torch.long, device=pred_x_0.device)
            noise = torch.randn_like(pred_x_0)
            x_t = noise_scheduler.add_noise(pred_x_0, noise, timestep)

            # 2. Cycle through axis permutations so successive steps slice
            #    along different views.
            permutation = self._PERMUTATIONS[i % len(self._PERMUTATIONS)]
            iterator.set_postfix({"state": "rolling"})

            # 3. Random circular shift to vary slice boundaries between steps.
            shifts = torch.randint(-self._MAX_SHIFT, self._MAX_SHIFT, (3,)).tolist()
            rolled = x_t.roll(shifts, dims=[0, 1, 2])
            rolled_ref = reference_image.roll(shifts, dims=[0, 1, 2])  # roll already copies

            # 4. Flatten the volumes into stacks of 2D slices.
            axial_batches = flatten(permutation, rolled)
            axial_conditioning = flatten(permutation, rolled_ref.unsqueeze(-1))

            # 5. Denoise the slices in minibatches (last batch may be smaller).
            pred_buf = self._denoise_slices(
                t, axial_batches, axial_conditioning,
                noise_scheduler, batch_size, iterator,
            ).contiguous()

            # 6. Unflatten back to a 3D volume.
            iterator.set_postfix({"state": "unflattening"})
            pred_vol = unflatten(permutation, pred_buf, x_t.shape)

            # 7. Reverse the random shift.
            iterator.set_postfix({"state": "unrolling"})
            pred_vol = pred_vol.roll([-s for s in shifts], dims=[0, 1, 2])

            # 8. Optionally inpaint known measurements. The coordinates are in
            #    the unrolled frame, so this must happen after step 7.
            if with_inpainting:
                pred_vol = self.inpainting(
                    pred_vol, integer_brain_coord,
                    brain2_slices_rs, reference_image,
                )

            return pred_vol

    def diffusion_pipeline(
        self, x_0_start, t_start, t_end, noise_scheduler,
        batch_size, integer_brain_coord, brain2_slices_rs,
        reference_image, num_steps=50, verbose=False, with_inpainting=False,
    ):
        """
        Run the reverse diffusion from t_start towards t_end over num_steps.

        ``num_steps + 1`` evenly spaced timesteps from t_start to t_end are
        truncated to int. Only the first ``num_steps`` of them are used, so
        ``t_end`` itself is never used as a step timestep.

        Args:
            x_0_start (Tensor): Initial guess for the volume (not modified).
            t_start (int): Starting (noisiest) timestep.
            t_end (int): Ending (cleanest) timestep.
            noise_scheduler: Diffusion noise scheduler.
            batch_size (int): Slices per batch in step().
            integer_brain_coord, brain2_slices_rs, reference_image: inpainting data.
            num_steps (int): Number of denoising iterations.
            verbose (bool): If True, show tqdm progress.
            with_inpainting (bool): If True, inpaint known voxels after every
                step. Defaults to False, the previous hard-coded behavior.

        Returns:
            Tensor: Final denoised volume.
        """
        with torch.no_grad():
            pred_x_0 = x_0_start.clone()

            timesteps = torch.linspace(t_start, t_end, num_steps + 1).int()
            iterator = tqdm(range(1, len(timesteps)), disable=not verbose)

            for i in iterator:
                t = timesteps[i - 1].item()
                pred_x_0 = self.step(
                    i, t, iterator, pred_x_0,
                    noise_scheduler, num_steps,
                    integer_brain_coord, brain2_slices_rs,
                    reference_image, batch_size,
                    verbose=verbose, with_inpainting=with_inpainting,
                )
            return pred_x_0