In [1]:
# Print the kernel's current working directory; relative paths used later in the
# notebook resolve against it.
import os
print(os.getcwd())

C:\Users\papup\OneDrive\桌面


In [2]:
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
import torch.optim as optim
import torchvision
import math
from random import randint
from PIL import Image

from Image_Inpainting_Project.networks_test import Generator, Discriminator
from Image_Inpainting_Project.loss import ls_loss_d, ls_loss_g, hinge_loss_d, hinge_loss_g
from Image_Inpainting_Project.datasets import ImageDataset, random_bbox, bbox2mask, brush_stroke_mask
from datetime import datetime

In [3]:
# Build the training dataset. ImageDataset() is called with no arguments, so the data
# location and any preprocessing come from defaults inside
# Image_Inpainting_Project.datasets (not visible here).
train_data = ImageDataset()

In [4]:
# Report how many samples the dataset exposes (its __len__).
print(f"Train size: {len(train_data)}")

Train size: 1224098


In [5]:
# Quick check that PyTorch can see a CUDA device; the bool is shown as the cell output.
torch.cuda.is_available()

True

In [6]:
# Train on the first CUDA GPU when PyTorch can see one, otherwise fall back to the CPU,
# and report which device (and GPU model, if any) was chosen.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('Device: {}'.format(device))
if device.type == 'cuda':
    print('GPU Model: {}'.format(torch.cuda.get_device_name(0)))

Device: cuda:0
GPU Model: NVIDIA GeForce GTX 1660 Ti


In [7]:
# Mini-batch size; train() also uses this notebook-level value.
batch_size = 4
# drop_last=True guarantees every batch has exactly `batch_size` samples. The dataset
# length is not a multiple of 4, so without it the final batch of each epoch is short,
# and code that tiles the mask / splits discriminator output by `batch_size` fails on it.
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    drop_last=True,
)

In [8]:
# Construct both networks with identical constructor arguments and move each to the
# training device (Module.to returns the module itself).
# NOTE(review): 4 presumably = input channels (image + 1 mask channel, matching the
# 4-channel tensors built in train()) and 64 a base channel width — confirm in networks_test.
generator = Generator(4, 64).to(device)
discriminator = Discriminator(4, 64).to(device)

In [9]:
# Adam hyperparameters shared by the generator and discriminator optimizers.
# `betas` is a (beta1, beta2) tuple, the type documented for torch.optim.Adam.
ADAM_LR = 1e-4
ADAM_BETAS = (0.5, 0.999)
optimizer_g = optim.Adam(generator.parameters(), lr=ADAM_LR, betas=ADAM_BETAS)
optimizer_d = optim.Adam(discriminator.parameters(), lr=ADAM_LR, betas=ADAM_BETAS)

In [10]:
# Generator / discriminator loss histories. train() appends to them every 1000 global
# iterations and never resets them, so they accumulate across repeated train() calls.
g_loss_log, d_loss_log = [], []

In [11]:
def _to_uint8_image(chw):
    """Convert one (C, H, W) float array into an (H, W, C) uint8 array.

    Values are mapped with (x + 1) * 127.5, i.e. [-1, 1] -> [0, 255]. They are
    clipped before the cast so out-of-range values saturate instead of wrapping
    around in uint8.
    """
    scaled = (chw + 1.0) * 127.5
    return np.clip(scaled, 0, 255).astype(np.uint8).transpose(1, 2, 0)


def _save_samples(images, iter_count, result_dir):
    """Save each image of a batch tensor as <result_dir>/iter_<iter_count>_<k>.jpg (k starts at 1)."""
    os.makedirs(result_dir, exist_ok=True)
    for idx, result in enumerate(images.detach().cpu().numpy()):
        inpainted_img = Image.fromarray(_to_uint8_image(result)).convert('RGB')
        inpainted_img.save(f"{result_dir}/iter_{iter_count}_{idx+1}.jpg")


def _save_checkpoint(checkpoint_dir, iter_count, generator, discriminator, optimizer_g, optimizer_d):
    """Save both networks' and both optimizers' state dicts, tagged with `iter_count`."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save({'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_g_state_dict': optimizer_g.state_dict(),
                'optimizer_d_state_dict': optimizer_d.state_dict(),
                }, f'{checkpoint_dir}/Inpainting_model_state_dict_iter_{iter_count}.pt')


def _plot_loss_curves(g_losses, d_losses, log_every):
    """Plot the logged generator (red) and discriminator (blue) losses with labels."""
    fig, ax = plt.subplots()
    ax.plot(g_losses, 'r', label='generator loss')
    ax.plot(d_losses, 'b', label='discriminator loss')
    ax.set(xlabel=f'logged step (every {log_every} iterations)', ylabel='loss',
           title='Training losses')
    ax.legend()
    plt.show()


def train(iterations, generator, discriminator, optimizer_g, optimizer_d, gan_loss_g, gan_loss_d, iter_start,
          result_dir="D:/inpainting_result_test",
          checkpoint_dir="D:inpainting_checkpoints_test",
          log_every=1000, sample_every=5000, checkpoint_every=20000):
    """Train the inpainting GAN for a fixed number of iterations.

    Each iteration does the following:
      1. Draws one batch from the notebook-level ``train_loader``.
      2. Builds a random hole mask (box OR brush stroke).
      3. Takes one discriminator step on real vs. completed images.
      4. Takes one generator step (L1 reconstruction of both stages plus the
         adversarial term).

    Args:
        iterations: number of optimisation steps to run in this call.
        generator, discriminator: networks to train (expected on ``device``).
        optimizer_g, optimizer_d: optimizers for generator / discriminator.
        gan_loss_g: callable, discriminator scores on generated images ->
            generator adversarial loss.
        gan_loss_d: callable, (scores on real, scores on completed) ->
            discriminator loss.
        iter_start: global iteration count to continue from. It drives the
            logging / sampling / checkpoint schedule and the output file names.
        result_dir: directory for sample JPEGs; created if missing.
        checkpoint_dir: directory for ``.pt`` checkpoints; created if missing.
            The default has no slash after ``D:``, so on Windows it is
            drive-relative. It is kept so existing checkpoint paths still match.
        log_every, sample_every, checkpoint_every: intervals, in global
            iterations, for loss logging, sample saving and checkpointing.

    Side effects:
        - Sets ``torch.backends.cudnn.benchmark = True``.
        - Appends to the notebook-level ``g_loss_log`` / ``d_loss_log``.
        - Writes sample images and checkpoints.
        - Shows a loss plot when finished.
    """
    losses = {}
    generator.train()
    discriminator.train()
    torch.backends.cudnn.benchmark = True
    iter_count = iter_start
    train_iter = iter(train_loader)
    start_time = datetime.now()
    for iters in range(iterations):
        try:
            image = next(train_iter)
        except StopIteration:
            # Epoch exhausted: start a new (reshuffled) pass. Any other error
            # (e.g. a broken image or worker crash) now propagates instead of
            # being silently swallowed.
            train_iter = iter(train_loader)
            image = next(train_iter)
        iter_count += 1
        image = image.to(device=device, dtype=torch.float)
        # Real batch size: can be smaller than `batch_size` for a final short batch.
        n = image.size(0)

        # Random hole: union of a rectangular box and a free-form brush stroke.
        # mask == 1 marks pixels that are removed and must be filled in.
        bbox = random_bbox()
        regular_mask = bbox2mask(bbox).to(device)
        irregular_mask = brush_stroke_mask().to(device)
        mask = torch.logical_or(irregular_mask, regular_mask).to(torch.float32)

        # Generator input: the image with the hole zeroed, plus one mask channel.
        incomplete_img = image * (1.0 - mask)
        ones_x = torch.ones_like(incomplete_img)[:, 0:1, :, :]
        x = torch.cat([incomplete_img, ones_x * mask], dim=1)

        coarse_img, fine_img = generator(x, mask)
        # Keep the known pixels; take only the hole region from the prediction.
        complete_img = fine_img * mask + incomplete_img * (1.0 - mask)

        # Mask replicated across the batch; appended as an extra channel to every D input.
        mask_batch = torch.tile(mask, [n, 1, 1, 1])

        real_mask = torch.cat((image, mask_batch), dim=1)
        filled_mask = torch.cat((complete_img.detach(), mask_batch), dim=1)
        real_filled = torch.cat((real_mask, filled_mask))

        # ---- Discriminator step: real and completed images in one forward pass.
        d_real_gen = discriminator(real_filled)
        d_real, d_gen = torch.split(d_real_gen, n)
        d_loss = gan_loss_d(d_real, d_gen)
        losses['d_loss'] = d_loss

        optimizer_d.zero_grad()
        d_loss.backward()
        optimizer_d.step()

        # ---- Generator step: L1 reconstruction of both stages + adversarial term.
        losses['ae_loss1'] = torch.mean(torch.abs(image - coarse_img))
        losses['ae_loss2'] = torch.mean(torch.abs(image - fine_img))
        losses['ae_loss'] = losses['ae_loss1'] + losses['ae_loss2']

        # NOTE(review): D is trained on `complete_img`, but the generator's
        # adversarial term scores the raw `fine_img` — confirm this is intended.
        gen_img = torch.cat((fine_img, mask_batch), dim=1)
        d_gen = discriminator(gen_img)

        # Total generator loss, built out-of-place (no in-place aliasing of the loss tensor).
        g_loss = gan_loss_g(d_gen) + losses['ae_loss']
        losses['g_loss'] = g_loss

        # This backward also leaves gradients on D's parameters; optimizer_d.zero_grad()
        # clears them before the next discriminator update.
        optimizer_g.zero_grad()
        g_loss.backward()
        optimizer_g.step()

        if iter_count % log_every == 0:
            g_loss_log.append(losses['g_loss'].item())
            d_loss_log.append(losses['d_loss'].item())

        if iter_count % sample_every == 0:
            # Copy to host only when samples are saved, avoiding a GPU sync every iteration.
            _save_samples(complete_img, iter_count, result_dir)

        if iter_count % checkpoint_every == 0:
            _save_checkpoint(checkpoint_dir, iter_count, generator, discriminator, optimizer_g, optimizer_d)
            print(f"Iteration: {iter_count}, time for {iters+1} iterations: {datetime.now() - start_time}")

    print('Finished Training')
    _plot_loss_curves(g_loss_log, d_loss_log, log_every)

In [None]:
# Launch a training session: run `iterations` steps, counting global iterations from
# `iter_count_resume` (which also sets the checkpoint/sample file numbering).
iterations = 300_000
iter_count_resume = 0
print(iter_count_resume)
train(
    iterations,
    generator,
    discriminator,
    optimizer_g,
    optimizer_d,
    hinge_loss_g,
    hinge_loss_d,
    iter_start=iter_count_resume,
)

0
Iteration: 20000, time for 20000 iterations: 2:16:21.937091
Iteration: 40000, time for 40000 iterations: 4:32:31.105637
Iteration: 60000, time for 60000 iterations: 6:48:42.331068
Iteration: 80000, time for 80000 iterations: 9:05:02.456299
Iteration: 100000, time for 100000 iterations: 11:21:15.968076
Iteration: 120000, time for 120000 iterations: 13:37:19.310976
Iteration: 140000, time for 140000 iterations: 15:54:29.493534
Iteration: 160000, time for 160000 iterations: 18:14:04.024214
Iteration: 180000, time for 180000 iterations: 20:34:25.994496
Iteration: 200000, time for 200000 iterations: 23:55:36.278773


In [None]:
# Restore both networks and both optimizers from a checkpoint written by train().
# The path is drive-relative ('D:' with no slash), matching where train() saves.
# NOTE(review): the checkpoint stores only the four state dicts, not the iteration
# number, so set iter_count_resume = checkpoint_iter manually before resuming.
# NOTE(review): the training call above (start 0, 300000 iterations) never reaches
# iteration 1000000 — confirm this file comes from another run.
checkpoint_iter = 1000000
model_check_point = f'D:inpainting_checkpoints_test/Inpainting_model_state_dict_iter_{checkpoint_iter}.pt'
# map_location remaps tensors saved on a GPU onto the current `device`, so loading
# also works on a CPU-only machine.
checkpoint = torch.load(model_check_point, map_location=device)
generator.load_state_dict(checkpoint['generator_state_dict'])
discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict'])
optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict'])

In [None]:
# Print each discriminator parameter tensor in registration order.
for weight in discriminator.parameters():
    print(weight)