In [None]:
import os
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms import transforms
from torch.utils.data import Dataset
from PIL import Image
import random
import urllib.request
import numpy as np
import cv2
from sklearn.model_selection import train_test_split

: 

In [75]:
def get_file_dataset(dir, extensions=('.jpg',)):
    """Recursively collect the paths of image files below ``dir``.

    Args:
        dir: Root directory to walk.
        extensions: File-name suffix (or iterable of suffixes) to keep.
            Matching is case-sensitive. Defaults to ``('.jpg',)``.

    Returns:
        list[str]: Paths (``os.path.join(root, fname)``) of every matching
        regular file, directory by directory in sorted order.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(extensions)

    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in sorted(fnames):
            # The file lives in ``root`` (the directory currently being
            # walked), not in the top-level ``dir``; checking against ``dir``
            # silently dropped every file located in a sub-directory.
            path = os.path.join(root, fname)
            if fname.endswith(suffixes) and os.path.isfile(path):
                images.append(path)

    return images

In [76]:
class ImageSketchDataset(Dataset):
    """Paired photo/sketch dataset read from two directories.

    Files are paired by position after sorting the paths found under each
    directory, so both directories must contain the same number of files and
    sort into matching order.

    Each item is a dict with keys ``'image'`` and ``'sketch'`` (single-channel
    tensors after ``ToTensor`` + ``Normalize(mean=[0.5], std=[0.5])``) and
    ``'image_path'`` / ``'sketch_path'`` (the source file paths).
    """

    def __init__(self, image_dir, sketch_dir):
        """
        Args:
            image_dir: Directory containing the photos.
            sketch_dir: Directory containing the sketches.

        Raises:
            ValueError: If the two directories hold a different number of
                files, since positional pairing would then be wrong.
        """
        self.image_dir = image_dir
        self.sketch_dir = sketch_dir

        self.image_files = sorted(get_file_dataset(image_dir))
        self.sketch_files = sorted(get_file_dataset(sketch_dir))

        if len(self.image_files) != len(self.sketch_files):
            raise ValueError(
                'image_dir has {} files but sketch_dir has {}; '
                'they must match to be paired by position'.format(
                    len(self.image_files), len(self.sketch_files)))

        # transform_list = [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))] # RGB
        transform_list = [transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])]
        self.transform = transforms.Compose(transform_list)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, i):
        image_path = self.image_files[i]
        # Previously this read from ``image_files`` as well, so the "sketch"
        # returned was always the photo itself.
        sketch_path = self.sketch_files[i]

        # ``convert`` returns a new, fully loaded image, so the file handle
        # opened by ``Image.open`` can be closed right away.
        with Image.open(image_path) as img:
            image = img.convert('L')
        with Image.open(sketch_path) as img:
            sketch = img.convert('L')

        return {'image': self.transform(image), 'sketch': self.transform(sketch), 'image_path': image_path, 'sketch_path': sketch_path}

In [77]:
# Set the path to the folder containing the images
img_folder = './photos/'
sketch_folder = './sketches/'

dataset = ImageSketchDataset(img_folder, sketch_folder)

# Split dataset.
# Split the *indices* and wrap them in Subset views instead of handing the
# Dataset itself to train_test_split: scikit-learn would index every element,
# which decodes every image up front and returns plain lists held in memory.
# test_size and random_state are unchanged.
train_idx, val_idx = train_test_split(
    list(range(len(dataset))), test_size=0.2, random_state=42)
train_set = torch.utils.data.Subset(dataset, train_idx)
val_set = torch.utils.data.Subset(dataset, val_idx)

# Params
batch_size = 32

# Create DataLoader objects for the training and validation sets
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

In [None]:
class GANLoss(nn.Module):
    """Adversarial loss against a constant "real" or "fake" target.

    Args:
        use_lsgan: If True use ``nn.MSELoss`` (least-squares GAN), otherwise
            ``nn.BCEWithLogitsLoss`` (which expects raw logits as input).
        target_real_label: Target value used when ``target_is_real`` is True.
        target_fake_label: Target value used when ``target_is_real`` is False.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCEWithLogitsLoss()

    def get_target_tensor(self, input_img, target_is_real):
        """Return a tensor shaped like ``input_img`` filled with the label.

        The target is created from the input itself, so it always has the
        input's shape, dtype and device. The previous implementation built it
        with ``torch.cuda.FloatTensor`` whenever CUDA was available (wrong
        device for CPU inputs), cached it by element count only (wrong shape
        for a same-sized input of different shape), and stored the cache as an
        ``nn.Parameter``, which registered it as a module parameter.
        """
        label = self.real_label if target_is_real else self.fake_label
        return torch.full_like(input_img, label)

    def forward(self, input_img, target_is_real):
        """Compute the loss of ``input_img`` against the real/fake target.

        Args:
            input_img: Discriminator output.
            target_is_real: Whether the target is the real or the fake label.

        Returns:
            Scalar loss tensor.
        """
        target_tensor = self.get_target_tensor(input_img, target_is_real)
        return self.loss(input_img, target_tensor)

In [None]:
class PerceptualLoss(nn.Module):
    """Weighted L1 distance between VGG-19 activations of two image batches.

    Args:
        layers: Integer indices into ``torchvision``'s ``vgg19().features``
            sequence whose outputs are compared.
        weights: One weight per entry of ``layers``.
        cuda: If truthy, move the VGG feature extractor to the GPU.

    Raises:
        ValueError: If ``layers`` and ``weights`` differ in length, or a layer
            index is outside the feature extractor.
    """

    def __init__(self, layers, weights, cuda):
        super(PerceptualLoss, self).__init__()
        if len(layers) != len(weights):
            raise ValueError('layers and weights must have the same length')
        self.layers = layers
        self.weights = weights
        # ``models.VGG19`` does not exist; the factory function is
        # ``models.vgg19``.
        # NOTE(review): newer torchvision releases prefer ``weights=...`` over
        # ``pretrained=True``; kept as-is to stay compatible with older ones.
        self.vgg = models.vgg19(pretrained=True).features.eval()
        if any(not 0 <= layer < len(self.vgg) for layer in layers):
            raise ValueError(
                'layer indices must be in [0, {})'.format(len(self.vgg)))
        # The stock ReLUs are in-place and would overwrite an activation that
        # has just been captured from the preceding layer.
        for module in self.vgg:
            if isinstance(module, nn.ReLU):
                module.inplace = False
        if cuda:
            self.vgg = self.vgg.cuda()
        for param in self.vgg.parameters():
            param.requires_grad = False

    def _activations(self, batch):
        """Run ``batch`` through VGG and return ``{layer_index: activation}``."""
        wanted = set(self.layers)
        if not wanted:
            return {}
        if batch.size(1) == 1:
            # VGG's first convolution expects 3 channels; replicate grayscale.
            batch = batch.expand(-1, 3, -1, -1)
        deepest = max(wanted)
        captured = {}
        for index, module in enumerate(self.vgg):
            batch = module(batch)
            if index in wanted:
                captured[index] = batch
            if index >= deepest:
                break  # nothing beyond the deepest requested layer is needed
        return captured

    def forward(self, x, y):
        """Return ``sum_i weights[i] * L1(vgg_layer_i(x), vgg_layer_i(y))``.

        Previously the final VGG output tensor was indexed with ``layer``,
        which selects along the batch axis rather than picking a layer's
        activation.
        """
        x_vgg, y_vgg = self._activations(x), self._activations(y)
        loss = 0
        for weight, layer in zip(self.weights, self.layers):
            loss = loss + weight * F.l1_loss(x_vgg[layer], y_vgg[layer])
        return loss

In [None]:
class CycleConsistencyLoss(nn.Module):
    """Sum of the two L1 reconstruction errors of a cycle.

    ``forward(x, y, x_recon, y_recon)`` returns
    ``l1_loss(x, x_recon) + l1_loss(y, y_recon)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y, x_recon, y_recon):
        x_term = F.l1_loss(x, x_recon)
        y_term = F.l1_loss(y, y_recon)
        return x_term + y_term

In [None]:
class IdentityLoss(nn.Module):
    """L1 (mean absolute error) loss between ``x`` and ``y`` as a module."""

    # No state of its own, so nn.Module's constructor is used as-is.

    def forward(self, x, y):
        distance = F.l1_loss(x, y)
        return distance

In [None]:
class PatchLoss(nn.Module):
    """L1 loss restricted to the pixels of one segmentation label."""

    def __init__(self):
        super(PatchLoss, self).__init__()
        self.loss = nn.L1Loss()

    def get_masked_images(self, input, gt, gtsegmap, label):
        """Zero out every pixel of ``input`` and ``gt`` outside ``label``.

        Args:
            input: Predicted image tensor.
            gt: Ground-truth image tensor.
            gtsegmap: Segmentation map tensor, broadcastable against ``input``.
            label: Value of ``gtsegmap`` that selects the region to keep.

        Returns:
            Tuple ``(input_m, gt_m)`` of the masked tensors.
        """
        # Build the mask with a tensor comparison on the input's own device.
        # The earlier version went through ``np.where`` (2-D CPU maps only)
        # and then called ``.cuda()`` unconditionally, which fails on
        # machines without a GPU.
        mask = (gtsegmap == label).to(dtype=input.dtype, device=input.device)
        input_m = input * mask
        gt_m = gt * mask
        return input_m, gt_m

    def forward(self, input, gt, gtsegmap, label):
        """Return the L1 loss between ``input`` and ``gt`` inside ``label``.

        The loss is averaged over all elements, masked-out ones included
        (they contribute zero). ``gt`` is detached so no gradient flows
        into it.
        """
        # Rescale the map before comparing with ``label``.
        # NOTE(review): presumably maps a [-1, 1]-normalised map onto label
        # ids 0..10; the comparison is an exact float equality - confirm.
        gtsegmap = (gtsegmap + 1) * 5.0
        input_m, gt_m = self.get_masked_images(input, gt, gtsegmap, label)
        gt_m = gt_m.detach()
        return self.loss(input_m, gt_m)

In [None]:
class FeatureMatchingLoss(nn.Module):
    """L1 distance between ``VGGNet`` features of a real and a fake image.

    Args:
        device: Device the internal ``VGGNet`` is moved to.
    """

    def __init__(self, device):
        super(FeatureMatchingLoss, self).__init__()
        self.vgg = VGGNet().to(device)
        self.criterion = nn.L1Loss()

    @staticmethod
    def _as_feature_list(features):
        """Wrap a single feature tensor in a list; pass sequences through."""
        if torch.is_tensor(features):
            return [features]
        return list(features)

    def forward(self, real_img, fake_img, D):
        """Sum the L1 losses over all features of the two images.

        Args:
            real_img: Batch of real images.
            fake_img: Batch of generated images.
            D: Depth argument forwarded unchanged to ``self.vgg``.

        Returns:
            The summed loss (``0`` if the network yields no features).
        """
        # Extract features from the VGG network
        real_features = self._as_feature_list(self.vgg(real_img, D))
        fake_features = self._as_feature_list(self.vgg(fake_img, D))

        # Calculate the L1 loss between the real and fake features.
        # The real features are the fixed target, so they are the ones that
        # get detached; detaching the fake ones (as before) cut the gradient
        # path back to the generated image.
        loss = 0
        for real_feat, fake_feat in zip(real_features, fake_features):
            loss = loss + self.criterion(fake_feat, real_feat.detach())

        # Return after the loop: the old ``return`` sat inside it, so only
        # the first feature was ever compared.
        return loss

class VGGNet(nn.Module):
    """VGG-16-style network: five conv/pool stages plus a 3-layer classifier.

    ``layer1`` .. ``layer5`` are the convolutional stages (each ends in a 2x2
    max-pool); ``layer6`` is the fully connected head, whose first layer
    expects 25088 input features.

    NOTE(review): the weights are randomly initialised here - nothing in this
    class loads pretrained parameters.
    """

    def __init__(self):
        super(VGGNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer5 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer6 = nn.Sequential(
            nn.Linear(25088, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 1000),
            nn.ReLU())

    def forward(self, x, D):
        """Run ``x`` through the network up to depth ``D``.

        Args:
            x: Input batch (3 channels, per the first convolution).
            D: ``0`` or ``5`` runs all five conv stages and the classifier
                ``layer6``. ``1`` .. ``4`` runs only the first ``D`` conv
                stages and returns that feature map.

        Returns:
            The ``layer6`` output for ``D`` in ``(0, 5)``, otherwise the
            feature map produced by stage ``D``.

        Raises:
            ValueError: If ``D`` is not in ``0..5``.
        """
        stages = (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5)
        depth = len(stages) if D == 0 else D
        if not 1 <= depth <= len(stages):
            raise ValueError('D must be in 0..5, got {!r}'.format(D))

        for stage in stages[:depth]:
            x = stage(x)

        if depth < len(stages):
            # A truncated feature map does not have the 25088 features the
            # classifier was sized for, so it is returned as-is instead of
            # being flattened and pushed through ``layer6``.
            return x

        x = x.view(x.size(0), -1)
        x = self.layer6(x)
        return x

In [None]:
class Generator(nn.Module):
    """Encoder / top-down fusion generator.

    The encoder produces four feature maps (32 channels at full resolution,
    then 64, 128 and 256 channels at 1/2, 1/4 and 1/8 resolution). Starting
    from the coarsest map, each level is bilinearly upsampled to the size of
    the next finer map, concatenated with it along the channel axis, and
    passed through a 1x1 and then a 3x3 convolution. The last fused map is
    returned and has the input's spatial size.

    NOTE(review): the channel/resolution plan is inferred from the shape
    hints in the original constructor arguments (224/32, 112/64, 56/128,
    28/256), and the ReLU activations are an addition - the original was not
    runnable, so please confirm this matches the intended architecture.

    Args:
        in_channels: Channels of the input image. Defaults to 1, matching the
            ``'L'`` conversion done by ``ImageSketchDataset`` in this notebook.
        out_channels: Channels of the generated image. Defaults to 1.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super(Generator, self).__init__()

        # Encoder. The attribute names keep the original's resolution
        # assignment: first = full size, fourth = 1/2, third = 1/4,
        # second = 1/8.
        self.first_stage = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.fourth_stage = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.third_stage = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.second_stage = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)

        # Per-level fusion convolutions, ordered coarse -> fine. They must be
        # attributes (not locals) to be registered and reachable in forward().
        self.one = nn.ModuleList([
            nn.Conv2d(256 + 128, 128, kernel_size=1),
            nn.Conv2d(128 + 64, 64, kernel_size=1),
            nn.Conv2d(64 + 32, 32, kernel_size=1),
        ])
        self.three = nn.ModuleList([
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.Conv2d(32, out_channels, kernel_size=3, padding=1),
        ])

    def forward(self, img):
        """Generate an image with the same spatial size as ``img``.

        Args:
            img: Input batch with ``in_channels`` channels.

        Returns:
            Tensor with ``out_channels`` channels and ``img``'s height/width.
            No output activation is applied.
        """
        full = F.relu(self.first_stage(img))
        half = F.relu(self.fourth_stage(full))
        quarter = F.relu(self.third_stage(half))
        eighth = F.relu(self.second_stage(quarter))

        # Coarse -> fine, so that feature_maps[i + 1] is the next finer level.
        feature_maps = [eighth, quarter, half, full]
        last_level = len(self.one) - 1

        for i, (ones_conv, threes_conv) in enumerate(zip(self.one, self.three)):
            skip = feature_maps[i + 1]
            # Resize to the skip map's exact size (not a fixed x2) so inputs
            # whose sides are not multiples of 8 still line up.
            upsampled_fm = F.interpolate(
                feature_maps[i], size=skip.shape[-2:],
                mode='bilinear', align_corners=False)
            # Concatenate along channels (dim=1); the default dim=0 would
            # stack along the batch axis.
            fused = torch.cat((skip, upsampled_fm), dim=1)
            fused = threes_conv(F.relu(ones_conv(fused)))
            if i < last_level:
                fused = F.relu(fused)
            feature_maps[i + 1] = fused

        return feature_maps[-1]




In [None]:
class Discriminator(nn.Module):
    """Convolutional patch discriminator with a sigmoid output.

    Five stride-2 3x3 convolutions (each followed by batch norm and ReLU)
    reduce the spatial size by a factor of 32; a final stride-1 convolution
    maps to a single channel, which is batch-normalised and squashed to
    (0, 1) with a sigmoid.

    NOTE(review): the output is already a probability, so it should not be
    fed to ``nn.BCEWithLogitsLoss`` (which ``GANLoss`` selects when
    ``use_lsgan=False``), as that applies a second sigmoid.
    NOTE(review): batch norm on the 1-channel output is kept from the
    original; confirm it is intended.

    Args:
        in_channels: Channels of the input image. Defaults to 3 (the value
            that was hard-coded before).
    """

    def __init__(self, in_channels=3):
        # Required before any sub-module can be assigned to the instance.
        super(Discriminator, self).__init__()

        self.first_layer = torch.nn.Conv2d(in_channels, 32, 3, stride=2, padding=1)
        self.batch1 = torch.nn.BatchNorm2d(32, eps=1e-10)
        self.second_layer = torch.nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.batch2 = torch.nn.BatchNorm2d(32, eps=1e-10)
        self.third_layer = torch.nn.Conv2d(32, 64, 3, stride=2, padding=1)
        self.batch3 = torch.nn.BatchNorm2d(64, eps=1e-10)
        self.fourth_layer = torch.nn.Conv2d(64, 64, 3, stride=2, padding=1)
        self.batch4 = torch.nn.BatchNorm2d(64, eps=1e-10)
        self.fifth_layer = torch.nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.batch5 = torch.nn.BatchNorm2d(128, eps=1e-10)
        self.sixth_layer = torch.nn.Conv2d(128, 1, 3, stride=1, padding=1)
        self.batch6 = torch.nn.BatchNorm2d(1, eps=1e-10)
        # The module class is ``Sigmoid``; ``torch.nn.sigmoid`` does not exist.
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, img):
        """Return per-patch "real" probabilities for ``img``.

        Args:
            img: Input batch with ``in_channels`` channels.

        Returns:
            Tensor of shape ``(batch, 1, H', W')`` with values in (0, 1).
        """
        first = F.relu(self.batch1(self.first_layer(img)))
        second = F.relu(self.batch2(self.second_layer(first)))
        third = F.relu(self.batch3(self.third_layer(second)))
        fourth = F.relu(self.batch4(self.fourth_layer(third)))
        fifth = F.relu(self.batch5(self.fifth_layer(fourth)))
        # No ReLU here: a non-negative input to the sigmoid can never yield a
        # value below 0.5, i.e. the network could never output "fake".
        sixth = self.batch6(self.sixth_layer(fifth))
        output = self.sigmoid(sixth)

        return output
