In [1]:
# pip install torch 

In [2]:
# pip install vtk

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os 
from vtk import *
from vtk.util import numpy_support
from vtk import vtkXMLImageDataReader


In [4]:
def compute_PSNR(arrgt,arr_recon):
    try:
        diff = arrgt - arr_recon
        sqd_max_diff = (np.max(arrgt)-np.min(arrgt))**2
        if(np.mean(diff**2) == 0):
            raise ZeroDivisionError("dividing by zero, cannot calculate psnr")
        snr = 10*np.log10(sqd_max_diff/np.mean(diff**2))
        return snr
    except ZeroDivisionError as err:
        return str(err)

def get_numpy_array_from_vtk_image_data(vtk_image_data):
    point_data = vtk_image_data.GetPointData()
    array = point_data.GetArray(0)
    if array is None:
        raise ValueError("No array found in vtkImageData.")
    
    numpy_array = numpy_support.vtk_to_numpy(array)
    dims = vtk_image_data.GetDimensions()  # Gets the dimensions of the vtkImageData
    numpy_array = numpy_array.reshape(dims[1], dims[0], dims[2]) # Reshape according to VTK's dimension order
    return numpy_array


def read_vti(filename):
    reader = vtkXMLImageDataReader()
    reader.SetFileName(filename)
    reader.Update()
    return reader.GetOutput()


def writeVti(data, filename):
    writer = vtkXMLImageDataWriter()
    writer.SetFileName(filename)
    writer.SetInputData(data)
    writer.Write()


def createVtkImageData(origin, dimensions, spacing):
    localDataset = vtkImageData()
    localDataset.SetOrigin(origin)
    localDataset.SetDimensions(dimensions)
    localDataset.SetSpacing(spacing)
    return localDataset



In [5]:
vtifile = read_vti('Dataset/Pf25.binLE.raw_corrected_2_subsampled.vti')
try:
    data = get_numpy_array_from_vtk_image_data(vtifile)
    print(data.shape)
except ValueError as e:
    print(e)

(250, 250, 50)


In [6]:
data = data.astype(np.float32)
# data = (data - data.min()) / (data.max() - data.min()) # do it -1 to 1 
data = 2 * ((data - data.min()) / (data.max() - data.min())) - 1 
#try standardscalar as well 
four_d_tensor = torch.from_numpy(data).float() 
input_tensor = four_d_tensor.unsqueeze(0)

block_dims = (125, 125, 25) 
# block_dims = (5, 5, 5) 

In [7]:
# flat_data_array = data.flatten()

# # Convert numpy array to VTK array
# vtk_array = numpy_support.numpy_to_vtk(flat_data_array)
# vtk_array.SetName("NormalizedData")  # Optional: set the array name for identification in ParaView



# rawdata = read_vti('Dataset/Pf25.binLE.raw_corrected_2_subsampled.vti')
# array = rawdata.GetPointData().GetArray(0)

# dim = rawdata.GetDimensions()
# spacing = rawdata.GetSpacing()
# origin = rawdata.GetOrigin()

# # Get the dimensions, spacing, and origin from the original VTI file
# # dim = vtifile.GetDimensions()
# # spacing = vtifile.GetSpacing()
# # origin = vtifile.GetOrigin()

# # Create a new VTK ImageData object
# new_data = createVtkImageData(origin, dim, spacing)

# # Add the VTK array to the ImageData object
# new_data.GetPointData().SetScalars(vtk_array)

# # Write to a VTI file
# writeVti(new_data, 'normalized_data.vti')

In [8]:
print(input_tensor.shape)

torch.Size([1, 250, 250, 50])


In [9]:
class Conv3DAutoencoder(nn.Module):
    def __init__(self):
        super(Conv3DAutoencoder, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=3, stride=2, padding=1),  # Output: (16, 63, 63, 13)
            nn.BatchNorm3d(16),
            nn.ReLU(),
            nn.Conv3d(16, 32, kernel_size=3, stride=2, padding=1), # Output: (32, 32, 32, 7)
            nn.BatchNorm3d(32),
            nn.ReLU(),
            nn.Conv3d(32, 64, kernel_size=3, stride=2, padding=1), # Output: (64, 16, 16, 4)
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.Conv3d(64, 128, kernel_size=3, stride=2, padding=1, padding_mode='zeros'), # Output: (128, 8, 8, 2)
            nn.BatchNorm3d(128),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * 8 * 8 * 2, 1024),
            nn.ReLU(),
            # nn.Dropout(0.4),
            nn.Linear(1024, 512),
            # nn.ReLU() #dont do this 
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(),
            # nn.Dropout(0.8), #remove dropout
            nn.Linear(1024, 128 * 8 * 8 * 2),
            nn.ReLU(),
            nn.Unflatten(1, (128, 8, 8, 2)),
            nn.BatchNorm3d(128),
            nn.ReLU(),
            nn.ConvTranspose3d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1), 
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.ConvTranspose3d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=(1,1,0)), 
            nn.BatchNorm3d(32),
            nn.ReLU(),
            nn.ConvTranspose3d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=0),
            nn.BatchNorm3d(16), 
            nn.ReLU(),
            nn.ConvTranspose3d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=0), 
            # nn.Sigmoid() #use tanh
            nn.Tanh()
        )
        
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

model = Conv3DAutoencoder()
model.eval()


Conv3DAutoencoder(
  (encoder): Sequential(
    (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
    (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
    (4): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
    (7): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
    (10): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
    (12): Flatten(start_dim=1, end_dim=-1)
    (13): Linear(in_features=16384, out_features=1024, bias=True)
    (14): ReLU()
    (15): Linear(in_features=1024, out_features=512, bias=True)
  )
  (

In [10]:
def divide_into_blocks(data, block_dims): # make blocks 5x5x5 

    if not isinstance(data, torch.Tensor):
        data = torch.tensor(data, dtype=torch.float)

    blocks = []
    height_step, width_step, depth_step = block_dims

    for h in range(0, data.shape[1], height_step): 
        for w in range(0, data.shape[2], width_step):  
            for d in range(0, data.shape[3], depth_step): 
                block = data[:, h:h + height_step, w:w + width_step, d:d + depth_step]
                blocks.append(block)

    return blocks


In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Conv3DAutoencoder().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [12]:
blocks = divide_into_blocks(input_tensor, block_dims)
model.train() 
#channel 32,5,5,5 4d tensor 
num_epochs = 50 
#pytorch dataset #batch size add
for epoch in range(num_epochs):
    total_loss = 0
    for block in blocks:
        block_tensor = torch.tensor(block, dtype=torch.float).unsqueeze(0).to(device)
        optimizer.zero_grad()
        output = model(block_tensor)
        loss = criterion(output, block_tensor)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Epoch {epoch+1}, Loss: {total_loss / len(blocks)}')

  block_tensor = torch.tensor(block, dtype=torch.float).unsqueeze(0).to(device)


Epoch 1, Loss: 0.3905640672892332
Epoch 2, Loss: 0.2491377331316471
Epoch 3, Loss: 0.18733981810510159
Epoch 4, Loss: 0.14240751694887877
Epoch 5, Loss: 0.10894173244014382
Epoch 6, Loss: 0.08429577434435487
Epoch 7, Loss: 0.06605103984475136
Epoch 8, Loss: 0.053325130604207516
Epoch 9, Loss: 0.04430016363039613
Epoch 10, Loss: 0.03791127190925181
Epoch 11, Loss: 0.034095620503649116
Epoch 12, Loss: 0.031217836309224367
Epoch 13, Loss: 0.029366314760409296
Epoch 14, Loss: 0.027644895599223673
Epoch 15, Loss: 0.02648532041348517
Epoch 16, Loss: 0.025625170208513737
Epoch 17, Loss: 0.024818112491630018
Epoch 18, Loss: 0.023944206070154905
Epoch 19, Loss: 0.02331663272343576
Epoch 20, Loss: 0.022724182228557765
Epoch 21, Loss: 0.022300655720755458
Epoch 22, Loss: 0.022583245998248458
Epoch 23, Loss: 0.022140466258861125
Epoch 24, Loss: 0.021798150381073356
Epoch 25, Loss: 0.020468764239922166
Epoch 26, Loss: 0.020887388673145324
Epoch 27, Loss: 0.021873077377676964
Epoch 28, Loss: 0.01820

In [13]:
model.eval()
reconstructed_blocks = []
with torch.no_grad():
    for block in blocks:
        block_tensor = torch.tensor(block, dtype=torch.float).unsqueeze(0).to(device)
        output = model(block_tensor).cpu().numpy()
        reconstructed_blocks.append(output.squeeze(0))


  block_tensor = torch.tensor(block, dtype=torch.float).unsqueeze(0).to(device)


In [14]:
import numpy as np
import torch

def reassemble_blocks(blocks, original_shape, block_dims):
    reassembled_data = np.zeros(original_shape)

    index = 0

    for h in range(0, original_shape[1], block_dims[0]):  # Height
        for w in range(0, original_shape[2], block_dims[1]):  # Width
            for d in range(0, original_shape[3], block_dims[2]):  # Depth
                # Ensuring the block is inserted into the correct slice
                if isinstance(blocks[index], torch.Tensor):
                    block_data = blocks[index].numpy()  # Convert to numpy if it's a tensor
                else:
                    block_data = blocks[index]
                
                reassembled_data[:, h:h + block_dims[0], w:w + block_dims[1], d:d + block_dims[2]] = block_data
                index += 1

    return reassembled_data


In [15]:
reassembled_data = reassemble_blocks(reconstructed_blocks, input_tensor.shape, block_dims)

original_tensor = torch.tensor(input_tensor, dtype=torch.float).to(device)
reassembled_tensor = torch.tensor(reassembled_data, dtype=torch.float).to(device)

final_loss = criterion(reassembled_tensor,input_tensor)
print(f'Final MSE Loss on Entire Data: {final_loss.item()}')

Final MSE Loss on Entire Data: 0.02266939915716648


  original_tensor = torch.tensor(input_tensor, dtype=torch.float).to(device)


In [16]:
print(reassembled_data.shape)
reconstructed_data = reassembled_data[0, :, :, :]

psnr_score = compute_PSNR(data,reconstructed_data)
print(psnr_score)

(1, 250, 250, 50)
22.46619937273561


In [17]:
flat_data_array = reconstructed_data.flatten()
vtk_array = numpy_support.numpy_to_vtk(flat_data_array)


data = read_vti('Dataset/Pf25.binLE.raw_corrected_2_subsampled.vti')
array = data.GetPointData().GetArray(0)

dim = data.GetDimensions()
spacing = data.GetSpacing()
origin = data.GetOrigin()


new_data = createVtkImageData(origin, dim, spacing)
new_data.GetPointData().AddArray(vtk_array)

writeVti(new_data, 'out.vti')