In [1]:
import torch
import torch.nn as nn
from torchvision.models import vgg16
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter
import os
import random

In [2]:
# pytorch settings
run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# video settings
RESOLUTION_WIDTH = 128
RESOLUTION_HEIGHT = 128
CHANNELS = 3

# model settings
BOTTLENECK_DIM = 28*28

# training settings
BATCH_SIZE = 128
EPOCHS = 512
optim = torch.optim.AdamW
lr = 0.001

class PerceptualLoss(nn.Module):
    def __init__(self, weights=None):
        super().__init__()
        self.vgg = vgg16(pretrained=True).features.eval()
        for param in self.vgg.parameters():
            param.requires_grad = False

        self.layers = {
            "0": "relu1_1",
            "3": "relu1_2",
            "8": "relu2_2",
            "15": "relu3_3"
        }

        self.layer_weights = weights or {
            "relu1_1": 1.5,
            "relu1_2": 1.0,
            "relu2_2": 0.8,
            "relu3_3": 0.3,
        }

    def forward(self, x, y):
        with torch.no_grad():
            loss = 0.0
            for i, layer in enumerate(self.vgg):
                x, y = layer(x), layer(y)
                name = self.layers.get(str(i))
                if name:
                    loss += self.layer_weights[name] * F.mse_loss(x, y)
                if i > max(map(int, self.layers.keys())):
                    break
        return loss


class CombinedLoss(nn.Module):
    def __init__(self, perceptual_weight=1.8, mse_weight=0.2, latent_weight=0.0):
        super().__init__()
        self.perceptual_loss = PerceptualLoss().to(run_device)
        self.mse_loss = nn.MSELoss()
        self.perceptual_weight = perceptual_weight
        self.mse_weight = mse_weight
        self.latent_weight = latent_weight

    def forward(self, x_recon, x_target, z=None, z_target=None):
        loss = (
            self.perceptual_weight * self.perceptual_loss(x_recon, x_target)
            + self.mse_weight * self.mse_loss(x_recon, x_target)
        )
        if self.latent_weight > 0 and z is not None and z_target is not None:
            loss += self.latent_weight * F.mse_loss(z, z_target)
        return loss

loss = CombinedLoss()
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR

# tensorboard settings
TENSORBOARD_LOG_DIR = "runs/autoencoder/exp16"



In [3]:
class ImageProcessor:
    def tensor_to_pil(self, image_tensor: torch.Tensor) -> Image.Image:
        """
        Convert a tensor to a PIL Image.
        
        Args:
            image_tensor (torch.Tensor): A tensor of shape (C, H, W) with pixel values in the range [0, 1].
        
        Returns:
            Image.Image: A PIL Image object.
        """
        # Clamp to [0, 1], convert to [0, 255] and uint8
        image_np = (image_tensor.clamp(0, 1).mul(255).byte().cpu().permute(1, 2, 0).numpy())
        return Image.fromarray(image_np)
    
    def pil_to_tensor(self, image: Image.Image) -> torch.Tensor:
        """
        Convert a PIL image to a PyTorch tensor of shape (C, H, W) with values in [0, 1].
        
        Args:
            image (Image.Image): A PIL Image object.
        
        Returns:
            torch.Tensor: A tensor of shape (C, H, W) with pixel values in the range [0, 1].
        """
        return transforms.ToTensor()(image)  # Already returns (C, H, W)

In [4]:
class ImageGenerator:
    def __init__(self, folder_path: str):
        self.images = [Image.open(os.path.join(folder_path, fname)).convert("RGBA") for fname in os.listdir(folder_path) if fname.lower().endswith(".png")]
    
    def generate_random_image(
        self,
        resolution=(RESOLUTION_WIDTH, RESOLUTION_HEIGHT),
        num_layers_range=(2, 24),
        scale_fraction_range=(0.15, 1.2),  # as % of canvas dimensions
        allow_rotation=True,
        apply_color_tint=True
    ) -> Image.Image:
        """
        Generate a composite RGBA image with random base color and transformed overlays.
        
        Args:
            resolution (tuple): Output image size (width, height).
            num_layers_range (tuple): Min and max number of overlays to paste.
            scale_fraction_range (tuple): Min and max fraction of resolution to scale overlays.
            allow_rotation (bool): Whether to apply random rotation to overlays.
            apply_color_tint (bool): Whether to apply random color tints.

        Returns:
            PIL.Image: Final composited RGBA image.
        """
        # Create base RGBA image with a random opaque color
        base_color = tuple(random.randint(0, 255) for _ in range(3)) + (255,)
        base = Image.new("RGBA", resolution, base_color)

        num_layers = random.randint(*num_layers_range)

        for _ in range(num_layers):
            overlay = random.choice(self.images).copy().convert("RGBA")

            # Random scaling based on canvas size
            scale_w = random.uniform(*scale_fraction_range)
            scale_h = random.uniform(*scale_fraction_range)
            new_size = (
                int(resolution[0] * scale_w),
                int(resolution[1] * scale_h)
            )
            overlay = overlay.resize(new_size, resample=Image.BICUBIC)

            # Optional rotation
            
            if allow_rotation:
                angle = random.uniform(0, 360)
                overlay = overlay.rotate(angle, expand=True)

            # Optional color tint (multiply RGB channels)
            if apply_color_tint:
                r, g, b, a = overlay.split()
                tint_factors = [random.uniform(0.25, 1.0) for _ in range(3)]
                r = r.point(lambda i: int(i * tint_factors[0]))
                g = g.point(lambda i: int(i * tint_factors[1]))
                b = b.point(lambda i: int(i * tint_factors[2]))
                
                overlay = Image.merge("RGBA", (r, g, b, a))

            # Allow pasting outside canvas
            offset_x = random.randint(-overlay.size[0] // 2, resolution[0])
            offset_y = random.randint(-overlay.size[1] // 2, resolution[1])
            pos = (offset_x, offset_y)

            # Create transparent layer and paste overlay
            temp_layer = Image.new("RGBA", resolution, (0, 0, 0, 0))
            temp_layer.paste(overlay, pos, overlay)

            # Composite with base
            base = Image.alpha_composite(base, temp_layer)

        return base.convert("RGB")

In [5]:
class ImageTargetLatentCompressor:
    def __init__(self):
        """
        A reference compressor that encodes RGB images into a fixed-size 28×28 grayscale latent
        by squeezing R, G, B channels side-by-side horizontally, with any remaining space padded.
        """
        self.encoding_size = int(BOTTLENECK_DIM ** 0.5)
        self.channel_width = self.encoding_size // CHANNELS
        self.padding_width = self.encoding_size - (self.channel_width * CHANNELS)

    def image_to_latent(self, image: Image.Image) -> torch.Tensor:
        """
        Compress an RGB image into a 28×28 grayscale latent.

        Args:
            image (PIL.Image): Input RGB image.

        Returns:
            torch.Tensor: Shape (1, 1, 28, 28), grayscale stacked channel encoding.
        """
        image = image.resize((RESOLUTION_WIDTH, RESOLUTION_HEIGHT)).convert("RGB")
        tensor = transforms.ToTensor()(image).unsqueeze(0)  # (1, 3, H, W)

        r = tensor[:, 0:1]
        g = tensor[:, 1:2]
        b = tensor[:, 2:3]

        r_comp = F.interpolate(r, size=(self.encoding_size, self.channel_width), mode='bilinear', align_corners=False)
        g_comp = F.interpolate(g, size=(self.encoding_size, self.channel_width), mode='bilinear', align_corners=False)
        b_comp = F.interpolate(b, size=(self.encoding_size, self.channel_width), mode='bilinear', align_corners=False)

        components = [r_comp, g_comp, b_comp]

        if self.padding_width > 0:
            pad = torch.zeros((1, 1, self.encoding_size, self.padding_width), dtype=r_comp.dtype)
            components.append(pad)

        encoded = torch.cat(components, dim=3)  # (1, 1, 28, 28)
        return encoded

    def latent_to_image(self, latent: torch.Tensor) -> Image.Image:
        """
        Decode a 28×28 latent into a full-resolution RGB image.

        Args:
            latent (torch.Tensor): Shape (1, 1, 28, 28)

        Returns:
            PIL.Image: RGB image of size (128, 128)
        """
        assert latent.shape == (1, 1, self.encoding_size, self.encoding_size), f"Expected latent shape (1, 1, {self.encoding_size}, {self.encoding_size})"

        r_slice = latent[:, :, :, 0:self.channel_width]
        g_slice = latent[:, :, :, self.channel_width:self.channel_width * 2]
        b_slice = latent[:, :, :, self.channel_width * 2:self.channel_width * 3]

        r_up = F.interpolate(r_slice, size=(RESOLUTION_HEIGHT, RESOLUTION_WIDTH), mode='bilinear', align_corners=False)
        g_up = F.interpolate(g_slice, size=(RESOLUTION_HEIGHT, RESOLUTION_WIDTH), mode='bilinear', align_corners=False)
        b_up = F.interpolate(b_slice, size=(RESOLUTION_HEIGHT, RESOLUTION_WIDTH), mode='bilinear', align_corners=False)

        reconstructed = torch.cat([r_up, g_up, b_up], dim=1)  # (1, 3, H, W)
        reconstructed_image = transforms.ToPILImage()(reconstructed.squeeze(0).clamp(0, 1))
        return reconstructed_image

In [6]:
class RandomImageDataset(Dataset):
    def __init__(self, folder_path: str, length: int):
        self.length = length
        
        self.generator = ImageGenerator(folder_path)
        self.processor = ImageProcessor()
        self.latent_compressor = ImageTargetLatentCompressor()

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        img = self.generator.generate_random_image()
        
        return self.processor.pil_to_tensor(img), self.latent_compressor.image_to_latent(img)

In [7]:
class Trainer:
    def __init__(
        self,
        dataset,
        model: nn.Module,
        criterion: nn.Module = loss,
        optimizer_class=optim,
        scheduler_class=scheduler,
        lr: float = lr,
        epochs: int = EPOCHS,
        batch_size: int = BATCH_SIZE,
        device: torch.device = run_device,
        writer: SummaryWriter = SummaryWriter(log_dir=TENSORBOARD_LOG_DIR)
    ):
        self.dataset = dataset
        self.model = model.to(device)
        self.criterion = criterion
        self.epochs = epochs
        self.batch_size = batch_size
        self.device = device
        self.writer = writer

        self.optimizer = optimizer_class(self.model.parameters(), lr=lr)
        self.scheduler = scheduler_class(self.optimizer, T_max=epochs) if scheduler_class else None

        self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
        self.losses = []

    def train(self):
        self.model.train()

        for epoch in range(self.epochs):
            total_loss = 0.0

            for batch in self.dataloader:
                x, z_target = batch  # x = image tensor, z_target = handcrafted latent
                x = x.to(self.device)
                z_target = z_target.to(self.device)

                self.optimizer.zero_grad()
                z = self.model.encode(x)  # predicted latent
                loss = self.criterion(z, z_target)
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()

            avg_loss = total_loss / len(self.dataloader)
            self.losses.append(avg_loss)

            # Log loss and LR
            self.writer.add_scalar("Loss/train", avg_loss, epoch)
            current_lr = self.scheduler.get_last_lr()[0] if self.scheduler else self.optimizer.param_groups[0]['lr']
            self.writer.add_scalar("LearningRate", current_lr, epoch)

            if self.scheduler:
                self.scheduler.step()

            # Log input/reconstruction from full autoencoder pass
            with torch.no_grad():
                x_sample, _ = next(iter(self.dataloader))
                x_sample = x_sample[:8].to(self.device)
                z_sample = self.model.encode(x_sample)
                x_recon = self.model.decode(z_sample)

                self.writer.add_images("Input", x_sample.clamp(0, 1), epoch)
                self.writer.add_images("Reconstruction", x_recon.clamp(0, 1), epoch)

        self.writer.close()
        self._plot_losses()

    def _plot_losses(self):
        plt.plot(self.losses, marker="o")
        plt.title("Training Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.grid(True)
        plt.show()

In [8]:
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.BatchNorm2d(channels)
        )

    def forward(self, x):
        return x + self.block(x)

class ConvAutoencoder(nn.Module):
    def __init__(self, in_channels=CHANNELS, latent_dim=BOTTLENECK_DIM):
        super().__init__()
        self.latent_dim = latent_dim

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, 2, 1),   # 64x64 -> 32x32
            nn.ReLU(),
            ResidualBlock(64),

            nn.Conv2d(64, 128, 4, 2, 1),           # -> 16x16
            nn.BatchNorm2d(128),
            nn.ReLU(),
            ResidualBlock(128),

            nn.Conv2d(128, 256, 4, 2, 1),          # -> 8x8
            nn.BatchNorm2d(256),
            nn.ReLU(),
            ResidualBlock(256),

            nn.Conv2d(256, 512, 4, 2, 1),          # -> 4x4
            nn.BatchNorm2d(512),
            nn.ReLU(),
            ResidualBlock(512),
        )
        self.encoder_fc = nn.Linear(512 * 4 * 4, latent_dim)

        # Decoder
        self.decoder_fc = nn.Linear(latent_dim, 512 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.Unflatten(1, (512, 4, 4)),
            ResidualBlock(512),

            nn.ConvTranspose2d(512, 256, 4, 2, 1), # -> 8x8
            nn.BatchNorm2d(256),
            nn.ReLU(),
            ResidualBlock(256),

            nn.ConvTranspose2d(256, 128, 4, 2, 1), # -> 16x16
            nn.BatchNorm2d(128),
            nn.ReLU(),
            ResidualBlock(128),

            nn.ConvTranspose2d(128, 64, 4, 2, 1),  # -> 32x32
            nn.BatchNorm2d(64),
            nn.ReLU(),
            ResidualBlock(64),

            nn.ConvTranspose2d(64, in_channels, 4, 2, 1),  # -> 64x64
            nn.Tanh()
        )

    def encode(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, 1)
        return self.encoder_fc(x)

    def decode(self, z):
        z = self.decoder_fc(z)
        return self.decoder(z)

    def forward(self, x):
        z = self.encode(x)
        return self.decode(z), z

In [9]:
model = ConvAutoencoder(latent_dim=BOTTLENECK_DIM).to(run_device)

dataset = RandomImageDataset("rand_img_components", 1800)
trainer = Trainer(dataset, model)

In [10]:
trainer.train()

RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x32768 and 8192x784)

In [None]:
img_proc = ImageProcessor()

In [None]:
img_proc.tensor_to_pil(test_img)

In [None]:
#save model to file
torch.save(model, "autoencoder.pth")

In [None]:
tnsr = torch.randn(768, device=torch.device("cuda"))

In [None]:
# load PIL from file and reshape to RGB 128x128 (single use)
test_img = Image.open("test.png").convert("RGB").resize((128, 128))

In [None]:
noise = torch.zeros(2560, device=run_device)

In [None]:
# run test_img thru autoencoder and display:
encoded = model.encode(img_proc.pil_to_tensor(test_img).unsqueeze(0).to(run_device))
noise[random.randint(0, 767)] += random.randint(-10, 10) * 0.6  # set first dimension to 1
encoded = encoded + noise * 0
decoded = model.decode(encoded).squeeze(0).clamp(0, 1)
img_proc.tensor_to_pil(decoded)

In [None]:
test_img