In [1]:
from matplotlib import pyplot as plt

def show_img(tensors, ncols=3, figsize=(15, 15), denormalize=False):
    """Plot a sequence of image tensors on a grid of subplots.

    Each tensor is detached, moved to CPU, squeezed (dropping singleton dims
    such as a batch dim of 1), permuted from channels-first to channels-last
    and drawn with ``imshow``. Assumes every tensor squeezes to a 3-D
    ``(C, H, W)`` RGB image — TODO confirm for all callers (a 1-channel image
    would lose its channel dim in ``squeeze`` and fail the permute).

    Args:
        tensors: sized sequence of image tensors.
        ncols: number of grid columns (3 by default, as before).
        figsize: figure size in inches, applied to the grid figure itself.
        denormalize: if True, map values from [-1, 1] (the range produced by
            ``Normalize((0.5,), (0.5,))`` and by a ``Tanh`` output) back to
            [0, 1] before plotting; otherwise ``imshow`` clips negative values.
    """
    if len(tensors) == 0:
        return
    # Ceil division: the old `n // 3 + 1` added an empty row whenever n % 3 == 0.
    nrows = (len(tensors) + ncols - 1) // ncols
    # figsize goes to subplots() directly: a separate plt.figure() call created a
    # second, empty figure. squeeze=False always returns a 2-D Axes array, so the
    # flattened view can be indexed linearly for any grid shape (the old
    # `ax[index]` returned a whole row once there was more than one row).
    _, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    flat_axes = axes.ravel()
    for axis, tensor in zip(flat_axes, tensors):
        image = tensor.detach().cpu().squeeze().permute(1, 2, 0)
        if denormalize:
            image = (image * 0.5 + 0.5).clamp(0, 1)
        axis.imshow(image.numpy())
    # Hide grid cells that received no image.
    for axis in flat_axes[len(tensors):]:
        axis.axis("off")
    plt.show()
            

# Data

In [2]:
# Load datasets
import os
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import transforms
from matplotlib import pyplot as plt
import tensorflow as tf

class ImageFolderDataset(Dataset):
    """Dataset over the image files found directly inside one directory.

    Files are matched by extension (``.png``, ``.jpg``, ``.jpeg``, compared
    case-insensitively) and sorted by name, so index ``i`` maps to the same file
    on every run. Each item is the image decoded to RGB, passed through
    ``transform`` if one is given.

    Args:
        root_dir: directory to scan (non-recursive).
        transform: optional callable applied to the PIL image on every access.
        use_cache: if True (default, matching the previous behavior) keep each
            decoded image in ``self.cache`` after its first load. The cache is
            unbounded, and with multi-process DataLoader workers each worker
            fills its own copy, so consider False for large folders.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None, use_cache=True):
        self.root_dir = root_dir
        self.transform = transform
        self.use_cache = use_cache
        # sorted(): os.listdir order is arbitrary. lower(): accept '.JPG' etc.
        # The leading dot in the extensions stops names like 'notes_png' matching.
        self.image_paths = [
            os.path.join(root_dir, fname)
            for fname in sorted(os.listdir(root_dir))
            if fname.lower().endswith(self.IMAGE_EXTENSIONS)
        ]
        # Maps image path -> decoded RGB PIL image (filled only when use_cache).
        self.cache = dict()

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = self.cache.get(img_path) if self.use_cache else None
        if image is None:
            # `with` closes the file handle; convert() returns a loaded copy
            # that stays valid after the source image is closed.
            with Image.open(img_path) as raw:
                image = raw.convert("RGB")
            if self.use_cache:
                self.cache[img_path] = image
        if self.transform:
            image = self.transform(image)
        return image


2024-06-09 12:17:45.218030: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-09 12:17:45.218225: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-09 12:17:45.384413: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# Both domains share one preprocessing pipeline: resize to 256x256, convert to
# a float tensor, then Normalize(0.5, 0.5) rescales [0, 1] to [-1, 1] — the same
# range as the generator's Tanh output.
DATA_DIR = '/kaggle/input/gan-getting-started'
image_transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize the image to 256x256
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
dataset_monet = ImageFolderDataset(os.path.join(DATA_DIR, 'monet_jpg'), image_transform)
dataset_real = ImageFolderDataset(os.path.join(DATA_DIR, 'photo_jpg'), image_transform)

# Model

In [4]:
class conv_block(nn.Module):
    """Three stacked 3x3 convolutions (padding=1), each followed by a ReLU.

    The first convolution maps ``in_c`` -> ``out_c`` channels and the next two
    keep ``out_c``. With padding=1 and the default stride, height and width
    are preserved.
    """

    @staticmethod
    def _conv3x3(c_in, c_out):
        # Same-padding 3x3 convolution used for every stage of the block.
        return nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1)

    def __init__(self, in_c, out_c):
        super().__init__()
        # Modules are registered in the order conv_1, relu_1, conv_2, ... so
        # state_dict keys and seeded initialization stay identical.
        self.conv_1, self.relu_1 = self._conv3x3(in_c, out_c), nn.ReLU()
        self.conv_2, self.relu_2 = self._conv3x3(out_c, out_c), nn.ReLU()
        self.conv_3, self.relu_3 = self._conv3x3(out_c, out_c), nn.ReLU()

    def forward(self, x):
        stages = (
            (self.conv_1, self.relu_1),
            (self.conv_2, self.relu_2),
            (self.conv_3, self.relu_3),
        )
        for conv, act in stages:
            x = act(conv(x))
        return x
    
class EncoderBlock(nn.Module):
    """U-Net encoder stage: conv_block feature extraction, then 2x2 max-pooling.

    ``forward`` returns ``(features, pooled)``. ``features`` is the
    full-resolution conv_block output, which the generator keeps as a skip
    connection. ``pooled`` is its max-pooled version, passed on to the next stage.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d(kernel_size=(2, 2))

    def forward(self, x):
        features = self.conv(x)
        return features, self.pool(features)
    
class DecoderBlock(nn.Module):
    """U-Net decoder stage: 2x learned upsampling, skip concatenation, conv_block.

    Args:
        in_c: channels of the incoming (coarser) feature map.
        out_c: channels after upsampling. The skip tensor given to ``forward``
            must also carry ``out_c`` channels, because conv_block is built for
            ``out_c * 2`` input channels.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = conv_block(out_c * 2, out_c)

    def forward(self, x, residual):
        """Upsample ``x``, concatenate with ``residual`` along channels, then convolve.

        When an encoder input had an odd height/width, 2x2 max-pooling floors
        the size. The upsampled map is then one row/column short of
        ``residual``, so it is zero-padded on the bottom/right to match instead
        of making ``torch.cat`` fail. When the sizes already agree, nothing is
        padded.
        """
        x = self.up(x)
        diff_h = residual.shape[-2] - x.shape[-2]
        diff_w = residual.shape[-1] - x.shape[-1]
        if diff_h or diff_w:
            x = nn.functional.pad(x, (0, diff_w, 0, diff_h))
        x = torch.cat([x, residual], dim=1)
        return self.conv(x)

In [5]:
class UNET_Generator(nn.Module):
    """Four-level U-Net mapping a 3-channel image to a 3-channel image in [-1, 1].

    The shape comments assume a (3, 256, 256) input. Encoder comments give the
    pooled output and the other comments give the module output.
    """

    def __init__(self):
        super().__init__()
        # Contracting path: channels double and spatial size halves per stage.
        self.enc_1 = EncoderBlock(3, 64)      # pooled -> (64, 128, 128)
        self.enc_2 = EncoderBlock(64, 128)    # pooled -> (128, 64, 64)
        self.enc_3 = EncoderBlock(128, 256)   # pooled -> (256, 32, 32)
        self.enc_4 = EncoderBlock(256, 512)   # pooled -> (512, 16, 16)

        self.bridge = conv_block(512, 1024)   # (1024, 16, 16)

        # Expanding path: each stage upsamples 2x and fuses the matching skip.
        self.dec_1 = DecoderBlock(1024, 512)  # (512, 32, 32)
        self.dec_2 = DecoderBlock(512, 256)   # (256, 64, 64)
        self.dec_3 = DecoderBlock(256, 128)   # (128, 128, 128)
        self.dec_4 = DecoderBlock(128, 64)    # (64, 256, 256)

        # 1x1 projection to RGB; Tanh bounds the output to [-1, 1].
        self.final = nn.Conv2d(64, 3, kernel_size=1, padding=0)
        self.activation = nn.Tanh()

    def forward(self, x):
        skips = []
        for encoder in (self.enc_1, self.enc_2, self.enc_3, self.enc_4):
            skip, x = encoder(x)
            skips.append(skip)

        out = self.bridge(x)
        # The deepest decoder pairs with the deepest skip, hence reversed().
        decoders = (self.dec_1, self.dec_2, self.dec_3, self.dec_4)
        for decoder, skip in zip(decoders, reversed(skips)):
            out = decoder(out, skip)

        return self.activation(self.final(out))
        

In [6]:
class Discriminator(nn.Module):
    """Real/fake classifier: frozen pretrained ResNet-50 features plus a trainable MLP head.

    The backbone is torch.hub's ``nvidia_resnet50`` with its last child module
    removed and all its parameters frozen, so only ``fc1``/``fc2`` train.
    ``forward`` returns one probability per image, shape ``(N, 1)``.

    NOTE(review): images in this notebook are scaled to [-1, 1], while the
    pretrained backbone was presumably trained on ImageNet-normalized inputs.
    Confirm whether inputs should be re-normalized before ``resnet50``.
    """

    def __init__(self):
        super().__init__()
        backbone = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_resnet50', pretrained=True)
        # Keep every child except the last one (presumably the classifier head).
        self.resnet50 = torch.nn.Sequential(*(list(backbone.children())[:-1]))
        # Global pooling to 1x1 guarantees a (N, 2048) vector after flatten,
        # whether or not the retained children already end with global pooling.
        # It is a no-op on an already (N, 2048, 1, 1) map.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(2048, 1024)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(1024, 1)
        self.sigmoid = nn.Sigmoid()

        for param in self.resnet50.parameters():
            param.requires_grad = False
        self.resnet50.eval()

    def train(self, mode=True):
        """Set train/eval mode for the head while keeping the frozen backbone in eval mode.

        ``requires_grad=False`` does not stop BatchNorm layers from updating
        their running statistics in train mode. Pinning the backbone to eval
        keeps its pretrained statistics intact. ``eval()`` calls ``train(False)``,
        so it is covered as well.
        """
        super().train(mode)
        self.resnet50.eval()
        return self

    def forward(self, x):
        result = self.resnet50(x)
        result = self.pool(result)
        result = self.flatten(result)
        result = self.fc1(result)
        result = self.relu(result)
        result = self.fc2(result)
        result = self.sigmoid(result)
        return result
        

# Training

In [7]:
# Sanity check of the target layout a discriminator loss uses:
# ones for the real batch, followed by zeros for the generated batch.
torch.cat((torch.ones(10), torch.zeros(10)), dim=0)

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.])

In [8]:
def discriminator_loss(real_img_desc_pred, gen_img_desc_pred, criterion=None):
    """Discriminator loss: real images should score 1, generated images 0.

    Args:
        real_img_desc_pred: discriminator outputs for real images, e.g. ``(N, 1)``.
        gen_img_desc_pred: discriminator outputs for generated images, e.g.
            ``(M, 1)``. Must match ``real_img_desc_pred`` on every dim except 0.
        criterion: callable ``(pred, target) -> loss``. Defaults to binary
            cross-entropy on probabilities, matching the ``Sigmoid`` output of
            ``Discriminator``.

    Returns:
        Loss tensor. With the default criterion this is the scalar mean over
        all N + M predictions.
    """
    if criterion is None:
        criterion = nn.functional.binary_cross_entropy
    # torch.cat takes a *sequence* of tensors; the old call passed two tensors
    # positionally and raised a TypeError. ones_like/zeros_like build targets
    # with the same shape, dtype and device as the predictions.
    pred = torch.cat([real_img_desc_pred, gen_img_desc_pred], dim=0)
    target = torch.cat(
        [torch.ones_like(real_img_desc_pred), torch.zeros_like(gen_img_desc_pred)],
        dim=0,
    )
    return criterion(pred, target)
    
    

In [9]:
# Seed before building the models so the random weight initialisation of the
# generators and discriminator heads is reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Presumably a CycleGAN-style pairing: x_y maps domain X -> Y and y_x maps
# Y -> X, each with its own discriminator (TODO confirm which domain is X).
x_y_generator = UNET_Generator()
x_y_discriminator = Discriminator()

y_x_generator = UNET_Generator()
y_x_discriminator = Discriminator()

Downloading: "https://github.com/NVIDIA/DeepLearningExamples/zipball/torchhub" to /root/.cache/torch/hub/torchhub.zip
Downloading: "https://api.ngc.nvidia.com/v2/models/nvidia/resnet50_pyt_amp/versions/20.06.0/files/nvidia_resnet50_200821.pth.tar" to /root/.cache/torch/hub/checkpoints/nvidia_resnet50_200821.pth.tar
100%|██████████| 97.7M/97.7M [00:04<00:00, 23.0MB/s]
Using cache found in /root/.cache/torch/hub/NVIDIA_DeepLearningExamples_torchhub
