# Image colorization using a CNN - Fruits dataset

Import packages

In [None]:
# Standard library
import os
import time
# For plotting
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
# For conversion
from skimage.color import lab2rgb, rgb2lab, rgb2gray
# For everything
import torch
import torch.nn as nn
# For our model
from torchvision import datasets, transforms
# For utilities
import patoolib

In [1]:
# Root folders for the training and validation images
tr, te = 'images/train/', 'images/val/'

If the images are still in an archive, extract them first (for example with `patoolib`, imported above).

In [3]:
# Check if GPU is available
use_gpu = torch.cuda.is_available()

### Functions used

In [3]:
# Class BaseColor from Zhang's github repository
class BaseColor(nn.Module):
    """Base module holding the constants used to (de)normalize L and ab.

    L is shifted by ``self.l_cent`` (50) and scaled by ``self.l_norm`` (100);
    ab is scaled by ``self.ab_norm`` (110).
    """
    def __init__(self):
        super().__init__()
        # Normalization constants
        self.l_cent = 50.
        self.l_norm = 100.
        self.ab_norm = 110.

    def normalize_l(self, in_l):
        """Return (in_l - l_cent) / l_norm."""
        centered = in_l - self.l_cent
        return centered / self.l_norm

    def unnormalize_l(self, in_l):
        """Inverse of ``normalize_l``: in_l * l_norm + l_cent."""
        return self.l_cent + self.l_norm * in_l

    def normalize_ab(self, in_ab):
        """Return in_ab / ab_norm."""
        return in_ab / self.ab_norm

    def unnormalize_ab(self, in_ab):
        """Inverse of ``normalize_ab``: in_ab * ab_norm."""
        return self.ab_norm * in_ab

In [73]:
class Colorization(BaseColor):
    """Small CNN that predicts the ab colour channels from a grayscale (L) input.

    Adapted from the ECCV'16 generator in Zhang et al.'s repository. The network has:
    three conv stages (the first two downsample by 2), a 1x1 conv, a channel
    softmax, a 1x1 projection to 2 channels and a x4 bilinear upsample, so the
    output matches the input resolution when height/width are multiples of 4.

    Args:
        norm_layer: normalization layer class appended to each of the first
            three stages (default ``nn.BatchNorm2d``).
    """
    def __init__(self, norm_layer=nn.BatchNorm2d):
        # The original called super(ECCVGenerator, self); ECCVGenerator is not
        # defined in this notebook, so constructing the model raised NameError.
        super().__init__()

        model1=[nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True),]
        model1+=[nn.ReLU(True),]
        model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True),]
        model1+=[nn.ReLU(True),]
        model1+=[norm_layer(64),]

        model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
        model2+=[nn.ReLU(True),]
        model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True),]
        model2+=[nn.ReLU(True),]
        model2+=[norm_layer(128),]

        model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model3+=[nn.ReLU(True),]
        model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model3+=[nn.ReLU(True),]
        model3+=[norm_layer(256),]

        model4 =[nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0, bias=True),]

        self.model1 = nn.Sequential(*model1)
        self.model2 = nn.Sequential(*model2)
        self.model3 = nn.Sequential(*model3)
        self.model4 = nn.Sequential(*model4)

        self.softmax = nn.Softmax(dim=1)
        self.model_out = nn.Conv2d(256, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=False)
        # align_corners=False is PyTorch's default for bilinear mode; passing it
        # explicitly keeps behaviour identical and avoids the default-behaviour
        # UserWarning that some PyTorch versions emit.
        self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)

    def forward(self, input_l):
        """Predict ab channels for a batch of grayscale images.

        NOTE(review): normalize_l / unnormalize_ab assume raw Lab ranges
        (L around [0, 100], ab scaled by 110). GrayscaleImageFolder, however,
        feeds L from rgb2gray (in [0, 1]) and ab rescaled as (ab + 128) / 255,
        so the network has to learn to compensate. This is left unchanged
        because changing it would invalidate saved checkpoints.
        """
        conv1 = self.model1(self.normalize_l(input_l))
        conv2 = self.model2(conv1)
        conv3 = self.model3(conv2)
        conv4 = self.model4(conv3)
        out_reg = self.model_out(self.softmax(conv4))

        return self.unnormalize_ab(self.upsample4(out_reg))

In [9]:
# Class used to transform images for the network
class GrayscaleImageFolder(datasets.ImageFolder):
  '''ImageFolder that yields (grayscale, ab, class index) triples.

  Each image is loaded, optionally transformed, then converted to:
    - a grayscale tensor of shape (1, H, W) from rgb2gray (values in [0, 1]);
    - an ab tensor of shape (2, H, W) taken from CIE Lab and rescaled as
      (ab + 128) / 255.
  '''
  def __getitem__(self, index):
    """Return ``(gray, ab, target)`` for the sample at ``index``.

    Assumes the loader returns an RGB image (H, W, 3); torchvision's default
    loader converts to RGB -- confirm if a custom loader is used.
    """
    path, target = self.imgs[index]
    img = self.loader(path)
    # Previously the colour conversion only ran when a transform was set, so a
    # folder built without one hit UnboundLocalError at the return statement.
    if self.transform is not None:
      img = self.transform(img)
    img_rgb = np.asarray(img)
    # Shift/scale the Lab values; only the ab channels (1:3) are kept.
    img_lab = (rgb2lab(img_rgb) + 128) / 255
    img_ab = torch.from_numpy(img_lab[:, :, 1:3].transpose((2, 0, 1))).float()
    img_gray = torch.from_numpy(rgb2gray(img_rgb)).unsqueeze(0).float()
    if self.target_transform is not None:
      target = self.target_transform(target)
    return img_gray, img_ab, target

In [8]:
# Function used to convert images from LAB color space to RGB in order to visualize the results
def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
  '''Convert a grayscale + ab pair back to RGB, optionally saving both images.

  Args:
    grayscale_input: CPU tensor (1, H, W) holding L scaled to [0, 1]
      (the scaling GrayscaleImageFolder uses).
    ab_input: CPU tensor (2, H, W) holding ab scaled as (ab + 128) / 255.
    save_path: dict of the form {'grayscale': '/path/', 'colorized': '/path/'}.
      The directory strings are concatenated directly with save_name, so they
      need a trailing separator.
    save_name: file name used for both saved images.

  Returns:
    The RGB image from lab2rgb as an (H, W, 3) array, so callers can display
    it (e.g. with plt.imshow). Images are written to disk only when both
    save_path and save_name are given.
  '''
  # No plt.clf() here: plt.imsave writes the array without using pyplot's
  # current figure. Clearing that figure only wiped whatever the caller had
  # open and left an empty figure in notebook output.
  color_image = torch.cat((grayscale_input, ab_input), 0).numpy()  # (3, H, W)
  color_image = color_image.transpose((1, 2, 0))  # channels last for skimage/matplotlib
  # Undo the dataset scaling back to Lab ranges.
  color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
  color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128
  color_image = lab2rgb(color_image.astype(np.float64))
  if save_path is not None and save_name is not None:
    grayscale_image = grayscale_input.squeeze().numpy()
    plt.imsave(arr=grayscale_image, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
    plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))
  return color_image

In [9]:
def evaluate(model, val_loader, criterion, save_images, epoch):
  '''Return the mean per-batch validation loss, optionally saving sample images.

  Args:
    model: network mapping grayscale batches to predicted ab batches.
    val_loader: iterable of (input_gray, input_ab, target) batches. Its
      ``batch_size`` attribute is read only when images are saved.
    criterion: loss comparing predicted ab with input_ab.
    save_images: if True, every image of the first batch is written through
      to_rgb to outputs/gray/ and outputs/color/.
    epoch: epoch number embedded in the saved file names.

  Returns:
    Sum of the batch losses divided by the number of batches.

  Raises:
    ValueError: if val_loader yields no batches.
  '''
  # Put batches on the same device as the model rather than relying on the
  # global use_gpu flag; this also works before the model has been moved.
  first_param = next(model.parameters(), None)
  device = first_param.device if first_param is not None else torch.device('cpu')

  # Evaluation mode (e.g. BatchNorm uses running statistics)
  model.eval()
  epoch_loss = 0.0
  n_batches = 0

  # Do not compute gradients
  with torch.no_grad():
    for i, (input_gray, input_ab, _target) in enumerate(val_loader):
      input_gray, input_ab = input_gray.to(device), input_ab.to(device)
      # The model predicts only the ab channels.
      output_ab = model(input_gray)
      loss = criterion(output_ab, input_ab)
      epoch_loss += loss.item()
      n_batches += 1

      # Save only the first batch, once per call.
      if save_images and i == 0:
        save_path = {'grayscale': 'outputs/gray/', 'colorized': 'outputs/color/'}
        for j in range(len(output_ab)):
          save_name = 'img-{}-epoch-{}.jpg'.format(i * val_loader.batch_size + j, epoch)
          to_rgb(input_gray[j].cpu(), ab_input=output_ab[j].cpu(), save_path=save_path, save_name=save_name)

  if n_batches == 0:
    raise ValueError('val_loader yielded no batches')
  return epoch_loss / n_batches

In [10]:
def train(model, train_loader, optimizer, criterion):
  '''Run one training epoch and return the mean per-batch loss.

  Args:
    model: network mapping grayscale batches to predicted ab batches.
    train_loader: iterable of (input_gray, input_ab, target) batches.
    optimizer: optimizer over model's parameters.
    criterion: loss comparing predicted ab with input_ab.

  Returns:
    Sum of the batch losses divided by the number of batches.

  Raises:
    ValueError: if train_loader yields no batches.
  '''
  # Put batches on the same device as the model rather than relying on the
  # global use_gpu flag.
  first_param = next(model.parameters(), None)
  device = first_param.device if first_param is not None else torch.device('cpu')

  # Train mode (e.g. BatchNorm uses batch statistics)
  model.train()
  epoch_loss = 0.0
  n_batches = 0

  for input_gray, input_ab, _target in train_loader:
    input_gray, input_ab = input_gray.to(device), input_ab.to(device)
    # Clear gradients *before* the backward pass. Previously this ran after
    # step(), so gradients already on the parameters leaked into the first
    # update.
    optimizer.zero_grad()
    output_ab = model(input_gray)
    loss = criterion(output_ab, input_ab)
    loss.backward()
    optimizer.step()

    epoch_loss += loss.item()
    n_batches += 1

  if n_batches == 0:
    raise ValueError('train_loader yielded no batches')
  return epoch_loss / n_batches

In [11]:
def model_training(n_epochs, model, train_loader, val_loader, optimizer, criterion, save_images, model_name):
  '''Train for n_epochs epochs, checkpointing the best-validation weights.

  Each epoch runs train() then evaluate() and prints timing and losses. It saves
  model.state_dict() to ``model_name`` whenever the validation loss beats the
  best seen so far *in this call* (the best loss starts at +inf on every call).

  Returns:
    (train_losses, valid_losses): lists with one mean loss per epoch.
  '''
  best_valid_loss = float('inf')
  train_losses, valid_losses = [], []

  for epoch in range(n_epochs):
    # perf_counter is monotonic, so it is the right clock for durations.
    start_time = time.perf_counter()
    train_loss = train(model, train_loader, optimizer, criterion)
    valid_loss = evaluate(model, val_loader, criterion, save_images, epoch)

    # Checkpoint whenever validation improves.
    if valid_loss < best_valid_loss:
      best_valid_loss = valid_loss
      torch.save(model.state_dict(), model_name)

    elapsed = time.perf_counter() - start_time

    print(f"\nEpoch: {epoch+1}/{n_epochs} -- Epoch Time: {elapsed:.2f} s")
    print("---------------------------------")
    print(f"Train -- Loss: {train_loss:.3f}")
    print(f"Val -- Loss: {valid_loss:.3f}")

    train_losses.append(train_loss)
    valid_losses.append(valid_loss)

  return train_losses, valid_losses

In [12]:
def count_parameters(model):
  """Return the number of scalar values in model's trainable parameters."""
  total = 0
  for param in model.parameters():
    if param.requires_grad:
      total += param.numel()
  return total

### Training

In [None]:
# Training data: resize to 80x80 and randomly flip for augmentation, shuffled
train_transforms = transforms.Compose([
    transforms.Resize((80, 80)),
    transforms.RandomHorizontalFlip(),
])
train_imagefolder = GrayscaleImageFolder(tr, transform=train_transforms)
train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=64, shuffle=True)

# Validation data: same resize, no augmentation, fixed order
val_transforms = transforms.Compose([transforms.Resize((80, 80))])
val_imagefolder = GrayscaleImageFolder(te, transform=val_transforms)
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=64, shuffle=False)

In [11]:
# Create the folders evaluate() writes sample images into
for sub in ('color', 'gray'):
  os.makedirs(f'./outputs/{sub}', exist_ok=True)

save_images = True   # write colorized samples during validation
best_losses = 1e10   # NOTE(review): not read by any visible cell

In [12]:
# Build the network, loss and optimizer
model = Colorization()
criterion = nn.MSELoss()  # alternative tried: nn.SmoothL1Loss(beta=0.75)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01, weight_decay=0.0)

Uncomment the next cell to load a saved model.

In [13]:
# pretrained = torch.load('./models/anim140.pth', map_location=lambda storage, loc: storage)
# model.load_state_dict(pretrained)

In [14]:
print("The model has {:,} trainable parameters.".format(count_parameters(model)))

The model has 3,572,288 trainable parameters.


In [15]:
# Put the network and the loss on the GPU when one is available
if use_gpu:
  model = model.to('cuda')
  criterion = criterion.to('cuda')

In [27]:
# Train in 14 chunks of N_EPOCHS epochs. Each chunk saves its own
# best-validation weights to anim10.pth, anim20.pth, ..., anim140.pth.
N_EPOCHS = 10
# model_training returns only the current chunk's losses, so keep the full
# history as well (train_losses / valid_losses still hold the last chunk).
all_train_losses, all_valid_losses = [], []
for i in range(1, 15):
    model_name = f"anim{N_EPOCHS*i}.pth"
    train_losses, valid_losses = model_training(N_EPOCHS,
                                                model,
                                                train_loader,
                                                val_loader,
                                                optimizer,
                                                criterion,
                                                save_images,
                                                model_name)
    all_train_losses += train_losses
    all_valid_losses += valid_losses

  "See the documentation of nn.Upsample for details.".format(mode)



Epoch: 1/10 -- Epoch Time: 71.62 s
---------------------------------
Train -- Loss: 0.036
Val -- Loss: 0.003

Epoch: 2/10 -- Epoch Time: 72.25 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.003

Epoch: 3/10 -- Epoch Time: 72.51 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.003

Epoch: 4/10 -- Epoch Time: 75.35 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.003

Epoch: 5/10 -- Epoch Time: 76.81 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.004

Epoch: 6/10 -- Epoch Time: 76.51 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.007

Epoch: 7/10 -- Epoch Time: 74.86 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.004

Epoch: 8/10 -- Epoch Time: 74.83 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.004

Epoch: 9/10 -- Epoch Time: 75.37 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.062

Epoch: 10

  return xyz2rgb(lab2xyz(lab, illuminant, observer))
  return xyz2rgb(lab2xyz(lab, illuminant, observer))



Epoch: 5/10 -- Epoch Time: 76.28 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.006


  return xyz2rgb(lab2xyz(lab, illuminant, observer))



Epoch: 6/10 -- Epoch Time: 77.45 s
---------------------------------
Train -- Loss: 0.004
Val -- Loss: 0.003

Epoch: 7/10 -- Epoch Time: 74.26 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.005


  return xyz2rgb(lab2xyz(lab, illuminant, observer))



Epoch: 8/10 -- Epoch Time: 74.63 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.008

Epoch: 9/10 -- Epoch Time: 74.20 s
---------------------------------
Train -- Loss: 0.004
Val -- Loss: 0.007

Epoch: 10/10 -- Epoch Time: 78.56 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.004

Epoch: 1/10 -- Epoch Time: 75.59 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.008

Epoch: 2/10 -- Epoch Time: 75.89 s
---------------------------------
Train -- Loss: 0.004
Val -- Loss: 0.008

Epoch: 3/10 -- Epoch Time: 76.25 s
---------------------------------
Train -- Loss: 0.004
Val -- Loss: 0.004

Epoch: 4/10 -- Epoch Time: 78.62 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.005

Epoch: 5/10 -- Epoch Time: 74.25 s
---------------------------------
Train -- Loss: 0.004
Val -- Loss: 0.012


  return xyz2rgb(lab2xyz(lab, illuminant, observer))



Epoch: 6/10 -- Epoch Time: 73.78 s
---------------------------------
Train -- Loss: 0.004
Val -- Loss: 0.005

Epoch: 7/10 -- Epoch Time: 73.81 s
---------------------------------
Train -- Loss: 0.004
Val -- Loss: 0.006

Epoch: 8/10 -- Epoch Time: 74.03 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.004

Epoch: 9/10 -- Epoch Time: 78.51 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.004

Epoch: 10/10 -- Epoch Time: 82.91 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.003

Epoch: 1/10 -- Epoch Time: 85.06 s
---------------------------------
Train -- Loss: 0.004
Val -- Loss: 0.004

Epoch: 2/10 -- Epoch Time: 83.96 s
---------------------------------
Train -- Loss: 0.004
Val -- Loss: 0.004

Epoch: 3/10 -- Epoch Time: 85.41 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.003

Epoch: 4/10 -- Epoch Time: 83.86 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.004

Epoch: 5

  return xyz2rgb(lab2xyz(lab, illuminant, observer))
  return xyz2rgb(lab2xyz(lab, illuminant, observer))
  return xyz2rgb(lab2xyz(lab, illuminant, observer))
  return xyz2rgb(lab2xyz(lab, illuminant, observer))
  return xyz2rgb(lab2xyz(lab, illuminant, observer))
  return xyz2rgb(lab2xyz(lab, illuminant, observer))
  return xyz2rgb(lab2xyz(lab, illuminant, observer))
  return xyz2rgb(lab2xyz(lab, illuminant, observer))



Epoch: 8/10 -- Epoch Time: 85.40 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.006

Epoch: 9/10 -- Epoch Time: 85.65 s
---------------------------------
Train -- Loss: 0.004
Val -- Loss: 0.003

Epoch: 10/10 -- Epoch Time: 84.69 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.003

Epoch: 1/10 -- Epoch Time: 84.89 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.003

Epoch: 2/10 -- Epoch Time: 84.26 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.004

Epoch: 3/10 -- Epoch Time: 77.42 s
---------------------------------
Train -- Loss: 0.004
Val -- Loss: 0.004

Epoch: 4/10 -- Epoch Time: 73.25 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.013

Epoch: 5/10 -- Epoch Time: 78.13 s
---------------------------------
Train -- Loss: 0.004
Val -- Loss: 0.005

Epoch: 6/10 -- Epoch Time: 76.06 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.003


  return xyz2rgb(lab2xyz(lab, illuminant, observer))
  return xyz2rgb(lab2xyz(lab, illuminant, observer))
  return xyz2rgb(lab2xyz(lab, illuminant, observer))
  return xyz2rgb(lab2xyz(lab, illuminant, observer))
  return xyz2rgb(lab2xyz(lab, illuminant, observer))
  return xyz2rgb(lab2xyz(lab, illuminant, observer))
  return xyz2rgb(lab2xyz(lab, illuminant, observer))
  return xyz2rgb(lab2xyz(lab, illuminant, observer))



Epoch: 7/10 -- Epoch Time: 76.13 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.005

Epoch: 8/10 -- Epoch Time: 75.67 s
---------------------------------
Train -- Loss: 0.004
Val -- Loss: 0.003

Epoch: 9/10 -- Epoch Time: 76.15 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.007

Epoch: 10/10 -- Epoch Time: 74.17 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.003


  return xyz2rgb(lab2xyz(lab, illuminant, observer))



Epoch: 1/10 -- Epoch Time: 70.85 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.003

Epoch: 2/10 -- Epoch Time: 73.32 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.004

Epoch: 3/10 -- Epoch Time: 69.77 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.004

Epoch: 4/10 -- Epoch Time: 70.64 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.004

Epoch: 5/10 -- Epoch Time: 70.10 s
---------------------------------
Train -- Loss: 0.004
Val -- Loss: 0.004

Epoch: 6/10 -- Epoch Time: 70.75 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.005

Epoch: 7/10 -- Epoch Time: 71.27 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.003

Epoch: 8/10 -- Epoch Time: 74.62 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.003

Epoch: 9/10 -- Epoch Time: 73.77 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.005

Epoch: 10

  return xyz2rgb(lab2xyz(lab, illuminant, observer))
  return xyz2rgb(lab2xyz(lab, illuminant, observer))
  return xyz2rgb(lab2xyz(lab, illuminant, observer))



Epoch: 1/10 -- Epoch Time: 77.75 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.005

Epoch: 2/10 -- Epoch Time: 74.66 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.005

Epoch: 3/10 -- Epoch Time: 76.26 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.005

Epoch: 4/10 -- Epoch Time: 74.85 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.004

Epoch: 5/10 -- Epoch Time: 78.04 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.004

Epoch: 6/10 -- Epoch Time: 76.56 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.004

Epoch: 7/10 -- Epoch Time: 76.06 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.004

Epoch: 8/10 -- Epoch Time: 75.86 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.005

Epoch: 9/10 -- Epoch Time: 81.17 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.004

Epoch: 10

  return xyz2rgb(lab2xyz(lab, illuminant, observer))
  return xyz2rgb(lab2xyz(lab, illuminant, observer))
  return xyz2rgb(lab2xyz(lab, illuminant, observer))



Epoch: 2/10 -- Epoch Time: 82.70 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.006

Epoch: 3/10 -- Epoch Time: 80.37 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.004

Epoch: 4/10 -- Epoch Time: 72.31 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.004

Epoch: 5/10 -- Epoch Time: 75.95 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.003

Epoch: 6/10 -- Epoch Time: 72.88 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.004

Epoch: 7/10 -- Epoch Time: 74.96 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.004

Epoch: 8/10 -- Epoch Time: 76.32 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.003

Epoch: 9/10 -- Epoch Time: 77.26 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.003

Epoch: 10/10 -- Epoch Time: 79.58 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.007

Epoch: 1

  return xyz2rgb(lab2xyz(lab, illuminant, observer))



Epoch: 4/10 -- Epoch Time: 78.71 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.005

Epoch: 5/10 -- Epoch Time: 77.13 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.005

Epoch: 6/10 -- Epoch Time: 76.88 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.003

Epoch: 7/10 -- Epoch Time: 76.76 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.003


  return xyz2rgb(lab2xyz(lab, illuminant, observer))
  return xyz2rgb(lab2xyz(lab, illuminant, observer))



Epoch: 8/10 -- Epoch Time: 77.31 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.004

Epoch: 9/10 -- Epoch Time: 77.54 s
---------------------------------
Train -- Loss: 0.003
Val -- Loss: 0.007

Epoch: 10/10 -- Epoch Time: 75.72 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.005

Epoch: 1/10 -- Epoch Time: 75.44 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.004

Epoch: 2/10 -- Epoch Time: 76.52 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.003

Epoch: 3/10 -- Epoch Time: 76.23 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.003

Epoch: 4/10 -- Epoch Time: 75.66 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.003

Epoch: 5/10 -- Epoch Time: 77.87 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.004

Epoch: 6/10 -- Epoch Time: 75.70 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.007


  return xyz2rgb(lab2xyz(lab, illuminant, observer))
  return xyz2rgb(lab2xyz(lab, illuminant, observer))



Epoch: 7/10 -- Epoch Time: 76.21 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.005

Epoch: 8/10 -- Epoch Time: 73.33 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.003

Epoch: 9/10 -- Epoch Time: 76.33 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.004

Epoch: 10/10 -- Epoch Time: 75.30 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.004

Epoch: 1/10 -- Epoch Time: 76.68 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.005

Epoch: 2/10 -- Epoch Time: 75.98 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.003

Epoch: 3/10 -- Epoch Time: 75.12 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.003

Epoch: 4/10 -- Epoch Time: 72.85 s
---------------------------------
Train -- Loss: 0.002
Val -- Loss: 0.006

Epoch: 5/10 -- Epoch Time: 78.58 s
---------------------------------
Train -- Loss: 0.001
Val -- Loss: 0.004

Epoch: 6

<Figure size 432x288 with 0 Axes>

### Let's see the results

In [20]:
# Loader for the extra images in out/: resize the shorter side to 120, then
# center-crop to 100x100. Previously val_transforms was passed by mistake, so
# out_transforms was never used.
out_transforms = transforms.Compose([transforms.Resize(120),
                                     transforms.CenterCrop(100)])
out_imagefolder = GrayscaleImageFolder('out/', out_transforms)
out_loader = torch.utils.data.DataLoader(out_imagefolder, batch_size=64, shuffle=False)

In [22]:
# Run a final validation pass and write the first batch's colorizations to
# outputs/gray/ and outputs/color/ (file names tagged with epoch 140)
save_images = True
with torch.no_grad():
  evaluate(model, val_loader, criterion, save_images=save_images, epoch=140)

<Figure size 432x288 with 0 Axes>