In [41]:
# pip install torch 

In [42]:
# pip install vtk

In [43]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os 
from vtk import *
from vtk.util import numpy_support
from vtk import vtkXMLImageDataReader


In [44]:
def compute_PSNR(arrgt,arr_recon):
    try:
        diff = arrgt - arr_recon
        sqd_max_diff = (np.max(arrgt)-np.min(arrgt))**2
        if(np.mean(diff**2) == 0):
            raise ZeroDivisionError("dividing by zero, cannot calculate psnr")
        snr = 10*np.log10(sqd_max_diff/np.mean(diff**2))
        return snr
    except ZeroDivisionError as err:
        return str(err)

def get_numpy_array_from_vtk_image_data(vtk_image_data):
    point_data = vtk_image_data.GetPointData()
    array = point_data.GetArray(0)
    if array is None:
        raise ValueError("No array found in vtkImageData.")
    
    numpy_array = numpy_support.vtk_to_numpy(array)
    dims = vtk_image_data.GetDimensions()  # Gets the dimensions of the vtkImageData
    numpy_array = numpy_array.reshape(dims[1], dims[0], dims[2]) # Reshape according to VTK's dimension order
    return numpy_array


def read_vti(filename):
    reader = vtkXMLImageDataReader()
    reader.SetFileName(filename)
    reader.Update()
    return reader.GetOutput()


def writeVti(data, filename):
    writer = vtkXMLImageDataWriter()
    writer.SetFileName(filename)
    writer.SetInputData(data)
    writer.Write()


def createVtkImageData(origin, dimensions, spacing):
    localDataset = vtkImageData()
    localDataset.SetOrigin(origin)
    localDataset.SetDimensions(dimensions)
    localDataset.SetSpacing(spacing)
    return localDataset



In [45]:
vtifile = read_vti('Dataset/Pf25.binLE.raw_corrected_2_subsampled.vti')
try:
    data = get_numpy_array_from_vtk_image_data(vtifile)
    print(data.shape)
except ValueError as e:
    print(e)

(250, 250, 50)


In [46]:
data = data.astype(np.float32)
data = 2 * ((data - data.min()) / (data.max() - data.min())) - 1 
four_d_tensor = torch.from_numpy(data).float() 
input_tensor = four_d_tensor.unsqueeze(0)

block_dims = (25, 25, 5) 

In [47]:
print(input_tensor.shape)

torch.Size([1, 250, 250, 50])


In [48]:
import torch
import torch.nn as nn

class Conv3DAutoencoder(nn.Module):
    def __init__(self):
        super(Conv3DAutoencoder, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=3, stride=2, padding=1),  # Output: (16, 13, 13, 3)
            nn.BatchNorm3d(16),
            nn.ReLU(),
            nn.Conv3d(16, 32, kernel_size=3, stride=2, padding=1),  # Output: (32, 7, 7, 2)
            nn.BatchNorm3d(32),
            nn.ReLU(),
            nn.Conv3d(32, 64, kernel_size=3, stride=2, padding=1),  # Output: (64, 4, 4, 1)
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4 * 1, 512),  # Adjust for flattened dimensions
            nn.ReLU(),
            nn.Linear(512, 32),  # Latent vector size
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(32, 512),
            nn.ReLU(),
            nn.Linear(512, 64 * 4 * 4 * 1),
            nn.ReLU(),
            nn.Unflatten(1, (64, 4, 4, 1)),
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.ConvTranspose3d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=(0, 0, 1)),  # Output: (32, 7, 7, 2)
            nn.BatchNorm3d(32),
            nn.ReLU(),
            nn.ConvTranspose3d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=(0, 0, 0)),  # Output: (16, 13, 13, 5)
            nn.BatchNorm3d(16),
            nn.ReLU(),
            nn.ConvTranspose3d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=(0, 0, 0)),  # Output: (1, 25, 25, 5)
            nn.Tanh()
        )
        
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# Instantiate the model
model = Conv3DAutoencoder()
model.eval()

# Create a dummy tensor with input size (1, 1, 25, 25, 5)
# dummy_input = torch.randn(1, 1, 25, 25, 5)

# # Pass the dummy tensor through the model
# output = model(dummy_input)

# # Print the output shape to verify
# print(output.shape)  # Expected output: (1, 1, 25, 25, 5)


Conv3DAutoencoder(
  (encoder): Sequential(
    (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
    (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
    (4): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
    (7): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): Flatten(start_dim=1, end_dim=-1)
    (10): Linear(in_features=1024, out_features=512, bias=True)
    (11): ReLU()
    (12): Linear(in_features=512, out_features=32, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=32, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=1024, bias=True)
    (3): ReLU()
    (4): Unflatten(d

In [49]:
def divide_into_blocks(data, block_dims): # make blocks 5x5x5 

    if not isinstance(data, torch.Tensor):
        data = torch.tensor(data, dtype=torch.float)

    blocks = []
    height_step, width_step, depth_step = block_dims

    for h in range(0, data.shape[1], height_step): 
        for w in range(0, data.shape[2], width_step):  
            for d in range(0, data.shape[3], depth_step): 
                block = data[:, h:h + height_step, w:w + width_step, d:d + depth_step]
                blocks.append(block)

    return blocks


In [50]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Conv3DAutoencoder().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [51]:
from torch.utils.data import Dataset, DataLoader

class BlockDataset(Dataset):
    def __init__(self, blocks):
        self.blocks = [torch.tensor(block, dtype=torch.float) for block in blocks]

    def __len__(self):
        return len(self.blocks)

    def __getitem__(self, idx):
        return self.blocks[idx]


In [52]:
blocks = divide_into_blocks(input_tensor, block_dims)
dataset = BlockDataset(blocks)

# Define DataLoader with a batch size
batch_size = 16  # You can adjust this as needed
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


  self.blocks = [torch.tensor(block, dtype=torch.float) for block in blocks]


In [53]:
model.train()
num_epochs = 100
for epoch in range(num_epochs):
    total_loss = 0
    for batch in dataloader:
        batch = batch.to(device)  # Move batch to device (e.g., GPU)
        optimizer.zero_grad()
        output = model(batch)
        loss = criterion(output, batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.size(0)  # Accumulate total loss for the epoch
    
    avg_loss = total_loss / len(dataset)  # Compute average loss
    print(f'Epoch {epoch+1}, Loss: {avg_loss:.4f}')


Epoch 1, Loss: 0.0286
Epoch 2, Loss: 0.0070
Epoch 3, Loss: 0.0054
Epoch 4, Loss: 0.0046
Epoch 5, Loss: 0.0039
Epoch 6, Loss: 0.0029
Epoch 7, Loss: 0.0027
Epoch 8, Loss: 0.0024
Epoch 9, Loss: 0.0021
Epoch 10, Loss: 0.0020
Epoch 11, Loss: 0.0017
Epoch 12, Loss: 0.0016
Epoch 13, Loss: 0.0014
Epoch 14, Loss: 0.0017
Epoch 15, Loss: 0.0014
Epoch 16, Loss: 0.0012
Epoch 17, Loss: 0.0011
Epoch 18, Loss: 0.0010
Epoch 19, Loss: 0.0010
Epoch 20, Loss: 0.0011
Epoch 21, Loss: 0.0010
Epoch 22, Loss: 0.0009
Epoch 23, Loss: 0.0008
Epoch 24, Loss: 0.0007
Epoch 25, Loss: 0.0007
Epoch 26, Loss: 0.0008
Epoch 27, Loss: 0.0006
Epoch 28, Loss: 0.0006
Epoch 29, Loss: 0.0006
Epoch 30, Loss: 0.0006
Epoch 31, Loss: 0.0006
Epoch 32, Loss: 0.0007
Epoch 33, Loss: 0.0006
Epoch 34, Loss: 0.0007
Epoch 35, Loss: 0.0007
Epoch 36, Loss: 0.0006
Epoch 37, Loss: 0.0006
Epoch 38, Loss: 0.0005
Epoch 39, Loss: 0.0005
Epoch 40, Loss: 0.0006
Epoch 41, Loss: 0.0006
Epoch 42, Loss: 0.0005
Epoch 43, Loss: 0.0006
Epoch 44, Loss: 0.00

In [54]:
model.eval()
reconstructed_blocks = []
with torch.no_grad():
    for block in blocks:
        block_tensor = torch.tensor(block, dtype=torch.float).unsqueeze(0).to(device)
        output = model(block_tensor).cpu().numpy()
        reconstructed_blocks.append(output.squeeze(0))


  block_tensor = torch.tensor(block, dtype=torch.float).unsqueeze(0).to(device)


In [55]:
import numpy as np
import torch

def reassemble_blocks(blocks, original_shape, block_dims):
    reassembled_data = np.zeros(original_shape)

    index = 0

    for h in range(0, original_shape[1], block_dims[0]):  # Height
        for w in range(0, original_shape[2], block_dims[1]):  # Width
            for d in range(0, original_shape[3], block_dims[2]):  # Depth
                # Ensuring the block is inserted into the correct slice
                if isinstance(blocks[index], torch.Tensor):
                    block_data = blocks[index].numpy()  # Convert to numpy if it's a tensor
                else:
                    block_data = blocks[index]
                
                reassembled_data[:, h:h + block_dims[0], w:w + block_dims[1], d:d + block_dims[2]] = block_data
                index += 1

    return reassembled_data


In [56]:
reassembled_data = reassemble_blocks(reconstructed_blocks, input_tensor.shape, block_dims)

original_tensor = torch.tensor(input_tensor, dtype=torch.float).to(device)
reassembled_tensor = torch.tensor(reassembled_data, dtype=torch.float).to(device)

final_loss = criterion(reassembled_tensor,input_tensor)
print(f'Final MSE Loss on Entire Data: {final_loss.item()}')

Final MSE Loss on Entire Data: 0.0002624118351377547


  original_tensor = torch.tensor(input_tensor, dtype=torch.float).to(device)


In [57]:
print(reassembled_data.shape)
reconstructed_data = reassembled_data[0, :, :, :]

psnr_score = compute_PSNR(data,reconstructed_data)
print(psnr_score)

(1, 250, 250, 50)
41.830765666748675


In [58]:
flat_data_array = reconstructed_data.flatten()
vtk_array = numpy_support.numpy_to_vtk(flat_data_array)


data = read_vti('Dataset/Pf25.binLE.raw_corrected_2_subsampled.vti')
array = data.GetPointData().GetArray(0)

dim = data.GetDimensions()
spacing = data.GetSpacing()
origin = data.GetOrigin()


new_data = createVtkImageData(origin, dim, spacing)
new_data.GetPointData().AddArray(vtk_array)

writeVti(new_data, 'out3.vti')