In [None]:
import os
import time
import torch
import torch as nn
import numpy as np
import torchvision
from PIL import Image
from models import *
from combined_model import *
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

# Google Drive (Colab mount) locations: input sketches, colour targets, and saved model states.
sketch_dir, target_dir, model_dir = (
    '/content/drive/My Drive/CSCI566/Data/simpler_sketch',
    '/content/drive/My Drive/CSCI566/Data/target_10w',
    '/content/drive/My Drive/CSCI566/Model_state',
)

In [None]:
def imshow_sketch(img):
    img = img / 2 + 0.5
    npimg = img.numpy()
    npimg = np.clip(npimg, 0, 1)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

def imshow_rgb(img):
    npimg = img.numpy().transpose((1, 2, 0))
    mean, std = np.array([0.5, 0.5, 0.5]), np.array([0.5, 0.5, 0.5])
    npimg = npimg * std + mean
    npimg = np.clip(npimg, 0, 1)
    plt.imshow(npimg)
    plt.show()

class AnimeSketchDataset(Dataset):
    """Anime sketches dataset converted from the original pictures.

    Each sample pairs a sketch and a target image that share the same file name
    (``0.jpg`` ... ``{num_images - 1}.jpg``) in two parallel directories.
    """

    def __init__(self, sketch_dir, target_dir, transform=None, num_images=100000):
        """
        Args:
            sketch_dir (string): directory with all converted sketches.
            target_dir (string): directory with all target images.
            transform (dict, optional): mapping with keys ``'sketch'`` and ``'rgb'``; each value is
                a callable applied to the sketch and the target image respectively.
            num_images (int, optional): number of samples. File names are generated as
                ``'0.jpg'`` .. ``f'{num_images - 1}.jpg'`` rather than listed from disk.
        """
        # Generated names (instead of os.listdir, whose order is arbitrary) give a fixed
        # index -> file mapping.
        self.names = [f'{num}.jpg' for num in range(num_images)]
        self.sketch_dir = sketch_dir
        self.target_dir = target_dir
        self.transform = transform

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            dict: ``{'sketch': ..., 'target': ..., 'image_name': str}``. Without a transform the
            images are PIL images: the sketch in mode ``'L'``, the target in mode ``'RGB'``.
        """
        name = self.names[idx]
        # ``with`` closes the file handle; convert() returns a fully loaded, independent copy.
        with Image.open(os.path.join(self.sketch_dir, name)) as img:
            s = img.convert('L')
        # Force 3 channels so a grayscale/RGBA/palette target cannot break the 3-channel
        # normalization applied by transform['rgb'].
        with Image.open(os.path.join(self.target_dir, name)) as img:
            t = img.convert('RGB')

        if self.transform:
            s = self.transform['sketch'](s)
            t = self.transform['rgb'](t)

        return {'sketch': s, 'target': t, 'image_name': name}

# Preprocessing per image kind: PIL image -> float tensor in [0, 1] (ToTensor) -> [-1, 1] via
# (x - 0.5) / 0.5. Note the trailing commas: `(0.5,)` is a one-element per-channel tuple,
# whereas `(0.5)` would just be a float.
transform = {
    'sketch': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # single-channel ('L') sketch
    ]),
    'rgb': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # 3-channel target
    ]),
}

######################################## tune #########################################
# Optimisation settings.
batch_size = 10
max_epoch = 1
learning_rate = 2e-4
betas = (0.5, 0.999)  # Adam (beta1, beta2)

# Architecture options passed to Combined_Model(params) and netD(params).
params = dict(
    CAE_norm='instance',
    FD_norm='instance',
    FD_output_channels=1,
    G_input_nc=1,
    ngf=4,
    G_norm='instance',
    G_n_downsampling=3,  # if changed, modification needed in networks.py
    G_n_blocks=1,
    D_n_downsampling=3,  # if changed, modification needed in networks.py
    D_n_blocks=0,
    parts=dict(
        face=dict(
            # weights file for the face part
            cae_weights=os.path.join(model_dir, 'ae_model_ss_30ep.pt'),
        ),
    ),
)
#######################################################################################

In [None]:
# Paired sketch/target samples, reshuffled each epoch and loaded by 4 worker processes.
dataset = AnimeSketchDataset(sketch_dir=sketch_dir, target_dir=target_dir, transform=transform)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

In [None]:
# The notebook's top-level `import torch as nn` binds the whole torch package, which has no
# MSELoss / BCEWithLogitsLoss; it only works if `from models import *` happens to re-export
# torch.nn. Bind the real module explicitly before using it.
import torch.nn as nn

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = Combined_Model(params).to(device)
discriminator = netD(params).to(device)

crit_m = nn.MSELoss()  # content loss between generated and target images (subject to change)
crit_d = nn.BCEWithLogitsLoss()  # adversarial loss; expects the discriminator to emit raw logits

# Only the face feature decoder and the generator are optimised; the face part encoder is
# not included here.
model_params = list(model.part_feature_decoder['face'].parameters()) + list(model.G.parameters())
optimizer_m = torch.optim.Adam(model_params, lr=learning_rate, betas=betas)
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=betas)

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
log_every = 1000  # iterations between loss printouts

model.train()
# The face encoder is not optimised (absent from optimizer_m); keep it in eval mode.
model.part_encoder['face'].eval()
discriminator.train()

for epoch in range(max_epoch):
    running_m_loss = 0.0
    running_d_loss = 0.0
    start = time.time()

    for i, data in enumerate(dataloader):
        sketch_batch = data['sketch'].to(device)
        target_batch = data['target'].to(device)

        ### Discriminator ###
        for p in discriminator.parameters():
            p.requires_grad = True
        optimizer_d.zero_grad()

        # with real -- labels are sized from the actual output rather than the configured
        # batch_size, so a short final batch cannot cause a shape mismatch in crit_d.
        real_dis_out, real_hints = discriminator(target_batch)
        real_label = torch.ones_like(real_dis_out)
        real_dis_loss = crit_d(real_dis_out, real_label)
        real_dis_loss.backward()

        # with fake
        # NOTE(review): real_dis_loss.backward() has already freed the real-pass graph. If the
        # model uses these hints differentiably, backpropagating total_loss through them would
        # fail or leak gradients into D. Detaching is harmless for optimizer_m, which holds no
        # discriminator parameters.
        if torch.is_tensor(real_hints):
            real_hints = real_hints.detach()
        fake_batch, hints = model(sketch_batch, target_batch, real_hints)
        fake_dis_out, _ = discriminator(fake_batch.detach(), hints.detach())
        fake_label = torch.zeros_like(fake_dis_out)
        fake_dis_loss = crit_d(fake_dis_out, fake_label)
        fake_dis_loss.backward()
        optimizer_d.step()

        ### Model (Feature Decoder + Generator) ###
        # Freeze D so the generator loss does not accumulate gradients on D's parameters.
        for p in discriminator.parameters():
            p.requires_grad = False

        optimizer_m.zero_grad()
        fake_dis_out, _ = discriminator(fake_batch, hints)
        gan_loss = crit_d(fake_dis_out, torch.ones_like(fake_dis_out))  # reward fooling D
        content_loss = crit_m(fake_batch, target_batch)
        total_loss = gan_loss + content_loss
        total_loss.backward()
        optimizer_m.step()

        running_m_loss += total_loss.item()
        running_d_loss += real_dis_loss.item() + fake_dis_loss.item()
        if i % log_every == log_every - 1:
            end = time.time()
            print(f'[{epoch+1},{i+1}] Model loss: {running_m_loss/log_every}, Dis loss: {running_d_loss/log_every} --- {(end-start)/60} min')
            running_m_loss = 0.0
            running_d_loss = 0.0
            start = time.time()

In [None]:
# torch.save(model.state_dict(), os.path.join(model_dir, 'm_1ep.pt'))
# torch.save(discriminator.state_dict(), os.path.join(model_dir, 'd_1ep.pt'))
# torch.save(optimizer_m.state_dict(), os.path.join(model_dir, 'om_1ep.pt'))
# torch.save(optimizer_d.state_dict(), os.path.join(model_dir, 'od_1ep.pt'))

In [None]:
# Single checkpoint holding the model, discriminator and both optimizer states.
checkpoint_path = os.path.join(model_dir, 'full_model_1ep.pt')
torch.save(
    dict(
        model=model.state_dict(),
        discriminator=discriminator.state_dict(),
        optimizer_m=optimizer_m.state_dict(),
        optimizer_d=optimizer_d.state_dict(),
    ),
    checkpoint_path,
)

In [None]:
# checkpoint = torch.load(os.path.join(model_dir, 'full_model_1ep.pt'), map_location=device)
# model.load_state_dict(checkpoint['model'])

<All keys matched successfully>

In [None]:
test_sketch = '/content/drive/My Drive/CSCI566/Data/temp/323.jpg'
test_target = '/content/drive/My Drive/CSCI566/Data/target/323.jpg'

model.eval()

# Sketch as 1 channel; target forced to 3 channels to match the 3-channel Normalize in
# transform['rgb']. ``with`` closes the files once the converted copies exist.
with Image.open(test_sketch) as img:
    ske = transform['sketch'](img.convert('L'))
with Image.open(test_target) as img:
    t = transform['rgb'](img.convert('RGB'))
imshow_rgb(torchvision.utils.make_grid(t))  # ground-truth colour image
ske, t = ske.unsqueeze(0), t.unsqueeze(0)  # add a batch dimension of size 1

with torch.no_grad():
    out_img, _ = model(ske.to(device), t.to(device))
    print(out_img.shape)
    imshow_rgb(torchvision.utils.make_grid(out_img[0].cpu()))