In [None]:
# Load the IPython extension module that ships inside the Kedro package.
# NOTE(review): presumably this is what registers `%reload_kedro` and exposes
# `catalog` / `context` used below — confirm for the installed Kedro version.
%load_ext kedro.extras.extensions.ipython

In [None]:
# Run the Kedro line magic before the first use of `catalog` / `context`.
# NOTE(review): presumably (re)creates the Kedro session variables — confirm.
%reload_kedro

In [None]:
# Call the catalog's list() and show its result as the cell output.
catalog.list()

In [None]:
# Load the 'parameters' entry through the context's catalog and keep only
# the value stored under its 'input' key.
my_dict = context.catalog.load('parameters')['input']

In [None]:
# Pull the two settings of interest out of the parameter dict in one step
# (both lookups happen before either name is bound).
img_path, suffix = (my_dict[key] for key in ('image_path', 'suffix'))

In [None]:
# Echo both values to check what was read from the parameters.
# NOTE(review): neither name is referenced by the later top-level cells shown
# here, which use hard-coded paths instead.
img_path, suffix

In [None]:
import os

import pytorch_lightning as pl
import torch
from torch import optim, nn, utils, Tensor
from torch.utils.data import DataLoader, Dataset

In [None]:
import os
from pathlib import Path
from PIL import Image
from torchvision import transforms

In [None]:
class HMDataset(Dataset):
    """Dataset over the image files stored directly inside one directory.

    Each item is a ``(image, file_name)`` pair, where ``file_name`` is the
    image's file name including its extension.

    Args:
        img_dir: Directory that contains the image files.
        transform: Optional callable applied to the loaded PIL image
            (for example ``transforms.ToTensor()``).
    """

    def __init__(self, img_dir, transform=None):
        self.img_dir = Path(img_dir)
        self.img_names = self._get_img_names()
        self.transform = transform

    def _get_img_names(self):
        """Return the sorted names of the regular, non-hidden files in ``img_dir``.

        Sub-directories and dot-files (e.g. macOS ``.DS_Store``) are skipped
        because they cannot be opened as images. Sorting makes the
        index -> file mapping reproducible, since ``iterdir()`` yields
        entries in arbitrary order.
        """
        return sorted(
            entry.name
            for entry in self.img_dir.iterdir()
            if entry.is_file() and not entry.name.startswith('.')
        )

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(image, file_name)``.

        The image is always converted to RGB so every sample has three
        channels; grayscale, RGBA or CMYK files would otherwise produce
        tensors of a different shape and fail when batched.
        """
        img_path = self.img_dir / self.img_names[idx]
        # convert() loads the pixel data and returns a new image, so the file
        # handle can be closed as soon as the `with` block exits.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        article_id = img_path.name
        return img, article_id

    def __len__(self):
        """Number of indexed image files."""
        return len(self.img_names)

In [None]:
# Dataset over the `images_128_128` folder; every PIL image is passed through
# ToTensor() before being returned.
# NOTE(review): absolute, machine-specific path — consider sourcing it from
# the Kedro parameters loaded earlier in this notebook.
dataset = HMDataset(
    img_dir='/Users/mmadej/Desktop/Projects/gid-ml-framework/data/01_raw/images_128_128/',
    transform=transforms.ToTensor(),
)

In [None]:
# Display the dataset object as the cell output.
dataset

In [None]:
# Shuffled mini-batches of 32 samples; the trailing incomplete batch is
# dropped and loading happens in the main process (no worker processes).
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    drop_last=True,
    num_workers=0,
)

In [None]:
# First candidate image path.
# NOTE(review): the next cell reassigns `single_image`, so on a top-to-bottom
# run this value is never used. The path is also absolute and machine-specific.
single_image = '/Users/mmadej/Desktop/Projects/gid-ml-framework/data/01_raw/images_128_128/0762212004.jpg'

In [None]:
# Image path that is actually opened by the following cell (it replaces the
# value assigned in the previous cell).
# NOTE(review): absolute, machine-specific path.
single_image = '/Users/mmadej/Desktop/Projects/gid-ml-framework/data/01_raw/images_128_128/0654410024.jpg'

In [None]:
# Open the selected image file with PIL.
img = Image.open(single_image)

In [None]:
# Show the opened image as the cell output.
img

In [None]:
# Show the image's `size` attribute (PIL reports it as (width, height)).
img.size

In [None]:
# Two small fully connected stacks that mirror each other:
#   encoder: 128*128*3 inputs -> 64 hidden units -> 3-dimensional code
#   decoder: 3-dimensional code -> 64 hidden units -> 128*128*3 outputs
# (presumably 128*128*3 is a flattened 128x128 three-channel image — confirm
# against the data)
encoder = nn.Sequential(
    nn.Linear(128 * 128 * 3, 64),
    nn.ReLU(),
    nn.Linear(64, 3),
)
decoder = nn.Sequential(
    nn.Linear(3, 64),
    nn.ReLU(),
    nn.Linear(64, 128 * 128 * 3),
)

In [None]:
# Lightning wrapper around an encoder/decoder pair
class LitAutoEncoder(pl.LightningModule):
    """Autoencoder trained to reconstruct its own flattened input.

    Args:
        encoder: Module mapping a flattened input batch to a latent code.
        decoder: Module mapping a latent code back to the flattened input size.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder

    def training_step(self, batch, batch_idx):
        """Compute and log the reconstruction loss for one batch.

        ``batch`` is unpacked into two elements; only the first is used.
        It is flattened to ``(batch_size, -1)``, passed through the encoder
        and then the decoder, and compared with itself using MSE.
        """
        inputs, _ = batch
        flat_inputs = inputs.view(inputs.size(0), -1)
        reconstruction = self.decoder(self.encoder(flat_inputs))
        loss = nn.functional.mse_loss(reconstruction, flat_inputs)
        # Recorded under the key "train_loss" by the trainer's logger.
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (learning rate 1e-3)."""
        return optim.Adam(self.parameters(), lr=1e-3)

In [None]:
# Wrap the encoder/decoder pair in the Lightning module.
autoencoder = LitAutoEncoder(encoder=encoder, decoder=decoder)

In [None]:
# Display the DataLoader object as the cell output.
dataloader

In [None]:
# Trainer configured for a quick experiment: a single epoch, capped at
# 100 training batches.
trainer = pl.Trainer(
    max_epochs=1,
    limit_train_batches=100,
)

In [None]:
# Train the autoencoder on the batches produced by `dataloader`.
trainer.fit(model=autoencoder, train_dataloaders=dataloader)

In [None]:
# Display the Trainer object as the cell output.
trainer

In [None]:
# Load the checkpoint written by the training run above.
# The path is taken from the trainer's checkpoint callback instead of assuming
# the run directory is `version_0` and that exactly 100 steps were taken:
# the version directory changes on every re-run, and the step count changes if
# the data yields fewer than 100 batches. The original hard-coded path is kept
# as a fallback for the case where no checkpoint path was recorded.
checkpoint = (
    getattr(trainer.checkpoint_callback, "best_model_path", "")
    or "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
)
# encoder/decoder are constructor arguments of LitAutoEncoder, so they are
# passed again here when the module is rebuilt from the checkpoint.
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

In [None]:
# choose your trained nn.Module
# Rebinds the notebook-level name `encoder` to the autoencoder's submodule.
encoder = autoencoder.encoder
# Switch the module to evaluation mode; as the last expression of the cell,
# the value returned by eval() is also displayed.
encoder.eval()

In [None]:
# Smoke-test the encoder on a batch of 4 random flattened "images".
# torch.rand draws initialised values in [0, 1). The legacy `Tensor(4, N)`
# constructor only allocates memory without initialising it, so the embeddings
# it produced were arbitrary (and could contain NaN/inf).
# The batch is created on the same device as the encoder's weights.
encoder_device = next(encoder.parameters()).device
fake_image_batch = torch.rand(4, 128 * 128 * 3, device=encoder_device)
with torch.no_grad():  # inference only, no autograd graph needed
    embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)

In [None]:
from torch import randn

In [None]:
# Run the encoder on a single random (standard-normal) input of flattened
# size 128*128*3 and display the result.
encoder(randn(1, 128*128*3))

In [None]:
# Same call as the previous cell, with a freshly drawn random input.
encoder(randn(1, 128*128*3))