## Prepare

In [1]:
import numpy as np 
import pandas as pd 
import os
import cv2
from tqdm.auto import tqdm

In [2]:
# Data description: https://www.kaggle.com/competitions/global-wheat-detection/data
df = pd.read_csv('E:/global-wheat-detection/train.csv')
bboxs = np.stack(df['bbox'].apply(lambda x: np.fromstring(x[1:-1], sep=',')))
for i, column in enumerate(['x', 'y', 'w', 'h']):
    df[column] = bboxs[:,i]
df.drop(columns=['bbox'], inplace=True)
df = df[['image_id','x', 'y', 'w', 'h']]
index = list(set(df.image_id))

In [6]:
!git clone https://github.com/eriklindernoren/PyTorch-GAN/

Cloning into 'PyTorch-GAN'...


In [3]:
%cd PyTorch-GAN/implementations/pix2pix

C:\Users\Moon\Colorization\PyTorch-GAN\implementations\pix2pix


In [4]:
!ls

Untitled.ipynb
__pycache__
datasets.py
models.py
pix2pix.py


## Config

In [5]:
import os
import numpy as np
import math
import itertools
import time
import datetime
import sys

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

# pix2pix
from models import *
from datasets import *

import torch.nn as nn
import torch.nn.functional as F
import torch

In [6]:
# Config
class opt:
    epoch = 0
    n_epochs = 50  # change to 50 for train
    dataset_name = 'test1'
    batch_size = 8
    lr = 0.0002
    b1 = 0.5
    b2 = 0.999
    decay_epoch = 100
    n_cpu = 1
    img_height = 256
    img_width = 256
    channels = 3
    sample_interval = 100
    checkpoint_interval = 338

In [7]:
os.makedirs("E:/Colorization/pix2pix/images/%s" % opt.dataset_name, exist_ok=True)
os.makedirs("E:/Colorization/pix2pix/saved_models/%s" % opt.dataset_name, exist_ok=True)

In [8]:
cuda = True if torch.cuda.is_available() else False

# Loss functions
criterion_GAN = torch.nn.MSELoss()
criterion_pixelwise = torch.nn.L1Loss()

# Loss weight of L1 pixel-wise loss between translated image and real image
lambda_pixel = 100

# Calculate output of image discriminator (PatchGAN)
patch = (1, opt.img_height // 2 ** 4, opt.img_width // 2 ** 4)

# Initialize generator and discriminator
generator = GeneratorUNet()
discriminator = Discriminator()

if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion_GAN.cuda()
    criterion_pixelwise.cuda()

if opt.epoch != 0:
    # Load pretrained models
    generator.load_state_dict(torch.load("E:/Colorization/pix2pix/saved_models/%s/generator_%d.pth" % (opt.dataset_name, opt.epoch)))
    discriminator.load_state_dict(torch.load("E:/Colorization/pix2pix/saved_models/%s/discriminator_%d.pth" % (opt.dataset_name, opt.epoch)))
else:
    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

In [9]:
generator

GeneratorUNet(
  (down1): UNetDown(
    (model): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): LeakyReLU(negative_slope=0.2)
    )
  )
  (down2): UNetDown(
    (model): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
  )
  (down3): UNetDown(
    (model): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
  )
  (down4): UNetDown(
    (model): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=Fals

In [10]:
# Optimizer
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

## Dataset

In [11]:
import glob
import random
import os
import numpy as np

from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
from matplotlib import pyplot as plt


class ImageDataset_color(Dataset):
    
    def __init__(self, root, transforms_=None, mode="train"):
        self.transform = transforms.Compose(transforms_)
        self.files = sorted(glob.glob(os.path.join(root, mode) + "/*.*"))
        
    def __getitem__(self, index):
        '''img_A : RGB color image (1024, 1024, 3)
           img_B : Gray image (1024, 1024, 3) - not 1 channel'''
        img_A = cv2.imread(self.files[index % len(self.files)])
        img_A = cv2.cvtColor(img_A, cv2.COLOR_BGR2RGB)
        
        img_B = cv2.cvtColor(cv2.cvtColor(img_A, cv2.COLOR_RGB2GRAY), cv2.COLOR_GRAY2RGB)
        
        img_A = Image.fromarray(np.array(img_A), "RGB")
        img_B = Image.fromarray(np.array(img_B), "RGB")
        if np.random.random() < 0.5:
            img_A = Image.fromarray(np.array(img_A)[:, ::-1, :], "RGB")
            img_B = Image.fromarray(np.array(img_B)[:, ::-1, :], "RGB")

        img_A = self.transform(img_A)
        img_B = self.transform(img_B)

        return {"A": img_A, "B": img_B}

    def __len__(self):
        return len(self.files)
    
    
class ImageDataset_edge(Dataset):
    
    def __init__(self, root, transforms_=None, mode="train"):
        self.transform = transforms.Compose(transforms_)
        self.files = sorted(glob.glob(os.path.join(root, mode) + "/*.*"))
        
    def __getitem__(self, index):
        '''img_A : RGB color image (1024, 1024, 3)
           img_B : Canny edge image (1024, 1024, 3)'''
        img_A = cv2.imread(self.files[index % len(self.files)])
        
        gray = cv2.cvtColor(img_A, cv2.COLOR_BGR2GRAY)
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)
        
        img_B = cv2.Canny(blurred, 50, 150)  # Canny edge
        img_B = cv2.cvtColor(img_B, cv2.COLOR_GRAY2RGB)
        
        img_A = cv2.cvtColor(img_A, cv2.COLOR_BGR2RGB)
        
        img_A = Image.fromarray(np.array(img_A), "RGB")
        img_B = Image.fromarray(np.array(img_B), "RGB")
        if np.random.random() < 0.5:
            img_A = Image.fromarray(np.array(img_A)[:, ::-1, :], "RGB")
            img_B = Image.fromarray(np.array(img_B)[:, ::-1, :], "RGB")

        img_A = self.transform(img_A)
        img_B = self.transform(img_B)

        return {"A": img_A, "B": img_B}

    def __len__(self):
        return len(self.files)    

In [12]:
# Configure dataloaders
transforms_ = [
    transforms.Resize((opt.img_height, opt.img_width), Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

dataloader = DataLoader(
    ImageDataset_color("E:/global-wheat-detection/", transforms_=transforms_),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=0,
)

val_dataloader = DataLoader(
    ImageDataset_color("E:/global-wheat-detection/", transforms_=transforms_, mode="test"),
    batch_size=10,
    shuffle=True,
    num_workers=0,
)

  "Argument interpolation should be of type InterpolationMode instead of int. "


In [13]:
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
Tensor

torch.cuda.FloatTensor

In [39]:
def sample_images(batches_done):
    """Saves a generated sample from the validation set"""
    imgs = next(iter(val_dataloader))
    
    real_A = Variable(imgs["B"].type(Tensor))  # Gray image (1024, 1024, 3)
    real_B = Variable(imgs["A"].type(Tensor))  # Color image (1024, 1024, 3)
    
    fake_B = generator(real_A)  # Generated color image
    
    img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -2).cpu().numpy().astype(np.float32)
    img_sample -= img_sample.min()
    img_sample/= img_sample.max()
    img_sample = img_sample.transpose(0, 2, 3, 1)
    plt.figure(figsize=[10, 20])
    for row in range(3):
        plt.subplot(1, 3, row+1)
        plt.imshow(img_sample[row])
    plt.show()
    

## Train

In [None]:
prev_time = time.time()

for epoch in range(opt.epoch, opt.n_epochs):
    for i, batch in enumerate(dataloader):

        # Model inputs
        real_A = Variable(batch["B"].type(Tensor))  # Gray image (1024, 1024, 3)
        real_B = Variable(batch["A"].type(Tensor))  # Color image (1024, 1024, 3)

        # Adversarial ground truths
        valid = Variable(Tensor(np.ones((real_A.size(0), *patch))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((real_A.size(0), *patch))), requires_grad=False)
        
        # ------------------
        #  Train Generators
        # ------------------

        optimizer_G.zero_grad()

        # GAN loss
        fake_B = generator(real_A)
        pred_fake = discriminator(fake_B, real_A)
        loss_GAN = criterion_GAN(pred_fake, valid)
        
        # Pixel-wise loss
        loss_pixel = criterion_pixelwise(fake_B, real_B)

        # Total loss
        loss_G = loss_GAN + lambda_pixel * loss_pixel

        loss_G.backward()

        optimizer_G.step()
        
        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Real loss
        pred_real = discriminator(real_B, real_A)
        loss_real = criterion_GAN(pred_real, valid)

        # Fake loss
        pred_fake = discriminator(fake_B.detach(), real_A)
        loss_fake = criterion_GAN(pred_fake, fake)

        # Total loss
        loss_D = 0.5 * (loss_real + loss_fake)

        loss_D.backward()
        optimizer_D.step()
        
        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left
        batches_done = epoch * len(dataloader) + i
        batches_left = opt.n_epochs * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()
        
        # Print log
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, pixel: %f, adv: %f] ETA: %s"
            % (
                epoch,
                opt.n_epochs,
                i,
                len(dataloader),
                loss_D.item(),
                loss_G.item(),
                loss_pixel.item(),
                loss_GAN.item(),
                time_left,
            )
        )
        
        # If at sample interval save image
        if batches_done % opt.sample_interval == 0:
            sample_images(batches_done)

    if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
        # Save model checkpoints
        torch.save(generator.state_dict(), "E:/Colorization/pix2pix/saved_models/%s/generator_%d.pth" % (opt.dataset_name, epoch))
        torch.save(discriminator.state_dict(), "E:/Colorization/pix2pix/saved_models/%s/discriminator_%d.pth" % (opt.dataset_name, epoch))