In [1]:
# --- standard library ---
import os
import time
import warnings
from collections import namedtuple

# --- third party ---
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.utils import make_grid, save_image
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from PIL import Image

# Silence UserWarnings for everything executed from this point on:
# the local imports below and every later cell.
warnings.simplefilter("ignore", UserWarning)

# --- local project helpers ---
from dzlib.utils.helper import info, stats, npshow, params, janimate, ccrop_pil
from dzlib.nn_models.unet import UNet

print("Imports Complete")

Imports Complete


In [3]:
# General Settings
dtype = torch.FloatTensor
%matplotlib notebook
matplotlib.rcParams['savefig.dpi'] = 80
matplotlib.rcParams['figure.dpi'] = 80
# %config InlineBackend.print_figure_kwargs={'bbox_inches':None}


In [5]:
# image_pil (original, as loaded from disk)
data_dir = os.path.join(os.getcwd(), "data")
fn = os.path.join(data_dir, "dog.jpg")
image_pil = Image.open(fn)
print(image_pil.size)

# image_pil_HR: center-crop so both dims are divisible by 32
image_pil_HR = ccrop_pil(image_pil, factor=32)
print(image_pil_HR.size)

# image_pil_LR: downsample the *cropped* HR image by `factor`.
# The LR size must come from image_pil_HR, not the uncropped original; otherwise
# HR/LR is not an exact `factor` ratio whenever the crop removed pixels.
factor = 4
width_HR, height_HR = image_pil_HR.size
LR_shape = (width_HR // factor, height_HR // factor)
# Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
image_pil_LR = image_pil_HR.resize(LR_shape, Image.LANCZOS)
width, height = image_pil_LR.size
print(image_pil_LR.size)

print(type(image_pil))

(960, 609)
(960, 608)
(240, 152)
<class 'PIL.JpegImagePlugin.JpegImageFile'>


In [None]:
def pil_to_chw01(img):
    """Convert a PIL image to a float32 numpy array of shape (C, H, W) scaled to [0, 1].

    The image is converted to RGB first so the channel axis always exists
    (grayscale 'L' images would otherwise be 2-D and break the transpose).

    Args:
        img: PIL image.

    Returns:
        np.ndarray of dtype float32, shape (3, H, W), values in [0, 1].
    """
    arr = np.asarray(img.convert("RGB")).astype(np.float32)  # (H, W, C), 0..255
    # float32 / int stays float32 (no silent upcast to float64).
    return arr.transpose(2, 0, 1) / 255


# image_np_HR (np > float32 > (C x H x W) > (0...1))
image_np_HR = pil_to_chw01(image_pil_HR)
info(image_np_HR, 'image_np_HR')
stats(image_np_HR, 'image_np_HR')

image_pt_HR = torch.from_numpy(image_np_HR).type(dtype)
info(image_pt_HR, 'image_pt_HR')
stats(image_pt_HR, 'image_pt_HR')

# image_np_LR (np > float32 > (C x H x W) > (0...1))
image_np_LR = pil_to_chw01(image_pil_LR)
info(image_np_LR, 'image_np_LR')
stats(image_np_LR, 'image_np_LR')

image_pt_LR = torch.from_numpy(image_np_LR).type(dtype)
info(image_pt_LR, 'image_pt_LR')
stats(image_pt_LR, 'image_pt_LR')


fig, ax = plt.subplots(1, figsize=(14, 8), ncols=2)
npshow(image_np_HR, ax[0])
npshow(image_np_LR, ax[1])



In [None]:
# Fixed network input: a (1, 32, H_HR, W_HR) tensor of uniform noise in [0, 0.1).
ninput = torch.zeros(1, 32, image_pt_HR.shape[1], image_pt_HR.shape[2])
std_scaler = 1./10.
# Std of the fresh Gaussian perturbation added to the input at every training step.
reg_noise_std = 0.03
ninput = ninput.uniform_() * std_scaler
ninput = ninput.type(dtype)
# Untouched copy of the input; the training loop adds noise on top of it.
ninput_saved = ninput.detach().clone()
# Scratch buffer with the input's shape/dtype, refilled in place by normal_() each step.
noise = ninput.detach().clone()

# LR image with a leading batch dimension: (1, C, H_LR, W_LR).
target = image_pt_LR.unsqueeze(0).type(dtype)

# Fixed: previously referenced the undefined name `net_input` (NameError on a fresh kernel).
info(ninput)
stats(ninput)

info(target)
stats(target)

In [None]:
# Show the UNet constructor documentation and the option tuples it exposes,
# one per line, in the same order as before.
for unet_detail in (
    UNet.__init__.__doc__,
    UNet.activations,
    UNet.Conv._fields,
    UNet.Upsample._fields,
):
    print(unet_detail)

In [None]:
# Padding mode shared by every convolution spec below.
pad = 'reflect'

# Channel counts come from the data: noise depth in, image channels out.
in_channels, out_channels = ninput.shape[1], target.shape[1]

# Five levels each for the down path, skip connections and up path.
n_levels = 5
down_channels = [128] * n_levels
skip_channels = [4] * n_levels
up_channels = [128] * n_levels

# Conv specs are passed positionally, exactly as before; see UNet.Conv._fields
# (printed in the previous cell) for what each position means.
down_conv1, down_conv2 = UNet.Conv(3, 2, pad), UNet.Conv(3, 1, pad)
down_convs = [down_conv1, down_conv2]

skip_conv1 = UNet.Conv(1, 1, pad)
skip_convs = [skip_conv1]

up_conv1, up_conv2 = UNet.Conv(3, 1, pad), UNet.Conv(1, 1, pad)
up_convs = [up_conv1, up_conv2]

batchnorm, last_batchnorm = True, False
activation, last_activation = 'leakyrelu', 'sigmoid'

upsample = UNet.Upsample(size=None, scale_factor=2, mode='bilinear', align_corners=None)

net = UNet(
    in_channels, out_channels,
    down_channels, skip_channels, up_channels,
    down_convs, skip_convs, up_convs,
    batchnorm, last_batchnorm,
    activation, last_activation,
    upsample,
).type(dtype)

n_params = params(net)
print(n_params)

In [None]:
n_iters = 2000
cwd = os.getcwd()
save_path = os.path.join(cwd, 'outputs', 'run3')
# makedirs also creates the parent 'outputs' dir (os.mkdir raised FileNotFoundError if it
# was missing), and exist_ok=True keeps the cell re-runnable under Restart & Run All.
# NOTE: re-running overwrites frames already in this run directory; change the run
# name to keep an earlier run's output.
os.makedirs(save_path, exist_ok=True)

# One zero-padded frame filename per iteration: 0001.png ... 2000.png
digits = len(str(n_iters))
filenames = [os.path.join(save_path, f"{i:0{digits}d}.png") for i in range(1, n_iters + 1)]

print(len(filenames))


In [None]:
fig = plt.figure(figsize=(12, 16))
# 6 grid rows: ax1 gets rows 0-1, ax2 rows 2-3, ax3 rows 4-5.
gs = GridSpec(nrows=6, ncols=1, figure=fig)

# Panel sized to twice the HR image width; y-axis inverted so row 0 is at the top,
# as for an image.
ax1 = fig.add_subplot(gs[:2, :])
ax1.set_xlim(0, image_np_HR.shape[2] * 2)
ax1.set_ylim(image_np_HR.shape[1], 0)
kwargs1 = None

# Same layout at LR resolution.
ax2 = fig.add_subplot(gs[2:4, :])
ax2.set_xlim(0, image_np_LR.shape[2] * 2)
ax2.set_ylim(image_np_LR.shape[1], 0)
# Fixed copy-paste bug: this previously re-assigned kwargs1, leaving kwargs2 undefined.
kwargs2 = None

ax3 = fig.add_subplot(gs[4:, :])
kwargs3 = {'color': 'k'}



In [None]:
# Optimisation setup: Adam on all network weights, MSE between the downsampled
# network output and the LR target.
lr = 0.01
mse = nn.MSELoss().type(dtype)
optimizer = optim.Adam(net.parameters(), lr=lr)
# NOTE(review): `losses`, `remove`, `train_time` and `iter_time` are not read in the
# lines shown here; presumably consumed by later code (loss logging / frame
# saving / cleanup) — confirm.
losses = []
remove = True

train_time = time.time()
# Fit the network so that its HR-sized output, once downsampled to the LR size,
# matches the LR target.
for i in range(n_iters):
    iter_time = time.time()
    
    # zero grad
    optimizer.zero_grad()
    
    # reg: perturb the saved fixed input with fresh Gaussian noise (std reg_noise_std).
    # `noise` is only a buffer here; normal_() overwrites it in place every iteration.
    ninput = ninput_saved + noise.normal_() * reg_noise_std
    
    # forward pass: output has the HR spatial size of the input
    out_HR = net(ninput)
    
    # downsample to the target's (H_LR, W_LR).
    # NOTE(review): plain bilinear without antialiasing, whereas the target was made
    # with PIL's antialiasing (Lanczos) resize — the two downsamplers do not match.
    # align_corners is also left unset, which may emit a UserWarning (silenced by
    # the global warnings filter). Consider antialias=True / an explicit align_corners.
    out_LR = F.interpolate(input=out_HR, size=(target.shape[2], target.shape[3]), mode='bilinear')
    
    # evaluate loss
    loss = mse(out_LR, target)
    
    # back prop
    loss.backward()
    
    # step
    optimizer.step()
