# ProjectB Colab Notebook
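This notebook combines the code from `projectB.py`, `utils.py`, and `variational_network.py`. Note that the last code cell still runs `from variational_network import *` and `from utils import *`: either upload those two files next to the notebook, or delete those two import lines, since the earlier cells already define the same names.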


In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt
import scipy.io

## Utility Functions

In [None]:
import numpy as np
import torch
import torch.nn.functional as F

def k2i(img, dims=(0, )):
    """Centred, unitary inverse FFT (k-space -> image).

    Args:
        img: array in k-space with the zero frequency at the centre.
        dims: axes to transform one after another.  ``None`` transforms all
            axes at once with ``ifftn``.

    Returns:
        Complex array of the same shape, scaled by ``sqrt(N)`` per transformed
        axis (or ``sqrt(prod(shape))`` when ``dims`` is ``None``) so that the
        transform is unitary given numpy's ``1/N`` inverse convention.
    """
    original_shape = img.shape
    if dims is None:
        scale = np.sqrt(np.prod(original_shape))
        centred = np.fft.ifftshift(img)
        return scale * np.fft.fftshift(np.fft.ifftn(centred))

    out = img
    for axis in dims:
        centred = np.fft.ifftshift(out, axes=axis)
        spectrum = np.fft.ifft(centred, axis=axis)
        out = np.sqrt(original_shape[axis]) * np.fft.fftshift(spectrum, axes=axis)
    return out

def i2k(img, dims=(0, )):
    """Centred, unitary forward FFT (image -> k-space).

    Args:
        img: array in the image domain.
        dims: axes to transform one after another.  ``None`` transforms all
            axes at once with ``fftn``.

    Returns:
        Complex array of the same shape with the zero frequency at the centre,
        scaled by ``1/sqrt(N)`` per transformed axis (or
        ``1/sqrt(prod(shape))`` when ``dims`` is ``None``).
    """
    original_shape = img.shape
    if dims is None:
        norm = 1 / np.sqrt(np.prod(original_shape))
        return norm * np.fft.fftshift(np.fft.fftn(np.fft.ifftshift(img)))

    out = img
    for axis in dims:
        centred = np.fft.ifftshift(out, axes=axis)
        spectrum = np.fft.fft(centred, axis=axis)
        out = (1 / np.sqrt(original_shape[axis])) * np.fft.fftshift(spectrum, axes=axis)
    return out

def vtv_loss(x):
    """
    Vectorial Total Variation (VTV) of a dynamic image, Eq. (19) of the script.

    Forward differences along height and width are taken per frame (the last
    difference is repeated to restore the full size).  Squared gradients are
    summed over time at each pixel, and the square roots of those per-pixel
    sums (with a 1e-8 stabiliser) are summed over the image.

    Args:
        x (torch.Tensor): image sequence of shape (T, H, W).

    Returns:
        torch.Tensor: scalar VTV value.
    """
    n_frames, height, width = x.shape  # enforces a 3-D input

    dx = F.pad(torch.diff(x, dim=2), (0, 1), mode='replicate')        # along W
    dy = F.pad(torch.diff(x, dim=1), (0, 0, 0, 1), mode='replicate')  # along H

    per_pixel = (dx.pow(2) + dy.pow(2)).sum(dim=0)  # couple frames: (H, W)
    return per_pixel.add(1e-8).sqrt().sum()

def generate_undersampling_mask(shape, acceleration, center_fraction=0.1, sigma=10):
    """
    Generates a 3D undersampling mask with incoherent sampling across time and a Laplace-shaped density.

    The central ``int(H * center_fraction) x int(W * center_fraction)`` block
    of k-space is fully sampled in every frame.  In each frame, additional
    points outside that block are drawn independently and without replacement,
    with probability proportional to ``exp(-r / sigma)``, where ``r`` is the
    distance from the k-space centre.  Each frame therefore holds
    ``int(H * W / acceleration)`` samples in total.  If the centre alone
    already exceeds that budget, only the centre is sampled.

    Args:
        shape: (H, W, T) — dimensions of k-space data
        acceleration: desired acceleration factor (e.g., 4 or 6)
        center_fraction: fraction of central k-space to fully sample
        sigma: width parameter for Laplace-shaped density (controls decay)

    Returns:
        mask: float32 numpy array of shape (H, W, T) with 0 (unsampled) or 1 (sampled)
    """
    H, W, T = shape
    mask = np.zeros((H, W, T), dtype=np.float32)

    # Fully sampled centre.  ``end = start + size`` keeps the block exactly
    # ``size`` wide for odd sizes too (``centre + size // 2`` would drop a
    # row/column), so the sample budget below subtracts the true count.
    center_size_h = int(H * center_fraction)
    center_size_w = int(W * center_fraction)
    ch_start = H // 2 - center_size_h // 2
    cw_start = W // 2 - center_size_w // 2
    ch_end = ch_start + center_size_h
    cw_end = cw_start + center_size_w
    mask[ch_start:ch_end, cw_start:cw_end, :] = 1

    # Laplace-shaped density over the remaining (outer) k-space.
    ky = np.arange(H) - H // 2
    kx = np.arange(W) - W // 2
    ky_grid, kx_grid = np.meshgrid(ky, kx, indexing='ij')
    distance = np.sqrt(ky_grid**2 + kx_grid**2)

    laplace_density = np.exp(-distance / sigma)
    laplace_density[ch_start:ch_end, cw_start:cw_end] = 0  # exclude fully sampled centre

    # Frame-invariant, so flattened and normalised once.  Guard against a
    # centre that covers the whole grid (all-zero density).
    flat_density = laplace_density.ravel()
    total = flat_density.sum()
    if total > 0:
        flat_density = flat_density / total

    # Outer samples per frame: the budget minus the centre, clamped to
    # [0, number of drawable positions] so np.random.choice cannot fail.
    num_samples = int(H * W / acceleration) - (center_size_h * center_size_w)
    num_samples = max(0, min(num_samples, int(np.count_nonzero(flat_density))))

    if num_samples > 0:
        for t in range(T):
            sampled_indices = np.random.choice(H * W, num_samples, replace=False, p=flat_density)
            rows, cols = np.unravel_index(sampled_indices, (H, W))
            mask[rows, cols, t] = 1

    return mask

def compute_noisy_undersampled_measurements(img, mask, sigma=0.01):
    """
    Simulate noisy, undersampled k-space data: s = M (F x + n).

    ``F`` is the centred unitary 2-D FFT over the spatial axes (H, W) applied
    to each frame.  ``n`` is complex noise whose real and imaginary parts are
    independent N(0, sigma^2) draws.  ``M`` is applied element-wise.

    Args:
        img (ndarray): image in the spatial domain, shape (H, W, T)
        mask (ndarray): binary sampling mask, shape (H, W, T)
        sigma (float): standard deviation of each noise component

    Returns:
        s (ndarray): complex undersampled k-space data, shape (H, W, T)
    """
    kspace = i2k(img, dims=(0, 1))  # spatial FT only; time axis untouched

    # Real part is drawn before the imaginary part.
    real_noise = np.random.normal(0, sigma, kspace.shape)
    imag_noise = np.random.normal(0, sigma, kspace.shape)
    noisy = kspace + (real_noise + 1j * imag_noise)

    return mask * noisy


def complex_from_tensor(t):
    """Turn a stacked real/imaginary tensor into a complex tensor.

    ``(2, T, H, W)`` becomes ``(H, W, T)`` and ``(B, 2, T, H, W)`` becomes
    ``(B, H, W, T)``.  The real/imag axis is consumed and time moves to the
    last position.

    Raises:
        ValueError: if ``t`` is neither 4-D nor 5-D.
    """
    # The real/imag axis is moved last because ``view_as_complex`` consumes a
    # trailing dimension of size 2 (and needs it contiguous).
    permutations = {
        4: (2, 3, 1, 0),     # (2, T, H, W)    -> (H, W, T, 2)
        5: (0, 3, 4, 2, 1),  # (B, 2, T, H, W) -> (B, H, W, T, 2)
    }
    order = permutations.get(t.ndim)
    if order is None:
        raise ValueError(
            "Tensor must have shape (2, T, H, W) or (B, 2, T, H, W)")
    interleaved = t.permute(*order).contiguous()
    return torch.view_as_complex(interleaved)


def tensor_from_complex(c):
    """Split a complex tensor into stacked real/imaginary channels.

    ``(H, W, T)`` becomes ``(2, T, H, W)`` and ``(B, H, W, T)`` becomes
    ``(B, 2, T, H, W)``.  The result is a permuted view of
    ``torch.view_as_real(c)``.

    Raises:
        ValueError: if ``c`` is neither 3-D nor 4-D.
    """
    stacked = torch.view_as_real(c)  # (..., 2): real/imag appended last
    layouts = {
        3: (3, 2, 0, 1),     # (H, W, T, 2)    -> (2, T, H, W)
        4: (0, 4, 3, 1, 2),  # (B, H, W, T, 2) -> (B, 2, T, H, W)
    }
    if c.ndim not in layouts:
        raise ValueError(
            "Complex tensor must have shape (H, W, T) or (B, H, W, T)")
    return stacked.permute(*layouts[c.ndim])


def i2k_torch(x):
    """Centred, orthonormal 2-D FFT from image to k-space.

    Input is ``(..., T, H, W)``.  The time axis is moved last and the FFT is
    taken over H and W, giving ``(..., H, W, T)`` with the zero frequency at
    the centre.
    """
    spatial_axes = (-3, -2)
    # (..., T, H, W) -> (..., H, W, T): afterwards H and W sit at -3 and -2.
    channels_last = x.movedim(-3, -1)
    centred = torch.fft.ifftshift(channels_last, dim=spatial_axes)
    spectrum = torch.fft.fftn(centred, dim=spatial_axes, norm="ortho")
    return torch.fft.fftshift(spectrum, dim=spatial_axes)


def k2i_torch(k):
    """Centred, orthonormal 2-D inverse FFT from k-space to a real image.

    Input is ``(..., H, W, T)``.  The inverse FFT is taken over H and W, the
    time axis is moved back, and only the real part is kept, giving a real
    tensor of shape ``(..., T, H, W)``.
    """
    spatial_axes = (-3, -2)
    centred = torch.fft.ifftshift(k, dim=spatial_axes)
    image = torch.fft.ifftn(centred, dim=spatial_axes, norm="ortho")
    image = torch.fft.fftshift(image, dim=spatial_axes)
    # (..., H, W, T) -> (..., T, H, W), discarding the imaginary part.
    return image.movedim(-1, -3).real


## Variational Network

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class VariationalNetwork(nn.Module):
    """Unrolled variational network with momentum for dynamic MRI reconstruction.

    Every layer ``k`` (of ``n_layers``) performs one momentum step::

        g^k     = alpha^k * (F^H(M * F x^k - s) + R^k(x^k))
        m^{k+1} = mu^k * m^k + g^k
        x^{k+1} = x^k - m^{k+1}

    ``R^k`` (see ``reg_vtv``) is built from ``n_filters`` learned
    ``filter_size x filter_size`` kernels and learned piecewise-linear
    activation functions.
    """

    def __init__(self, n_layers=10, n_filters=1, filter_size=3):
        super(VariationalNetwork, self).__init__()
        self.n_layers = n_layers

        # Learnable step sizes (alpha) and momentum weights (mu), one per layer.
        self.alpha = nn.ParameterList([nn.Parameter(torch.tensor(0.1)) for _ in range(n_layers)])
        self.mu = nn.ParameterList([nn.Parameter(torch.tensor(0.9)) for _ in range(n_layers)])

        # (K, F, f, f): one single-channel spatial kernel per layer and filter.
        self.filters = nn.Parameter(torch.randn(n_layers, n_filters, filter_size, filter_size) * 0.01)

        # Activation functions are lookup tables sampled on a fixed uniform grid
        # spanning ``activation_range``.  The grid is a non-persistent buffer:
        # it follows ``.to(device)`` with the module but is not part of the
        # state dict.
        self.activation_range = (0.0, 1.5)
        self.register_buffer(
            "activation_grid",
            torch.linspace(*self.activation_range, steps=100),  # (G,)
            persistent=False,
        )
        self.activation_values = nn.Parameter(torch.rand(n_layers, n_filters, 100))  # (K, F, G)

    def apply_activation(self, x, k, i):
        """
        Evaluate the learnable piecewise-linear activation phi^{k,i} element-wise.

        Inputs are clamped to ``activation_range`` (so phi is constant outside
        it).  Between grid points the tabulated values are linearly
        interpolated.

        Args:
            x: input tensor of any shape
            k: layer index
            i: filter index
        Returns:
            activated: tensor with the same shape as x
        """
        grid = self.activation_grid.to(x.device)  # (G,)
        values = self.activation_values[k, i, :].to(x.device)  # (G,)

        lo, hi = self.activation_range
        z = x.clamp(min=lo, max=hi).reshape(-1)

        # bucketize gives grid[idx-1] < z <= grid[idx]; clamping idx to
        # [1, G-1] keeps both neighbours valid (z == grid[0] uses segment 0).
        idx = torch.bucketize(z, grid).clamp(1, grid.numel() - 1)
        x0, x1 = grid[idx - 1], grid[idx]
        y0, y1 = values[idx - 1], values[idx]
        slope = (y1 - y0) / (x1 - x0)
        return (y0 + slope * (z - x0)).reshape(x.shape)

    def reg_vtv(self, x, k):
        """Learned vectorial-TV regulariser term for layer ``k``.

        For every filter ``i`` of layer ``k``:

        1. filter each frame:        y_t = K_i * x_t
        2. temporal RMS magnitude:   r = sqrt(sum_t y_t^2 / T + 1e-8)
        3. activation:               p = phi^{k,i}(r)
        4. adjoint filtering:        out_t += K_i^T (y_t * p)

        ``K_i^T`` is implemented as correlation with the spatially flipped
        kernel under ``padding="same"``, which is the exact adjoint only for
        odd filter sizes.

        Args:
            x: image(s) of shape (T, H, W) or (B, T, H, W)
            k: layer index
        Returns:
            tensor with the same shape as ``x``
        """
        unbatched = x.dim() == 3
        if unbatched:
            x = x.unsqueeze(0)
        B, T, H, W = x.shape
        frames = x.reshape(B * T, 1, H, W)  # every frame is an independent conv input

        reg_term = torch.zeros_like(x)
        for i in range(self.filters.shape[1]):
            filt = self.filters[k, i].unsqueeze(0).unsqueeze(0)  # (1, 1, f, f)
            flipped_filt = torch.flip(filt, dims=[2, 3])

            filtered = F.conv2d(frames, filt, padding="same").view(B, T, H, W)
            magnitude = torch.sqrt((filtered ** 2).sum(dim=1) / T + 1e-8)  # (B, H, W)
            phi = self.apply_activation(magnitude, k, i)  # (B, H, W)

            # Weight all frames by the shared map and apply the adjoint filter
            # in one batched convolution instead of a per-frame loop.
            weighted = (filtered * phi.unsqueeze(1)).reshape(B * T, 1, H, W)
            back = F.conv2d(weighted, flipped_filt, padding="same")
            reg_term = reg_term + back.view(B, T, H, W)

        return reg_term.squeeze(0) if unbatched else reg_term

    def reg_tv(self, x):
        """
        Isotropic total variation regularizer.

        Args:
            x: input tensor of shape ``(T, H, W)``

        Returns:
            reg_term: divergence of normalized gradients with the same shape as ``x``
        """
        # NOTE(review): the TV (sub)gradient is -div(grad x / |grad x|).  This
        # returns +div, so it would need negating before being plugged into
        # ``compute_g`` in place of ``reg_vtv`` (it is not used there today).

        grad_x = x[:, :, 1:] - x[:, :, :-1]
        grad_y = x[:, 1:, :] - x[:, :-1, :]

        grad_x = F.pad(grad_x, (0, 1), mode="replicate")
        grad_y = F.pad(grad_y, (0, 0, 0, 1), mode="replicate")

        magnitude = torch.sqrt(grad_x ** 2 + grad_y ** 2 + 1e-8)
        grad_x_norm = grad_x / magnitude
        grad_y_norm = grad_y / magnitude

        div_x = grad_x_norm - F.pad(grad_x_norm[:, :, :-1], (1, 0), mode="replicate")
        div_y = grad_y_norm - F.pad(grad_y_norm[:, :-1, :], (0, 0, 1, 0), mode="replicate")

        return div_x + div_y

    def reg_tikhonov(self, x):
        """
        Tikhonov regularizer implemented via a discrete Laplacian.

        Args:
            x: input tensor of shape ``(T, H, W)``

        Returns:
            reg_term: the 5-point negative Laplacian of each frame (zero padded),
            with the same shape as ``x``
        """

        kernel = torch.tensor([[0., -1., 0.],
                              [-1., 4., -1.],
                              [0., -1., 0.]], device=x.device, dtype=x.dtype)
        kernel = kernel.view(1, 1, 3, 3)
        reg_term = F.conv2d(x.unsqueeze(1), kernel, padding=1)
        return reg_term.squeeze(1)

    def compute_g(self, x, s, M, F, FH, k):
        """
        Compute g^k = α^k * (F^H(M * F x^k - s) + Reg^k(x^k)).

        The step size α^k scales both the data-consistency term and the
        regulariser.  Assuming ``s`` is already masked (zero wherever M is 0)
        and M is binary, ``M * F x - s`` equals ``M * (F x - s)``.

        Args:
            x: current iterate (T, H, W) or (B, T, H, W)
            s: measured data, layout of ``F(x)`` e.g. (H, W, T)
            M: sampling mask, same layout as ``s``
            F: forward operator (function).  Inside this method it shadows
               ``torch.nn.functional``, which is not used here.
            FH: adjoint operator (function)
            k: current layer index

        Returns:
            g: tensor with the same shape as ``x``
        """
        x_kspace = F(x)                  # image -> k-space
        residual = (M * x_kspace) - s    # k-space residual on sampled locations
        data_term = FH(residual)         # back to image space
        reg_term = self.reg_vtv(x, k)
        return self.alpha[k] * (data_term + reg_term)

    def update_momentum(self, m_prev, g, k):
        """
        Compute m^{k+1} = μ^{k+1} * m^k + g^k (μ^{k+1} is stored as ``self.mu[k]``).

        Args:
            m_prev: previous momentum, same shape as g
            g: current gradient step
            k: current layer index

        Returns:
            m_next: updated momentum
        """
        return self.mu[k] * m_prev + g

    def update_x(self, x, m_next):
        """
        Compute x^{k+1} = x^k - m^{k+1}.

        Args:
            x: current estimate
            m_next: updated momentum, same shape as x

        Returns:
            x_next: updated image estimate
        """
        return x - m_next

    def forward(self, x0, s, i2k, k2i, mask, return_intermediate=False):
        """Run all ``n_layers`` momentum iterations starting from ``x0``.

        Args:
            x0: initial estimate, (T, H, W) or (B, T, H, W)
            s: measured k-space in the layout produced by ``i2k``
            i2k: forward operator, image -> k-space
            k2i: adjoint operator, k-space -> image
            mask: sampling mask in the layout of ``s``
            return_intermediate: if True, also return a copy of the estimate
                after every layer

        Returns:
            the final estimate, or ``(final, intermediates)`` when
            ``return_intermediate`` is True
        """
        x = x0.clone()
        m = torch.zeros_like(x0)  # momentum starts at zero

        xs = []
        for k in range(self.n_layers):
            g = self.compute_g(x, s, mask, i2k, k2i, k)
            m = self.update_momentum(m, g, k)
            x = self.update_x(x, m)
            if return_intermediate:
                xs.append(x.clone())

        if return_intermediate:
            return x, xs
        return x


## Dataset, Training, and Utilities

In [None]:
from variational_network import *
from utils import *
import scipy.io
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt
import os

#─────────────────────────────────────────────────────────────────────────────
# Hyperparameters
# Network architecture (passed to VariationalNetwork in train_vn).
N_LAYERS   = 8   # number of unrolled iterations / layers
N_FILTERS  = 5   # learned filters per layer
FILTER_SZ  = 3   # spatial kernel size of each filter

# Measurement simulation (HeartDataset defaults).
NOISE_STD          = 0.01  # std of the Gaussian noise on the real and on the imaginary part of k-space
ACCEL_RATE         = 4     # undersampling acceleration passed to generate_undersampling_mask
MASK_CENTER_RADIUS = 8     # NOTE(review): not referenced in the code shown here; the centre size comes from HeartDataset's center_fraction

# Training / evaluation.
BATCH_SIZE   = 4     # default batch size for create_dataloaders, train_vn and validate_vn
NUM_EPOCHS   = 20    # default number of training epochs in train_vn
LR           = 1e-2  # Adam learning rate
PRINT_EVERY  = 10    # print the epoch loss every this many epochs
TRAIN_SPLIT  = 0.8   # fraction of samples assigned to the training split
DS_TAU       = 0.1   # deep-supervision decay: layer k of K is weighted by exp(-DS_TAU * (K - k))
SHOW_VAL_IMAGES = True  # NOTE(review): not referenced in the code shown here; presumably meant for validate_vn(display_examples=...)
#─────────────────────────────────────────────────────────────────────────────


class HeartDataset(Dataset):
    """Noisy, undersampled k-space measurements simulated from ``2dt_heart.mat``.

    The ``imgs`` array in the .mat file holds the ground-truth sequences with
    the sample index on the last axis.  Every ``__getitem__`` call draws a
    fresh random sampling mask and fresh noise, so repeated access to the
    same index yields different measurements.

    Items are ``(img, s, mask)`` float32 tensors: ``img`` is (T, H, W), ``s`` is
    (2, T, H, W) with the real and imaginary parts stacked first, and
    ``mask`` is (T, H, W).
    """

    def __init__(self, mat_path="2dt_heart.mat", noise_std=NOISE_STD,
                 acceleration=ACCEL_RATE, center_fraction=0.1, sigma=10):
        super().__init__()
        volumes = scipy.io.loadmat(mat_path)["imgs"]
        self.imgs = volumes.astype(np.float32)  # (H, W, T, N)
        self.noise_std = noise_std
        self.acceleration = acceleration
        self.center_fraction = center_fraction
        self.sigma = sigma

    def __len__(self):
        return self.imgs.shape[3]

    def __getitem__(self, idx):
        sequence = self.imgs[..., idx]  # (H, W, T)

        # Random mask first, then noisy measurements through that mask.
        mask = generate_undersampling_mask(
            sequence.shape, self.acceleration, self.center_fraction, self.sigma
        )
        kspace = compute_noisy_undersampled_measurements(
            sequence, mask, sigma=self.noise_std
        )

        to_thw = (2, 0, 1)  # (H, W, T) -> (T, H, W)
        parts = [
            component.astype(np.float32).transpose(to_thw)
            for component in (np.real(kspace), np.imag(kspace))
        ]
        return (
            torch.from_numpy(sequence.transpose(to_thw)),
            torch.from_numpy(np.stack(parts, axis=0)),  # (2, T, H, W)
            torch.from_numpy(mask.transpose(to_thw).astype(np.float32)),
        )


def create_dataloaders(batch_size=BATCH_SIZE, train_split=TRAIN_SPLIT, seed=0):
    """Create train/validation dataloaders over ``HeartDataset``.

    Args:
        batch_size: samples per batch for both loaders.
        train_split: fraction of samples assigned to the training set; the
            remainder forms the validation set.
        seed: seed for the train/validation partition.  It is fixed by
            default so that independent calls (``train_vn`` and
            ``validate_vn`` each build their own loaders) produce the same
            split, and validation samples are never trained on.  Pass
            ``None`` for an unseeded split drawn from the global RNG.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    dataset = HeartDataset()
    n_train = int(len(dataset) * train_split)
    n_val = len(dataset) - n_train

    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_ds, val_ds = random_split(dataset, [n_train, n_val], **split_kwargs)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader




def train_vn(num_epochs=NUM_EPOCHS, lr=LR, batch_size=BATCH_SIZE):
    """Train a variational network on the heart dataset and plot the loss.

    The optimised objective is a deep-supervision loss: the MSE of every
    layer's output against the ground truth, weighted by
    ``exp(-DS_TAU * (K - k))`` for layer ``k = 1..K``, so the final layer has
    weight 1.  The reported and plotted loss is the mean L1 error between
    output and ground-truth magnitudes, averaged over batches.

    Args:
        num_epochs: number of passes over the training set.
        lr: Adam learning rate.
        batch_size: training batch size.

    Returns:
        ``(model, losses)``, where ``losses`` holds one average L1 value per epoch.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_loader, _ = create_dataloaders(batch_size=batch_size)
    n_batches = len(train_loader)

    vn = VariationalNetwork(n_layers=N_LAYERS, n_filters=N_FILTERS,
                            filter_size=FILTER_SZ).to(device)
    optim = torch.optim.Adam(vn.parameters(), lr=lr)

    # Deep-supervision weights for layers k = 1..K (exponents K-1, ..., 0).
    # They are constant across iterations, so they are built once.
    n_stages = vn.n_layers
    ds_weights = torch.exp(
        -DS_TAU * torch.arange(n_stages - 1, -1, -1, dtype=torch.float32, device=device)
    )

    losses = []
    for epoch in range(num_epochs):
        vn.train()
        running = 0.0
        for i, (gt, s, m) in enumerate(train_loader, start=1):
            percent = (i / n_batches) * 100
            print(f"Epoch {epoch+1}/{num_epochs} - Progress: {percent:.3f}%", end='\r')
            gt = gt.to(device)
            s_complex = complex_from_tensor(s).to(device)  # (B, H, W, T)
            m = m.permute(0, 2, 3, 1).to(device)           # (B, T, H, W) -> (B, H, W, T)

            x0 = k2i_torch(s_complex)  # zero-filled reconstruction
            pred, preds_all = vn(x0, s_complex, i2k_torch, k2i_torch, m,
                                 return_intermediate=True)

            # Reporting-only metric: no autograd graph is needed for it.
            with torch.no_grad():
                plot_loss = torch.mean(torch.abs(torch.abs(pred) - torch.abs(gt)))

            ds_loss = sum(w * F.mse_loss(x_k, gt) for w, x_k in zip(ds_weights, preds_all))

            optim.zero_grad()
            ds_loss.backward()
            optim.step()
            running += plot_loss.item()

        epoch_loss = running / n_batches
        losses.append(epoch_loss)
        if (epoch + 1) % PRINT_EVERY == 0:
            print(f"Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.6f}")

    plt.figure()
    plt.plot(losses)
    plt.xlabel('Epoch')
    plt.ylabel('L1 Loss')
    plt.title('Training Loss')
    plt.show()
    return vn, losses


def validate_vn(model, val_loader=None, batch_size=BATCH_SIZE,
                display_examples=False, num_examples=3):
    """Validate a trained variational network on the held-out set.

    Parameters
    ----------
    model : torch.nn.Module
        The trained variational network.
    val_loader : DataLoader, optional
        Validation dataloader.  If ``None``, a loader is created using
        ``create_dataloaders`` with ``batch_size``.
    batch_size : int
        Batch size to use when constructing the dataloader.
    display_examples : bool, optional
        If ``True``, display ``num_examples`` pairs of ground truth and
        reconstructed images (first frame of each sequence).  Both images in
        a pair share the ground truth's intensity window, so contrast errors
        in the reconstruction remain visible.
    num_examples : int, optional
        Number of example pairs to display when ``display_examples`` is
        ``True``.

    Returns
    -------
    float
        The average L1 validation loss (mean over batches of the mean
        absolute magnitude error).
    """

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if val_loader is None:
        _, val_loader = create_dataloaders(batch_size=batch_size)

    model = model.to(device)
    model.eval()
    running = 0.0
    examples = []
    with torch.no_grad():
        for gt, s, m in val_loader:
            gt = gt.to(device)
            s_complex = complex_from_tensor(s).to(device)  # (B, H, W, T)
            m = m.permute(0, 2, 3, 1).to(device)           # (B, T, H, W) -> (B, H, W, T)

            x0 = k2i_torch(s_complex)
            pred = model(x0, s_complex, i2k_torch, k2i_torch, m)

            loss = torch.mean(torch.abs(torch.abs(pred) - torch.abs(gt)))
            running += loss.item()

            if display_examples and len(examples) < num_examples:
                needed = num_examples - len(examples)
                examples.extend(
                    (g.cpu(), p.cpu()) for g, p in zip(gt[:needed], pred[:needed])
                )

    if display_examples and examples:
        n = len(examples)
        # squeeze=False keeps ``axes`` two-dimensional even when n == 1.
        _, axes = plt.subplots(n, 2, figsize=(6, 3 * n), squeeze=False)
        for row, (gt_ex, pred_ex) in zip(axes, examples):
            gt_img = torch.abs(gt_ex[0]).numpy()
            pred_img = torch.abs(pred_ex[0]).numpy()
            # Shared window: imshow would otherwise autoscale each image
            # independently and hide intensity differences.
            vmin, vmax = float(gt_img.min()), float(gt_img.max())
            panels = ((row[0], gt_img, "Ground Truth"),
                      (row[1], pred_img, "Reconstruction"))
            for ax, img, title in panels:
                ax.imshow(img, cmap="gray", vmin=vmin, vmax=vmax)
                ax.set_title(title)
                ax.axis("off")
        plt.tight_layout()
        plt.show()

    return running / len(val_loader)


def save_trained_model(model, directory="models", filename=None, **hyperparams):
    """Save a trained model's ``state_dict`` to ``directory/filename``.

    If ``filename`` is ``None``, it is built as ``vn_<key><value>_....pth``
    from ``hyperparams`` sorted by key.  Floats below ``1e-3`` (including
    negatives) or at least ``1e3`` use exponent notation without ``+`` (e.g.
    ``1e-04``, ``1e03``); other floats use ``:g`` formatting.  Every ``.``
    becomes ``p``.  When the compact form would round a float, enough digits
    are kept to round-trip it, so distinct values never share a filename.

    Args:
        model (torch.nn.Module): trained variational network
        directory (str): directory to store the model file (created if missing)
        filename (str, optional): file name for the saved state dict.  If not
            provided, a name is constructed from ``hyperparams``.
        **hyperparams: keyword arguments describing the hyperparameters used
            during training.
    """

    def _format_float(val):
        if val < 1e-3 or val >= 1e3:
            # Shortest mantissa that round-trips; one-digit values keep the
            # compact ``1e-04`` form.
            for digits in range(17):
                text = f"{val:.{digits}e}"
                if float(text) == val:
                    break
        else:
            text = f"{val:g}"
            if float(text) != val:
                # ``:g`` keeps only 6 significant digits; repr round-trips.
                text = repr(val)
        return text.replace("+", "")

    def _sanitize(val):
        text = _format_float(val) if isinstance(val, float) else str(val)
        return text.replace(".", "p")

    os.makedirs(directory, exist_ok=True)

    if filename is None:
        parts = [f"{k}{_sanitize(v)}" for k, v in sorted(hyperparams.items())]
        filename = "vn_" + "_".join(parts) + ".pth"

    path = os.path.join(directory, filename)
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")


