In [55]:
import os
import torch
import torch.nn as nn
from torch import Tensor
import torchvision
from torchvision.transforms.v2 import ToDtype, Compose, Resize, RandomCrop, RandomHorizontalFlip, RandomVerticalFlip, Transform
from torchvision.io import read_image
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import lightning as l
import kornia as k
import kornia.augmentation as ka
import torch.utils.data as data
import typing as t
import pathlib as pb
from glob import glob

In [41]:
# Folder that holds the `apple2orange/` dataset directory, resolved relative to
# the kernel's current working directory.
ROOT_DIR: pb.Path = pb.Path(".")

In [70]:
class Apple2OrangeDataset(data.Dataset):
    """One split of one domain of the unpaired apple2orange image folders.

    Images are read from ``<path>/apple2orange/<split><A|B>/*.jpg`` where the
    ``A`` folder is used for ``kind == 'apple'`` and ``B`` for anything else.

    Training samples are resized to 286x286, randomly cropped to 256x256 and
    randomly flipped horizontally/vertically; test samples get no geometric
    transform. Every sample is finally converted to ``float32`` with value
    scaling enabled (``ToDtype(..., scale=True)``).

    Attributes:
        object: The ``kind`` argument as passed in.
        split: The ``split`` argument as passed in.
        path: Folder the images are read from.
        imgpaths: Sorted list of ``.jpg`` file paths found in ``path``.
        transform: Composed transform applied to every loaded image.
    """

    def __init__(self, path: pb.Path, split: t.Literal['train', 'test'], kind: t.Literal['apple', 'orange']) -> None:
        """Index the image folder and build the transform pipeline.

        Args:
            path: Directory that contains the ``apple2orange`` folder.
            split: ``'train'`` enables augmentation; any other value does not.
            kind: ``'apple'`` selects the ``A`` folder, otherwise ``B``.
        """
        super().__init__()
        # Named `domain` rather than `k` so the notebook-level `kornia as k`
        # alias is not shadowed inside this method.
        domain = 'A' if kind == 'apple' else 'B'
        self.object = kind
        self.split = split
        self.path = path / 'apple2orange' / f'{split}{domain}'
        self.imgpaths = self._find_images(self.path)

        steps = []
        if split == 'train':
            steps.append(Resize([286, 286], antialias=True))
            steps.append(RandomCrop([256, 256]))
            steps.append(RandomHorizontalFlip())
            steps.append(RandomVerticalFlip())
        steps.append(ToDtype(torch.float32, scale=True))
        self.transform = Compose(steps)

    @staticmethod
    def _find_images(folder: pb.Path) -> t.List[str]:
        """Return the ``.jpg`` files directly inside ``folder`` in sorted order.

        ``glob`` yields entries in arbitrary, filesystem-dependent order;
        sorting makes ``dataset[i]`` refer to the same file on every run.
        """
        return sorted(glob(str(folder) + '/' + '*.jpg'))

    def __getitem__(self, key: int) -> Tensor:
        """Load image number ``key`` as RGB and apply the transform pipeline."""
        image: Tensor = read_image(self.imgpaths[key], torchvision.io.ImageReadMode.RGB)
        return self.transform(image)

    def __len__(self) -> int:
        """Number of indexed image files."""
        return len(self.imgpaths)


# Build the four unpaired subsets: *_x holds apples, *_y holds oranges.
train_x, train_y = (Apple2OrangeDataset(ROOT_DIR, 'train', fruit) for fruit in ('apple', 'orange'))
test_x, test_y = (Apple2OrangeDataset(ROOT_DIR, 'test', fruit) for fruit in ('apple', 'orange'))

In [78]:
class Encode(nn.Module):
    """Down-sampling stage: 3x3 same-padding conv -> ReLU -> 2x2 max-pool.

    The convolution keeps height/width; the pooling halves them (rounding
    down for odd sizes).
    """

    def __init__(self, input_channel, output_channel) -> None:
        """
        Args:
            input_channel: Number of channels of the incoming feature map.
            output_channel: Number of channels produced by the convolution.
        """
        super(Encode, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(input_channel, output_channel, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), 2)
        )

    def forward(self, x: Tensor) -> Tensor:
        """Map ``(N, C_in, H, W)`` to ``(N, C_out, H // 2, W // 2)``."""
        return self.layers(x)


class Decode(nn.Module):
    """Up-sampling stage: stride-2 transposed 3x3 conv, optionally followed by ReLU.

    ``output_padding=1`` makes the transposed convolution produce exactly
    ``2 * H`` by ``2 * W``; with kernel 3 / stride 2 / padding 1 alone the
    result would be ``2 * H - 1``, which never lines up with the encoder's
    skip connections.
    """

    def __init__(self, input_channel, output_channel, activation: bool = True) -> None:
        """
        Args:
            input_channel: Number of channels of the incoming feature map.
            output_channel: Number of channels produced by the stage.
            activation: Apply ReLU after the transposed convolution. Pass
                ``False`` for an output stage whose result feeds a different
                non-linearity.
        """
        super(Decode, self).__init__()
        self.layers = nn.Sequential(
            nn.ConvTranspose2d(input_channel, output_channel, 3, 2, 1, output_padding=1),
            # Identity keeps the Sequential layout (and its indices) stable.
            nn.ReLU() if activation else nn.Identity(),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Map ``(N, C_in, H, W)`` to ``(N, C_out, 2 * H, 2 * W)``."""
        return self.layers(x)


class Generator(nn.Module):
    """U-Net style image-to-image generator with three skip connections.

    Four ``Encode`` stages reduce the input to 1/16 resolution with 512
    channels. Three ``Decode`` stages bring it back up, each one consuming
    the concatenation of the matching encoder output (deepest first) and the
    current feature map. A sigmoid maps the 3-channel result into (0, 1).

    The output has the same height and width as the input. Inputs need at
    least 16 pixels per side so that all four poolings are valid.
    """

    def __init__(self, ) -> None:
        super(Generator, self).__init__()

        self.encoder = nn.ModuleList([
            Encode(3, 64),
            Encode(64, 128),
            Encode(128, 256),
            Encode(256, 512),
        ])

        self.decoder = nn.ModuleList([
            Decode(512 + 256, 256),
            Decode(256 + 128, 128),
            # No ReLU here: ReLU followed by sigmoid can only emit values
            # >= 0.5, which would make dark pixels unreachable.
            Decode(128 + 64, 3, activation=False),
        ])

        self.activ_fn = nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        """Translate a batch of images.

        Args:
            x: Tensor of shape ``(N, 3, H, W)`` with ``H, W >= 16``.

        Returns:
            Tensor of shape ``(N, 3, H, W)`` with values in (0, 1).
        """
        input_size = x.shape[-2:]
        residuals: t.List[Tensor] = []

        # Encode the input, remembering every stage output except the
        # bottleneck itself.
        last_stage = len(self.encoder) - 1
        for i, layer in enumerate(self.encoder):
            x = layer(x)
            if i != last_stage:
                residuals.append(x)

        # Decode, pairing each stage with the skip of matching depth
        # (deepest first: 256, 128, 64 channels).
        for layer, skip in zip(self.decoder, reversed(residuals)):
            # The bottleneck sits one pooling below the deepest skip, and
            # odd input sizes lose a pixel when pooled; resize so the
            # concatenation is always spatially valid.
            if x.shape[-2:] != skip.shape[-2:]:
                x = nn.functional.interpolate(x, size=skip.shape[-2:], mode='nearest')
            x = layer(torch.cat([skip, x], dim=1))

        # Only needed when H or W is odd; otherwise sizes already agree.
        if x.shape[-2:] != input_size:
            x = nn.functional.interpolate(x, size=input_size, mode='nearest')

        return self.activ_fn(x)


In [81]:
# Smoke test: push one random 256x256 RGB batch through a fresh generator.
# The module is called directly (not via `.forward`) so that any registered
# hooks run, gradients are disabled because nothing is trained here, and only
# the output shape is displayed instead of the full tensor.
gen = Generator()
with torch.no_grad():
    smoke_out = gen(torch.randn((1, 3, 256, 256)))
smoke_out.shape

In [80]:
# Display the generator's module tree (layer types and their hyper-parameters).
gen

Generator(
  (encoder): ModuleList(
    (0): Encode(
      (layers): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
      )
    )
    (1): Encode(
      (layers): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
      )
    )
    (2): Encode(
      (layers): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
      )
    )
    (3): Encode(
      (layers): Sequential(
        (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0,