In [None]:
# Load the Kedro IPython extension; the %reload_kedro magic in the next cell presumably comes from it.
%load_ext kedro.extras.extensions.ipython

In [None]:
# (Re)initialise Kedro; later cells use the `catalog` and `context` names, presumably injected by this magic.
%reload_kedro

In [None]:
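# Reference: UvA Deep Learning tutorial 9 (autoencoder on CIFAR10), presumably the source of the Encoder/Decoder code further down:
# https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial9/AE_CIFAR10.html#Building-the-autoencoder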

In [None]:
# Show the entries registered in the Kedro data catalog.
catalog.list()

In [None]:
# Load the 'image_embeddings' catalog entry and display it as the cell output.
# NOTE(review): the whole loaded object is rendered; consider showing only a small preview.
context.catalog.load('image_embeddings')

In [None]:
from gid_ml_framework.image_embeddings.data.hm_data import HMDataset
from gid_ml_framework.image_embeddings.model.pl_autoencoder_module import LitAutoEncoder
from gid_ml_framework.image_embeddings.model import pl_encoders, pl_decoders
import mlflow.pytorch
from mlflow.tracking import MlflowClient
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from torchvision import transforms
import mlflow
import torch
import numpy as np

In [None]:
# Run identifier read from the Kedro parameters; used below to build the MLflow `runs:/` URI and to look up the run.
RUN_ID=context.catalog.load('params:inference.RUN_ID')

In [None]:
import os

# Move one directory up so the relative paths used later (e.g. 'data/01_raw/...') resolve.
# NOTE(review): not idempotent -- every re-run of this cell climbs one more level.
os.chdir('..')
# Kept as the last expression so the resulting working directory is actually displayed.
os.getcwd()

In [None]:
# Load the PyTorch model stored under the 'model' artifact path of the MLflow run RUN_ID.
logged_model_uri = f'runs:/{RUN_ID}/model'
loaded_model = mlflow.pytorch.load_model(logged_model_uri)

# Dataset over the image folder; the relative path depends on the current working directory.
# shuffle=False and drop_last=False keep every image, in dataset order.
hm_dataset = HMDataset('data/01_raw/images_128_128', transform=transforms.ToTensor())
hm_dataloader = DataLoader(dataset=hm_dataset, drop_last=False, shuffle=False, num_workers=8, batch_size=32)

In [None]:
# Trainer used here only to run prediction over the whole dataloader; logger=False turns experiment logging off.
trainer = pl.Trainer(max_epochs=1, logger=False)
predictions = trainer.predict(loaded_model, dataloaders=hm_dataloader)

In [None]:
# Check what kind of container trainer.predict returned.
type(predictions)

In [None]:
# First 25 entries of `predictions`, for a quick look at their structure.
sample = predictions[:25]

In [None]:
# Fresh accumulators, filled by the loop in the next cell.
out_emb = []
out_labels = []

In [None]:
# Each entry of `sample` unpacks into (emb, labels); only element 0 of each is kept.
# presumably one entry is one batch, so this keeps the first item per batch -- TODO confirm
for emb, labels in sample:
    out_emb.append(emb[0])
    out_labels.append(labels[0])

In [None]:
# Stack the collected embeddings into one tensor and check its shape.
torch.stack(out_emb).shape

In [None]:
# Peek at the first five collected labels.
out_labels[:5]

In [None]:
# Same experiment as with `sample`, now on the first 64 entries.
sample2 = predictions[:64]

In [None]:
# Reset the accumulators before collecting from `sample2`.
out_emb = []
out_labels = []

In [None]:
# Keep element 0 of the embeddings and labels of every entry in `sample2`.
for emb, labels in sample2:
    out_emb.append(emb[0])
    out_labels.append(labels[0])

In [None]:
# Shape of the stacked embeddings collected from `sample2`.
torch.stack(out_emb).shape

In [None]:
# Reset again; the next loop keeps the complete (emb, labels) output of every entry.
out_emb = []
out_labels = []

In [None]:
# Collect the full embeddings and labels of every entry in `predictions`.
for emb, labels in predictions:
    out_emb.append(emb)
    out_labels.append(labels)

In [None]:
# Concatenate the collected embedding tensors along dim 0 and check the resulting shape.
torch.cat(out_emb).shape

In [None]:
from itertools import chain

In [None]:
# Total number of labels once the per-entry label collections are flattened.
len(list(chain(*out_labels)))

In [None]:
# Shape of the concatenated embeddings after conversion to a NumPy array.
torch.cat(out_emb).numpy().shape

In [None]:
# Shape of the flattened labels as a NumPy array.
np.array(list(chain(*out_labels))).shape

In [None]:
import pandas as pd

In [None]:
# Look up the MLflow run for RUN_ID; its logged params are read a few cells below.
client = MlflowClient()
run = client.get_run(RUN_ID)

In [None]:
# Article id = label text before the first '.' (presumably strips the image file extension -- TODO confirm).
article_ids = [article_id.split('.')[0] for article_id in chain(*out_labels)]

In [None]:
# One column name per embedding dimension: emb_1 ... emb_N, with N taken from the run's logged 'embedding_size' param.
column_names = [f'emb_{i+1}' for i in range(int(run.data.params['embedding_size']))]

In [None]:
# Write the embeddings to a parquet file, one row per article id and one column per embedding dimension.
# NOTE(review): 'adsada.pq' looks like a throw-away file name and is written relative to the current working directory.
pd.DataFrame(data=torch.cat(out_emb).numpy(), index=article_ids, columns=column_names).to_parquet('adsada.pq', engine='pyarrow')

In [None]:
# np.stack requires equally-shaped inputs, but the embeddings are 2-D (rows x embedding dims)
# while the flattened labels are 1-D, so the original call raised ValueError.
# Pair them row-wise instead: column 0 holds the label, the remaining columns the embedding.
# Object dtype keeps the floats as numbers instead of promoting everything to strings.
# NOTE(review): the intended layout is an assumption -- confirm this is the pairing that was wanted.
labels_arr = np.array(list(chain(*out_labels)), dtype=object)
embeddings_arr = torch.cat(out_emb).numpy().astype(object)
np.column_stack([labels_arr, embeddings_arr])

In [None]:
# NOTE(review): displays every collected label; a slice such as out_labels[:5] keeps the output small.
out_labels

In [None]:
import torch
import mlflow.pytorch

In [None]:
import os

# NOTE(review): an earlier cell already runs os.chdir('..'); executed top to bottom this climbs a second
# level, and every re-run of the cell climbs once more. The relative checkpoint path used further down
# depends on where this ends up.
os.chdir('..')
os.getcwd()

In [None]:
import os
from torch import optim, nn, utils, Tensor
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset

In [None]:
import os
from pathlib import Path
from PIL import Image
from torchvision import transforms

In [None]:
class HMDataset(Dataset):
    """Dataset yielding ``(image, file_name)`` pairs for the image files of one directory.

    NOTE(review): this definition shadows the ``HMDataset`` imported from
    ``gid_ml_framework`` near the top of the notebook.

    Args:
        img_dir: Directory that contains the image files.
        transform: Optional callable applied to every loaded PIL image.
    """

    def __init__(self, img_dir, transform=None):
        self.img_dir = Path(img_dir)
        self.img_names = self._get_img_names()
        self.transform = transform

    def _get_img_names(self):
        """Return the sorted names of the regular, non-hidden files in ``img_dir``.

        ``Path.iterdir`` yields entries in arbitrary order, so the names are sorted
        to keep the index -> image mapping stable between runs. Sub-directories and
        dot-files (e.g. ``.DS_Store``) are skipped because they are not images.
        """
        return sorted(
            entry.name
            for entry in self.img_dir.iterdir()
            if entry.is_file() and not entry.name.startswith('.')
        )

    def __getitem__(self, idx):
        """Return ``(image, file_name)`` for position ``idx``; the image is transformed when a transform is set."""
        img_path = self.img_dir / self.img_names[idx]
        # Read the pixel data inside the context manager so the file handle is released immediately.
        with Image.open(img_path) as img:
            img.load()

        if self.transform is not None:
            img = self.transform(img)

        # The returned "article id" is the file name, extension included.
        article_id = img_path.name
        return img, article_id

    def __len__(self):
        """Number of image files found in ``img_dir``."""
        return len(self.img_names)

In [None]:
# Image folder, overridable through the HM_IMAGES_DIR environment variable so the notebook is not tied
# to one machine. The default is the original hard-coded local checkout path.
IMAGES_DIR = os.environ.get('HM_IMAGES_DIR', '/Users/mmadej/Desktop/Projects/gid-ml-framework/data/01_raw/images_128_128/')
dataset = HMDataset(IMAGES_DIR, transform=transforms.ToTensor())

In [None]:
# Display the dataset object.
dataset

In [None]:
# Training loader: shuffled batches of 32, incomplete final batch dropped, loading done in the main process (num_workers=0).
dataloader = DataLoader(dataset=dataset,
                          batch_size=32,
                          drop_last=True,
                          shuffle=True,
                          num_workers=0)

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder that compresses an image batch into latent vectors.

    Three stride-2 convolutions each halve the spatial size (rounding up); the
    resulting feature map is flattened and projected to ``latent_dim``.
    """

    def __init__(self,
                 num_input_channels : int,
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.GELU,
                 width : int = 32,
                 height : int = 32):
        """
        Inputs:
            - num_input_channels : Number of input channels of the image. For CIFAR, this parameter is 3
            - base_channel_size : Number of channels we use in the first convolutional layers. Deeper layers use twice as many.
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function class, instantiated after every convolution
            - width, height : Spatial size of the input images. Defaults to 32x32, the size the
              layer was originally hard-wired for; pass e.g. 128, 128 for 128x128 images.
        Raises:
            ValueError: if width or height is smaller than 1.
        """
        super().__init__()
        if width < 1 or height < 1:
            raise ValueError(f"width and height must be positive, got {width}x{height}")
        c_hid = base_channel_size

        def _downsampled(size: int) -> int:
            # Output length after the three kernel=3, padding=1, stride=2 convolutions below.
            for _ in range(3):
                size = (size - 1) // 2 + 1
            return size

        # Number of features entering the final linear layer (2*16*c_hid for 32x32 inputs).
        flat_features = 2 * c_hid * _downsampled(height) * _downsampled(width)
        self.net = nn.Sequential(
            nn.Conv2d(num_input_channels, c_hid, kernel_size=3, padding=1, stride=2), # halves H and W, e.g. 32x32 => 16x16
            act_fn(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.Conv2d(c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # e.g. 16x16 => 8x8
            act_fn(),
            nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # e.g. 8x8 => 4x4
            act_fn(),
            nn.Flatten(), # Image grid to single feature vector
            nn.Linear(flat_features, latent_dim)
        )

    def forward(self, x):
        """Encode a batch of images of the configured size into latent vectors."""
        return self.net(x)

In [None]:
class Decoder(nn.Module):
    """Convolutional decoder that expands latent vectors back into images.

    A linear layer produces a small feature map which three stride-2 transposed
    convolutions each double in size, i.e. an overall 8x upsampling.
    """

    def __init__(self,
                 num_input_channels : int,
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.GELU,
                 width : int = 32,
                 height : int = 32):
        """
        Inputs:
            - num_input_channels : Number of channels of the image to reconstruct. For CIFAR, this parameter is 3
            - base_channel_size : Number of channels we use in the last convolutional layers. Early layers use twice as many.
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function class, instantiated after every hidden layer
            - width, height : Spatial size of the images to reconstruct. Defaults to 32x32, the size the
              layer was originally hard-wired for. Both must be positive multiples of 8.
        Raises:
            ValueError: if width or height is not a positive multiple of 8.
        """
        super().__init__()
        if width < 8 or height < 8 or width % 8 or height % 8:
            raise ValueError(f"width and height must be positive multiples of 8, got {width}x{height}")
        c_hid = base_channel_size
        # Size of the feature map the transposed convolutions start from (4x4 for 32x32 images).
        self.feat_height = height // 8
        self.feat_width = width // 8
        self.linear = nn.Sequential(
            nn.Linear(latent_dim, 2 * c_hid * self.feat_height * self.feat_width),
            act_fn()
        )
        self.net = nn.Sequential(
            nn.ConvTranspose2d(2*c_hid, 2*c_hid, kernel_size=3, output_padding=1, padding=1, stride=2), # doubles H and W, e.g. 4x4 => 8x8
            act_fn(),
            nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.ConvTranspose2d(2*c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2), # e.g. 8x8 => 16x16
            act_fn(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.ConvTranspose2d(c_hid, num_input_channels, kernel_size=3, output_padding=1, padding=1, stride=2), # e.g. 16x16 => 32x32
            nn.Tanh() # Bounds the output to [-1, 1]; assumes the input images are scaled to that range -- TODO confirm
        )

    def forward(self, x):
        """Decode a batch of latent vectors into images of the configured size."""
        x = self.linear(x)
        # Unflatten into (batch, channels, feat_height, feat_width) for the convolutional stack.
        x = x.reshape(x.shape[0], -1, self.feat_height, self.feat_width)
        x = self.net(x)
        return x

In [None]:
class LitAutoEncoder(pl.LightningModule):
    """Lightning module wrapping an encoder/decoder pair trained on reconstruction loss.

    NOTE(review): `LitAutoEncoder` is defined again further down in this notebook with a
    different constructor; when the cells run in order, that later definition replaces this one.
    """

    def __init__(self,
                 base_channel_size: int,
                 latent_dim: int,
                 encoder_class : object = Encoder,
                 decoder_class : object = Decoder,
                 num_input_channels: int = 3,
                 width: int = 32,
                 height: int = 32):
        """
        Inputs:
            - base_channel_size : Forwarded to the encoder and decoder constructors
            - latent_dim : Dimensionality of the latent representation
            - encoder_class, decoder_class : Classes called as cls(num_input_channels, base_channel_size, latent_dim)
            - num_input_channels : Number of image channels
            - width, height : Only used to size the example input array
        """
        super().__init__()
        # Saving hyperparameters of autoencoder
        self.save_hyperparameters()
        # Creating encoder and decoder
        self.encoder = encoder_class(num_input_channels, base_channel_size, latent_dim)
        self.decoder = decoder_class(num_input_channels, base_channel_size, latent_dim)
        # Example input array needed for visualizing the graph of the network
        self.example_input_array = torch.zeros(2, num_input_channels, width, height)

    def forward(self, x):
        """
        The forward function takes in an image and returns the reconstructed image
        """
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

    def _get_reconstruction_loss(self, batch):
        """
        Given a batch of images, this function returns the reconstruction loss (MSE in our case):
        squared error summed over dims 1-3 of each sample, then averaged over the batch.
        """
        x, _ = batch # We do not need the labels
        x_hat = self.forward(x)
        # `nn.functional` is used because the notebook never imports an `F` alias;
        # arguments follow the (input, target) order of mse_loss.
        loss = nn.functional.mse_loss(x_hat, x, reduction="none")
        loss = loss.sum(dim=[1,2,3]).mean(dim=[0])
        return loss

    def configure_optimizers(self):
        """Adam optimizer plus a plateau scheduler driven by the logged 'val_loss'."""
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        # Using a scheduler is optional but can be helpful.
        # The scheduler reduces the LR if the validation performance hasn't improved for the last N epochs
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                         mode='min',
                                                         factor=0.2,
                                                         patience=20,
                                                         min_lr=5e-5)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

    def training_step(self, batch, batch_idx):
        """Compute, log ('train_loss') and return the reconstruction loss of one batch."""
        loss = self._get_reconstruction_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the reconstruction loss of one validation batch as 'val_loss'."""
        loss = self._get_reconstruction_loss(batch)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        """Log the reconstruction loss of one test batch as 'test_loss'."""
        loss = self._get_reconstruction_loss(batch)
        self.log('test_loss', loss)

In [None]:
class GenerateCallback(pl.Callback):
    """Callback that logs input images next to their reconstructions every N epochs."""

    def __init__(self, input_imgs, every_n_epochs=1):
        """
        Inputs:
            - input_imgs : Batch of images to reconstruct during training
            - every_n_epochs : Log only every N epochs (otherwise tensorboard gets quite large)
        """
        super().__init__()
        self.input_imgs = input_imgs # Images to reconstruct during training
        self.every_n_epochs = every_n_epochs # Only save those images every N epochs (otherwise tensorboard gets quite large)

    def on_epoch_end(self, trainer, pl_module):
        """Reconstruct the stored images with ``pl_module`` and add the image grid to the trainer's logger."""
        # NOTE(review): newer PyTorch Lightning releases replaced `Callback.on_epoch_end` with
        # `on_train_epoch_end` -- confirm which hook the installed version still calls.
        if trainer.current_epoch % self.every_n_epochs == 0:
            # Imported here because the notebook only does `from torchvision import transforms`,
            # which does not bind the name `torchvision`.
            from torchvision.utils import make_grid

            # Reconstruct images
            input_imgs = self.input_imgs.to(pl_module.device)
            # Remember the current mode so a module that was already in eval mode is not forced into train mode.
            was_training = pl_module.training
            pl_module.eval()
            try:
                with torch.no_grad():
                    reconst_imgs = pl_module(input_imgs)
            finally:
                pl_module.train(was_training)
            # Plot and add to tensorboard; inputs and reconstructions are interleaved so each grid row is one pair
            imgs = torch.stack([input_imgs, reconst_imgs], dim=1).flatten(0,1)
            # `value_range` is the current keyword; `range` is the deprecated spelling.
            grid = make_grid(imgs, nrow=2, normalize=True, value_range=(-1,1))
            trainer.logger.experiment.add_image("Reconstructions", grid, global_step=trainer.global_step)

In [None]:
def train_cifar(latent_dim):
    """Train (or load) an autoencoder with the given latent size and evaluate it.

    NOTE(review): CHECKPOINT_PATH, device, ModelCheckpoint, LearningRateMonitor,
    get_train_images, Autoencoder, train_loader, val_loader and test_loader are not
    defined in the visible part of this notebook -- the function looks copied from the
    referenced tutorial and cannot run here as-is.

    Args:
        latent_dim: Latent dimensionality; also used in the checkpoint directory/file names.

    Returns:
        Tuple ``(model, result)`` where ``result`` is a dict with the "test" and "val"
        outputs of ``trainer.test``.
    """
    # Create a PyTorch Lightning trainer with the generation callback
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, f"cifar10_{latent_dim}"),
                         gpus=1 if str(device).startswith("cuda") else 0,
                         max_epochs=500,
                         callbacks=[ModelCheckpoint(save_weights_only=True),
                                    GenerateCallback(get_train_images(8), every_n_epochs=10),
                                    LearningRateMonitor("epoch")])
    trainer.logger._log_graph = True         # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, f"cifar10_{latent_dim}.ckpt")
    if os.path.isfile(pretrained_filename):
        print("Found pretrained model, loading...")
        model = Autoencoder.load_from_checkpoint(pretrained_filename)
    else:
        model = Autoencoder(base_channel_size=32, latent_dim=latent_dim)
        trainer.fit(model, train_loader, val_loader)
    # Test best model on validation and test set
    # NOTE(review): `model` is passed directly, so this evaluates the in-memory model rather than
    # explicitly reloading a best checkpoint.
    val_result = trainer.test(model, val_loader, verbose=False)
    test_result = trainer.test(model, test_loader, verbose=False)
    result = {"test": test_result, "val": val_result}
    return model, result

In [None]:
# define any number of nn.Modules (or use your current ones)
# Here: fully-connected encoder/decoder over flattened 128*128*3 inputs with a 3-dimensional bottleneck.
encoder = nn.Sequential(nn.Linear(128 * 128 * 3, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 128 * 128 * 3))

In [None]:
# define the LightningModule
class LitAutoEncoder(pl.LightningModule):
    """Minimal Lightning autoencoder built from a ready-made encoder and decoder module."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder

    def training_step(self, batch, batch_idx):
        """Flatten the images, reconstruct them and return the mean squared error."""
        # The training loop lives here and does not go through `forward`.
        images, _ = batch
        flat_images = images.view(images.size(0), -1)
        reconstruction = self.decoder(self.encoder(flat_images))
        loss = nn.functional.mse_loss(reconstruction, flat_images)
        # Goes to TensorBoard unless another logger is configured
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters with a fixed learning rate of 1e-3."""
        return optim.Adam(self.parameters(), lr=1e-3)

In [None]:
# init the autoencoder from the fully-connected encoder/decoder pair defined above
autoencoder = LitAutoEncoder(encoder, decoder)

In [None]:
# Display the training dataloader object.
dataloader

In [None]:
# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
# A single epoch, capped at 100 training batches, keeps this run short.
trainer = pl.Trainer(limit_train_batches=100, max_epochs=1)

In [None]:
# Fit the autoencoder on the training dataloader; no validation loader is passed.
trainer.fit(model=autoencoder, train_dataloaders=dataloader)

In [None]:
# Display the trainer object.
trainer

In [None]:
# load checkpoint
# Prefer the checkpoint written by the trainer in this session: the `version_N` directory and the step
# count in the file name change between runs, so the hard-coded path may point at stale weights.
# It is kept only as a fallback for when the trainer has not recorded a checkpoint.
fallback_checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
checkpoint_callback = getattr(trainer, "checkpoint_callback", None)
checkpoint = getattr(checkpoint_callback, "best_model_path", "") or fallback_checkpoint
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

In [None]:
# choose your trained nn.Module
# Rebinds `encoder` to the encoder of the autoencoder restored from the checkpoint and switches it to eval mode.
encoder = autoencoder.encoder
encoder.eval()

In [None]:
# torch.rand returns initialised values in [0, 1); `Tensor(4, n)` only allocates memory, so its
# contents -- and therefore the printed embeddings -- were arbitrary.
fake_image_batch = torch.rand(4, 128 * 128 * 3)
# Inference only: no autograd graph is needed.
with torch.no_grad():
    embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)

In [None]:
from torch import randn

In [None]:
# Encode one random input of the flattened image size (values drawn by torch.randn).
encoder(randn(1, 128*128*3))

In [None]:
# Same call again on a freshly drawn random input.
encoder(randn(1, 128*128*3))