In [6]:
# Imports
import os
import torch
from PIL import Image
from tqdm import tqdm
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torchmetrics.image.fid import FrechetInceptionDistance

In [7]:
# Generator class
import torch.nn as nn

class ResBlock(nn.Module):
    """Residual block: two 3x3 conv + InstanceNorm stages with an identity skip.

    The channel count is preserved and the convolutions use padding=1, so the
    block's output has the input's shape and can be added to it element-wise.

    Args:
        channels: number of input (and output) feature channels.
    """

    def __init__(self, channels):
        super().__init__()

        def conv_norm(n):
            # 3x3 conv that keeps H/W (padding=1), followed by instance norm.
            return [nn.Conv2d(n, n, kernel_size=3, padding=1), nn.InstanceNorm2d(n)]

        # Layer order (conv, norm, relu, conv, norm) is kept stable so that
        # state_dict keys stay "block.0.*" / "block.3.*".
        self.block = nn.Sequential(
            *conv_norm(channels),
            nn.ReLU(inplace=True),
            *conv_norm(channels),
        )

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class Generator(nn.Module):
    """ResNet-style image-to-image generator: encoder -> 9 residual blocks -> decoder.

    Takes a 3-channel image tensor and returns a 3-channel tensor in [-1, 1]
    (final Tanh). The encoder halves the spatial size twice (stride-2 convs)
    and the decoder doubles it twice (stride-2 transposed convs with
    output_padding=1), so for heights/widths divisible by 4 the output has the
    same spatial size as the input.
    """

    def __init__(self):
        super().__init__()

        def downsample(c_in, c_out):
            # stride-2 3x3 conv: halves H and W
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        def upsample(c_in, c_out):
            # stride-2 transposed conv with output_padding=1: doubles H and W
            return [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Encoder: reflection-padded 7x7 stem (keeps H/W), then 64 -> 128 -> 256 channels.
        self.encoder = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=0),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
            *downsample(64, 128),
            *downsample(128, 256),
        )

        # Bottleneck: nine 256-channel residual blocks.
        self.resnet_blocks = nn.Sequential(*(ResBlock(256) for _ in range(9)))

        # Decoder: 256 -> 128 -> 64 channels, then a reflection-padded 7x7 conv
        # down to 3 output channels squashed by Tanh.
        self.decoder = nn.Sequential(
            *upsample(256, 128),
            *upsample(128, 64),
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, 3, kernel_size=7, stride=1, padding=0),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.decoder(self.resnet_blocks(self.encoder(x)))

In [11]:
# Root of the MonetGAN data folder. Set the MONETGAN_DATA_DIR environment
# variable to run on another machine without editing the notebook.
DATA_DIR = os.environ.get("MONETGAN_DATA_DIR", "C:/Users/itayg/Documents/MonetGAN/Data")

# Both paths hang off one root so their spelling cannot drift apart
# (e.g. "Data" vs "data" only resolve to the same folder on
# case-insensitive filesystems such as Windows).
input_path = os.path.join(DATA_DIR, "monet_jpg")                 # reference ("real") images for FID
output_path = os.path.join(DATA_DIR, "monet_generated_no_VGG")  # generated ("fake") images for FID

In [12]:
# Release cached-but-unused GPU memory held by PyTorch's CUDA caching
# allocator; there is nothing to release on a CPU-only machine.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [13]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

BATCH_SIZE = 16

# The FID metric below is fed uint8 images in [0, 255] (torchmetrics' default
# normalize=False). PILToTensor yields uint8 directly. A ToTensor() ->
# (x*255).byte() round trip goes through float and .byte() truncates, so a
# value landing a hair below its integer would lose one intensity level.
# Using no lambda also keeps the transform picklable if DataLoader workers
# are enabled later (needed on Windows).
transform = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.PILToTensor(),
])

# Datasets. ImageFolder needs at least one class subdirectory under each root;
# the class labels it returns are ignored here.
real_dataset = datasets.ImageFolder(root=input_path, transform=transform)
fake_dataset = datasets.ImageFolder(root=output_path, transform=transform)

# Dataloaders (FID only uses aggregate feature statistics, so order is irrelevant)
real_loader = DataLoader(real_dataset, batch_size=BATCH_SIZE, shuffle=False)
fake_loader = DataLoader(fake_dataset, batch_size=BATCH_SIZE, shuffle=False)

# 2048 = InceptionV3 final pooling features, the standard FID setting.
fid = FrechetInceptionDistance(feature=2048).to(device)


def _feed_fid(metric, loader, real):
    """Push every image batch from `loader` into `metric` as real or generated samples."""
    for imgs, _ in tqdm(loader, desc="real" if real else "fake"):
        metric.update(imgs.to(device), real=real)


_feed_fid(fid, real_loader, real=True)
_feed_fid(fid, fake_loader, real=False)

# Calculate FID
fid_score = fid.compute()

print(f"FID Score: {fid_score.item():.4f}")

FID Score: 127.2188
