In [1]:
import torch
import torchvision
import torchvision.transforms as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
from PIL import Image

import itertools

from MangaDataset import MangaDataset
from utils import ReplayBuffer
from utils import LambdaLR
#from utils import Logger
from utils import weights_init_normal

In [2]:
# https://github.com/aitorzip/PyTorch-CycleGAN/tree/master
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        conv_block = [  nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features),
                        nn.ReLU(inplace=True),
                        nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features)  ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)

class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()

        # Initial convolution block       
        model = [   nn.ReflectionPad2d(3),
                    nn.Conv2d(input_nc, 64, 7),
                    nn.InstanceNorm2d(64),
                    nn.ReLU(inplace=True) ]

        # Downsampling
        in_features = 64
        out_features = in_features*2
        for _ in range(2):
            model += [  nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features*2

        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling
        out_features = in_features//2
        for _ in range(2):
            model += [  nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features//2

        # Output layer
        model += [  nn.ReflectionPad2d(3),
                    nn.Conv2d(64, output_nc, 7),
                    nn.Tanh() ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

class Discriminator(nn.Module):
    def __init__(self, input_nc):
        super(Discriminator, self).__init__()

        # A bunch of convolutions one after another
        model = [   nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(64, 128, 4, stride=2, padding=1),
                    nn.InstanceNorm2d(128), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(128, 256, 4, stride=2, padding=1),
                    nn.InstanceNorm2d(256), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(256, 512, 4, padding=1),
                    nn.InstanceNorm2d(512), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        # FCN classification layer
        model += [nn.Conv2d(512, 1, 4, padding=1)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        x =  self.model(x)
        # Average pooling and flatten
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# If 'cuda' is printed, it means GPU is available

cuda


In [4]:
netG_color_to_gray = Generator(3, 3).to(device)
netG_gray_to_color = Generator(3, 3).to(device)
netD_color = Discriminator(3).to(device)
netD_gray = Discriminator(3).to(device)

In [5]:
netG_color_to_gray.apply(weights_init_normal)
netG_gray_to_color.apply(weights_init_normal)
netD_color.apply(weights_init_normal)
netD_gray.apply(weights_init_normal)

Discriminator(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
)

In [6]:
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

In [7]:
lr = 0.0001

optimizer_G = torch.optim.Adam(itertools.chain(netG_color_to_gray.parameters(), netG_gray_to_color.parameters()), 
                               lr=lr, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_color.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_gray.parameters(), lr=lr, betas=(0.5, 0.999))

In [8]:
batch_size = 32

# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor
input_A = Tensor(batch_size, 3, 64, 64)
input_B = Tensor(batch_size, 3, 64, 64)
target_real = Variable(Tensor(batch_size).fill_(1.0), requires_grad=False)
target_fake = Variable(Tensor(batch_size).fill_(0.0), requires_grad=False)

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

In [9]:
transforms_ = T.Compose([T.Resize(int(64*1.12), Image.BICUBIC), 
                         T.RandomCrop(64), 
                         T.RandomHorizontalFlip(),
                         T.ToTensor(),
                         T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 32
path = '../dataset/train_test' 
dataloader = DataLoader(MangaDataset(path, transforms_=transforms_), 
                        batch_size=batch_size, shuffle=True, num_workers=2)

In [10]:
# Training loop
epochs = 200  # Adjust the number of epochs as needed
for epoch in range(epochs):
    for i, batch in enumerate(dataloader):
        # Set model input
        real_A = batch[0].to(device)
        real_B = batch[1].to(device)

        ###### Generators A2B and B2A ######
        optimizer_G.zero_grad()

        # Identity loss
        same_B = netG_color_to_gray(real_B)
        loss_identity_B = criterion_identity(same_B, real_B) * 5.0
        
        same_A = netG_gray_to_color(real_A)
        loss_identity_A = criterion_identity(same_A, real_A) * 5.0

        # GAN loss
        fake_B = netG_color_to_gray(real_A)
        pred_fake = netD_gray(fake_B)
        loss_GAN_A2B = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

        fake_A = netG_gray_to_color(real_B)
        pred_fake = netD_color(fake_A)
        loss_GAN_B2A = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

        # Cycle loss
        recovered_A = netG_gray_to_color(fake_B)
        loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10.0

        recovered_B = netG_color_to_gray(fake_A)
        loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10.0

        # Total loss
        loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
        loss_G.backward()
        optimizer_G.step()

        ###### Discriminator A ######
        optimizer_D_A.zero_grad()

        pred_real = netD_color(real_A)
        loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

        fake_A = fake_A_buffer.push_and_pop(fake_A)
        pred_fake = netD_color(fake_A.detach())
        loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        loss_D_A = (loss_D_real + loss_D_fake) * 0.5
        loss_D_A.backward()
        optimizer_D_A.step()

        ###### Discriminator B ######
        optimizer_D_B.zero_grad()

        pred_real = netD_gray(real_B)
        loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))
        
        fake_B = fake_B_buffer.push_and_pop(fake_B)
        pred_fake = netD_gray(fake_B.detach())
        loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        loss_D_B = (loss_D_real + loss_D_fake) * 0.5
        loss_D_B.backward()
        optimizer_D_B.step()
        
    print(f'Epoch: {epoch}')

Epoch: 0
Epoch: 1
Epoch: 2
Epoch: 3
Epoch: 4
Epoch: 5
Epoch: 6
Epoch: 7
Epoch: 8
Epoch: 9
Epoch: 10
Epoch: 11
Epoch: 12
Epoch: 13
Epoch: 14
Epoch: 15
Epoch: 16
Epoch: 17
Epoch: 18
Epoch: 19
Epoch: 20
Epoch: 21
Epoch: 22
Epoch: 23
Epoch: 24
Epoch: 25
Epoch: 26
Epoch: 27
Epoch: 28
Epoch: 29
Epoch: 30
Epoch: 31
Epoch: 32
Epoch: 33
Epoch: 34
Epoch: 35
Epoch: 36
Epoch: 37
Epoch: 38
Epoch: 39
Epoch: 40
Epoch: 41
Epoch: 42
Epoch: 43
Epoch: 44
Epoch: 45
Epoch: 46
Epoch: 47
Epoch: 48
Epoch: 49
Epoch: 50
Epoch: 51
Epoch: 52
Epoch: 53
Epoch: 54
Epoch: 55
Epoch: 56
Epoch: 57
Epoch: 58
Epoch: 59
Epoch: 60
Epoch: 61
Epoch: 62
Epoch: 63
Epoch: 64
Epoch: 65
Epoch: 66
Epoch: 67
Epoch: 68
Epoch: 69
Epoch: 70
Epoch: 71
Epoch: 72
Epoch: 73
Epoch: 74
Epoch: 75
Epoch: 76
Epoch: 77
Epoch: 78
Epoch: 79
Epoch: 80
Epoch: 81
Epoch: 82
Epoch: 83
Epoch: 84
Epoch: 85
Epoch: 86
Epoch: 87
Epoch: 88
Epoch: 89
Epoch: 90
Epoch: 91
Epoch: 92
Epoch: 93
Epoch: 94
Epoch: 95
Epoch: 96
Epoch: 97
Epoch: 98
Epoch: 99
Epoch: 100

In [11]:
# Step 1: Prepare Input
input_image_path = "../dataset/train_test/test_gray/1.png"

# Step 2: Preprocess Input
preprocess = T.Compose([T.Resize(int(64*1.12), Image.BICUBIC), 
                         T.RandomCrop(64), 
                         T.RandomHorizontalFlip(),
                         T.ToTensor(),
                         T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

input_image = Image.open(input_image_path)
input_tensor = preprocess(input_image).unsqueeze(0)  # Add batch dimension

# Step 3: Forward Pass
# Assuming netG_gray_to_color is your generator model for converting grayscale to color
# Assuming netG_color_to_gray is your generator model for converting color to grayscale
# Pass input_tensor through the appropriate generator based on your desired transformation
# For example, if you want to generate a colored image from grayscale:
input_tensor = input_tensor.to(device)
output_tensor = netG_gray_to_color(input_tensor)
# If you want to generate a grayscale image from color:
# output_tensor = netG_color_to_gray(input_tensor)

# Step 4: Postprocess Output
output_tensor = output_tensor.squeeze(0)  # Remove batch dimension
output_image = T.ToPILImage()(output_tensor.detach().cpu())  # Convert tensor to PIL Image
output_image.show()  # Display the generated image

In [13]:
output_image.save("test_200.jpg")

In [14]:
import torchvision.transforms.functional as TF

# Convert the input tensor back to a PIL Image
input_pil_image = TF.to_pil_image(input_tensor.squeeze(0).detach().cpu())

# Save the input image
input_pil_image.save("input_image.png")

In [15]:
torch.save(netG_color_to_gray.state_dict(), 'model_gray_to_color.pth')
torch.save(netG_gray_to_color.state_dict(), 'model_color_to_gray.pth')
torch.save(netD_color.state_dict(), 'model_color_d.pth')
torch.save(netD_gray.state_dict(), 'model_gray_d.pth')