In [2]:
import os
from torch.utils.data import DataLoader
## reference links
import os
import numpy as np
import random
import math
import scipy
import torch
import torchvision
import torchvision.utils as vutils
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.optim as optim
import wandb
from fid_metric import fid_metric as fid
import torchvision.datasets as datasets
from torch.utils.data import Dataset
import natsort
from PIL import Image
import glob

# Root folder of the AFHQ dataset. The default is the original author's local
# path; set the AFHQ_ROOT environment variable to run on another machine
# without editing the notebook.
afhq_root = os.environ.get(
    "AFHQ_ROOT",
    "D:/masa üstü/Hamza Proje Dosyalar/comp511-project/afhq",
).rstrip("/\\")

train_cat_path = afhq_root + "/train/cat"
train_dog_path = afhq_root + "/train/dog"
val_cat_path = afhq_root + "/val/cat"
val_dog_path = afhq_root + "/val/dog"

# Use the first CUDA device when one is visible, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Training hyper-parameters.
manual_seed = 10
num_epochs = 200
decay_start_epoch = 5
lr = 0.0002

# Seed every RNG the notebook imports so that a fresh "Restart & Run All"
# starts from the same random state (NumPy was previously left unseeded).
random.seed(manual_seed)
np.random.seed(manual_seed)
torch.manual_seed(manual_seed)


def read_images(path):
    """Collect file paths under ``path`` and split them into cat and dog lists.

    A file is assigned to a group when one of the *directory names* on its
    path is exactly ``"cat"`` or ``"dog"`` (``"cat"`` wins if both occur).
    Matching whole path components rather than substrings keeps folders such
    as ``catalog/`` or ``hotdogs/`` from being misclassified.

    Args:
        path: Root directory to walk recursively.

    Returns:
        A tuple ``(group1_images, group2_images)`` holding the cat paths and
        the dog paths. Directories and file names are visited in sorted order
        so the result does not depend on the file system's listing order.
    """
    group1_images = []
    group2_images = []
    for root, dirs, files in os.walk(path):
        # Sorting in place makes os.walk descend in a deterministic order.
        dirs.sort()
        components = os.path.normpath(root).split(os.sep)
        if "cat" in components:
            target = group1_images
        elif "dog" in components:
            target = group2_images
        else:
            # Files outside a cat/dog folder belong to neither group.
            continue
        target.extend(os.path.join(root, name) for name in sorted(files))
    return group1_images, group2_images

# Data-loading settings.
batch_size = 1  ## 128
num_workers = 2

# Preprocessing applied to every image: resize to 64x64, convert to a tensor,
# then normalise with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

class CustomDataSet(Dataset):
    """Unlabelled image dataset backed by a single flat directory.

    Only regular files with a known image extension are indexed, so stray
    entries such as sub-folders, ``Thumbs.db`` or ``.DS_Store`` no longer make
    ``Image.open`` fail in the middle of an epoch.

    Args:
        main_dir: Directory that directly contains the image files.
        transform: Optional callable applied to each loaded RGB PIL image.
            When ``None`` the PIL image is returned unchanged.
    """

    # Lower-case file extensions accepted as images.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tif", ".tiff", ".webp")

    def __init__(self, main_dir, transform=None):
        self.main_dir = main_dir
        self.transform = transform
        all_imgs = [
            name
            for name in os.listdir(main_dir)
            if self._is_image_file(os.path.join(main_dir, name))
        ]
        # Natural sort keeps "img2" before "img10".
        self.total_imgs = natsort.natsorted(all_imgs)

    @staticmethod
    def _is_image_file(path):
        """Return True when ``path`` is a regular file with an image extension."""
        return os.path.isfile(path) and path.lower().endswith(CustomDataSet.IMAGE_EXTENSIONS)

    def __len__(self):
        """Number of indexed image files."""
        return len(self.total_imgs)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it after the optional transform."""
        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
        image = Image.open(img_loc).convert("RGB")
        if self.transform is None:
            return image
        return self.transform(image)

In [3]:
# Build one dataset per (split, species) folder, all sharing the same transform.
train_cat, train_dog, val_cat, val_dog = (
    CustomDataSet(folder, transform=transform)
    for folder in (train_cat_path, train_dog_path, val_cat_path, val_dog_path)
)

In [4]:
# Every loader uses the same batch size and reshuffles on each pass.
_loader_kwargs = dict(batch_size=batch_size, shuffle=True)

# Training loaders: domain A = cats, domain B = dogs.
load_train_A = DataLoader(train_cat, **_loader_kwargs)
load_train_B = DataLoader(train_dog, **_loader_kwargs)

# Validation loaders for the same two domains.
load_test_A = DataLoader(val_cat, **_loader_kwargs)
load_test_B = DataLoader(val_dog, **_loader_kwargs)

In [5]:
def save_iter(real_A,real_B, fake_B, i, epoch):
    """Write three PNG snapshots for the given epoch and iteration.

    The files are created in the current working directory, in this order:
    ``cat_epoch_<epoch>_iter_<i>.png`` from ``real_A``,
    ``dog_epoch_<epoch>_iter_<i>.png`` from ``real_B`` and
    ``generated_dog_epoch_<epoch>_iter_<i>.png`` from ``fake_B``.
    Each tensor is saved with ``normalize=True``.
    """
    suffix = str(epoch) + "_iter_" + str(i) + '.png'
    snapshots = (
        ('cat_epoch_', real_A),
        ('dog_epoch_', real_B),
        ('generated_dog_epoch_', fake_B),
    )
    for prefix, batch in snapshots:
        vutils.save_image(batch, prefix + suffix, normalize=True)

In [6]:
# define the generator
class Generator(nn.Module):
    """Image-to-image generator that preserves the input's spatial size.

    The previous stack was a latent-to-image decoder (five transposed
    convolutions). Applied to an image it only ever upsamples: a 64x64 input
    became 67 -> 134 -> 268 -> 536 -> 1072 pixels wide, so generated images
    could not be compared with real ones and feeding an output back into the
    generator grew it again.

    This version mirrors four stride-2 convolutions (encoder) with four
    stride-2 transposed convolutions (decoder), so an ``(N, 3, H, W)`` input
    yields an ``(N, 3, H, W)`` output whenever ``H`` and ``W`` are multiples
    of 16. The final ``Tanh`` keeps values in ``[-1, 1]``.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # Encoder: each 4x4 / stride-2 / pad-1 convolution halves H and W.
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # Decoder: each 4x4 / stride-2 / pad-1 transposed convolution
            # doubles H and W, undoing the encoder's downsampling.
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        """Translate a batch of images; the output has the same shape as ``input``."""
        return self.main(input)

# define the discriminator
class Discriminator(nn.Module):
    """Convolutional discriminator with a sigmoid output.

    Four 4x4 / stride-2 / pad-1 convolutions (3 -> 64 -> 128 -> 256 -> 512
    channels) each halve the spatial size; a final 4x4 / stride-1 / pad-0
    convolution reduces to one channel before the sigmoid. For a 64x64 input
    the result is a single score per image, shaped ``(N, 1, 1, 1)``.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # First stage has no batch norm.
        stages = [
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Remaining downsampling stages: conv -> batch norm -> leaky ReLU.
        for in_ch, out_ch in ((64, 128), (128, 256), (256, 512)):
            stages.append(nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
        # Scoring head.
        stages.append(nn.Conv2d(512, 1, 4, 1, 0, bias=False))
        stages.append(nn.Sigmoid())
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        """Return the sigmoid score map for a batch of images."""
        scores = self.main(input)
        return scores

# Build the generator first, then the discriminator, and move each one onto
# the selected device.
netG = Generator()
netG = netG.to(device)
netD = Discriminator()
netD = netD.to(device)

### CycleGAN-style model wrapper (reference: https://nn.labml.ai/gan/cycle_gan/index.html)
class CycleGAN(nn.Module):
    """Bundles one generator and one discriminator with their loss helpers.

    Calling the module applies the generator. The same generator and the same
    discriminator are used by every loss below, whichever tensors are passed.
    """

    def __init__(self, generator, discriminator):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator

    def forward(self, x):
        """Run ``x`` through the generator."""
        translated = self.generator(x)
        return translated

    def gen_loss(self, x):
        """Negated mean discriminator score of the generator output for ``x``."""
        generated = self(x)
        scores = self.discriminator(generated)
        return -scores.mean()

    def disc_loss(self, x, y):
        """Negated difference of the mean discriminator scores of ``x`` and ``y``.

        The value decreases as the score of ``x`` rises relative to ``y``.
        """
        score_x = self.discriminator(x).mean()
        score_y = self.discriminator(y).mean()
        return -(score_x - score_y)

    def cycle_loss(self, x, y):
        """Mean absolute difference between ``x`` and the generator output for ``y``."""
        reconstructed = self(y)
        return torch.mean(torch.abs(x - reconstructed))

    def identity_loss(self, x, y):
        """Mean absolute difference between ``x`` and ``y``.

        Neither network is called here, so the result depends only on the two
        tensors passed in.
        """
        return torch.mean(torch.abs(x - y))

# Wrap the two networks in a single module.
model = CycleGAN(generator=netG, discriminator=netD)

# One Adam optimizer over every parameter of the wrapper (generator and
# discriminator alike).
optimizer = optim.Adam(params=model.parameters(), lr=lr, betas=(0.5, 0.999))

# train the model
def train():
    """Train the global ``model`` for ``num_epochs`` epochs.

    Each iteration draws one batch from ``load_train_A`` and one from
    ``load_train_B``, takes a generator step followed by a discriminator step,
    and every 100 iterations logs losses, an FID score and sample images to
    wandb. After each epoch the last batch is written to disk with
    ``save_iter`` and the model weights are checkpointed to
    ``model_epoch_<epoch>.pth``.

    Differences from the previous loop:
      * The generator and discriminator get their own Adam optimizers. The
        module-level ``optimizer`` covers both networks, so a generator step
        taken through it also moved the discriminator towards *accepting*
        generated images.
      * Generated images are detached before the discriminator loss. They were
        produced before the generator update and already back-propagated
        through once, so differentiating through them again raised a
        RuntimeError.
      * The end-of-epoch snapshot is skipped when the loaders yield no batch
        instead of failing on undefined names.
    """
    vgg16 = models.vgg16(pretrained=True)

    # Drop the final pooling layer of the feature stack and the classifier,
    # keeping the convolutional features followed by the average pool.
    vgg16 = torch.nn.Sequential(*list(vgg16.features)[:-1], vgg16.avgpool)
    vgg16.eval()
    feature_model = vgg16.to(device)

    optimizer_G = torch.optim.Adam(model.generator.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(model.discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

    for epoch in range(num_epochs):
        last_batch = None
        for i, (real_A, real_B) in enumerate(zip(load_train_A, load_train_B)):
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            fake_B = model(real_A)
            fake_A = model(real_B)

            # ---- generator step ----
            # NOTE(review): identity_loss compares two real batches without
            # calling the generator, so those terms change the logged value
            # but contribute no gradient. Confirm whether that is intended.
            optimizer_G.zero_grad()
            loss_G = (
                model.gen_loss(real_A)
                + model.gen_loss(real_B)
                + model.cycle_loss(real_A, fake_B)
                + model.cycle_loss(real_B, fake_A)
                + model.identity_loss(real_A, real_B)
                + model.identity_loss(real_B, real_A)
            )
            loss_G.backward()
            optimizer_G.step()

            # The generator graph is no longer needed; keep plain tensors.
            fake_A = fake_A.detach()
            fake_B = fake_B.detach()

            # ---- discriminator step ----
            # zero_grad also clears the gradients that loss_G.backward() left
            # on the discriminator parameters.
            optimizer_D.zero_grad()
            loss_D = model.disc_loss(real_A, fake_A) + model.disc_loss(real_B, fake_B)
            loss_D.backward()
            optimizer_D.step()

            last_batch = (real_A, real_B, fake_B, i)

            if (i+1) % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss_D: {:.4f}, Loss_G: {:.4f}'
                    .format(epoch + 1, num_epochs, i + 1, len(load_train_A), loss_D.item(), loss_G.item()))

                # Metric evaluation does not need gradients.
                with torch.no_grad():
                    fid_score = fid(feature_model, real_B, fake_B)

                # A single log call keeps all values on the same wandb step.
                wandb.log({
                    "loss_D": loss_D.item(),
                    "loss_G": loss_G.item(),
                    "FID": fid_score,
                    "epoch": epoch,
                    "real_A": [wandb.Image(real_A[0], caption="real_sample")],
                    "real_B": [wandb.Image(real_B[0], caption="reference_sample")],
                    "fake_B": [wandb.Image(fake_B[0], caption="generated_sample")],
                })

        # save the images of the last processed batch (if there was one)
        if last_batch is not None:
            save_iter(*last_batch, epoch)

        # save the model
        torch.save(model.state_dict(), 'model_epoch_'+str(epoch)+'.pth')


In [7]:
# Release any cached CUDA memory before starting.
torch.cuda.empty_cache()

# Open the wandb run that train() logs to.
wandb.init(entity="comp511", project="CGAN_TRANSLATION")

# Start training.
train()

## To run this as a script from a terminal:
##   python cgan_translation.py
## Training progress can be followed on the wandb project page.

[34m[1mwandb[0m: Currently logged in as: [33mgkecibas16[0m ([33mcomp511[0m). Use [1m`wandb login --relogin`[0m to force relogin


: 

: 