In [69]:
# pip install torch 

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os 

In [None]:
directory_path = '/Users/manasvijain/Desktop/Autoencoder/matrices' 
files = [f"Matrix_{i}.txt" for i in range(50)]
matrices = []

for file in files:
    file_path = os.path.join(directory_path, file) 
    matrix = np.loadtxt(file_path)  
    if matrix.shape != (250, 250):
        raise ValueError(f"Expected matrix size 250x250, but got {matrix.shape} in file {file}")
    matrices.append(matrix)  

data = np.stack(matrices, axis=2)
data = data.astype(np.float32)
data = (data - data.min()) / (data.max() - data.min())

# print(data.shape)  # Should be (250, 250, 50)
# print(data.dtype)

four_d_tensor = torch.from_numpy(data).float() 
input_tensor = four_d_tensor.unsqueeze(0)

# print("Corrected input tensor shape:", input_tensor.shape)

block_dims = (125, 125, 25) 

In [None]:
class Conv3DAutoencoder(nn.Module):
    def __init__(self):
        super(Conv3DAutoencoder, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=3, stride=2, padding=1),  # Output: (16, 63, 63, 13)
            nn.ReLU(),
            nn.BatchNorm3d(16),
            nn.Conv3d(16, 32, kernel_size=3, stride=2, padding=1), # Output: (32, 32, 32, 7)
            nn.ReLU(),
            nn.BatchNorm3d(32),
            nn.Conv3d(32, 64, kernel_size=3, stride=2, padding=1), # Output: (64, 16, 16, 4)
            nn.ReLU(),
            nn.BatchNorm3d(64),
            nn.Conv3d(64, 128, kernel_size=3, stride=2, padding=1, padding_mode='zeros'), # Output: (128, 8, 8, 2)
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * 8 * 8 * 2, 1024),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(1024, 512),
            nn.ReLU()
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(1024, 128 * 8 * 8 * 2),
            nn.ReLU(),
            nn.Unflatten(1, (128, 8, 8, 2)),
            nn.ConvTranspose3d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1), 
            nn.ReLU(),
            nn.ConvTranspose3d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=(1,1,0)), 
            nn.ReLU(),
            nn.ConvTranspose3d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=0), 
            nn.ReLU(),
            nn.ConvTranspose3d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=0), 
            nn.Sigmoid()
        )
        
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

model = Conv3DAutoencoder()
model.eval()

# dummy_input = torch.rand(1, 1, 125, 125, 25)

# with torch.no_grad():
#     output = model(dummy_input)

# print("Output shape:", output.shape)


Output shape: torch.Size([1, 1, 125, 125, 25])


In [None]:
def divide_into_blocks(data, block_dims):

    if not isinstance(data, torch.Tensor):
        data = torch.tensor(data, dtype=torch.float)

    blocks = []
    height_step, width_step, depth_step = block_dims

    for h in range(0, data.shape[1], height_step): 
        for w in range(0, data.shape[2], width_step):  
            for d in range(0, data.shape[3], depth_step): 
                block = data[:, h:h + height_step, w:w + width_step, d:d + depth_step]
                blocks.append(block)

    return blocks

# Example usage:
# data_size = (1, 250, 250, 50)  
# block_dims = (125, 125, 25)    
# data = torch.randn(data_size) 
# blocks = divide_into_blocks(data, block_dims)

# print(f"Number of blocks: {len(blocks)}")
# print(f"Block shape: {blocks[0].shape}")


In [75]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Conv3DAutoencoder().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
blocks = divide_into_blocks(input_tensor, block_dims)
model.train()
num_epochs = 50
for epoch in range(num_epochs):
    total_loss = 0
    for block in blocks:
        block_tensor = torch.tensor(block, dtype=torch.float).unsqueeze(0).to(device)
        optimizer.zero_grad()
        output = model(block_tensor)
        loss = criterion(output, block_tensor)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Epoch {epoch+1}, Loss: {total_loss / len(blocks)}')

  block_tensor = torch.tensor(block, dtype=torch.float).unsqueeze(0).to(device)


Epoch 1, Loss: 0.04297915659844875
Epoch 2, Loss: 0.02842120884452015
Epoch 3, Loss: 0.026696497341617942
Epoch 4, Loss: 0.01599616667954251
Epoch 5, Loss: 0.010846497054444626
Epoch 6, Loss: 0.005475501689943485
Epoch 7, Loss: 0.006194891931954771
Epoch 8, Loss: 0.00804750838142354
Epoch 9, Loss: 0.010682100342819467
Epoch 10, Loss: 0.006616594328079373
Epoch 11, Loss: 0.004979448225640226
Epoch 12, Loss: 0.008494828172842972
Epoch 13, Loss: 0.01962474489118904
Epoch 14, Loss: 0.011987173173110932
Epoch 15, Loss: 0.021181277406867594
Epoch 16, Loss: 0.008708809633390047
Epoch 17, Loss: 0.008870898120221682
Epoch 18, Loss: 0.00538771537685534
Epoch 19, Loss: 0.004541985283140093
Epoch 20, Loss: 0.005100293812574819
Epoch 21, Loss: 0.00818628445995273
Epoch 22, Loss: 0.004517680070421193
Epoch 23, Loss: 0.004938029138429556
Epoch 24, Loss: 0.00959273154148832
Epoch 25, Loss: 0.005692511971574277
Epoch 26, Loss: 0.007552204129751772
Epoch 27, Loss: 0.007788841707224492
Epoch 28, Loss: 0.

In [77]:
model.eval()
reconstructed_blocks = []
with torch.no_grad():
    for block in blocks:
        block_tensor = torch.tensor(block, dtype=torch.float).unsqueeze(0).to(device)
        output = model(block_tensor).cpu().numpy()
        reconstructed_blocks.append(output.squeeze(0))


  block_tensor = torch.tensor(block, dtype=torch.float).unsqueeze(0).to(device)


In [None]:
reassembled_data = reassemble_blocks(reconstructed_blocks, input_tensor.shape, block_dims)

original_tensor = torch.tensor(input_tensor, dtype=torch.float).to(device)
reassembled_tensor = torch.tensor(reassembled_data, dtype=torch.float).to(device)

final_loss = criterion(reassembled_tensor,input_tensor)
print(f'Final MSE Loss on Entire Data: {final_loss.item()}')

Final MSE Loss on Entire Data: 0.0029473984614014626


  original_tensor = torch.tensor(input_tensor, dtype=torch.float).to(device)


In [None]:
vtk_array = numpy_support.numpy_to_vtk(reassembled_data)

data = read_vti('test.vti')
array = data.GetPointData().GetArray(0)

dim = data.GetDimensions()
spacing = data.GetSpacing()
origin = data.GetOrigin()

new_data = createVtkImageData(origin, dim, spacing)
new_data.GetPointData().AddArray(vtk_array)

writeVti(new_data, 'out.vti')

psnr_score = compute_PSNR(data,new_data)
print(psnr_score)