In [2]:
from core.data.data import Dataset
from core.loss.loss import TotalLoss
from core.network.smoothingnetwork import SmoothingNet
from core.data.makeedge import make_edge_files
from torch.optim import Adam
from torch.utils import data
from torch import autograd
from torch import nn
import torch
from core.loss.dataloss import DataLoss
from core.loss.smoothnessloss import SmoothnessLoss
from core.loss.edgepreservingloss import EdgePreservingLoss

import os
import time
import numpy as np
import matplotlib.pyplot as plt
import torchvision

#make_edge_files("data/eval", "data/eval_edges", os.getcwd(), high=255, low=150)

def show_generated_images(dataset, net, device, show_n=5):
    """Plot randomly sampled dataset images above the network's outputs for them.

    The top axis shows a grid of the sampled inputs, the bottom axis a grid of
    ``net``'s outputs for the same batch, clamped to ``[0, 1]``.

    Args:
        dataset: Indexable collection with a length; ``dataset[idx][0]`` must be
            an image tensor. The sampled images are stacked into one batch, so
            they must all share the same shape.
        net: Callable mapping the stacked image batch to an output batch. When
            it is an ``nn.Module`` it is evaluated in eval mode and every
            sub-module's previous train/eval flag is restored afterwards.
        device: Device the sampled batch is moved to before calling ``net``.
        show_n: Number of images to sample. Sampling is done with replacement,
            so the same image may appear more than once.
    """
    with torch.no_grad():
        sample_indices = np.random.choice(len(dataset), show_n)
        originals = torch.stack(
            [dataset[idx][0] for idx in sample_indices]
        ).to(device)

        # A preview must not behave like a training step: in train mode dropout
        # stays active and BatchNorm running statistics are updated even under
        # no_grad. Remember the mode of every sub-module so that deliberately
        # frozen layers are restored exactly as they were.
        is_module = isinstance(net, nn.Module)
        saved_modes = (
            [(module, module.training) for module in net.modules()]
            if is_module else []
        )
        if is_module:
            net.eval()
        try:
            outputs = net(originals).clamp(0.0, 1.0)
        finally:
            for module, mode in saved_modes:
                module.training = mode

        fig, axes = plt.subplots(2, figsize=(20, 10))
        for axis, batch in zip(axes, (originals, outputs)):
            grid = torchvision.utils.make_grid(batch, padding=50)
            # make_grid returns channels-first; imshow expects channels-last.
            axis.imshow(
                np.transpose(grid.cpu().numpy(), (1, 2, 0)),
                interpolation='nearest',
            )

        plt.show()
        # This function is called repeatedly during training; release the
        # figure so non-inline backends do not accumulate open figures.
        plt.close(fig)
        
def train(net, epochs, batch_size, optimizer, dataset, window_size=5, device='cuda',
          log_every=100, checkpoint_path='model.pth'):
    """Train ``net`` on ``dataset`` with a weighted sum of three loss terms.

    Each step computes ``loss = D + S + E`` where ``D`` is the data loss
    (weight 1.0), ``S`` the smoothness loss (weight 1.0) and ``E`` the
    edge-preserving loss (weight 0.1), then back-propagates and steps
    ``optimizer``. Every ``log_every`` batches (counted per epoch, starting
    with batch 0) it prints the loss terms, shows a preview via
    ``show_generated_images`` and overwrites the checkpoint file.

    Args:
        net: Network to train; moved to ``device`` and put in train mode.
        epochs: Number of passes over ``dataset``.
        batch_size: Batch size passed to the ``DataLoader``.
        optimizer: Optimizer already bound to ``net``'s parameters.
        dataset: Dataset yielding ``(images, binary_mask)`` pairs.
        window_size: Forwarded to ``SmoothnessLoss(window_size=...)``.
        device: Device for the network and each batch.
        log_every: Log/preview/checkpoint interval in batches; must be >= 1.
        checkpoint_path: File the checkpoint dict
            ``{'state_dict': ..., 'epoch': ...}`` is saved to.

    Raises:
        ValueError: If ``log_every`` is smaller than 1.
    """
    if log_every < 1:
        raise ValueError("log_every must be a positive integer")

    dataloader = data.DataLoader(dataset, batch_size=batch_size)
    net.to(device)
    net.train()

    data_loss = DataLoss()
    smoothness_loss = SmoothnessLoss(window_size=window_size)
    edge_preserving_loss = EdgePreservingLoss()
    w_data = 1.0
    w_smooth = 1.0
    w_edge_preserving = 0.1

    n = 0  # global step counter across all epochs
    for e in range(epochs):
        print("Epoch: {}".format(e+1))
        for i, (images, binary_mask) in enumerate(dataloader):
            n += 1
            optimizer.zero_grad()
            images = images.to(device)
            binary_mask = binary_mask.to(device)

            # The network output is used directly as the smoothed image; it is
            # not added back onto the input as a residual.
            smooth_images = net(images)

            D = w_data * data_loss(images, smooth_images)
            S = w_smooth * smoothness_loss(images, smooth_images)
            E = w_edge_preserving * edge_preserving_loss(binary_mask, images, smooth_images)
            loss = D + S + E

            loss.backward()
            optimizer.step()

            if i % log_every == 0:
                print("\tN: {}".format(n))
                show_generated_images(dataset=dataset, net=net, device=device)
                print("\tD: {:.5f}".format(D.item()), end=' ')
                print("\tS: {:.5f}".format(S.item()), end=' ')
                print("\tE: {:.5f}".format(E.item()), end=' ')
                # Print the scalar value like the other terms instead of the
                # tensor repr (which includes device and grad_fn).
                print("\tTotal: {:.5f}".format(loss.item()))
                torch.save({'state_dict': net.state_dict(),
                            'epoch': e}, checkpoint_path)
                
                
# Model under training (constructed with its default arguments).
net = SmoothingNet()

# Run configuration.
epochs = 1
batch_size = 1
data_dir = "data"

# Both input directories live under data_dir: the training images and the
# matching edge files (looked up with an empty filename prefix).
image_dir = os.path.join(data_dir, "train2014")
edge_dir = os.path.join(data_dir, "edges")
dataset = Dataset(image_dir, edge_dir, edge_prefix='')

# Adam over every parameter of the network.
optimizer = Adam(net.parameters(), lr=1e-2)

def init_weights(m):
    """Re-initialise a ``Conv2d`` layer's weight with Kaiming-uniform values.

    Intended as a callback for ``Module.apply``: modules that are not
    ``nn.Conv2d`` instances are left untouched, and only ``weight`` is
    modified (in place) on those that are.

    Args:
        m: Module visited by ``apply``.
    """
    if not isinstance(m, nn.Conv2d):
        return
    nn.init.kaiming_uniform_(m.weight)

#net.apply(init_weights)

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

%matplotlib inline
train(net, epochs, batch_size, optimizer, dataset)