## Idea Illustration (with code)

In this notebook, we illustrate the proposed procedure for super-resolution of CFD/4DF data.

### Step 0: Convert EnSight data to VTK data

Before conducting the analysis, we need to convert the EnSight data to VTK data. The example below does this for one data set; the lines marked `## change here` must be edited for the others.

In [None]:
import os
import numpy as np
import vtk
import vtk.numpy_interface.dataset_adapter as dsa
wdo = dsa.WrapDataObject
from tqdm import tqdm

wd = r'/data/ANY-011-001-ori/'  ## change here for other data sets
ensightFolder = '/'
outputFolder = 'ens-vtk'
filename = 'ensight.encas'
outputBaseName = 'ANY-011-001'  ## change here to match wd

wssComponentArrayNames = ['x_wall_shear',
                          'y_wall_shear',
                          'z_wall_shear']

def get_block_names(dataset: vtk.vtkMultiBlockDataSet):
    numBlocks = dataset.GetNumberOfBlocks()
    blockNames = [None] * numBlocks
    for i in range(numBlocks):
        blockNames[i] = dataset.GetMetaData(i).Get(
                vtk.vtkMultiBlockDataSet.NAME())
    return blockNames

def split_multiblock_dataset_by_name(dataset: vtk.vtkMultiBlockDataSet, 
                                     blockNames: list=[]):
    if not blockNames:
        blockNames = get_block_names(dataset)
        
    blocks = {}
    for i, name in enumerate(blockNames):
        blocks[name] = dataset.GetBlock(i)
    
    return blocks

def dataset_surface_filter(dataset: vtk.vtkUnstructuredGrid):
    surfaceFilter = vtk.vtkDataSetSurfaceFilter()
    surfaceFilter.SetInputData(dataset)
    surfaceFilter.Update()
    return surfaceFilter.GetOutput()

print('Reading EnSight case file...')

filepath = os.path.join(wd, filename)

reader = vtk.vtkEnSightGoldBinaryReader()
reader.SetCaseFileName(filepath)
reader.Update()

# Workaround for VTK versions whose numpy_interface still references np.bool
# (removed in NumPy 1.24)
np.bool = np.bool_
timeset = dsa.vtkDataArrayToVTKArray(reader.GetTimeSets().GetItem(0))
print('Reading time series...')

writer = vtk.vtkXMLMultiBlockDataWriter()
writer.SetCompressorTypeToZLib()

if not outputBaseName:
    outputBaseName = os.path.splitext(filename)[0]

# Create the output folder up front (harmless if it already exists)
outputDir = os.path.join(wd, outputFolder)
os.makedirs(outputDir, exist_ok=True)

numTimePoints = timeset.size
numDigits = len(str(numTimePoints))  # zero-padding width of the file index

for i, ti in enumerate(tqdm(timeset)):
    # Load this time step and tag it with its time value
    reader.SetTimeValue(ti)
    reader.Update()
    dataset = reader.GetOutput()

    dataset_dsa = wdo(dataset)
    dataset_dsa.FieldData.append(ti, 'Time')

    indexStr = f'{i}'.zfill(numDigits)
    outputFileName = f'{outputBaseName}_{indexStr}.vtm'
    outputFilePath = os.path.join(outputDir, outputFileName)

    writer.SetInputData(dataset)
    writer.SetFileName(outputFilePath)
    writer.Write()

# Save the list of time values next to the .vtm files
timeFileName = f'{outputBaseName}_time.timeset'
with open(os.path.join(outputDir, timeFileName), 'w') as f:
    for ti in timeset:
        f.write(f'{ti}\n')


Running the code creates a folder named `ens-vtk` in the working directory, containing one `.vtm` file per time point and a `.timeset` file listing the time values. This should be done for the original data of all 10 patients.
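
The output names follow a simple zero-padding scheme, which is why Step 1 refers to files such as `ANY-011-001_00.vtm`. The helpers below are a minimal sketch (the function names are our own, not part of VTK) that reproduce this naming and read back the `.timeset` file, for example to loop over all time points of one patient in Step 1.

```python
def expected_vtm_names(base_name, num_time_points):
    """File names written by the Step 0 loop, in time order."""
    # The index is zero-padded to the number of digits of the time-point count
    num_digits = len(str(num_time_points))
    return [f"{base_name}_{str(i).zfill(num_digits)}.vtm" for i in range(num_time_points)]


def read_timeset(path):
    """Read the time values written to <base_name>_time.timeset, one per line."""
    with open(path) as f:
        return [float(line) for line in f if line.strip()]


names = expected_vtm_names("ANY-011-001", 20)
print(names[0], names[-1])
```

With exactly 10 time points, `len(str(10))` is 2, so the indices run from `00` to `09`.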

### Step 1: Obtain data from the VTK files

Next, we extract the data used in the analysis step from the VTK files. The example below reads a single `.vtm` file.

In [None]:
import vtk
import numpy as np
from vtk.util.numpy_support import vtk_to_numpy

reader = vtk.vtkXMLMultiBlockDataReader()
reader.SetFileName("/data/ANY-011-001-ori/ens-vtk/ANY-011-001_00.vtm")  ## change here for other data sets
reader.Update()

## extract data from vtk file

def vtk_to_numpy_array(vtk_array):
    return vtk_to_numpy(vtk_array)

# Extract blocks
blocks = [reader.GetOutput().GetBlock(i) for i in range(reader.GetOutput().GetNumberOfBlocks())]

# Initialize lists to store data
points_list = []
point_data_list = []

for block in blocks:
    if isinstance(block, vtk.vtkUnstructuredGrid):
        points = vtk_to_numpy_array(block.GetPoints().GetData())
        point_data = vtk_to_numpy_array(block.GetPointData().GetArray('velocity'))

        points_list.append(points)
        point_data_list.append(point_data)

# Convert lists to single numpy arrays
all_points = np.vstack(points_list)
all_point_data = np.vstack(point_data_list)

# Optionally save the arrays; the keys match the variable names used in Step 2
#np.savez('ANY-011-001_00.npz', all_points=all_points, all_point_data=all_point_data)

A few things to note:

- The data we use from each vtm file are (i) `all_points`, which stores the spatial coordinates $(x,y,z)$ of each measurement, and (ii) `all_point_data`, which stores the velocity $(v_x,v_y,v_z)$ at each measured point.

- Each VTK folder contains multiple vtm files, one per time point; here we only deal with one. We propose to retrieve `all_points` and `all_point_data` from all the vtm files and average them, which gives a single file for each patient. Averaging point by point assumes that every file stores the same points in the same order (a static mesh).

- Based on personal experience, installing the Python `vtk` package is a little painful: it requires specific Python versions and may not support the latest one. Therefore, I recommend doing this step on a local computer, saving the results, and then uploading them to the server. It does not take long.
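
A minimal NumPy sketch of the proposed averaging is given below, assuming the mesh is static so that every time point stores the same points in the same order; the function name and tolerance are our own choices. Each list entry would be the `all_points` / `all_point_data` pair extracted from one `.vtm` file with the code above.

```python
import numpy as np


def time_average_velocity(points_per_step, velocity_per_step, atol=1e-9):
    """Average velocities over time points on a static mesh.

    points_per_step:   list of (n, 3) coordinate arrays, one per .vtm file
    velocity_per_step: list of (n, 3) velocity arrays in the same point order
    Returns the shared (n, 3) coordinates and the (n, 3) mean velocity.
    """
    ref = points_per_step[0]
    for pts in points_per_step[1:]:
        # Point-wise averaging only makes sense if the points match
        if pts.shape != ref.shape or not np.allclose(pts, ref, rtol=0.0, atol=atol):
            raise ValueError("point coordinates differ between time points")
    mean_velocity = np.stack(velocity_per_step).mean(axis=0)
    return ref, mean_velocity


pts = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
v0 = np.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
v1 = np.array([[3.0, 0.0, 0.0], [0.0, 4.0, 0.0]])
points, mean_v = time_average_velocity([pts, pts], [v0, v1])
print(mean_v)
```

If the files stored the points in different orders, they would have to be matched first; the sketch simply refuses to average in that case.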

### Step 2: Processing data

In this step, we create image-like data for training and testing. The procedure is as follows:

- For each measurement, we find its local $\epsilon$-ball, i.e. all measurements within Euclidean distance $\epsilon$ of it.

- From the number of measurements $N$ in the local $\epsilon$-ball, we compute a resolution parameter $t=\lfloor\log_2(N)/3\rfloor$, so that a larger value means higher resolution. This choice gives $2^{3t}\le N$: a $2^t\times 2^t\times 2^t$ grid has no more voxels than the ball has measurements. The code below additionally clamps $t$ to be at least 1.

- We then partition the bounding box of the ball's measurements into $2^t\times 2^t\times 2^t$ voxels. The channel values of each voxel are the three velocity components, evaluated at the voxel center by inverse distance weighting (IDW); in the code below, every measurement in the $\epsilon$-ball contributes, with weight proportional to the inverse of its distance to the center. Before interpolation, each velocity component is min-max normalized within the ball to the range $[-1, 1]$.

- To give all tensors the same size, we resample each grid to $2^T\times 2^T\times 2^T$ voxels with nearest-neighbor interpolation, where $T$ is the highest resolution level ($T=4$, i.e. $16\times16\times16$, in the code below). Grids with $t>T$ are downsampled by the same resampling.

This yields $n$ tensors of size $3\times 2^T\times 2^T\times 2^T$, where $n$ is the total number of measurements in a vtm file (the code stores them channels-last, as $2^T\times 2^T\times 2^T\times 3$). Each tensor represents a local "image" around the corresponding measurement, associated with its resolution parameter $t$.
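
As a quick check of the resolution rule, the snippet below evaluates $t=\lfloor\log_2(N)/3\rfloor$ with the same lower clamp of 1 used in `process_point`, together with the resulting grid side $2^t$. It is an illustrative sketch; `resolution_level` is our own name.

```python
import math


def resolution_level(num_neighbors, t_min=1):
    """t = floor(log2(N) / 3), clamped below at t_min as in process_point."""
    return max(t_min, int(math.log2(num_neighbors) // 3))


for n in [1, 8, 63, 64, 512, 4096, 40000]:
    t = resolution_level(n)
    print(f"N={n:>6}  t={t}  grid={2 ** t}^3  voxels={8 ** t}")
```

Except where the clamp applies ($N<8$), $2^{3t}\le N$ holds. With a fixed $16^3$ grid ($T=4$), any ball with $N\ge 2^{15}=32768$ measurements gets $t\ge 5$ and is downsampled by the nearest-neighbor resampling.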

In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from scipy.spatial import cKDTree
from scipy.ndimage import zoom  # For resampling
import os
import math
from tqdm.auto import tqdm  
from pathos.multiprocessing import ProcessingPool as Pool

def process_point(args):
    idx, all_data, all_data_point, epsilon, tree, fixed_grid_size = args
    point = all_data[idx]
    indices = tree.query_ball_point(point, r=epsilon)
    if idx not in indices:
        indices.append(idx)
    selected_points = all_data[indices]
    num_neighbors = len(indices)
    t = max(1, int(math.log2(num_neighbors) // 3))

    x_min, y_min, z_min = selected_points.min(axis=0)
    x_max, y_max, z_max = selected_points.max(axis=0)

    num_partitions = 2 ** t
    x_edges = np.linspace(x_min, x_max, num_partitions + 1)
    y_edges = np.linspace(y_min, y_max, num_partitions + 1)
    z_edges = np.linspace(z_min, z_max, num_partitions + 1)
    x_centers = (x_edges[:-1] + x_edges[1:]) / 2
    y_centers = (y_edges[:-1] + y_edges[1:]) / 2
    z_centers = (z_edges[:-1] + z_edges[1:]) / 2
    Xc, Yc, Zc = np.meshgrid(
        x_centers, y_centers, z_centers, indexing='ij'
    )
    centers = np.column_stack((Xc.ravel(), Yc.ravel(), Zc.ravel()))
        
    channel_x = np.zeros(centers.shape[0])
    channel_y = np.zeros(centers.shape[0])
    channel_z = np.zeros(centers.shape[0])
    
    ## normalize velocities
    selected_values = all_data_point[indices]
    min_vals = selected_values.min(axis=0)
    max_vals = selected_values.max(axis=0)
    ranges = max_vals - min_vals
    ranges[ranges == 0] = 1
    selected_val_scaled = (selected_values - min_vals) / ranges
    selected_values = selected_val_scaled * 2 - 1

    for i, center in enumerate(centers):
        distances = np.linalg.norm(selected_points - center, axis=1)
        if np.any(distances == 0):
            idx_zero = np.where(distances == 0)[0][0]
            channel_x[i] = selected_values[idx_zero, 0]
            channel_y[i] = selected_values[idx_zero, 1]
            channel_z[i] = selected_values[idx_zero, 2]
        else:
            weights = 1 / distances
            weights /= weights.sum()
            channel_x[i] = np.dot(weights, selected_values[:, 0])
            channel_y[i] = np.dot(weights, selected_values[:, 1])
            channel_z[i] = np.dot(weights, selected_values[:, 2])
    
    channel_values = np.stack((channel_x, channel_y, channel_z), axis=-1)   
    tensor_shape = (num_partitions, num_partitions, num_partitions, 3)
    tensor = channel_values.reshape(tensor_shape)
    
    # Resample tensor to fixed grid size using nearest neighbor interpolation

    zoom_factors = [fixed_size / float(orig_size) for fixed_size, orig_size in zip(fixed_grid_size, tensor.shape[:3])]
    # Apply zoom with order=0 for nearest neighbor interpolation
    tensor_resized = zoom(tensor, zoom_factors + [1], order=0)
    tensor_resized = torch.from_numpy(tensor_resized).float()

    return (tensor_resized, t, idx)

def process_all_points_parallel(
    all_data, all_data_point, epsilon, batch_size=32, save_dir='tensor_batches', fixed_grid_size=(32, 32, 32)
):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        
    N = all_data.shape[0]
    tree = cKDTree(all_data)

    args_list = [
        (i, all_data, all_data_point, epsilon, tree, fixed_grid_size) for i in range(N)
    ]
    batch = []
    batch_idx = 0
    num_batches = 0

    # Use pathos Pool for parallel execution
    with Pool(processes=os.cpu_count()) as pool:
        
        for result in tqdm(
            pool.imap(process_point, args_list), total=N, desc='Processing'
        ):
            if result is not None:
                tensor, t, idx = result
                batch.append((tensor, t, idx))
                batch_idx += 1
                if batch_idx >= batch_size:
                    save_path = os.path.join(
                        save_dir, f'batch_{num_batches}.pt'
                    )
                    torch.save(batch, save_path)
                    batch = []
                    batch_idx = 0
                    num_batches += 1

    # Save any remaining tensors in the final batch
    if batch:
        save_path = os.path.join(save_dir, f'batch_{num_batches}.pt')
        torch.save(batch, save_path)
        num_batches += 1

    return num_batches

# Define a custom Dataset to load the saved batches
class VoxelDataset(Dataset):
    def __init__(self, save_dir='tensor_batches'):
        self.save_dir = save_dir
        self.batch_files = [
            os.path.join(save_dir, f)
            for f in os.listdir(save_dir)
            if f.endswith('.pt')
        ]
        self.batch_files.sort()
        self.index_map = []
        self._create_index_map()

    def _create_index_map(self):
        for batch_file in self.batch_files:
            batch = torch.load(batch_file)
            batch_size = len(batch)
            for i in range(batch_size):
                self.index_map.append((batch_file, i))

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        batch_file, tensor_idx = self.index_map[idx]
        batch = torch.load(batch_file)
        tensor, t, point_idx = batch[tensor_idx]
        return tensor, t  # Return both tensor and t

# Usage
if __name__ == '__main__':
    # Assume all_points and all_point_data are defined elsewhere
    all_data = all_points
    all_data_point = all_point_data
    epsilon = 0.001  # Define the epsilon radius
    batch_size = 32  # Define the batch size
    fixed_grid_size = (16, 16, 16)  # Define the fixed grid size for resampling

    # Process all points in parallel and save tensors in batches with progress bar
    num_batches = process_all_points_parallel(
        all_data, all_data_point, epsilon, batch_size=batch_size, fixed_grid_size=fixed_grid_size
    )
    print(f"Saved {num_batches} batches of tensors.")

    # Create a Dataset and DataLoader for PyTorch
    dataset = VoxelDataset(save_dir='tensor_batches')
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Iterate over the dataloader
    for batch_tensors, batch_t_values in dataloader:
        print(f"Batch tensors shape: {batch_tensors.shape}")
        print(f"Unique batch t values: {batch_t_values.unique()}")
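
The inner loop over voxel centers in `process_point` runs in pure Python and can be slow for larger $t$. A vectorized alternative is sketched below, assuming the $m\times n$ distance matrix (voxel centers by measurements in the ball) fits in memory. It is meant to reproduce the loop's behavior: weights $1/d$ over all measurements in the ball, and a center that coincides with a measurement takes that measurement's value. `idw_at_centers` is a hypothetical helper, not a SciPy function.

```python
import numpy as np


def idw_at_centers(points, values, centers, power=1.0):
    """Inverse distance weighting of `values` at `centers`.

    points:  (n, 3) measurement coordinates in the epsilon-ball
    values:  (n, c) normalized velocities at those points
    centers: (m, 3) voxel centers
    Returns an (m, c) array; power=1 matches the loop in process_point.
    """
    # Pairwise center-to-point distances, shape (m, n)
    d = np.linalg.norm(centers[:, None, :] - points[None, :, :], axis=-1)
    exact = d == 0
    with np.errstate(divide="ignore"):
        w = 1.0 / d ** power  # inf where a center coincides with a point
    # A center that coincides with a point takes the first such point's value
    rows = np.flatnonzero(exact.any(axis=1))
    w[rows] = 0.0
    w[rows, exact[rows].argmax(axis=1)] = 1.0
    w /= w.sum(axis=1, keepdims=True)
    return w @ values


pts = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
vals = np.array([[0.0, 0.0, 0.0], [2.0, 2.0, 2.0]])
centers = np.array([[0.5, 0.0, 0.0], [0.0, 0.0, 0.0], [0.25, 0.0, 0.0]])
print(idw_at_centers(pts, vals, centers))
```

Inside `process_point`, the loop and the three `channel_*` arrays could then be replaced by `channel_values = idw_at_centers(selected_points, selected_values, centers)`, which has the same $(m, 3)$ shape as the stacked channels.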

### Step 3: Train the model

Now, we train the model. Specifically, we learn a function $g$ that improves the resolution of an image by doubling it in each dimension, which increases the number of voxels by a factor of 8.

- We assume a pre-additive noise model, $\mathbf{X}_t=g(\mathbf{X}_{t-1}+\boldsymbol{\epsilon}_{t-1})$, where $\boldsymbol{\epsilon}_{t-1}\sim N(0,\sigma_{t-1}^2\mathbf{I})$. Currently, we assume $\sigma_{t-1}^2$ is known.

- $\mathbf{X}_{t-1}$ is the lower-resolution image and $\mathbf{X}_t$ the higher-resolution image; $\mathbf{X}_t$ has 8 times as many voxels as $\mathbf{X}_{t-1}$.

- The processing step gives the highest-resolution image $\mathbf{X}_T$ from real data. To obtain the corresponding lower-resolution image $\mathbf{X}_{T-1}$, we downsample $\mathbf{X}_T$ to half its size in each dimension (the code below also applies a Gaussian blur after each downsampling). Repeating this yields $\mathbf{X}_{T-2},\dots,\mathbf{X}_0$.

- We use a 3D U-Net to learn $g$: the input is the processed image $\mathbf{X}_{t-1}$ plus generated Gaussian noise $\boldsymbol{\epsilon}_{t-1}$, and the target is $\mathbf{X}_t$. Currently, the model is trained by minimizing the mean squared error.

- We still want all tensors to have the same size, $3\times 2^T\times 2^T\times 2^T$. In the code below, each downsampled image is therefore upsampled back to $2^T$ voxels per dimension with trilinear interpolation rather than padded.
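
The pair construction can be sketched in NumPy as follows. This is a simplified stand-in for `SuperResolutionDataset` below, not the same operator: it downsamples by averaging $2\times2\times2$ blocks and brings every level back to $2^T$ voxels per dimension by repeating voxels, whereas the code below uses trilinear interpolation plus a Gaussian blur. As in the code, `sigma2[t-1]` is the variance $\sigma_{t-1}^2$, so the noise is scaled by its square root.

```python
import numpy as np


def downsample2(x):
    """Average non-overlapping 2x2x2 blocks of a (C, D, H, W) array (even D, H, W)."""
    c, d, h, w = x.shape
    return x.reshape(c, d // 2, 2, h // 2, 2, w // 2, 2).mean(axis=(2, 4, 6))


def upsample_to(x, size):
    """Nearest-neighbor upsampling by voxel repetition to `size` voxels per axis."""
    f = size // x.shape[1]
    return x.repeat(f, axis=1).repeat(f, axis=2).repeat(f, axis=3)


def make_pairs(x_T, T, sigma2, rng):
    """Return [(X_{t-1} + eps_{t-1}, X_t, t) for t = 1..T], all at full size."""
    size = x_T.shape[1]
    levels = [x_T]  # levels[k] has size / 2**k voxels per axis
    for _ in range(T):
        levels.append(downsample2(levels[-1]))
    full = [upsample_to(lv, size) for lv in reversed(levels)]  # full[t] stands for X_t
    pairs = []
    for t in range(1, T + 1):
        eps = rng.normal(0.0, np.sqrt(sigma2[t - 1]), size=full[t - 1].shape)
        pairs.append((full[t - 1] + eps, full[t], t))
    return pairs


rng = np.random.default_rng(0)
x_T = rng.uniform(-1.0, 1.0, size=(3, 16, 16, 16))  # one processed image, channels first
pairs = make_pairs(x_T, T=4, sigma2=[0.1] * 4, rng=rng)
print([(p[0].shape, p[2]) for p in pairs])
```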

In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from scipy.spatial import cKDTree
from scipy.ndimage import zoom  
import os
import math
from tqdm.auto import tqdm  
from pathos.multiprocessing import ProcessingPool as Pool
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
       

# Define the SuperResolutionDataset
class SuperResolutionDataset(Dataset):
    def __init__(self, images, T, sigma_t_list):
        self.images = images
        self.T = T
        self.sigma_t_list = sigma_t_list
        self.data_pairs = []
        self.prepare_data()

    def gaussian_kernel_1d(self, kernel_size, sigma):
        # Create a 1D Gaussian kernel
        x = torch.arange(kernel_size) - kernel_size // 2
        kernel = torch.exp(-0.5 * (x / sigma) ** 2)
        kernel = kernel / kernel.sum()
        return kernel

    def gaussian_blur_3d(self, x, kernel_size=5, sigma=1):
        # x: Tensor of shape (C, D, H, W)
        device = x.device
        x = x.unsqueeze(0)  # Add batch dimension
        N, C, D, H, W = x.shape

        # Adjust kernel_size if necessary
        max_kernel_size = min(kernel_size, D, H, W)
        if max_kernel_size % 2 == 0:
            max_kernel_size -= 1  # Ensure it's odd
        if max_kernel_size < 1:
            x_blur = x
        else:
            # Create 1D Gaussian kernel
            kernel = self.gaussian_kernel_1d(max_kernel_size, sigma).to(device)
            # Create 3D Gaussian kernel
            kernel_3d = kernel[:, None, None] * kernel[None, :, None] * kernel[None, None, :]
            kernel_3d = kernel_3d / kernel_3d.sum()
            # Reshape kernel for conv3d
            kernel_3d = kernel_3d.view(1, 1, max_kernel_size, max_kernel_size, max_kernel_size)
            kernel_3d = kernel_3d.repeat(C, 1, 1, 1, 1)
            padding = max_kernel_size // 2
            # Adjust padding for each dimension
            pad_D = min(padding, D - 1)
            pad_H = min(padding, H - 1)
            pad_W = min(padding, W - 1)
            x_padded = F.pad(x, (pad_W, pad_W, pad_H, pad_H, pad_D, pad_D), mode='reflect')
            x_blur = F.conv3d(x_padded, kernel_3d, groups=C)
        x_blur = x_blur.squeeze(0)  
        return x_blur

    def interpolate_3d(self, tensor, **kwargs):
        # tensor: (C, D, H, W)
        tensor = tensor.unsqueeze(0)  
        tensor_interp = F.interpolate(tensor, **kwargs)
        tensor_interp = tensor_interp.squeeze(0)  
        return tensor_interp

    def generate_downsampled_images(self, image):
        # image: Tensor of shape (C, D, H, W)
        # Generate downsampled images from X_T to X_0
        X_t_list = []
        X_t = image  # Start with the original image at highest resolution
        X_t_list.append(X_t)
        for t in range(self.T, 0, -1):
            # Downsample X_t to half size in each dimension
            X_t_down = self.interpolate_3d(X_t, scale_factor=0.5, mode='trilinear', align_corners=False, recompute_scale_factor=True)
            # Apply Gaussian blur to X_t_down
            X_t_blur = self.gaussian_blur_3d(X_t_down)
            # Upsample back to original size
            original_size = image.shape[1:]  # (D, H, W)
            X_t_blur_upsampled = self.interpolate_3d(X_t_blur, size=original_size, mode='trilinear', align_corners=False)
            X_t_list.insert(0, X_t_blur_upsampled)  
            X_t = X_t_down  # Update X_t for the next iteration
        return X_t_list  # X_t_list[0] corresponds to t=0, X_t_list[T] corresponds to t=T

    def prepare_data(self):
        # Create training pairs (X_{t-1} + ε_{t-1}, X_t) for t = 1 to T
        for image in self.images:
            X_t_list = self.generate_downsampled_images(image)
            for t in range(1, self.T + 1):
                X_t = X_t_list[t]
                X_t_minus_1 = X_t_list[t - 1]
                sigma_t_minus_1 = self.sigma_t_list[t - 1]
                
                epsilon_t_minus_1 = torch.randn_like(X_t_minus_1) * (sigma_t_minus_1 ** 0.5)
                X_t_minus_1_noisy = X_t_minus_1 + epsilon_t_minus_1
               
                assert X_t_minus_1_noisy.shape == X_t.shape, f"Shape mismatch: X_t_minus_1_noisy {X_t_minus_1_noisy.shape}, X_t {X_t.shape}"
                self.data_pairs.append((X_t_minus_1_noisy, X_t, t))
        # Now self.data_pairs contains all training pairs

    def __len__(self):
        return len(self.data_pairs)

    def __getitem__(self, idx):
        X_input, X_target, t = self.data_pairs[idx]
        return X_input, X_target, t

# Define the 3D UNet model 
class UNet3D(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, init_features=32):
        super(UNet3D, self).__init__()
        features = init_features
        self.encoder1 = UNet3D._block(in_channels, features)
        self.pool1 = nn.MaxPool3d(kernel_size=2)
        self.encoder2 = UNet3D._block(features, features * 2)
        self.pool2 = nn.MaxPool3d(kernel_size=2)
        self.encoder3 = UNet3D._block(features * 2, features * 4)
        self.pool3 = nn.MaxPool3d(kernel_size=2)  # downsample once more before the bottleneck

        self.bottleneck = UNet3D._block(features * 4, features * 8)

        self.upconv3 = nn.ConvTranspose3d(features * 8, features * 4, kernel_size=2, stride=2)
        self.decoder3 = UNet3D._block(features * 8, features * 4)
        self.upconv2 = nn.ConvTranspose3d(features * 4, features * 2, kernel_size=2, stride=2)
        self.decoder2 = UNet3D._block(features * 4, features * 2)
        self.upconv1 = nn.ConvTranspose3d(features * 2, features, kernel_size=2, stride=2)
        self.decoder1 = UNet3D._block(features * 2, features)

        self.conv = nn.Conv3d(features, out_channels, kernel_size=1)

    @staticmethod
    def _block(in_channels, features):
        return nn.Sequential(
            nn.Conv3d(in_channels, features, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(features),
            nn.ReLU(inplace=True),
            nn.Conv3d(features, features, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(features),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        enc1 = self.encoder1(x)  # [N, features, D, H, W]
        enc2 = self.encoder2(self.pool1(enc1))  # [N, features*2, D/2, H/2, W/2]
        enc3 = self.encoder3(self.pool2(enc2))  # [N, features*4, D/4, H/4, W/4]

        bottleneck = self.bottleneck(self.pool3(enc3))  # [N, features*8, D/8, H/8, W/8]

        dec3 = self.upconv3(bottleneck)  # [N, features*4, D/4, H/4, W/4]
        # Interpolate only if pooling truncated an odd size
        if dec3.shape[2:] != enc3.shape[2:]:
            dec3 = F.interpolate(dec3, size=enc3.shape[2:], mode='trilinear', align_corners=False)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)

        dec2 = self.upconv2(dec3)  # [N, features*2, D/2, H/2, W/2]
        if dec2.shape[2:] != enc2.shape[2:]:
            dec2 = F.interpolate(dec2, size=enc2.shape[2:], mode='trilinear', align_corners=False)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)

        dec1 = self.upconv1(dec2)  # [N, features, D, H, W]
        if dec1.shape[2:] != enc1.shape[2:]:
            dec1 = F.interpolate(dec1, size=enc1.shape[2:], mode='trilinear', align_corners=False)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)
        return self.conv(dec1)

# Custom collate function
def custom_collate_fn(batch):
    X_inputs, X_targets, ts = zip(*batch)
    # Determine the maximum spatial dimensions in the batch
    max_C = max(x_input.shape[0] for x_input in X_inputs)
    max_D = max(x_input.shape[1] for x_input in X_inputs)
    max_H = max(x_input.shape[2] for x_input in X_inputs)
    max_W = max(x_input.shape[3] for x_input in X_inputs)

    # Pad all tensors to the maximum size
    X_inputs_padded = []
    X_targets_padded = []
    for x_input, x_target in zip(X_inputs, X_targets):
        padding_input = (
            0, max_W - x_input.shape[3],  # Width padding
            0, max_H - x_input.shape[2],  # Height padding
            0, max_D - x_input.shape[1],  # Depth padding
        )
        padding_target = (
            0, max_W - x_target.shape[3],
            0, max_H - x_target.shape[2],
            0, max_D - x_target.shape[1],
        )
        x_input_padded = F.pad(x_input, padding_input, mode='constant', value=0)
        x_target_padded = F.pad(x_target, padding_target, mode='constant', value=0)
        X_inputs_padded.append(x_input_padded)
        X_targets_padded.append(x_target_padded)

    X_inputs_batch = torch.stack(X_inputs_padded)
    X_targets_batch = torch.stack(X_targets_padded)
    ts_batch = torch.tensor(ts)
    return X_inputs_batch, X_targets_batch, ts_batch

# Training function
def train(model, dataloader, optimizer, criterion, device, num_epochs=10):
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0
        print(f"Epoch [{epoch+1}/{num_epochs}]")

        with tqdm(total=len(dataloader), desc=f"Training Epoch {epoch+1}", unit="batch") as pbar:
            for X_input, X_target, t in dataloader:
                X_input = X_input.to(device)
                X_target = X_target.to(device)
                # Forward pass
                outputs = model(X_input)
                # Ensure outputs and X_target have the same shape
                if outputs.shape != X_target.shape:
                    X_target = F.interpolate(X_target, size=outputs.shape[2:], mode='trilinear', align_corners=False)
                loss = criterion(outputs, X_target)
                # Backward pass and optimization
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                pbar.update(1)
        avg_loss = epoch_loss / len(dataloader)
        print(f"Average Loss: {avg_loss:.4f}")

# Main code

if __name__ == "__main__":
    
    save_dir = 'tensor_batches'

    # List all tensor batch files
    batch_files = [
        os.path.join(save_dir, f)
        for f in os.listdir(save_dir)
        if f.endswith('.pt')
    ]
    batch_files.sort()

    images = []

    print("Loading images from processed tensors...")
    for batch_file in tqdm(batch_files, desc="Loading batches"):
        batch = torch.load(batch_file)  # Each batch is a list of (tensor, t, idx)
        for data in batch:
            tensor, t, idx = data
            # tensor is of shape (D, H, W, C), need to permute to (C, D, H, W)
            tensor = tensor.permute(3, 0, 1, 2)
            images.append(tensor)


    T = 4  # Maximum resolution level
    sigma_t_list = [0.1] * T  # Prespecified \sigma_t^2 for each t from 0 to T-1

    dataset = SuperResolutionDataset(images, T, sigma_t_list)
    batch_size = 32
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)

    model = UNet3D(in_channels=3, out_channels=3)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.MSELoss()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    num_epochs = 50  # Adjust the number of epochs as needed
    train(model, dataloader, optimizer, criterion, device, num_epochs)
    
    # Save the trained model
    model_path = 'unet3d_super_resolution.pth'
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to '{model_path}'")

### Step 4: Sampling

Using the learned super-resolution function $\hat{g}$, we can improve the resolution of any local image.

- For a local image $\mathbf{X}_t$, we increase its number of voxels by a factor of 8 via $\mathbf{X}_{t+1}=\hat{g}(\mathbf{X}_t+\boldsymbol{\epsilon}_t)$, where $\boldsymbol{\epsilon}_t\sim N(0,\sigma_t^2\mathbf{I})$.

- Repeating this $T-t$ times yields the highest-resolution image $\mathbf{X}_T$.

- We then use the improved image $\mathbf{X}_T$ to map back to the original measurements, the idea being that $\mathbf{X}_T$ carries more information than the original $\mathbf{X}_t$. Currently, we assign the value of the central voxel of $\mathbf{X}_T$ to the original measurement. A $16\times16\times16$ grid has no single central voxel, so the code uses voxel $(7,7,7)$ (zero-based), one of the 8 voxels around the grid center. Note also that the grid spans the bounding box of the measurements in the $\epsilon$-ball, so its center only approximately coincides with the measurement's location.
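
Stripped of the tensor bookkeeping, the sampling step is the loop below. It is a framework-free sketch: `g` stands for the trained network $\hat{g}$ (any function mapping an array to an array of the same shape), `halve` is a toy stand-in used only to make it runnable, and `sigma2[s]` is the variance $\sigma_s^2$, matching `sigma_t_list` in the code below.

```python
import numpy as np


def sample_chain(g, x_t, t, T, sigma2, rng):
    """Compute X_T from X_t via X_{s+1} = g(X_s + eps_s), eps_s ~ N(0, sigma2[s] I)."""
    x = x_t
    for s in range(t, T):  # T - t applications; none if t >= T
        eps = rng.normal(0.0, np.sqrt(sigma2[s]), size=x.shape)
        x = g(x + eps)
    return x


def halve(z):
    # Toy stand-in for g-hat, only to make the loop runnable
    return 0.5 * z


rng = np.random.default_rng(0)
x_1 = np.ones((3, 16, 16, 16))
x_T = sample_chain(halve, x_1, t=1, T=4, sigma2=[0.0] * 4, rng=rng)
print(x_T[0, 0, 0, 0])
```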

In [None]:
def sample_super_resolution(model, tensor, t, device, T=4, sigma_t_list=(0.1,) * 4):
    """Apply X_{s+1} = g(X_s + eps_s) for s = t, ..., T-1 to one local image.

    tensor: a single image of shape [D, H, W, C], as saved in Step 2.
    sigma_t_list[s] is the noise variance sigma_s^2 at level s.
    """
    model.eval()
    tensor = tensor.to(device)

    # Permute the tensor dimensions from [D, H, W, C] to [C, D, H, W]
    if tensor.dim() != 4:
        raise ValueError(f"Expected a single [D, H, W, C] tensor, got shape {tuple(tensor.shape)}")
    tensor = tensor.permute(3, 0, 1, 2)

    with torch.no_grad():
        # range(t, T) is empty when t >= T, so such images are returned unchanged
        for res_level in range(t, T):
            # Add noise eps_t with variance sigma_t^2 (standard deviation sigma_t)
            sigma2_t = sigma_t_list[res_level]
            epsilon_t = torch.randn_like(tensor) * (sigma2_t ** 0.5)
            X_input = tensor + epsilon_t

            # Apply the model to X_input
            X_input = X_input.unsqueeze(0)  # Add batch dimension
            tensor = model(X_input).squeeze(0)  # Remove batch dimension

    # Permute back to [D, H, W, C] before returning
    return tensor.cpu().permute(1, 2, 3, 0)

if __name__ == '__main__':
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the trained model
    model_path = 'unet3d_super_resolution.pth'
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file '{model_path}' not found.")
    model = UNet3D(in_channels=3, out_channels=3)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    # Input data directory
    input_dir = 'tensor_batches'  # Directory containing your saved data batches
    if not os.path.exists(input_dir):
        raise FileNotFoundError(f"Input data directory '{input_dir}' not found.")

    # Output directory
    output_dir = 'upsampled_batches'
    os.makedirs(output_dir, exist_ok=True)

    # Get list of all .pt files in input_dir
    batch_files = [f for f in os.listdir(input_dir) if f.endswith('.pt')]
    batch_files.sort()

    # Process each batch file
    for batch_file in tqdm(batch_files, desc='Processing batches'):
        batch_path = os.path.join(input_dir, batch_file)
        # Load the batch data
        batch_data = torch.load(batch_path)  # Should be a list of (tensor, t, point_idx)

        upsampled_batch = []

        for data in batch_data:
            tensor, t, point_idx = data
            # Apply the sampling function
            upsampled_tensor = sample_super_resolution(model, tensor, t, device)
            # Append the upsampled tensor and point index
            upsampled_batch.append((upsampled_tensor, point_idx))

        # Save the upsampled batch
        output_batch_path = os.path.join(output_dir, batch_file)
        torch.save(upsampled_batch, output_batch_path)

    print("Upsampled data saved.")
    
def extract_center_voxel_channels(batch_dir='upsampled_batches'):
    
    center_voxel_channels = []

    # Get list of all .pt files in the batch directory
    batch_files = [f for f in os.listdir(batch_dir) if f.endswith('.pt')]
    batch_files.sort()

    for batch_file in tqdm(batch_files, desc='Processing batches'):
        batch_path = os.path.join(batch_dir, batch_file)
        # Load the batch data
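        batch_data = torch.load(batch_path)  # list of (tensor, point_idx), or (tensor, t, point_idx) for Step 2 batches

        for data in batch_data:
            # Upsampled batches store (tensor, point_idx), while the original Step 2
            # batches store (tensor, t, point_idx): take the first and last items
            upsampled_tensor, point_idx = data[0], data[-1]
            # upsampled_tensor shape: [D, H, W, C], expected to be [16, 16, 16, 3]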

            # Verify tensor shape
            if upsampled_tensor.shape != (16, 16, 16, 3):
                raise ValueError(f"Unexpected tensor shape: {upsampled_tensor.shape}")

            # Get the center voxel indices
            center_idx = (7, 7, 7)  # Zero-based indexing

            # Extract the channels at the center voxel
            center_channels = upsampled_tensor[center_idx[0], center_idx[1], center_idx[2], :]  # Shape: [3]

            # Append to the list
            center_voxel_channels.append((center_channels, point_idx))

    return center_voxel_channels


if __name__ == '__main__':
    # Directory containing the upsampled batches
    batch_dir_up = 'upsampled_batches'
    batch_dir_ori = 'tensor_batches'

    # Extract the center voxel channels
    center_voxel_data_up = extract_center_voxel_channels(batch_dir_up)
    center_voxel_data_ori = extract_center_voxel_channels(batch_dir_ori)
    
    channels_list_up = [channels.numpy() for channels, _ in center_voxel_data_up]
    channels_list_ori = [channels.numpy() for channels, _ in center_voxel_data_ori]
    
    channels_array_up = np.array(channels_list_up)
    channels_array_ori = np.array(channels_list_ori)

    np.savez('center_voxel_data.npz', channels_up=channels_array_up, channels_ori=channels_array_ori)
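
Instead of reading off voxel $(7,7,7)$, one could interpolate $\mathbf{X}_T$ at the measurement's own location. The sketch below performs trilinear interpolation on a $16^3$ grid whose voxel centers are evenly spaced between a first and a last center. Two assumptions are not covered by the current code: `process_point` would also have to save each ball's bounding box (`x_min`, ..., `z_max`) so the grid can be placed in space, and the result is still in the per-ball normalized $[-1,1]$ units, so the min/max used for normalization would be needed to recover physical velocities. `trilinear_at` is our own name.

```python
import numpy as np


def trilinear_at(grid, first_center, last_center, q):
    """Trilinear interpolation of an (n, n, n, C) grid at the physical point q.

    Voxel centers are assumed evenly spaced from first_center to last_center
    (3-vectors) along each axis; q is clamped to that range.
    """
    n = grid.shape[0]
    a, b, q = (np.asarray(v, dtype=float) for v in (first_center, last_center, q))
    span = np.where(b > a, b - a, 1.0)  # avoid division by zero on a degenerate axis
    u = np.clip((q - a) / span * (n - 1), 0.0, n - 1.0)  # continuous voxel index
    i0 = np.minimum(np.floor(u).astype(int), n - 2)
    f = u - i0
    out = np.zeros(grid.shape[-1])
    for dx in (0, 1):
        for dy in (0, 1):
            for dz in (0, 1):
                w = ((f[0] if dx else 1.0 - f[0])
                     * (f[1] if dy else 1.0 - f[1])
                     * (f[2] if dz else 1.0 - f[2]))
                out += w * grid[i0[0] + dx, i0[1] + dy, i0[2] + dz]
    return out


# Demo grid whose channels are the voxel indices themselves
grid = np.indices((16, 16, 16)).transpose(1, 2, 3, 0).astype(float)
print(trilinear_at(grid, [0, 0, 0], [15, 15, 15], [7.5, 7.5, 7.5]))
```

For a grid built from $2^t$ partitions of `[lo, hi]` and resampled to 16 voxels with `scipy.ndimage.zoom` (default `grid_mode=False`, which aligns the first and last voxel centers), one would pass `first_center = lo + (hi - lo) / 2 ** (t + 1)` and `last_center = hi - (hi - lo) / 2 ** (t + 1)`.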



### Things to consider...

- Instead of building a local image for every measurement, we could build only as many local images as needed, from a set of $\epsilon$-balls that covers the entire domain.

- We could use the engression framework instead of the MSE loss.

- When mapping the higher-resolution image back to the original space, we could first fit an interpolating function and then evaluate it at the measurements' spatial coordinates to obtain the corresponding values.
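
For reference, the engression framework fits a pre-additive noise model like the one in Step 3 by minimizing an energy-score loss rather than the MSE. In population form, the loss is $\mathbb{E}\|\mathbf{X}_t-g(\mathbf{X}_{t-1}+\boldsymbol{\epsilon})\|-\tfrac12\mathbb{E}\|g(\mathbf{X}_{t-1}+\boldsymbol{\epsilon})-g(\mathbf{X}_{t-1}+\boldsymbol{\epsilon}')\|$, with $\boldsymbol{\epsilon},\boldsymbol{\epsilon}'$ independent noise draws. The NumPy sketch below estimates it from two draws per target; it only illustrates the formula and is not taken from any engression package.

```python
import numpy as np


def energy_loss(y, yhat1, yhat2):
    """Two-sample empirical energy score, averaged over the batch.

    y:      targets X_t, shape (batch, ...)
    yhat1:  g(X_{t-1} + eps) for one noise draw, same shape as y
    yhat2:  g(X_{t-1} + eps') for an independent draw
    """
    b = y.shape[0]
    y, yhat1, yhat2 = (a.reshape(b, -1) for a in (y, yhat1, yhat2))
    # Fit term: average distance from the target to each sample
    fit = 0.5 * (np.linalg.norm(y - yhat1, axis=1) + np.linalg.norm(y - yhat2, axis=1))
    # Spread term: unbiased two-sample estimate of 0.5 * E||g(X+eps) - g(X+eps')||
    spread = 0.5 * np.linalg.norm(yhat1 - yhat2, axis=1)
    return float(np.mean(fit - spread))


rng = np.random.default_rng(0)
target = rng.normal(size=(4, 3, 16, 16, 16))
print(energy_loss(target,
                  target + 0.1 * rng.normal(size=target.shape),
                  target + 0.1 * rng.normal(size=target.shape)))
```

In training, the same expression would be written with `torch.linalg.norm` so that gradients flow, and each batch would need two forward passes with independent noise.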