In [1]:
import numpy as np
from pathlib import Path

from io import BytesIO
import ipywidgets as widgets
from IPython.display import display

import torch
from torchvision.transforms import ToPILImage

from minerva.data.readers.png_reader import PNGReader
from minerva.data.readers.tiff_reader import TiffReader
from minerva.models.nets.image.convae import ConvVAE

from matplotlib import pyplot as plt

from scipy.ndimage import gaussian_gradient_magnitude, laplace

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

print("PyTorch Version:", torch.__version__)
print("CUDA Version:", torch.version.cuda)
print("CUDA Available:", torch.cuda.is_available())
print("CUDA Device Count:", torch.cuda.device_count())
print("CUDA Device Name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No CUDA Device")



cuda
PyTorch Version: 2.5.1+cu124
CUDA Version: 12.4
CUDA Available: True
CUDA Device Count: 1
CUDA Device Name: NVIDIA GeForce RTX 4090


In [2]:
# f3
train_path = "/workspaces/Minerva-Discovery/shared_data/seismic/f3_segmentation/images"
annotation_path = "/workspaces/Minerva-Discovery/shared_data/seismic/f3_segmentation/annotations"

# seam-ai (parihaka)
# train_path = "/workspaces/Minerva-Discovery/shared_data/seam_ai_datasets/seam_ai/images"
# annotation_path = "/workspaces/Minerva-Discovery/shared_data/seam_ai_datasets/seam_ai/annotations"

In [None]:
def normalize_data(data):
    """
    Normaliza os valores dos pixels para o intervalo [0, 1].
    """
    data_min, data_max = data.min(), data.max()
    return (data - data_min) / (data_max - data_min)

def extract_patches(data, patch_size=64, stride=32):
        patches = []
        h, w, _ = data.shape
        for i in range(0, h - patch_size + 1, stride):
            for j in range(0, w - patch_size + 1, stride):
                patch = data[i:i + patch_size, j:j + patch_size]
                patch = patch.transpose(2, 0, 1).astype(np.float32)  # Transpõe para (C, H, W)
                patch = np.expand_dims(patch, axis=0)  # Adiciona uma dimensão no começo
                patch_tensor = torch.from_numpy(patch)  # Converte para tensor PyTorch
                patches.append(patch_tensor)
        return np.array(patches)

train_img_reader = [normalize_data(image) for image in TiffReader(Path(train_path) / "train")] # lendo imagens e normalizando
# train_label_reader = PNGReader(Path(annotation_path) / "train")

patches_img = extract_patches(train_img_reader[0])
sample_img = patches_img[0]
# sample_lab = train_label_reader[0]

print(type(sample_img), sample_img.shape)

<class 'torch.Tensor'> torch.Size([1, 3, 64, 64])


  return np.array(patches)
  return np.array(patches)


In [4]:
checkpoint_path = "checkpoints/convVAE-sam_model-2024-11-23-epoch=19-val_loss=0.01.ckpt"
model = ConvVAE.load_from_checkpoint(
    checkpoint_path=checkpoint_path,
    z_size=64
).to(device)
model.eval()

ConvVAE(
  (encoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (5): ReLU()
    (6): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): ReLU()
  )
  (fc_mu): Linear(in_features=4096, out_features=64, bias=True)
  (fc_logvar): Linear(in_features=4096, out_features=64, bias=True)
  (fc_decode): Linear(in_features=64, out_features=4096, bias=True)
  (decoder): Sequential(
    (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (5): ReLU()
    (6): ConvTranspose2d(32, 3, kernel_size=(4, 4), stride=(2, 2), p

In [5]:
def normalize_image(tensor):
    tensor = tensor - tensor.min()  # Ajusta para começar em 0
    tensor = tensor / tensor.max()  # Ajusta para terminar em 1
    return tensor.clamp(0, 1)  # Garante que os valores estejam no intervalo [0, 1]

# Função para gerar imagem a partir de z
def generate_image(z_values):
    z_tensor = torch.tensor([z_values], dtype=torch.float32).to(device)
    with torch.no_grad():
        reconstructed = model.decode(z_tensor) # decodificar a partir do Z
        reconstructed = normalize_image(reconstructed.squeeze(0)) # normalizar
        img = ToPILImage()(reconstructed) # converter para PIL
    return img

def pil_to_bytes(image):
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()

In [6]:
# Preparar a imagem de entrada
original_img_pil = ToPILImage()(sample_img.squeeze(0))  # Converte o tensor original para PIL
original_image_bytes = pil_to_bytes(original_img_pil)  # Converte PIL para bytes

# Exibir a imagem original à esquerda
original_image_display = widgets.Image(value=original_image_bytes, format="png", layout=widgets.Layout(width="256px", height="256px"))

# Inicializar sliders com valores iniciais de z
mu, logvar = model.encode(sample_img.to(device))  # Obter mu e logvar
z_initial = model.reparameterize(mu, logvar).squeeze(0).tolist()  # Amostra inicial de Z

latent_dim = model.z_size
sliders = [
    widgets.FloatSlider(value=z_initial[i], min=-3.0, max=3.0, step=0.1, description=f"z_{i}")
    for i in range(latent_dim)
]

In [7]:
# Função para atualizar a imagem reconstruída ao alterar os sliders
def update_reconstructed_image(*args):
    z_values = [slider.value for slider in sliders]
    reconstructed_img = generate_image(z_values)
    reconstructed_image_display.value = pil_to_bytes(reconstructed_img)

# Conectar os sliders à função de atualização
for slider in sliders:
    slider.observe(update_reconstructed_image, names="value")

# Exibir sliders no centro
sliders_box = widgets.VBox(sliders)

# Exibir a imagem reconstruída à direita
reconstructed_image_display = widgets.Image(layout=widgets.Layout(width="256px", height="256px"))

# Layout final
layout = widgets.HBox([original_image_display, sliders_box, reconstructed_image_display], layout=widgets.Layout(align_items="center"))
display(layout)

HBox(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00@\x00\x00\x00@\x08\x02\x00\x00\x00…