In [1]:
# Imports
import os
import torch
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
from tqdm import tqdm

In [2]:
# Generator class
import torch.nn as nn

class ResBlock(nn.Module):
    def __init__(self, channels):
        super(ResBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x):
        return x + self.block(x)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=0),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True),
        )

        # ResNet blocks
        self.resnet_blocks = nn.Sequential(*[ResBlock(256) for _ in range(9)])

        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),

            nn.ReflectionPad2d(3),
            nn.Conv2d(64, 3, kernel_size=7, stride=1, padding=0),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.resnet_blocks(x)
        x = self.decoder(x)
        return x

In [3]:
transform = transforms.Compose([
    transforms.Resize((256,256), interpolation=Image.BICUBIC),transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))
])



In [4]:
input_path = 'C:/Users/itayg/Documents/MonetGAN/Data/monet_jpg'
output_path = 'C:/Users/itayg/Documents/MonetGAN/Data/monet_generated_no_VGG'

In [None]:
input_path = 'C:/Users/itayg/Documents/MonetGAN/Data/photo_jpg'
output_path = 'C:/Users/itayg/Documents/MonetGAN/data/monet_generated_no_VGG/class'

In [None]:
def denormalize(tensor):
    """
    Converts a tensor from [-1, 1] to [0, 1] for display.
    Expects tensor shape (C, H, W) or (B, C, H, W).
    """
    tensor = tensor * 0.5 + 0.5
    return tensor.clamp(0, 1)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G_A2B = Generator().to(device)
G_A2B.load_state_dict(torch.load('generator_photo2monet_epoch195_identity_1.5.pth'))
G_A2B.eval();

In [None]:
images = os.listdir(input_path)[:300]

for img in tqdm(images):
    if img.lower().endswith(('.jpg')):
        img_path = os.path.join(input_path, img)

        image = Image.open(img_path).convert('RGB')
        input_tensor = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            output_tensor = G_A2B(input_tensor)

        output_tensor = denormalize(output_tensor.squeeze(0)).clamp(0, 1)
        save_image(output_tensor, os.path.join(output_path, img))

print ('All images were generated and saved in:', output_path)

In [None]:
import torch
import matplotlib.pyplot as plt

def show_generated_samples(generator, real_photo, epoch):
    generator.eval()  # Set the generator to evaluation mode
    with torch.no_grad():
        fake_monet = generator(real_photo.to(device))  # Generate Monet-style image

    # Get first image in batch
    real_photo = real_photo[0].cpu().detach()
    fake_monet = fake_monet[0].cpu().detach()

    # De-normalize from [-1, 1] → [0, 1]
    real_photo = denormalize(real_photo).permute(1, 2, 0)
    fake_monet = denormalize(fake_monet).permute(1, 2, 0)

    # Plot images side by side
    fig, ax = plt.subplots(1, 2, figsize=(8, 4))
    ax[0].imshow(real_photo.numpy())
    ax[0].set_title("Real Photo")
    ax[0].axis("off")

    ax[1].imshow(fake_monet.numpy())
    ax[1].set_title(f"Generated Monet (Epoch {epoch})")
    ax[1].axis("off")

    plt.tight_layout()

In [None]:
import sys
!{sys.executable} -m pip install torchmetrics


In [None]:
torch.cuda.empty_cache()

In [6]:
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torchmetrics.image.fid import FrechetInceptionDistance
from PIL import Image
import os

# ----------- FIX 1: Typo in "cuda" -----------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ----------- FIX 2: Incorrect "transforms.compose" (should be Compose with capital C) -----------
transform = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: (x*255).byte())
])

# ----------- Datasets -----------
real_dataset = datasets.ImageFolder(root=input_path, transform=transform)
fake_dataset = datasets.ImageFolder(root=output_path,transform=transform)
# ----------- Dataloaders -----------
real_loader = DataLoader(real_dataset, batch_size=16, shuffle=False)
fake_loader = DataLoader(fake_dataset, batch_size=16, shuffle=False)

fid = FrechetInceptionDistance(feature=2048).to(device)


# ----------- Process real images -----------
for imgs, _ in real_loader:
    imgs = imgs.to(device)
    fid.update(imgs, real=True)

# ----------- Process fake images -----------
for imgs, _ in fake_loader:
    imgs = imgs.to(device)
    fid.update(imgs, real=False)

# ----------- Calculate FID -----------
fid_score = fid.compute()

print(f"FID Score: {fid_score.item():.4f}")


FID Score: 127.2188
