## Imports

In [1]:
import os
import random
import numpy as np
import gc
from pathlib import Path

import torch
import torch.nn as nn
from torch.optim import Adam
import torch.nn.functional as F
import lightning as L

from typing import Optional

from scipy.ndimage import gaussian_gradient_magnitude, laplace

from minerva.data.datasets.supervised_dataset import SupervisedReconstructionDataset
from minerva.data.readers.png_reader import PNGReader
from minerva.data.readers.tiff_reader import TiffReader
from minerva.transforms.transform import _Transform
from minerva.pipelines.lightning_pipeline import SimpleLightningPipeline
from torch.utils.data import DataLoader

from matplotlib import pyplot as plt

# Report the PyTorch build and the CUDA runtime visible to this kernel.
# Each entry is a (label, value) pair printed as "label value".
environment_report = (
    ("PyTorch Version:", torch.__version__),
    ("CUDA Version:", torch.version.cuda),
    ("CUDA Available:", torch.cuda.is_available()),
    ("CUDA Device Count:", torch.cuda.device_count()),
    (
        "CUDA Device Name:",
        # Only query the device name when a CUDA device is actually available.
        torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No CUDA Device",
    ),
)
for report_label, report_value in environment_report:
    print(report_label, report_value)



PyTorch Version: 2.5.1+cu124
CUDA Version: 12.4
CUDA Available: True
CUDA Device Count: 1
CUDA Device Name: NVIDIA GeForce RTX 4090


## Variables

In [2]:
# Dataset selection: F3 is active. To switch to seam-ai (Parihaka), use the
# commented root and the commented image size below together.
dataset_root = "/workspaces/Minerva-Discovery/shared_data/seismic/f3_segmentation"  # f3
# dataset_root = "/workspaces/Minerva-Discovery/shared_data/seam_ai_datasets/seam_ai"  # seam-ai (parihaka)

train_path = dataset_root + "/images"
annotation_path = dataset_root + "/annotations"

# Image size as (height, width).
height = 255  # f3
width = 701  # f3
# height, width = 1006, 590 # parihaka

# Run settings.
model_name = "vae"
num_epochs = 1000
batch_size = 4

## Transforms

In [3]:
class Padding(_Transform):
    """Pad an image to a minimum size and return it as a normalised tensor.

    The transform:

    1. reflect-pads the bottom/right of the image so that it is at least
       ``target_h_size`` x ``target_w_size`` (larger images are left as is);
    2. converts it to a ``float32`` tensor in channel-first ``(C, H, W)``
       layout (2-D inputs get a single channel);
    3. min-max normalises the values to ``[-1, 1]``.

    Args:
        target_h_size: Minimum output height.
        target_w_size: Minimum output width.
    """

    def __init__(self, target_h_size: int, target_w_size: int):
        self.target_h_size = target_h_size
        self.target_w_size = target_w_size

    def __call__(self, x: np.ndarray) -> torch.Tensor:
        """Apply the transform to a ``(H, W)`` or ``(H, W, C)`` array."""
        h, w = x.shape[:2]
        pad_h = max(0, self.target_h_size - h)
        pad_w = max(0, self.target_w_size - w)

        # Give 2-D images an explicit channel axis so both cases share one path.
        if x.ndim == 2:
            x = x[:, :, np.newaxis]
        padded = np.pad(x, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")

        # (H, W, C) -> (C, H, W). Use the tensor's own permute instead of
        # np.transpose, which only works on tensors through a NumPy fallback.
        tensor = torch.from_numpy(padded).float().permute(2, 0, 1).contiguous()
        return self.normalize_data(tensor)

    def normalize_data(self, data, target_min=-1, target_max=1):
        """Min-max normalise ``data`` to ``[target_min, target_max]``.

        Args:
            data: Array or tensor to rescale.
            target_min: Value the minimum of ``data`` is mapped to.
            target_max: Value the maximum of ``data`` is mapped to.

        Returns:
            The rescaled data. A constant input (max == min) is mapped
            entirely to ``target_min`` rather than producing NaNs.
        """
        data_min, data_max = data.min(), data.max()
        data_range = data_max - data_min
        if data_range == 0:
            # Constant input: dividing by the zero range would give NaN.
            return (data - data_min) + target_min
        return target_min + (data - data_min) * (target_max - target_min) / data_range

## Init model

In [4]:
"""
Variational encoder model, used as a visual model
for our model of the world.
"""
class Decoder(nn.Module):
    """ VAE decoder """
    def __init__(self, img_channels, latent_size):
        super(Decoder, self).__init__()
        self.latent_size = latent_size
        self.img_channels = img_channels

        self.fc1 = nn.Linear(latent_size, 14 * 41 * 256)
        self.deconv1 = nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        self.deconv3 = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)
        self.deconv4 = nn.ConvTranspose2d(32, img_channels, 4, stride=2, padding=1)

    def forward(self, x): # pylint: disable=arguments-differ
        x = F.relu(self.fc1(x))
        # x = x.unsqueeze(-1).unsqueeze(-1)
        x = x.view(-1, 256, 14, 41)
        x = F.relu(self.deconv1(x))
        x = F.relu(self.deconv2(x))
        x = F.relu(self.deconv3(x))
        reconstruction = F.sigmoid(self.deconv4(x))
        reconstruction = F.interpolate(reconstruction, size=(255, 701), mode="bilinear", align_corners=False)
        return reconstruction

class Encoder(nn.Module): # pylint: disable=too-many-instance-attributes
    """VAE encoder.

    Four unpadded 4x4, stride-2 convolutions (channels
    ``img_channels -> 32 -> 64 -> 128 -> 256``) are followed by two linear
    heads that produce the latent mean and log-sigma. The heads expect a
    flattened 256 x 14 x 41 feature map, which is what a 255 x 701 input
    yields after the four convolutions.

    Args:
        img_channels: Number of channels of the input image.
        latent_size: Dimensionality of the latent vector.
    """

    def __init__(self, img_channels, latent_size):
        super().__init__()
        self.latent_size = latent_size
        self.img_channels = img_channels

        self.conv1 = nn.Conv2d(img_channels, 32, kernel_size=4, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2)

        # Size of the flattened conv4 output: height * width * channels.
        flat_features = 14 * 41 * 256
        self.fc_mu = nn.Linear(flat_features, latent_size)
        self.fc_logsigma = nn.Linear(flat_features, latent_size)

    def forward(self, x): # pylint: disable=arguments-differ
        """Encode images ``(B, C, H, W)`` into ``(mu, logsigma)``, each ``(B, latent_size)``."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = F.relu(conv(features))
        flat = features.view(features.size(0), -1)
        return self.fc_mu(flat), self.fc_logsigma(flat)

class VAE(L.LightningModule):
    """Variational autoencoder built from ``Encoder`` and ``Decoder``.

    Args:
        img_channels: Number of image channels (input and reconstruction).
        latent_size: Dimensionality of the latent vector.
        lr: Learning rate of the Adam optimizer.
    """

    def __init__(self, img_channels, latent_size, lr=1e-3):
        super().__init__()
        self.encoder = Encoder(img_channels, latent_size)
        self.decoder = Decoder(img_channels, latent_size)
        self.lr = lr

    def forward(self, x): # pylint: disable=arguments-differ
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            Tuple ``(recon_x, mu, logsigma)``.
        """
        mu, logsigma = self.encoder(x)
        # Reparameterisation trick: z = mu + sigma * eps, with eps ~ N(0, I).
        sigma = logsigma.exp()
        eps = torch.randn_like(sigma)
        z = eps.mul(sigma).add_(mu)

        recon_x = self.decoder(z)
        return recon_x, mu, logsigma

    def _loss_function(self, recon_x, x, mu, logsigma):
        """Return the summed squared reconstruction error plus the KL divergence.

        Both terms are summed (not averaged) over the whole batch.
        """
        recon_loss = F.mse_loss(recon_x, x, reduction='sum')
        # KL(N(mu, sigma^2) || N(0, I)) written in terms of log(sigma).
        kld = -0.5 * torch.sum(1 + 2 * logsigma - mu.pow(2) - (2 * logsigma).exp())
        return recon_loss + kld

    def _step(self, batch, batch_idx: int, step_name: str):
        """Shared train/val/test step: compute and log ``{step_name}_loss``."""
        data = batch[0]  # image -> torch.Size([B, C, H, W]); batch[1] is the label
        recon_batch, mu, logsigma = self.forward(data)
        loss = self._loss_function(recon_batch, data, mu, logsigma)
        self.log(f"{step_name}_loss", loss, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx: int):
        return self._step(batch, batch_idx, "test")

    def predict_step(self, batch, batch_idx: int, dataloader_idx: Optional[int] = None):
        """Return ``(recon_x, mu, logsigma)`` for a batch.

        ``batch`` may be a plain image tensor or an ``(image, label, ...)``
        list/tuple as yielded by the supervised dataloaders. In the latter
        case only the image is used, consistent with ``_step``.
        """
        data = batch[0] if isinstance(batch, (list, tuple)) else batch
        return self.forward(data)

    def configure_optimizers(self):
        """Use Adam over all parameters with learning rate ``self.lr``."""
        return Adam(self.parameters(), lr=self.lr)

In [5]:
# Seed Python, NumPy and PyTorch before building the model so the random
# weight initialisation is identical on every Restart & Run All.
L.seed_everything(0, workers=True)

# 3 input channels, 512-dimensional latent space.
model = VAE(img_channels=3, latent_size=512)

## Data module

In [6]:
class DataModule(L.LightningDataModule):
    """Lightning data module for image/annotation pairs split into train/val/test.

    For each split, images are read with ``TiffReader`` from
    ``<train_path>/<split>`` and labels with ``PNGReader`` from
    ``<annotations_path>/<split>``. Both are wrapped in a
    ``SupervisedReconstructionDataset``.

    Args:
        train_path: Root directory of the images (contains the split folders).
        annotations_path: Root directory of the annotations.
        transforms: Transform passed on to every dataset.
        batch_size: Batch size of every dataloader.
        num_workers: Dataloader worker count. ``None`` uses
            ``os.cpu_count()``, falling back to 0 if that is unknown.
    """

    _STAGES = (None, "fit", "validate", "test", "predict")

    def __init__(
        self,
        train_path: str,
        annotations_path: str,
        transforms: Optional[_Transform] = None,
        batch_size: int = 1,
        num_workers: Optional[int] = None,
    ):
        super().__init__()
        self.train_path = Path(train_path)
        self.annotations_path = Path(annotations_path)
        self.transforms = transforms
        self.batch_size = batch_size
        if num_workers is None:
            # os.cpu_count() may return None; DataLoader needs an int.
            num_workers = os.cpu_count() or 0
        self.num_workers = num_workers

        self.datasets = {}

    def _build_dataset(self, split: str):
        """Create the (image, label) dataset for one split folder."""
        img_reader = TiffReader(self.train_path / split)
        label_reader = PNGReader(self.annotations_path / split)
        return SupervisedReconstructionDataset(
            readers=[img_reader, label_reader],
            transforms=self.transforms,
        )

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``.

        ``"fit"`` builds train and val, ``"validate"`` builds val only,
        ``"test"``/``"predict"`` build the test set (shared by both), and
        ``None`` builds everything.

        Raises:
            ValueError: If ``stage`` is not one of the stages above.
        """
        if stage not in self._STAGES:
            raise ValueError(f"Invalid stage: {stage}")

        if stage in (None, "fit"):
            self.datasets["train"] = self._build_dataset("train")
        if stage in (None, "fit", "validate"):
            self.datasets["val"] = self._build_dataset("val")
        if stage in (None, "test", "predict"):
            test_dataset = self._build_dataset("test")
            self.datasets["test"] = test_dataset
            self.datasets["predict"] = test_dataset

    def _dataloader(self, split: str, shuffle: bool = False):
        """Build the DataLoader for an already set-up split."""
        return DataLoader(
            self.datasets[split],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        # Only the training data is shuffled.
        return self._dataloader("train", shuffle=True)

    def val_dataloader(self):
        return self._dataloader("val")

    def test_dataloader(self):
        return self._dataloader("test")

    def predict_dataloader(self):
        return self._dataloader("predict")

## Applying transforms in data module

In [7]:
# Earlier torchvision-style pipeline, kept for reference only (not executed):
# transform_train = transforms.Compose([
#     transforms.ToPILImage(),
#     Padding(height, width),
#     transforms.RandomHorizontalFlip(),
#     transforms.ToTensor(),
# ])

# The single Padding transform is handed to the data module, which passes it
# on to every dataset it builds.
padding_transform = Padding(target_h_size=height, target_w_size=width)

data_module = DataModule(
    batch_size=batch_size,
    transforms=padding_transform,
    annotations_path=annotation_path,
    train_path=train_path,
)

In [8]:
def get_train_dataloader(data_module):
    """Run the data module's ``"fit"`` setup and return its training DataLoader."""
    data_module.setup("fit")
    return data_module.train_dataloader()

# Build the loader once and reuse it: every call to get_train_dataloader
# re-runs setup("fit") and creates a new DataLoader.
train_loader = get_train_dataloader(data_module)
print("Total batches: ", len(train_loader))

# Pull one batch to check the tensor layout fed to the model.
train_batch_x, train_batch_y = next(iter(train_loader))
print("train_batch_x shape: ", train_batch_x.shape)

Total batches:  248


train_batch_x shape:  torch.Size([4, 3, 255, 701])


## Train model

In [9]:
from lightning.pytorch.callbacks import ModelCheckpoint
from datetime import datetime

# --- Optional checkpointing / logging, currently disabled --------------------
# current_date = datetime.now().strftime("%Y-%m-%d")

# # Callback that saves the model with the lowest validation metric
# checkpoint_callback = ModelCheckpoint(
#     monitor="val_loss", # Metric to monitor
#     dirpath="./checkpoints", # Directory where the checkpoints are saved
#     filename=f"convVAE-sam_model-{current_date}-{{epoch:02d}}-{{val_loss:.2f}}", # Checkpoint file name
#     save_top_k=1, # How many of the best checkpoints to keep (here, only the best)
#     mode="min", # How to treat the metric ('min': a lower val_loss is better)
# )

# from lightning.pytorch.loggers import TensorBoardLogger
# # TensorBoard logger
# logger = TensorBoardLogger("logs", name="sam_model")

# from lightning.pytorch.loggers import CSVLogger

# logger = CSVLogger("logs", name="conv_vae")
# -----------------------------------------------------------------------------

# Trainer options gathered in one place; the commented entries enable the
# optional logger / checkpoint callback defined above.
trainer_options = dict(
    max_epochs=num_epochs,
    accelerator="gpu",
    devices=1,
    # logger=logger,
    # callbacks=[checkpoint_callback],
)
trainer = L.Trainer(**trainer_options)
# trainer.fit(model, data_module)

# Training is run through the Minerva pipeline instead of calling
# trainer.fit directly.
pipeline = SimpleLightningPipeline(model=model, trainer=trainer, save_run_status=True)
pipeline.run(data=data_module, task="fit")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/lightning/fabric/utilities/seed.py:42: No seed found, seed set to 0
Seed set to 0
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Log directory set to: /workspaces/Minerva-Discovery/my_experiments/vae_v1/notebook/lightning_logs/version_12
Pipeline info saved at: /workspaces/Minerva-Discovery/my_experiments/vae_v1/notebook/lightning_logs/version_12/run_2024-12-21-17-06-1859a3bb9cbe2342d490e006b98f774adb.yaml



  | Name    | Type    | Params | Mode 
--------------------------------------------
0 | encoder | Encoder | 151 M  | train
1 | decoder | Decoder | 76.1 M | train
--------------------------------------------
227 M     Trainable params
0         Non-trainable params
227 M     Total params
908.936   Total estimated model params size (MB)
13        Modules in train mode
0         Modules in eval mode


Epoch 999: 100%|██████████| 248/248 [00:09<00:00, 26.82it/s, v_num=12, train_loss=5.21e+4, val_loss=5.22e+4]

`Trainer.fit` stopped: `max_epochs=1000` reached.


Epoch 999: 100%|██████████| 248/248 [00:15<00:00, 15.82it/s, v_num=12, train_loss=5.21e+4, val_loss=5.22e+4]
Pipeline info saved at: /workspaces/Minerva-Discovery/my_experiments/vae_v1/notebook/lightning_logs/version_12/run_2024-12-21-17-06-1859a3bb9cbe2342d490e006b98f774adb.yaml


## Test model

In [10]:
# without normalization: 230 epochs gave val_loss=5.27e+4
# with normalization: 230 epochs gave val_loss=5.24e+4
# with normalization, 230 epochs and latent size 512: val_loss=<not recorded>