In [18]:
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from torchvision import transforms as tf
import torch.utils.data as data

import os
import cv2
import functools
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

In [1]:
from models import vgg19

In [14]:
model = vgg19(pretrained=True).features[:-2]

model = model.eval()

In [15]:
model

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [9]:
img = torch.rand(4,3,256,256)

In [10]:
out = model(img)
out.shape

torch.Size([4, 512, 8, 8])

In [19]:
class GatedConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, dilation = 1, activation = 'lrelu', norm = 'in'):
        super(GatedConv2d, self).__init__()
        self.pad = nn.ZeroPad2d(padding)
        if norm is not None:
            self.norm = nn.InstanceNorm2d(out_channels)
        else:
            self.norm = None
            
        if activation == 'tanh':
            self.activation = nn.Tanh()
        else:
            self.activation = nn.LeakyReLU(0.2, inplace = True)
        
       
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation)
        self.mask_conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation)
        self.sigmoid = torch.nn.Sigmoid()
    
    def forward(self, x):
        x = self.pad(x)
        conv = self.conv2d(x)
        mask = self.mask_conv2d(x)
        gated_mask = self.sigmoid(mask)
        x = conv * gated_mask
        if self.norm:
            x = self.norm(x)
        if self.activation:
            x = self.activation(x)
        return x

class TransposeGatedConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, dilation = 1, norm=None, scale_factor = 2):
        super(TransposeGatedConv2d, self).__init__()
        # Initialize the conv scheme
        self.scale_factor = scale_factor
        self.gated_conv2d = GatedConv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, norm=norm)
    
    def forward(self, x):
        x = F.interpolate(x, scale_factor = self.scale_factor, mode = 'nearest')
        x = self.gated_conv2d(x)
        return x

In [20]:
class GatedGenerator(nn.Module):
    def __init__(self, in_channels=4, latent_channels=64, out_channels=3):
        super(GatedGenerator, self).__init__()
        self.coarse = nn.Sequential(
            # encoder
            GatedConv2d(in_channels, latent_channels, 7, 1, 3, norm = None),
            GatedConv2d(latent_channels, latent_channels * 2, 4, 2, 1),
            GatedConv2d(latent_channels * 2, latent_channels * 4, 3, 1, 1),
            GatedConv2d(latent_channels * 4, latent_channels * 4, 4, 2, 1),
            # Bottleneck
            GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),
            GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),
            GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 2, dilation = 2),
            GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 4, dilation = 4),
            GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 8, dilation = 8),
            GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 16, dilation = 16),
            GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),
            GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),
            # decoder
            TransposeGatedConv2d(latent_channels * 4, latent_channels * 2, 3, 1, 1),
            GatedConv2d(latent_channels * 2, latent_channels * 2, 3, 1, 1),
            TransposeGatedConv2d(latent_channels * 2, latent_channels, 3, 1, 1),
            GatedConv2d(latent_channels, out_channels, 7, 1, 3, activation = 'tanh', norm = None)
        )
        self.refinement = nn.Sequential(
            # encoder
            GatedConv2d(in_channels, latent_channels, 7, 1, 3, norm = None),
            GatedConv2d(latent_channels, latent_channels * 2, 4, 2, 1),
            GatedConv2d(latent_channels * 2, latent_channels * 4, 3, 1, 1),
            GatedConv2d(latent_channels * 4, latent_channels * 4, 4, 2, 1),
            # Bottleneck
            GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),
            GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),
            GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 2, dilation = 2),
            GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 4, dilation = 4),
            GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 8, dilation = 8),
            GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 16, dilation = 16),
            GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),
            GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),
            # decoder
            TransposeGatedConv2d(latent_channels * 4, latent_channels * 2, 3, 1, 1),
            GatedConv2d(latent_channels * 2, latent_channels * 2, 3, 1, 1),
            TransposeGatedConv2d(latent_channels * 2, latent_channels, 3, 1, 1),
            GatedConv2d(latent_channels, out_channels, 7, 1, 3, activation = 'tanh', norm = None)
        )
        
    def forward(self, img, mask):
        # img: entire img
        # mask: 1 for mask region; 0 for unmask region
        # 1 - mask: unmask
        # img * (1 - mask): ground truth unmask region
        # Coarse
     
        first_masked_img = img * (1 - mask) + mask
        first_in = torch.cat((first_masked_img, mask), 1)       # in: [B, 4, H, W]
        first_out = self.coarse(first_in)                       # out: [B, 3, H, W]
        # Refinement
        second_masked_img = img * (1 - mask) + first_out * mask
        second_in = torch.cat((second_masked_img, mask), 1)     # in: [B, 4, H, W]
        second_out = self.refinement(second_in)                 # out: [B, 3, H, W]
        return first_out, second_out

In [21]:
class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
        super(NLayerDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        return self.model(input)

In [22]:
class PerceptualNet(nn.Module):
    def __init__(self):
        super(PerceptualNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.ReLU(inplace = True),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.ReLU(inplace = True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.ReLU(inplace = True),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.ReLU(inplace = True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.ReLU(inplace = True),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.ReLU(inplace = True),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(256, 512, 3, 1, 1),
            nn.ReLU(inplace = True),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.ReLU(inplace = True),
            nn.Conv2d(512, 512, 3, 1, 1)
        )

    def forward(self, x):
        x = self.features(x)
        return x

In [6]:
class GANLoss(nn.Module):
    def __init__(self, target_real_label=1.0, target_fake_label=0.0):
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(input)

    def __call__(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)

In [7]:
class InpaintDataset(data.Dataset):
    def __init__(self, img_dir):
        self.img_dir = img_dir
        self.load_images()
        
    def load_images(self):
        self.fns =[]
        img_paths = sorted(os.listdir(self.img_dir))
        for path in img_paths:
            self.fns.append(os.path.join(self.img_dir, path))
            
    def __getitem__(self, index):
        img_path = self.fns[index]
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (256,256))
        
        mask = self.random_ff_mask()
        img = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous()
        mask = torch.from_numpy(mask.astype(np.float32)).contiguous()
        return img, mask
    
    def collate_fn(self, batch):
        imgs = torch.stack([i[0] for i in batch])
        masks = torch.stack([i[1] for i in batch])
        return {
            'imgs': imgs,
            'masks': masks
        }
    
    def __len__(self):
        return len(self.fns)
    
    def random_ff_mask(self, shape =256 , max_angle = 4, max_len = 40, max_width = 10, times = 15):
            """Generate a random free form mask with configuration.
            Args:
                config: Config should have configuration including IMG_SHAPES,
                    VERTICAL_MARGIN, HEIGHT, HORIZONTAL_MARGIN, WIDTH.
            Returns:
                tuple: (top, left, height, width)
            """
            height = shape
            width = shape
            mask = np.zeros((height, width), np.float32)
            times = np.random.randint(times)
            for i in range(times):
                start_x = np.random.randint(width)
                start_y = np.random.randint(height)
                for j in range(1 + np.random.randint(5)):
                    angle = 0.01 + np.random.randint(max_angle)
                    if i % 2 == 0:
                        angle = 2 * 3.1415926 - angle
                    length = 10 + np.random.randint(max_len)
                    brush_w = 5 + np.random.randint(max_width)
                    end_x = (start_x + length * np.sin(angle)).astype(np.int32)
                    end_y = (start_y + length * np.cos(angle)).astype(np.int32)
                    cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w)
                    start_x, start_y = end_x, end_y
            return mask.reshape((1, ) + mask.shape).astype(np.float32)

In [None]:
dataset = InpaintDataset(img_dir='datasets/places365standard_easyformat/places365_standard/train/waterfall')
dataloader = data.DataLoader(dataset, batch_size=4, collate_fn = dataset.collate_fn)

In [None]:
for batch in dataloader:
    imgs = batch['imgs']
    masks = batch['masks']
    
    break

In [8]:
device = torch.device('cuda')

In [23]:
model_G = GatedGenerator()
model_D = NLayerDiscriminator(3, use_sigmoid=True)
model_P = PerceptualNet()
criterion_adv = GANLoss()
criterion_rec = nn.MSELoss()
criterion_per = nn.L1Loss()
optimizer_D = torch.optim.Adam(model_D.parameters(), lr=1e-4)
optimizer_G = torch.optim.Adam(model_G.parameters(), lr=1e-4)

NameError: name 'GANLoss' is not defined

In [24]:
torch.save({
    'D': model_D.state_dict(),
    'G': model_G.state_dict()
}, 's.pth')

In [None]:
def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
print(count_params(model_G))
print(count_params(model_D))
print(count_params(model_P))

In [10]:
def random_ff_mask(shape =256 , max_angle = 4, max_len = 40, max_width = 10, times = 15):
            """Generate a random free form mask with configuration.
            Args:
                config: Config should have configuration including IMG_SHAPES,
                    VERTICAL_MARGIN, HEIGHT, HORIZONTAL_MARGIN, WIDTH.
            Returns:
                tuple: (top, left, height, width)
            """
            height = shape
            width = shape
            mask = np.zeros((height, width), np.float32)
            times = np.random.randint(times)
            for i in range(times):
                start_x = np.random.randint(width)
                start_y = np.random.randint(height)
                for j in range(1 + np.random.randint(5)):
                    angle = 0.01 + np.random.randint(max_angle)
                    if i % 2 == 0:
                        angle = 2 * 3.1415926 - angle
                    length = 10 + np.random.randint(max_len)
                    brush_w = 5 + np.random.randint(max_width)
                    end_x = (start_x + length * np.sin(angle)).astype(np.int32)
                    end_y = (start_y + length * np.cos(angle)).astype(np.int32)
                    cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w)
                    start_x, start_y = end_x, end_y
            return mask.reshape((1, ) + mask.shape).astype(np.float32)

In [11]:
img = cv2.imread('datasets/places365standard_easyformat/places365_standard/train/waterfall/00000003.jpg')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (256, 256))
img = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous()
img_tensor = img.unsqueeze(0)
mask = random_ff_mask()
mask = torch.from_numpy(mask).contiguous().unsqueeze(0)

In [None]:
def visualize(img):
    np_img = img.squeeze(0).detach().cpu().numpy()
    return np_img.transpose(1, 2, 0)

In [None]:
plt.imshow(visualize(first_out_wholeimg))

In [12]:
first_out, second_out = model_G(img_tensor, mask)

first_out_wholeimg = img_tensor * (1 - mask) + first_out * mask     
second_out_wholeimg = img_tensor * (1 - mask) + second_out * mask

In [13]:
# Train discriminator
optimizer_D.zero_grad()

fake_D = model_D(second_out_wholeimg.detach())
real_D = model_D(img_tensor)

loss_fake_D = criterion_adv(fake_D, target_is_real=False)
loss_real_D = criterion_adv(real_D, target_is_real=True)

loss_D = (loss_fake_D + loss_real_D) *0.5

loss_D.backward()
optimizer_D.step()

In [15]:
# Train Generator

optimizer_G.zero_grad()

fake_D = model_D(second_out_wholeimg)
G_loss = criterion_adv(fake_D, target_is_real=True)

In [16]:
# Reconstruction loss

loss_rec_1 = criterion_rec(first_out_wholeimg, img_tensor)
loss_rec_2 = criterion_rec(second_out_wholeimg, img_tensor)

In [17]:
# Perceptual loss

img_featuremaps = model_P(img_tensor)                            # feature maps
second_out_wholeimg_featuremaps = model_P(second_out_wholeimg)

loss_P = criterion_per(second_out_wholeimg_featuremaps, img_featuremaps)

In [18]:
loss = lambda_G * G_loss + lambda_rec_1 * loss_rec_1 + lambda_rec_2 * loss_rec_2 + lambda_per * loss_P
loss.backward()
optimizer_G.step()

NameError: name 'lambda_G' is not defined