<font size="6"> 
<b>
Picture Monetization
</b>
</font>

The goal of this project is to implement a Generative Adversarial Network that translates between two datasets: a dataset consisting of 7038 pictures (`photo`) and a dataset consisting of 300 Monet paintings (`monet`). To achieve this, we will implement two __generators__, each capable of taking an image from one dataset and outputting an image following the same distribution as the other dataset, and a __discriminator__ for each dataset, trained to determine whether an image is fake or real.

To start, install the following packages by running the code below.

`pip install ipykernel numpy torch albumentations pillow tqdm torchvision transformers absl-py`

Run the following imports to load all packages.

In [1]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

import copy

import numpy as np

import os

from PIL import Image

import random

import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from torchvision.utils import save_image

from tqdm import tqdm

import multiprocessing as mp

# Force the spawn method (this is required on Windows)
mp.set_start_method('spawn', force=True)



<font size="5"> 
<b>
Config
</b>
</font>

The configurations for the GAN. Some variables such as the directories may have to be tweaked to fit your machine.

In [2]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # Prefer to work on GPU if available
TRAIN_DIR = "Data/train"
VAL_DIR = "Data/val"
BATCH_SIZE = 1 # i.e. SGD
LEARNING_RATE = 1e-5
LAMBDA_IDENTITY = 5 # 0.0 # weight for identity loss (photo->photo, Monet->Monet)
LAMBDA_CYCLE = 10 # weight for cycle loss (photo->Monet, Monet->photo)
NUM_WORKERS = 0 # 0 loads data in the main process; increase only if worker processes are available
NUM_EPOCHS = 10
LOAD_MODEL = False
SAVE_MODEL = True
CHECKPOINT_GEN_P = "genp.pth.tar" # photos
CHECKPOINT_GEN_M = "genm.pth.tar" # monet
CHECKPOINT_CRITIC_P = "criticp.pth.tar"
CHECKPOINT_CRITIC_M = "criticm.pth.tar"

transforms = A.Compose( # Note that we only augment the dataset by flipping, not by altering colors or brightness (since these are integral to the photos and Monet paintings)
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5), # Double dataset by flipping images
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ],
    additional_targets={"image0": "image"},
)
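
To make the effect of `A.Normalize` with mean 0.5 and standard deviation 0.5 concrete, the following minimal sketch reproduces the arithmetic in plain Python for single pixel values, together with the `x * 0.5 + 0.5` inverse that is applied later when images are saved. The helper names `normalize` and `denormalize` are illustrative assumptions and not part of albumentations.

```python
def normalize(pixel, mean=0.5, std=0.5, max_pixel_value=255):
    """Map an 8-bit pixel value to the [-1, 1] range fed to the networks."""
    return (pixel / max_pixel_value - mean) / std


def denormalize(value, mean=0.5, std=0.5):
    """Invert `normalize`; the result lies in [0, 1]."""
    return value * std + mean


for pixel in (0, 128, 255):
    x = normalize(pixel)
    # Print the original value, the normalized value and the recovered 8-bit value.
    print(pixel, round(x, 4), round(denormalize(x) * 255))
```

The [-1, 1] range matches the `tanh` output of the generators, so real and generated images live on the same scale.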

<font size="5"> 
<b>
Dataset
</b>
</font>

We define the dataset class holding the photos and the Monet paintings, along with its constructor, length, and item retrieval function. The retrieval function looks up the photo and the Monet painting for the given index (wrapping around with a modulo when one dataset is shorter than the other), loads the images from their respective paths, applies the given transformation to them (if any), and returns the pair.

In [3]:
class PhotoMonetDataset(Dataset):
    def __init__(self, root_photo, root_monet, transform=None):
        self.root_photo = root_photo # dir to photos
        self.root_monet = root_monet # dir to monet pictures
        self.transform = transform

        self.photo_images = os.listdir(root_photo)
        self.monet_images = os.listdir(root_monet)
        self.photo_len = len(self.photo_images)
        self.monet_len = len(self.monet_images)
        self.length_dataset = max(self.photo_len, self.monet_len)

    def __len__(self):
        return self.length_dataset
    
    def __getitem__(self,index):
        photo_img = self.photo_images[index % self.photo_len ] # preventing index errors
        monet_img = self.monet_images[index % self.monet_len ]

        photo_path = os.path.join(self.root_photo, photo_img)
        monet_path = os.path.join(self.root_monet, monet_img)
        print(f"Trying to load: {monet_path}")
        if not os.path.exists(monet_path):
            print(f"❌ ERROR: File not found -> {monet_path}")
        photo_img = np.array(Image.open(photo_path).convert("RGB"))
        monet_img = np.array(Image.open(monet_path).convert("RGB"))

        if self.transform:
            augmentations = self.transform(image=photo_img, image0=monet_img)
            photo_img = augmentations["image"]
            monet_img = augmentations["image0"]

        return photo_img, monet_img
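
The modulo indexing in `__getitem__` means the shorter dataset is cycled several times per epoch. The sketch below illustrates this in plain Python, assuming the dataset sizes quoted in the introduction (7038 photos and 300 Monet paintings); the actual training split under `TRAIN_DIR` may be smaller.

```python
from collections import Counter

photo_len, monet_len = 7038, 300       # sizes assumed from the introduction
length = max(photo_len, monet_len)     # what __len__ would return

# The (photo index, Monet index) pair served for every dataset index.
pairs = [(i % photo_len, i % monet_len) for i in range(length)]

# How often each Monet painting is served during one epoch.
counts = Counter(monet_index for _, monet_index in pairs)
print(length, min(counts.values()), max(counts.values()))
```

Under these assumptions every photo is seen once per epoch, while each Monet painting is reused roughly two dozen times.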

<font size="5"> 
<b>
Models
</b>
</font>

The GAN essentially consists of two kinds of model: a discriminator, which decides whether an image is fake or not, and a generator, which generates fake images by converting an image from one distribution (e.g. photo) to another (e.g. Monet painting). The two are trained against each other; see the section on __Training__ below.

<font size="4"> 
<b>
Discriminator model
</b>
</font>

__Block__: Shorthand for a convolution block, consisting of a 2D convolution layer, an instance normalization, and a LeakyReLU. Used to define the discriminator model.

__Discriminator__: Classifies images as "real" or "fake". It consists of blocks that iteratively reduce the spatial size, culminating in a single-channel output whose values lie between 0 ("fake") and 1 ("real").

In [4]:
class Block(nn.Module):  # inherits from nn.Module
    def __init__(self, in_channels, out_channels, stride):
        super().__init__()  # initialize the parent class (nn.Module) before registering layers
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels,out_channels,4,stride,1,bias=True,padding_mode="reflect"),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.conv(x)

class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=[64,128,256,512]):
        super().__init__()
        self.initial = nn.Sequential(nn.Conv2d(in_channels,features[0],kernel_size=4,stride=2,padding=1, padding_mode="reflect"),nn.LeakyReLU(0.2))
        layers = []
        in_channels = features[0]
        for feature in features[1:]:
            layers.append(Block(in_channels, feature, stride = 1 if feature == features[-1] else 2))
            in_channels = feature
        layers.append(nn.Conv2d(in_channels,1,kernel_size=4,stride=1,padding=1, padding_mode="reflect")) # output has dimension 1
        self.model = nn.Sequential(*layers) # unwrapping the list

    def forward(self,x):
        x = self.initial(x)
        return torch.sigmoid(self.model(x)) # normalize the output between 0 and 1

    @staticmethod
    def test():
        x = torch.randn((5, 3, 256, 256))
        model = Discriminator(in_channels=3)
        preds = model(x)
        print(preds.shape)

In [5]:
# Run this block to test the discriminator
Discriminator.test()

torch.Size([5, 1, 30, 30])
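
The 30×30 output can be checked by hand with the usual convolution size formula, `floor((n + 2p - k) / s) + 1`. The following sketch applies it to the five convolutions of `Discriminator` (kernel size 4 and padding 1 throughout); the helper `conv_out` is an illustrative name and not part of PyTorch.

```python
def conv_out(size, kernel=4, stride=1, padding=1):
    """Spatial output size of a convolution: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


# Strides of the five convolutions: initial layer, three Blocks, final layer.
strides = [2, 2, 2, 1, 1]

sizes = [256]
for stride in strides:
    sizes.append(conv_out(sizes[-1], stride=stride))

print(sizes)
```

Each of the resulting 30×30 values is computed from a limited region of the input, so the discriminator effectively scores local patches rather than emitting a single verdict per image.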


<font size="4"> 
<b>
Generator model
</b>
</font>

We wish to strip the "unimportant features" from the input by compressing it to a bottleneck in the middle (downsampling), and then upscale it to the desired distribution while keeping the most important features (upsampling).

__ConvBlock__: Conv2d + InstanceNorm + ReLU if `down=True`, else the same with a transposed convolution (i.e., the same operation in the other direction).

__ResidualBlock__: Consists of two ConvBlocks with `down=True`, the second one without activation. The forward pass adds the input back to the output (a residual connection) to avoid vanishing gradients in deep networks.

__Generator__: Consists of an initial convolution that processes the image input, followed by two down ConvBlocks boiling it down to its most important features, followed by some ResidualBlocks that process the features (without risking vanishing gradients due to the depth of the network). Finally, two up ConvBlocks are applied, followed by a last convolution (i.e., the same as the first one, but using `tanh` instead of ReLU).

In [6]:
class ConvBlock(nn.Module):    # Down and upsampling
    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs) :#key word arguments
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
            if down
            else nn.ConvTranspose2d(in_channels,out_channels,**kwargs), # either Conv2d (if down) or its transpose
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True) if use_act else nn.Identity()
        )

    def forward(self,x):
        return self.conv(x)

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block= nn.Sequential(
            ConvBlock(channels, channels, kernel_size=3, padding=1, stride=1),
            ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1)
        )

    def forward(self,x):
        return x+ self.block(x)

class Generator(nn.Module):
    def __init__(self, img_channels, num_features=64, num_residuals=9):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.ReLU(inplace=True)
        )
        self.down_blocks = nn.ModuleList(
           [ConvBlock(num_features, num_features*2, kernel_size=3, stride=2, padding=1),
            ConvBlock(num_features*2, num_features*4, kernel_size=3, stride=2, padding=1)]
        )
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(num_features*4) for _ in range (num_residuals) ]
        )
        self.up_blocks = nn.ModuleList(
            [ConvBlock(num_features*4, num_features*2,down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
             ConvBlock(num_features*2, num_features*1,down=False, kernel_size=3, stride=2, padding=1, output_padding=1)]
        )
        self.last = nn.Conv2d(num_features*1, img_channels, kernel_size=7,stride=1, padding=3, padding_mode="reflect")

    def forward (self,x):
        x= self.initial(x)
        for layer in self.down_blocks:
            x= layer(x)
        x=self.residual_blocks(x)
        for layer in self.up_blocks:
            x= layer(x)
        return torch.tanh(self.last(x))

def generator_test():
    img_channels = 3
    img_size =256
    x= torch.randn((2, img_channels, img_size, img_size))
    gen = Generator(img_channels,64,9 )
    print(gen(x).shape)


In [7]:
# Run to test the generator
generator_test()

torch.Size([2, 3, 256, 256])
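
The fact that the generator returns an image of the same size as its input follows from the layer parameters. The sketch below traces one spatial dimension through the network in plain Python, using the standard size formulas for `Conv2d` and `ConvTranspose2d` (with dilation 1); the helper names are illustrative assumptions.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of Conv2d along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride, padding, output_padding):
    """Output size of ConvTranspose2d along one spatial dimension (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


trace = [256]
trace.append(conv_out(trace[-1], kernel=7, stride=1, padding=3))  # initial convolution
for _ in range(2):  # two downsampling ConvBlocks
    trace.append(conv_out(trace[-1], kernel=3, stride=2, padding=1))
trace.append(conv_out(trace[-1], kernel=3, stride=1, padding=1))  # residual convolutions keep the size
for _ in range(2):  # two upsampling ConvBlocks
    trace.append(conv_transpose_out(trace[-1], kernel=3, stride=2, padding=1, output_padding=1))
trace.append(conv_out(trace[-1], kernel=7, stride=1, padding=3))  # last convolution

print(trace)
```

The `output_padding=1` argument is what lets each transposed convolution exactly double the size; without it the result would be one pixel short.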


<font size="5"> 
<b>
Utils
</b>
</font>

Utility functions used in training.

In [8]:
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)

def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this then it will just have learning rate of old checkpoint
    # and it will lead to many hours of debugging \:
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

def seed_everything(seed=42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

<font size="5"> 
<b>
Training
</b>
</font>

Here, we first train the discriminators and then the generators. We use the generators to compute the loss of the discriminators, and the discriminators to train the generators. Note that for low batch sizes (e.g. 1), the difference between the generator/discriminator before and after a weight update is quite small, which might be good since we ideally would like to update them "at the same time". This is of course impossible, but by using small batch sizes, we approximate it better.

__Training the discriminator__

We first apply the discriminator to a real photo, then generate a fake photo from the Monet data and apply the discriminator to that as well. We then compare the output values with the correct ones (1 and 0 respectively) using the mean squared error (MSE), and let the loss be the sum of the two. We then do the same for Monet and let the total loss be the average of the two losses. Finally, we update the weights based on the losses. This way, we train the discriminators to recognize both fake and real data using the generators. As the generators improve, so will the discriminators.
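
As a small numerical illustration of the discriminator loss described above, the following plain-Python sketch computes the MSE against the targets 1 (real) and 0 (fake) for made-up discriminator scores. The function names and the score values are assumptions chosen purely for illustration.

```python
def mse(predictions, target):
    """Mean squared error between a list of scores and a constant target."""
    return sum((p - target) ** 2 for p in predictions) / len(predictions)


def discriminator_loss(scores_real, scores_fake):
    """Real images should score 1, generated images should score 0."""
    return mse(scores_real, 1.0) + mse(scores_fake, 0.0)


# A discriminator that separates the two classes well ...
confident = discriminator_loss([0.9, 0.8], [0.1, 0.2])
# ... versus one that outputs 0.5 for everything.
unsure = discriminator_loss([0.5, 0.5], [0.5, 0.5])
print(confident, unsure)
```

The better the discriminator separates real from generated images, the closer this loss gets to zero.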

__Training the generator__

We wish to improve the generator in three regards:

1. __Adversarial__: We want the generator to be able to fool the discriminator. Since the discriminator was improved in the preceding step, we generate data and change the weights of the generator such that the discriminator believes the data to be real, which is why we take the MSE of the discriminator's output on the fake data with respect to 1. Note that the generator is never trained against a fixed discriminator, which it could learn to fool quickly before it stops improving; instead, each update capitalizes on the latest, improved discriminator, so the generated data iteratively moves closer to the target distribution.
2. __Cycle__: An image is converted to the other distribution, and then back again to the first distribution. It is then compared with the original image using the $L^1$ norm. By doing this, we hope to train the CycleGAN to only make necessary changes to images, and for the two generators to identify the same features to change.
3. __Identity__: To strengthen our model's capacity to generate images of the target distribution, we feed the generator an image that already belongs to that distribution and, similar to above, take the $L^1$ loss between its output and the input.

We combine all of these losses by summing them, with weight factors `LAMBDA_CYCLE` for the cycle losses and `LAMBDA_IDENTITY` for the identity losses (these implicitly weight the adversarial loss as well, since lowering the weights of the others de facto raises that of the adversarial one, and vice versa).
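
Written out, the combined generator loss can be summarized as follows. The notation is an assumption introduced here for readability: $p$ is a photo, $m$ a Monet painting, $G_P$ and $G_M$ stand for `gen_P` (Monet to photo) and `gen_M` (photo to Monet), $D_P$ and $D_M$ for the two discriminators, and $\operatorname{MSE}$ and $\operatorname{L1}$ for the mean squared and mean absolute error.

$$
\begin{aligned}
\mathcal{L}_G ={}& \operatorname{MSE}\big(D_M(G_M(p)),\, 1\big) + \operatorname{MSE}\big(D_P(G_P(m)),\, 1\big) \\
&+ \lambda_{\text{cycle}} \Big[ \operatorname{L1}\big(G_M(G_P(m)),\, m\big) + \operatorname{L1}\big(G_P(G_M(p)),\, p\big) \Big] \\
&+ \lambda_{\text{identity}} \Big[ \operatorname{L1}\big(G_P(p),\, p\big) + \operatorname{L1}\big(G_M(m),\, m\big) \Big]
\end{aligned}
$$

With the configuration above, $\lambda_{\text{cycle}} = 10$ and $\lambda_{\text{identity}} = 5$.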

In [None]:
def train_fn(disc_P, disc_M, gen_P, gen_M, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler, epoch):
    loop = tqdm(loader, leave=True)    # progress bar
    for idx, (photo, monet) in enumerate(loop):  # the dataset returns (photo, monet)
        print(f"Entering iteration {idx}")
        photo = photo.to(DEVICE)
        monet = monet.to(DEVICE)

        # Train Discriminators P and M.
        with torch.amp.autocast('cuda'):
            fake_Photo = gen_P(monet) # generate a fake photo based off a Monet picture
            D_P_real = disc_P(photo) # take a real photo and see if it is real or not
            D_P_fake = disc_P(fake_Photo.detach()) # check whether the fake photo is classified as real or not
            D_P_real_loss = mse(D_P_real, torch.ones_like(D_P_real))
            D_P_fake_loss = mse(D_P_fake, torch.zeros_like(D_P_fake))
            D_P_loss = D_P_fake_loss+D_P_real_loss # loss based on how well it classified true and fake images

            fake_Monet = gen_M(photo)
            D_M_real = disc_M(monet)  # the Monet discriminator judges Monet images
            D_M_fake = disc_M(fake_Monet.detach())
            D_M_real_loss = mse(D_M_real, torch.ones_like(D_M_real))
            D_M_fake_loss = mse(D_M_fake, torch.zeros_like(D_M_fake))
            D_M_loss = D_M_fake_loss + D_M_real_loss

            D_loss = (D_P_loss+D_M_loss)/2 # loss as average of Monet and photo

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update() # Update the discriminator based on the loss

        # Train generators P and M
        with ((torch.amp.autocast('cuda'))):
            # Adversarial loss
            D_P_fake = disc_P(fake_Photo) # whether the photo is fake or not (after updating)
            D_M_fake = disc_M(fake_Monet)
            Loss_G_M = mse(D_M_fake, torch.ones_like(D_M_fake))
            Loss_G_P = mse(D_P_fake, torch.ones_like(D_P_fake))
            # Cycle loss
            cycle_monet = gen_M(fake_Photo) # make a Monet out of a photo
            cycle_photo = gen_P(fake_Monet)
            cycle_monet_loss = l1(cycle_monet, monet)
            cycle_photo_loss = l1(cycle_photo,photo)
            # Identity loss
            identity_photo= gen_P(photo) # generate photo out of a photo
            identity_monet= gen_M(monet)
            identity_photo_loss= l1(identity_photo,photo)
            identity_monet_loss= l1(identity_monet,monet)
            # Add all together
            G_loss = (Loss_G_M+Loss_G_P
            + cycle_monet_loss* LAMBDA_CYCLE
            + cycle_photo_loss* LAMBDA_CYCLE
            + identity_monet_loss * LAMBDA_IDENTITY
            + identity_photo_loss * LAMBDA_IDENTITY) # use everything in loss, with weights

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update() # update based on total loss

        if idx % 20 == 0:
            save_image(fake_Photo * 0.5 + 0.5, f"saved_images/photo_{epoch}_{idx}.png")
            save_image(fake_Monet * 0.5 + 0.5, f"saved_images/monet_{epoch}_{idx}.png")

In [10]:
def validate_fn(disc_P, disc_M, gen_P, gen_M, loader, l1, mse, epoch):
    # Set models to evaluation mode
    disc_P.eval()
    disc_M.eval()
    gen_P.eval()
    gen_M.eval()

    total_D_loss = 0.0
    total_G_loss = 0.0
    num_batches = len(loader)

    # Disable gradient computations for validation
    with torch.no_grad():
        loop = tqdm(loader, leave=True)
        for idx, (photo, monet) in enumerate(loop):  # the dataset returns (photo, monet)
            photo = photo.to(DEVICE)
            monet = monet.to(DEVICE)
            # --------------------
            #  Discriminator Loss
            # --------------------
            with torch.amp.autocast('cuda'):
                fake_Photo = gen_P(monet)
                D_P_real = disc_P(photo)
                D_P_fake = disc_P(fake_Photo)
                D_P_real_loss = mse(D_P_real, torch.ones_like(D_P_real))
                D_P_fake_loss = mse(D_P_fake, torch.zeros_like(D_P_fake))
                D_P_loss = D_P_real_loss + D_P_fake_loss

                fake_Monet = gen_M(photo)
                D_M_real = disc_M(monet)
                D_M_fake = disc_M(fake_Monet)
                D_M_real_loss = mse(D_M_real, torch.ones_like(D_M_real))
                D_M_fake_loss = mse(D_M_fake, torch.zeros_like(D_M_fake))
                D_M_loss = D_M_real_loss + D_M_fake_loss

                D_loss = (D_P_loss + D_M_loss) / 2

            # --------------------
            #  Generator Loss
            # --------------------
            with torch.amp.autocast('cuda'):
                # Re-compute for generators (if needed for validation metrics)
                D_P_fake = disc_P(fake_Photo)
                D_M_fake = disc_M(fake_Monet)
                Loss_G_P = mse(D_P_fake, torch.ones_like(D_P_fake))
                Loss_G_M = mse(D_M_fake, torch.ones_like(D_M_fake))

                # Cycle consistency
                cycle_monet = gen_M(fake_Photo)
                cycle_photo = gen_P(fake_Monet)
                cycle_monet_loss = l1(cycle_monet, monet)
                cycle_photo_loss = l1(cycle_photo, photo)

                # Identity loss
                identity_photo = gen_P(photo)
                identity_monet = gen_M(monet)
                identity_photo_loss = l1(identity_photo, photo)
                identity_monet_loss = l1(identity_monet, monet)

                G_loss = (Loss_G_P + Loss_G_M +
                          cycle_monet_loss * LAMBDA_CYCLE +
                          cycle_photo_loss * LAMBDA_CYCLE +
                          identity_monet_loss * LAMBDA_IDENTITY +
                          identity_photo_loss * LAMBDA_IDENTITY)
            # --------------------
            #  CMMD
            # --------------------
            # Save the denormalized outputs so that CMMD can be computed on them later.
            # `transforms` is the albumentations pipeline and has no ToPILImage,
            # so torchvision's save_image (imported above) is used instead.
            os.makedirs(f'Output/Monet/{epoch}', exist_ok=True)
            os.makedirs(f'Output/Photo/{epoch}', exist_ok=True)
            save_image(fake_Monet.float().cpu() * 0.5 + 0.5, f'Output/Monet/{epoch}/{idx}.jpg')
            save_image(fake_Photo.float().cpu() * 0.5 + 0.5, f'Output/Photo/{epoch}/{idx}.jpg')
            total_D_loss += D_loss.item()
            total_G_loss += G_loss.item()

    avg_D_loss = total_D_loss / num_batches
    avg_G_loss = total_G_loss / num_batches
    # print(f"Epoch {epoch} | Validation D Loss: {avg_D_loss:.4f}, G Loss: {avg_G_loss:.4f}")

    # Set models back to training mode
    disc_P.train()
    disc_M.train()
    gen_P.train()
    gen_M.train()

    return avg_D_loss, avg_G_loss

<font size="5"> 
<b>
CMMD code
</b>
</font>

CMMD (an MMD distance computed on CLIP image embeddings) is used to evaluate the quality of the generated images. The code is taken from __[GitHub](https://github.com/sayakpaul/cmmd-pytorch)__, written by Sayak Paul and Agneet Chatterjee.

In [11]:
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Memory-efficient MMD implementation, ported from JAX to PyTorch."""

import torch

# The bandwidth parameter for the Gaussian RBF kernel. See the paper for more
# details.
_SIGMA = 10
# The following is used to make the metric more human readable. See the paper
# for more details.
_SCALE = 1000


def mmd(x, y):
    """Memory-efficient MMD implementation, ported from JAX to PyTorch.

    This implements the minimum-variance/biased version of the estimator described
    in Eq.(5) of
    https://jmlr.csail.mit.edu/papers/volume13/gretton12a/gretton12a.pdf.
    As described in Lemma 6's proof in that paper, the unbiased estimate and the
    minimum-variance estimate for MMD are almost identical.

    Note that the JIT-compilation warm-up mentioned in the original JAX version
    does not apply to this PyTorch port, which runs eagerly.

    Args:
      x: The first set of embeddings of shape (n, embedding_dim).
      y: The second set of embeddings of shape (n, embedding_dim).

    Returns:
      The MMD distance between x and y embedding sets.
    """
    x = torch.from_numpy(x)
    y = torch.from_numpy(y)

    x_sqnorms = torch.diag(torch.matmul(x, x.T))
    y_sqnorms = torch.diag(torch.matmul(y, y.T))

    gamma = 1 / (2 * _SIGMA**2)
    k_xx = torch.mean(
        torch.exp(-gamma * (-2 * torch.matmul(x, x.T) + torch.unsqueeze(x_sqnorms, 1) + torch.unsqueeze(x_sqnorms, 0)))
    )
    k_xy = torch.mean(
        torch.exp(-gamma * (-2 * torch.matmul(x, y.T) + torch.unsqueeze(x_sqnorms, 1) + torch.unsqueeze(y_sqnorms, 0)))
    )
    k_yy = torch.mean(
        torch.exp(-gamma * (-2 * torch.matmul(y, y.T) + torch.unsqueeze(y_sqnorms, 1) + torch.unsqueeze(y_sqnorms, 0)))
    )

    return _SCALE * (k_xx + k_yy - 2 * k_xy)
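
For intuition, the same biased MMD estimate with a Gaussian RBF kernel can be written as a slow but explicit double loop. The sketch below is a plain-Python reference written for this notebook, not part of the upstream code; `rbf`, `mean_kernel` and `mmd_reference` are illustrative names, and the toy vectors are made up.

```python
import math

SIGMA = 10    # same bandwidth as _SIGMA above
SCALE = 1000  # same scaling as _SCALE above


def rbf(a, b, sigma=SIGMA):
    """Gaussian RBF kernel exp(-||a - b||^2 / (2 sigma^2))."""
    squared_distance = sum((u - v) ** 2 for u, v in zip(a, b))
    return math.exp(-squared_distance / (2 * sigma ** 2))


def mean_kernel(xs, ys):
    """Average kernel value over all pairs (a, b) with a in xs and b in ys."""
    return sum(rbf(a, b) for a in xs for b in ys) / (len(xs) * len(ys))


def mmd_reference(xs, ys):
    """Biased MMD estimate: mean k(x, x') + mean k(y, y') - 2 mean k(x, y)."""
    return SCALE * (mean_kernel(xs, xs) + mean_kernel(ys, ys) - 2 * mean_kernel(xs, ys))


xs = [[1.0, 0.0], [0.0, 1.0]]
ys = [[-1.0, 0.0], [0.0, -1.0]]
print(mmd_reference(xs, xs), mmd_reference(xs, ys))
```

The distance is zero when both sets are identical and grows as the two sets of embeddings drift apart, which is why a lower CMMD indicates generated images closer to the reference set.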


In [12]:
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Embedding models used in the CMMD calculation."""

from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
import torch
import numpy as np

_CLIP_MODEL_NAME = "openai/clip-vit-large-patch14-336"
_CUDA_AVAILABLE = torch.cuda.is_available()


def _resize_bicubic(images, size):
    images = torch.from_numpy(images.transpose(0, 3, 1, 2))
    images = torch.nn.functional.interpolate(images, size=(size, size), mode="bicubic")
    images = images.permute(0, 2, 3, 1).numpy()
    return images


class ClipEmbeddingModel:
    """CLIP image embedding calculator."""

    def __init__(self):
        self.image_processor = CLIPImageProcessor.from_pretrained(_CLIP_MODEL_NAME)

        self._model = CLIPVisionModelWithProjection.from_pretrained(_CLIP_MODEL_NAME).eval()
        if _CUDA_AVAILABLE:
            self._model = self._model.cuda()

        self.input_image_size = self.image_processor.crop_size["height"]

    @torch.no_grad()
    def embed(self, images):
        """Computes CLIP embeddings for the given images.

        Args:
          images: An image array of shape (batch_size, height, width, 3). Values are
            in range [0, 1].

        Returns:
          Embedding array of shape (batch_size, embedding_width).
        """

        images = _resize_bicubic(images, self.input_image_size)
        inputs = self.image_processor(
            images=images,
            do_normalize=True,
            do_center_crop=False,
            do_resize=False,
            do_rescale=False,
            return_tensors="pt",
        )
        if _CUDA_AVAILABLE:
            inputs = {k: v.to("cuda") for k, v in inputs.items()}

        image_embs = self._model(**inputs).image_embeds.cpu()
        image_embs /= torch.linalg.norm(image_embs, axis=-1, keepdims=True)
        return image_embs


In [13]:
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""IO utilities."""

import glob
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import tqdm


class CMMDDataset(Dataset):
    def __init__(self, path, reshape_to, max_count=-1):
        self.path = path
        self.reshape_to = reshape_to

        self.max_count = max_count
        img_path_list = self._get_image_list()
        if max_count > 0:
            img_path_list = img_path_list[:max_count]
        self.img_path_list = img_path_list

    def __len__(self):
        return len(self.img_path_list)

    def _get_image_list(self):
        ext_list = ["png", "jpg", "jpeg"]
        image_list = []
        for ext in ext_list:
            image_list.extend(glob.glob(f"{self.path}/*.{ext}"))
            image_list.extend(glob.glob(f"{self.path}/*.{ext.upper()}"))
        # Sort the list to ensure a deterministic output.
        image_list.sort()
        return image_list

    def _center_crop_and_resize(self, im, size):
        w, h = im.size
        l = min(w, h)
        top = (h - l) // 2
        left = (w - l) // 2
        box = (left, top, left + l, top + l)
        im = im.crop(box)
        # Note that the following performs anti-aliasing as well.
        return im.resize((size, size), resample=Image.BICUBIC)  # pytype: disable=module-attr

    def _read_image(self, path, size):
        im = Image.open(path)
        if size > 0:
            im = self._center_crop_and_resize(im, size)
        return np.asarray(im).astype(np.float32)

    def __getitem__(self, idx):
        img_path = self.img_path_list[idx]

        x = self._read_image(img_path, self.reshape_to)
        if x.ndim == 3:
            return x
        elif x.ndim == 2:
            # Convert grayscale to RGB by duplicating the channel dimension.
            return np.tile(x[Ellipsis, np.newaxis], (1, 1, 3))


def compute_embeddings_for_dir(
    img_dir,
    embedding_model,
    batch_size,
    max_count=-1,
):
    """Computes embeddings for the images in the given directory.

    This drops the remainder of the images after batching with the provided
    batch_size to enable efficient computation on TPUs. This usually does not
    affect results assuming we have a large number of images in the directory.

    Args:
      img_dir: Directory containing .jpg or .png image files.
      embedding_model: The embedding model to use.
      batch_size: Batch size for the embedding model inference.
      max_count: Max number of images in the directory to use.

    Returns:
      Computed embeddings of shape (num_images, embedding_dim).
    """
    dataset = CMMDDataset(img_dir, reshape_to=embedding_model.input_image_size, max_count=max_count)
    count = len(dataset)
    print(f"Calculating embeddings for {count} images from {img_dir}.")

    dataloader = DataLoader(dataset, batch_size=batch_size)

    all_embs = []
    for batch in tqdm.tqdm(dataloader, total=count // batch_size):
        image_batch = batch.numpy()

        # Normalize to the [0, 1] range.
        image_batch = image_batch / 255.0

        if np.min(image_batch) < 0 or np.max(image_batch) > 1:
            raise ValueError(
                "Image values are expected to be in [0, 1]. Found:" f" [{np.min(image_batch)}, {np.max(image_batch)}]."
            )

        # Compute the embeddings using a pmapped function.
        embs = np.asarray(
            embedding_model.embed(image_batch)
        )  # The output has shape (num_devices, batch_size, embedding_dim).
        all_embs.append(embs)

    all_embs = np.concatenate(all_embs, axis=0)

    return all_embs
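
Before embedding, `CMMDDataset` crops every image to a centered square and only then resizes it. The crop box arithmetic from `_center_crop_and_resize` can be checked in isolation with the sketch below; `center_crop_box` is an illustrative helper name and the image sizes are made-up examples.

```python
def center_crop_box(width, height):
    """Return the (left, top, right, bottom) box of the largest centered square."""
    side = min(width, height)
    top = (height - side) // 2
    left = (width - side) // 2
    return (left, top, left + side, top + side)


print(center_crop_box(400, 300))  # landscape image: crop the sides
print(center_crop_box(300, 400))  # portrait image: crop top and bottom
```

Images that are already square, such as 256×256 generator outputs, pass through the crop unchanged and are only resized.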


In [14]:
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The main entry point for the CMMD calculation."""

from absl import app
from absl import flags
import distance
import embedding
import io_util
import numpy as np


_BATCH_SIZE = flags.DEFINE_integer("batch_size", 32, "Batch size for embedding generation.")
_MAX_COUNT = flags.DEFINE_integer("max_count", -1, "Maximum number of images to read from each directory.")
_REF_EMBED_FILE = flags.DEFINE_string(
    "ref_embed_file", None, "Path to the pre-computed embedding file for the reference images."
)


def compute_cmmd(ref_dir, eval_dir, ref_embed_file=None, batch_size=32, max_count=-1):
    """Calculates the CMMD distance between reference and eval image sets.

    Args:
      ref_dir: Path to the directory containing reference images.
      eval_dir: Path to the directory containing images to be evaluated.
      ref_embed_file: Path to the pre-computed embedding file for the reference images.
      batch_size: Batch size used in the CLIP embedding calculation.
      max_count: Maximum number of images to use from each directory. A
        non-positive value reads all images available except for the images
        dropped due to batching.

    Returns:
      The CMMD value between the image sets.
    """
    if ref_dir and ref_embed_file:
        raise ValueError("`ref_dir` and `ref_embed_file` both cannot be set at the same time.")
    embedding_model = embedding.ClipEmbeddingModel()
    if ref_embed_file is not None:
        ref_embs = np.load(ref_embed_file).astype("float32")
    else:
        ref_embs = io_util.compute_embeddings_for_dir(ref_dir, embedding_model, batch_size, max_count).astype(
            "float32"
        )
    eval_embs = io_util.compute_embeddings_for_dir(eval_dir, embedding_model, batch_size, max_count).astype("float32")
    val = distance.mmd(ref_embs, eval_embs)
    return val.numpy()


def main(argv):
    if len(argv) != 3:
        raise app.UsageError("Too few/too many command-line arguments.")
    _, dir1, dir2 = argv
    print(
        "The CMMD value is: "
        f" {compute_cmmd(dir1, dir2, _REF_EMBED_FILE.value, _BATCH_SIZE.value, _MAX_COUNT.value):.3f}"
    )


# `app.run(main)` is intentionally not called here: absl parses `sys.argv`,
# and a Jupyter kernel is launched with its own `-f <connection file>` argument,
# which absl rejects with "Unknown command line flag 'f'".
# In the notebook, call `compute_cmmd(ref_dir, eval_dir)` directly instead.
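
To give an intuition for the value returned by `compute_cmmd`, here is a minimal NumPy sketch of a squared Maximum Mean Discrepancy (MMD) with a Gaussian RBF kernel, computed between two sets of embeddings. It is an illustration only: the function name `rbf_mmd`, the bandwidth `SIGMA`, the scaling factor `SCALE`, and the use of the simple biased estimator are assumptions made for this sketch, and the `distance` module imported above may differ in its details.

```python
import numpy as np

SIGMA = 10.0    # RBF kernel bandwidth (assumed value for this sketch)
SCALE = 1000.0  # constant factor to make values easier to read (assumed)


def rbf_mmd(x, y, sigma=SIGMA, scale=SCALE):
    """Biased estimate of the squared MMD between two embedding sets.

    x: array of shape (n, d), y: array of shape (m, d).
    """
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    gamma = 1.0 / (2.0 * sigma**2)

    def mean_kernel(a, b):
        # Pairwise squared Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b
        sq_dists = (
            np.sum(a**2, axis=1)[:, None]
            + np.sum(b**2, axis=1)[None, :]
            - 2.0 * a @ b.T
        )
        sq_dists = np.maximum(sq_dists, 0.0)  # guard against round-off
        return np.mean(np.exp(-gamma * sq_dists))

    # MMD^2 = E[k(x, x')] + E[k(y, y')] - 2 E[k(x, y)]
    return scale * (mean_kernel(x, x) + mean_kernel(y, y) - 2.0 * mean_kernel(x, y))
```

Two identical sets give a value of zero, and the value grows as the two distributions of embeddings move apart, which is why a lower CMMD is read as "generated images are closer to the reference images".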


<font size="5"> 
<b>
Main
</b>
</font>

We define all relevant discriminators, generators, optimizers, and losses, then run the training for `NUM_EPOCHS` epochs. If `SAVE_MODEL` is set to `True`, we also save a checkpoint of every model after each epoch (for example, for debugging, or to fall back to an earlier version if the model starts to overfit).
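
Keeping a separate copy per epoch requires a distinct filename for each epoch. If `save_checkpoint` writes to the filename exactly as given, the fixed names from the config would be overwritten every epoch. The sketch below shows one hypothetical way to derive per-epoch names from the configured checkpoint paths; the helper `epoch_checkpoint_path` is an assumption for illustration and is not part of the notebook.

```python
import os


def epoch_checkpoint_path(filename, epoch):
    """Return `filename` with an epoch tag inserted before its extension(s).

    Hypothetical helper: "genp.pth.tar" with epoch 3 becomes "genp_epoch003.pth.tar".
    """
    directory, base = os.path.split(filename)
    # Split at the first dot so multi-part extensions such as ".pth.tar" stay intact.
    stem, dot, extensions = base.partition(".")
    tagged = f"{stem}_epoch{epoch:03d}{dot}{extensions}"
    return os.path.join(directory, tagged)


print(epoch_checkpoint_path("genp.pth.tar", 3))
```

Such a name could then be passed as the `filename` argument of `save_checkpoint`, at the cost of one set of checkpoint files per epoch on disk.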

In [None]:
def main():
    disc_P = Discriminator(in_channels=3).to(DEVICE)
    disc_M = Discriminator(in_channels=3).to(DEVICE)
    gen_P = Generator(img_channels=3, num_residuals=9).to(DEVICE)
    gen_M = Generator(img_channels=3, num_residuals=9).to(DEVICE)
    opt_disc = optim.Adam(
        list(disc_P.parameters()) + list(disc_M.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    opt_gen = optim.Adam(
        list(gen_P.parameters()) + list(gen_M.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    # These checkpoint files allow the training process to resume from where it left off, without starting over from scratch.
    if LOAD_MODEL:
        load_checkpoint(
            CHECKPOINT_GEN_P,
            gen_P,
            opt_gen,
            LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_GEN_M,
            gen_M,
            opt_gen,
            LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_CRITIC_P,
            disc_P,
            opt_disc,
            LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_CRITIC_M,
            disc_M,
            opt_disc,
            LEARNING_RATE,
        )
    
    train_dataset = PhotoMonetDataset(
        root_photo=TRAIN_DIR + "/Photo",
        root_monet=TRAIN_DIR + "/Monet",
        transform=transforms,
    )
    val_dataset = PhotoMonetDataset(
        root_photo=VAL_DIR + "/Photo",
        root_monet=VAL_DIR + "/Monet",
        transform=transforms,
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    g_scaler = torch.amp.GradScaler('cuda')
    d_scaler = torch.amp.GradScaler('cuda')

    G_loss = []
    D_loss = []
    CMMD_Monet = []
    CMMD_Photo = []

    for epoch in range(NUM_EPOCHS):
        print(f"Epoch {epoch}:")

        print("Training...")
        train_fn(
            disc_P,
            disc_M,
            gen_P,
            gen_M,
            train_loader,
            opt_disc,
            opt_gen,
            L1,
            mse,
            d_scaler,
            g_scaler,
            epoch
        )

        print("Validation...")
        new_d_loss, new_g_loss = validate_fn(
            disc_P,
            disc_M,
            gen_P,
            gen_M,
            val_loader,
            L1,
            mse,
        )
        
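        # CMMD between the real validation images and the images generated this epoch (lower is better)
        new_CMMD_Monet = compute_cmmd(VAL_DIR + "/Monet", f"Output/Monet/{epoch}")
        new_CMMD_Photo = compute_cmmd(VAL_DIR + "/Photo", f"Output/Photo/{epoch}")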
        

        G_loss.append(new_g_loss)
        D_loss.append(new_d_loss)

        CMMD_Monet.append(new_CMMD_Monet)
        CMMD_Photo.append(new_CMMD_Photo)

        print("G loss:", G_loss)
        print("D loss:", D_loss)
        print("CMMD Monet:", CMMD_Monet)
        print("CMMD Photo:", CMMD_Photo)

        if SAVE_MODEL:
            save_checkpoint(gen_P, opt_gen, filename=CHECKPOINT_GEN_P)
            save_checkpoint(gen_M, opt_gen, filename=CHECKPOINT_GEN_M)
            save_checkpoint(disc_P, opt_disc, filename=CHECKPOINT_CRITIC_P)
            save_checkpoint(disc_M, opt_disc, filename=CHECKPOINT_CRITIC_M)
    print("G loss:", G_loss)
    print("D loss:", D_loss)
    print("CMMD Monet:", CMMD_Monet)
    print("CMMD Photo:", CMMD_Photo)

In [None]:
main()



Epoch 0:
Training...


  0%|          | 0/6938 [00:00<?, ?it/s]
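
CMMD is a distance, so lower values indicate generated images that are closer to the reference set. Assuming the per-epoch lists were made available after training (the notebook's `main` only prints them), the best epoch could be picked as sketched below. The function `best_epoch` is hypothetical, and the numbers are placeholders for illustration, not results from this notebook.

```python
def best_epoch(metric_per_epoch):
    """Return the index of the epoch with the lowest metric value."""
    if not metric_per_epoch:
        raise ValueError("metric_per_epoch must contain at least one value")
    return min(range(len(metric_per_epoch)), key=metric_per_epoch.__getitem__)


# Placeholder values only, NOT measured results.
placeholder_cmmd = [3.2, 2.7, 2.9]
print(best_epoch(placeholder_cmmd))  # index of the smallest placeholder value
```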