In [24]:
# pip install torch 

In [25]:
# pip install vtk

In [26]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import os 
import vtk
from vtk import *
from vtk.util import numpy_support
from vtk import vtkXMLImageDataReader
from vtk import vtkImageGradient, vtkImageMagnitude
from vtkmodules.vtkCommonDataModel import vtkImageData
from vtkmodules.util import numpy_support


In [27]:
def compute_PSNR(arrgt,arr_recon):
    try:
        diff = arrgt - arr_recon
        sqd_max_diff = (np.max(arrgt)-np.min(arrgt))**2
        if(np.mean(diff**2) == 0):
            raise ZeroDivisionError("dividing by zero, cannot calculate psnr")
        snr = 10*np.log10(sqd_max_diff/np.mean(diff**2))
        return snr
    except ZeroDivisionError as err:
        return str(err)

def read_vti(filename):
    reader = vtkXMLImageDataReader()
    reader.SetFileName(filename)
    reader.Update()
    return reader.GetOutput()


def writeVti(data, filename):
    writer = vtkXMLImageDataWriter()
    writer.SetFileName(filename)
    writer.SetInputData(data)
    writer.Write()


def createVtkImageData(origin, dimensions, spacing):
    localDataset = vtkImageData()
    localDataset.SetOrigin(origin)
    localDataset.SetDimensions(dimensions)
    localDataset.SetSpacing(spacing)
    return localDataset


In [28]:
def compute_gradient_magnitude(vtk_image_data):
    if not isinstance(vtk_image_data, vtk.vtkImageData):
        raise TypeError("compute_gradient_magnitude requires a vtkImageData object")
    
    if vtk_image_data.GetPointData().GetScalars() is None:
        raise ValueError("VTK ImageData has no scalars set. Cannot compute gradients.")

    gradient_filter = vtk.vtkImageGradient()
    gradient_filter.SetInputData(vtk_image_data)
    gradient_filter.SetDimensionality(3)
    gradient_filter.Update()

    magnitude_filter = vtk.vtkImageMagnitude()
    magnitude_filter.SetInputConnection(gradient_filter.GetOutputPort())
    magnitude_filter.Update()

    return magnitude_filter.GetOutput()

def get_numpy_array_from_vtk_image_data(vtk_image_data):
    point_data = vtk_image_data.GetPointData()
    array = point_data.GetScalars()  
    
    if array is None:
        raise ValueError("No scalar array found in vtkImageData.")
    
    numpy_array = numpy_support.vtk_to_numpy(array)
    dims = vtk_image_data.GetDimensions()  
    num_components = array.GetNumberOfComponents()

    expected_size = dims[0] * dims[1] * dims[2] * num_components

    if numpy_array.size != expected_size:
        raise ValueError(f"Shape mismatch! Cannot reshape {numpy_array.size} elements into {dims[2], dims[1], dims[0], num_components}")

    numpy_array = numpy_array.reshape(dims[2], dims[1], dims[0], num_components)  
    return numpy_array.squeeze()  

def tensor_to_vtk_image_data(tensor):
    np_array = tensor.detach().cpu().numpy()

    if np_array.ndim == 4:  
        np_array = np_array[0] 

    vtk_image_data = vtk.vtkImageData()

    dims = np_array.shape  
    vtk_image_data.SetDimensions(dims[2], dims[1], dims[0])  

    flat_array = np_array.flatten(order='F')
    vtk_array = numpy_support.numpy_to_vtk(flat_array, deep=True, array_type=vtk.VTK_FLOAT)
    vtk_image_data.GetPointData().SetScalars(vtk_array)
    
    return vtk_image_data



In [29]:
vtifile = read_vti('/Users/manasvijain/Desktop/Autoencoder/Dataset/Pf25.binLE.raw_corrected_2_subsampled.vti')
try:
    data = get_numpy_array_from_vtk_image_data(vtifile)
    print(data.shape)
except ValueError as e:
    print(e)

(50, 250, 250)


In [30]:
gradient_magnitude = compute_gradient_magnitude(vtifile)
gradient_magnitude_array = get_numpy_array_from_vtk_image_data(gradient_magnitude)

In [31]:
data = data.astype(np.float32)
data = 2 * ((data - data.min()) / (data.max() - data.min())) - 1 
four_d_tensor = torch.from_numpy(data).float() 
input_tensor = four_d_tensor.unsqueeze(0)

block_dims = (5, 5, 5) 

In [32]:
print(input_tensor.shape)

torch.Size([1, 50, 250, 250])


In [33]:
class Conv3DAutoencoder(nn.Module):
    def __init__(self):
        super(Conv3DAutoencoder, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=3, stride=1, padding=1), 
            nn.BatchNorm3d(16),
            nn.ReLU(),
            nn.Conv3d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(),
            nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1), 
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.Flatten(),
            # nn.Linear(64 * 5 * 5 * 5, 512),
            nn.Linear(64 * 9 * 9 * 9, 512),
            nn.ReLU(),
            nn.Linear(512, 32),
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(32,512),
            nn.ReLU(),
            # nn.Linear(512, 64 * 5 * 5 * 5),
            nn.Linear(512, 64 * 9 * 9 * 9),
            nn.ReLU(),
            # nn.Unflatten(1, (64, 5, 5, 5)),
            nn.Unflatten(1, (64, 9, 9, 9)),
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.ConvTranspose3d(64, 32, kernel_size=3, stride=1, padding=1), 
            nn.BatchNorm3d(32),
            nn.ReLU(),
            nn.ConvTranspose3d(32, 16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm3d(16), 
            nn.ReLU(),
            nn.ConvTranspose3d(16, 1, kernel_size=3, stride=1, padding=1), 
            nn.Tanh()
        )
        
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

model = Conv3DAutoencoder()
model.eval()



Conv3DAutoencoder(
  (encoder): Sequential(
    (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (4): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (7): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): Flatten(start_dim=1, end_dim=-1)
    (10): Linear(in_features=46656, out_features=512, bias=True)
    (11): ReLU()
    (12): Linear(in_features=512, out_features=32, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=32, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=46656, bias=True)
    (3): ReLU()
    (4): Unflatten

In [34]:
def divide_into_blocks(data, block_dims): # make blocks 5x5x5 

    if not isinstance(data, torch.Tensor):
        data = torch.tensor(data, dtype=torch.float)

    blocks = []
    height_step, width_step, depth_step = block_dims

    for h in range(0, data.shape[1], height_step): 
        for w in range(0, data.shape[2], width_step):  
            for d in range(0, data.shape[3], depth_step): 
                block = data[:, h:h + height_step, w:w + width_step, d:d + depth_step]
                blocks.append(block)

    return blocks


In [35]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = Conv3DAutoencoder().to(device)
# criterion = nn.MSELoss()
# optimizer = optim.Adam(model.parameters(), lr=0.001)

In [36]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Conv3DAutoencoder().to(device)
mse_criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [37]:
import torch
import torch.nn as nn

def custom_loss(output, target, gradient_target, alpha=0.5):

    reconstruction_loss = mse_criterion(output, target)
    
    output_vtk = tensor_to_vtk_image_data(output)
    output_gradient_magnitude = compute_gradient_magnitude(output_vtk)
    output_gradient_magnitude_array = get_numpy_array_from_vtk_image_data(output_gradient_magnitude)


    assert output_gradient_magnitude_array.shape == gradient_target.shape, "Shapes must match"
    output_magnitude = np.linalg.norm(output_gradient_magnitude_array)
    gradient_target_magnitude = np.linalg.norm(gradient_target)
    gradient_loss = np.abs(output_magnitude - gradient_target_magnitude)

    total_loss = alpha * reconstruction_loss + (1 - alpha) * gradient_loss
    return total_loss


In [38]:
from torch.utils.data import Dataset, DataLoader

class BlockDataset(Dataset):
    def __init__(self, blocks):
        self.blocks = [torch.tensor(block, dtype=torch.float) for block in blocks]

    def __len__(self):
        return len(self.blocks)

    def __getitem__(self, idx):
        return self.blocks[idx]


In [39]:
def pad_blocks(data, block_shape, padded_shape, pad_value=0):
    if not isinstance(data, torch.Tensor):
        data = torch.tensor(data, dtype=torch.float)
    if data.ndim == 3:
        data = data.unsqueeze(0)
    pad_h = padded_shape[0] // 2
    pad_w = padded_shape[1] // 2
    pad_d = padded_shape[2] // 2
    padded_data = F.pad(data, (pad_d, pad_d, pad_w, pad_w, pad_h, pad_h), mode='constant', value=pad_value)
    _, H, W, D = data.shape
    n_blocks_h = H // block_shape[0]
    n_blocks_w = W // block_shape[1]
    n_blocks_d = D // block_shape[2]
    blocks_list = []
    for i in range(n_blocks_h):
        for j in range(n_blocks_w):
            for k in range(n_blocks_d):
                center_h = i * block_shape[0] + (block_shape[0] // 2)
                center_w = j * block_shape[1] + (block_shape[1] // 2)
                center_d = k * block_shape[2] + (block_shape[2] // 2)
                c_h = center_h + pad_h
                c_w = center_w + pad_w
                c_d = center_d + pad_d
                block = padded_data[:, 
                    c_h - pad_h : c_h + pad_h + 1,
                    c_w - pad_w : c_w + pad_w + 1,
                    c_d - pad_d : c_d + pad_d + 1
                ]
                blocks_list.append(block)
    return blocks_list

def unpad_block(padded_block, block_dims, padded_shape):
    start_h = padded_shape[0] // 2 - block_dims[0] // 2
    start_w = padded_shape[1] // 2 - block_dims[1] // 2
    start_d = padded_shape[2] // 2 - block_dims[2] // 2
    return padded_block[:, start_h:start_h + block_dims[0],
                           start_w:start_w + block_dims[1],
                           start_d:start_d + block_dims[2]]

def reassemble_blocks(blocks, original_shape, block_dims, padded_shape):
    reassembled_data = np.zeros(original_shape)
    index = 0
    for h in range(0, original_shape[1], block_dims[0]):
        for w in range(0, original_shape[2], block_dims[1]):
            for d in range(0, original_shape[3], block_dims[2]):
                block = blocks[index]
                unpadded = unpad_block(block, block_dims, padded_shape)
                block_data = unpadded.numpy()
                reassembled_data[:, h:h + block_dims[0],
                                  w:w + block_dims[1],
                                  d:d + block_dims[2]] = block_data
                index += 1
    return reassembled_data

In [40]:
input_tensor = torch.rand(1, 250, 250, 50)
block_dims = (5, 5, 5)
padded_shape = (9, 9, 9)
pad_value = 0

blocks = pad_blocks(input_tensor, block_dims, padded_shape, pad_value)
dataset = BlockDataset(blocks)
batch_size = 16  
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


  self.blocks = [torch.tensor(block, dtype=torch.float) for block in blocks]


In [42]:
model.train()
# num_epochs = 50
num_epochs = 10
alpha = 0.5  

for epoch in range(num_epochs):
    total_loss = 0
    for batch in dataloader:
        batch = batch.to(device)  
        
        batch_vtk = tensor_to_vtk_image_data(batch)
        batch_gradient_magnitude = compute_gradient_magnitude(batch_vtk)
        batch_gradient_magnitude_array = get_numpy_array_from_vtk_image_data(batch_gradient_magnitude)
        
        optimizer.zero_grad()
        output = model(batch)
        
        loss = custom_loss(output, batch, batch_gradient_magnitude_array, alpha)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.size(0)  
    
    avg_loss = total_loss / len(dataset) 
    print(f'Epoch {epoch+1}, Loss: {avg_loss:.4f}')

Epoch 1, Loss: 1.0814
Epoch 2, Loss: 1.0087
Epoch 3, Loss: 0.9950


KeyboardInterrupt: 

In [None]:
model.eval()
reconstructed_blocks = []
with torch.no_grad():
    for block in blocks:
        block_tensor = torch.tensor(block, dtype=torch.float).unsqueeze(0).to(device)
        output = model(block_tensor).cpu().numpy()
        reconstructed_blocks.append(output.squeeze(0))


  block_tensor = torch.tensor(block, dtype=torch.float).unsqueeze(0).to(device)


In [144]:
def reassemble_blocks(blocks, original_shape, block_dims):
    reassembled_data = np.zeros(original_shape)

    index = 0

    for h in range(0, original_shape[1], block_dims[0]):  # Height
        for w in range(0, original_shape[2], block_dims[1]):  # Width
            for d in range(0, original_shape[3], block_dims[2]):  # Depth
                # Ensuring the block is inserted into the correct slice
                if isinstance(blocks[index], torch.Tensor):
                    block_data = blocks[index].numpy()  # Convert to numpy if it's a tensor
                else:
                    block_data = blocks[index]
                
                reassembled_data[:, h:h + block_dims[0], w:w + block_dims[1], d:d + block_dims[2]] = block_data
                index += 1

    return reassembled_data


In [23]:
unpadded_blocks = unpad_blocks(reconstructed_blocks, block_dims, padded_shape)
reassembled_data = reassemble_blocks(unpadded_blocks, input_tensor.shape, block_dims)

original_tensor = torch.tensor(input_tensor, dtype=torch.float).to(device)
reassembled_tensor = torch.tensor(reassembled_data, dtype=torch.float).to(device)

final_loss = mse_criterion(reassembled_tensor,input_tensor)
print(f'Final MSE Loss on Entire Data: {final_loss.item()}')

SyntaxError: invalid syntax (3179831414.py, line 1)

In [None]:
print(reassembled_data.shape)
reconstructed_data = reassembled_data[0, :, :, :]

psnr_score = compute_PSNR(data,reconstructed_data)
print(psnr_score)

(1, 50, 250, 250)
41.49894151745377


In [None]:
flat_data_array = reconstructed_data.flatten()
vtk_array = numpy_support.numpy_to_vtk(flat_data_array)


data = read_vti('/Users/manasvijain/Desktop/Autoencoder/Dataset/Pf25.binLE.raw_corrected_2_subsampled.vti')
array = data.GetPointData().GetArray(0)

dim = data.GetDimensions()
spacing = data.GetSpacing()
origin = data.GetOrigin()


new_data = createVtkImageData(origin, dim, spacing)
new_data.GetPointData().AddArray(vtk_array)

writeVti(new_data, 'out100epochs.vti')