In [19]:
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
from torchvision.utils import save_image
import CNN_VAE
import importlib
import numpy as np
# Reload the module to ensure changes are recognized
importlib.reload(CNN_VAE)

# Import the updated class and custom dataset
from CNN_VAE import ConvolVariatinalAutoEncoder, CustomImageDataset


In [20]:

# Define the model parameters

BATCH_SIZE = 2048
NUM_EPOCHS = 1000

Z_DIMS = [8]
H_DIMS = [32]
im_size =(256, 256)
# Define the transformation
transform = transforms.Compose([
    transforms.Resize(im_size),  # Resize images to 128x128
    transforms.ToTensor()
])

# Load the dataset
dataset = CustomImageDataset(root_dir='../dataset/Waves4', transform=transform)

# Create a DataLoader
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

# Move the images to the same device as the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:

for h_dim in H_DIMS:
    for z_dim in Z_DIMS:
            
        print(f'Training model with z_dim={z_dim} and h_dim={h_dim}...')
        # Create an instance of the model
        model = ConvolVariatinalAutoEncoder(input_dim=1, h_dim=h_dim, z_dim=z_dim, output_channels=1, im_size=im_size)

        # Load the saved state dictionary
        model.load_state_dict(torch.load('./Modelos_VAE/cnn_vae_model_Z{}_H{}_epoch_6000.pth'.format(z_dim, h_dim)))

        # Set the model to evaluation mode
        model.eval()

        model = model.to(device)

        # Pass the images through the model
        for x, _ in train_loader:  
            x = x.to(device)
            
            x_reconstructed, mu, logvar = model(x)
            # Move the reconstructed images back to CPU for saving
            x_reconstructed = x_reconstructed.cpu().detach().numpy()
            error = x_reconstructed-x.cpu().detach().numpy()
            # Transpose the image to (height, width, channels) format for saving
            for j in range(x_reconstructed.shape[0]):
                print(x_reconstructed.max())
                image = x_reconstructed[j].transpose(1, 2, 0)
                error_image = error[j].transpose(1, 2, 0)
                # Convert the image to a NumPy array
                image = (image * 255).astype(np.uint8)
                error_image = (error_image * 255).astype(np.uint8)
                # Convert the NumPy array to a PIL image
                
                image_pil = Image.fromarray(np.squeeze(image))
                error_image = Image.fromarray(np.squeeze(error_image))
                # Save the image as a PNG file
                try: 
                    os.mkdir('Modelos_VAE/Imagens_sample/Model_Z{}_H{}'.format(z_dim, h_dim))
                    os.mkdir('Modelos_VAE/Error_sample/Model_Z{}_H{}'.format(z_dim, h_dim))
                    
                except: pass
                image_pil.save('Modelos_VAE/Imagens_sample/Model_Z{}_H{}/Im{}.png'.format(z_dim, h_dim, j))
                error_image.save('Modelos_VAE/Error_sample/Model_Z{}_H{}/Im{}.png'.format(z_dim, h_dim, j))
                
# # Select the first image in the batch
# image = x_reconstructed[1]

# # Transpose the image to (height, width, channels) format for saving
# image = image.transpose(1, 2, 0)


# import numpy as np
# # Convert the image to a NumPy array
# image = (image * 255).astype(np.uint8)

# # Convert the NumPy array to a PIL image
# image_pil = Image.fromarray(image)

# # Save the image as a PNG file
# image_pil.save('image_2.png')


Training model with z_dim=8 and h_dim=32...
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261
1.0383261


In [22]:
x_reconstructed.max()

1.0292159

In [23]:
image_pil = Image.fromarray(np.squeeze(image))

In [24]:
np.squeeze(image)

array([[ 99, 101, 100, ...,  89,  87,  85],
       [ 99, 101, 102, ...,  88,  86,  84],
       [ 98, 101,  99, ...,  89,  87,  86],
       ...,
       [ 80,  83,  77, ...,  97,  94,  91],
       [ 79,  80,  80, ...,  91,  98,  79],
       [ 81,  83,  84, ...,  92,  93,  65]], dtype=uint8)