In [246]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.autograd import Variable
import torchvision.models as models

import os, time, shutil, argparse
import numpy as np
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, rgb2gray, lab2rgb
from skimage import io

from functools import partial
import pickle


In [247]:
class ColorizationNet(nn.Module):
    def __init__(self, midlevel_input_size=128, global_input_size=512):
        super(ColorizationNet, self).__init__()
        # Fusion layer to combine midlevel and global features
        self.midlevel_input_size = midlevel_input_size
        self.global_input_size = global_input_size
        self.fusion = nn.Linear(midlevel_input_size + global_input_size, midlevel_input_size)
        self.bn1 = nn.BatchNorm1d(midlevel_input_size)

        # Convolutional layers and upsampling
        self.deconv1_new = nn.ConvTranspose2d(midlevel_input_size, 128, kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.Conv2d(midlevel_input_size, 128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(32)
        self.conv5 = nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1)
        self.upsample = nn.Upsample(scale_factor=2)

        print('Loaded colorization net.')

    def forward(self, midlevel_input): #, global_input):
        
        # Convolutional layers and upsampling
        x = F.relu(self.bn2(self.conv1(midlevel_input)))
        x = self.upsample(x)
        x = F.relu(self.bn3(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.upsample(x)
        x = F.sigmoid(self.conv4(x))
        x = self.upsample(self.conv5(x))
        return x


class ColorNet(nn.Module):
    def __init__(self):
        super(ColorNet, self).__init__()
        
        # Build ResNet and change first conv layer to accept single-channel input
        resnet_gray_model = models.resnet18(num_classes=365)
        resnet_gray_model.conv1.weight = nn.Parameter(resnet_gray_model.conv1.weight.sum(dim=1).unsqueeze(1).data)
        
        # Only needed if not resuming from a checkpoint: load pretrained ResNet-gray model
        if torch.cuda.is_available(): # and only if gpu is available
            resnet_gray_weights = torch.load('pretrained/resnet_gray_weights.pth.tar') #torch.load('pretrained/resnet_gray.tar')['state_dict']
            resnet_gray_model.load_state_dict(resnet_gray_weights)
            print('Pretrained ResNet-gray weights loaded')

        # Extract midlevel and global features from ResNet-gray
        self.midlevel_resnet = nn.Sequential(*list(resnet_gray_model.children())[0:6])
        self.global_resnet = nn.Sequential(*list(resnet_gray_model.children())[0:9])
        self.fusion_and_colorization_net = ColorizationNet()

    def forward(self, input_image):

        # Pass input through ResNet-gray to extract features
        midlevel_output = self.midlevel_resnet(input_image)
        # global_output = self.global_resnet(input_image)

        # Combine features in fusion layer and upsample
        output = self.fusion_and_colorization_net(midlevel_output) #, global_output)
        return output

## Utils

In [248]:
plt.switch_backend('agg')

def save_checkpoint(state, is_best_so_far, filename='checkpoints/checkpoint.pth.tar'):
    '''Saves checkpoint, and replace the old best model if the current model is better'''
    torch.save(state, filename)
    if is_best_so_far:
        shutil.copyfile(filename, 'checkpoints/model_best.pth.tar')

class AverageMeter(object):
    '''An easy way to compute and store both average and current values'''
    def __init__(self):
        self.reset()
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def visualize_image(grayscale_input, ab_input=None, show_image=False, save_path=None, save_name=None):
    '''Show or save image given grayscale (and ab color) inputs. Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
    plt.clf() # clear matplotlib plot
    ab_input = ab_input.cpu()
    grayscale_input = grayscale_input.cpu()    
    if ab_input is None:
        grayscale_input = grayscale_input.squeeze().numpy() 
        if save_path is not None and save_name is not None: 
            plt.imsave(grayscale_input, '{}.{}'.format(save_path['grayscale'], save_name) , cmap='gray')
        if show_image: 
            plt.imshow(grayscale_input, cmap='gray')
            plt.show()
    else: 
        color_image = torch.cat((grayscale_input, ab_input), 0).numpy()
        color_image = color_image.transpose((1, 2, 0))  
        color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
        color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
        color_image = lab2rgb(color_image.astype(np.float64))
        grayscale_input = grayscale_input.squeeze().numpy()
        if save_path is not None and save_name is not None:
            plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
            plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))
        if show_image: 
            f, axarr = plt.subplots(1, 2)
            axarr[0].imshow(grayscale_input, cmap='gray')
            axarr[1].imshow(color_image)
            plt.show()

class GrayscaleImageFolder(datasets.ImageFolder):
    '''Custom images folder, which converts images to grayscale before loading'''
    def __getitem__(self, index):
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img_original = self.transform(img)
            img_original = np.asarray(img_original)
            img_lab = rgb2lab(img_original)
            img_lab = (img_lab + 128) / 255
            img_ab = img_lab[:, :, 1:3]
            img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
            img_original = rgb2gray(img_original)
            img_original = torch.from_numpy(img_original).unsqueeze(0).float()
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img_original, img_ab, target

## Main

In [233]:
def train(train_loader, model, criterion, optimizer, epoch):
    '''Train model on data in train_loader for a single epoch'''
    print('Starting training epoch {}'.format(epoch))

    # Prepare value counters and timers
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    
    # Switch model to train mode
    model.train()
    
    # Train for single eopch
    end = time.time()
    for i, (input_gray, input_ab, target) in enumerate(train_loader):
        
        # Use GPU if available
        input_gray_variable = Variable(input_gray).cuda() if use_gpu else Variable(input_gray)
        input_ab_variable = Variable(input_ab).cuda() if use_gpu else Variable(input_ab)
        target_variable = Variable(target).cuda() if use_gpu else Variable(target)

        # Record time to load data (above)
        data_time.update(time.time() - end)

        # Run forward pass
        output_ab = model(input_gray_variable) # throw away class predictions
        loss = criterion(output_ab, input_ab_variable) # MSE
        
        # Record loss and measure accuracy
        losses.update(loss.data[0], input_gray.size(0))
        
        # Compute gradient and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Record time to do forward and backward passes
        batch_time.update(time.time() - end)
        end = time.time()
        
        # Print model accuracy -- in the code below, val refers to value, not validation
        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                   epoch, i, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses)) 

    print('Finished training epoch {}'.format(epoch))

def validate(val_loader, model, criterion, save_images, epoch):
    '''Validate model on data in val_loader'''
    print('Starting validation.')

    # Prepare value counters and timers
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    
    # Switch model to validation mode
    model.eval()
    
    # Run through validation set
    end = time.time()
    for i, (input_gray, input_ab, target) in enumerate(val_loader):
        
        # Use GPU if available
        target = target.cuda() if use_gpu else target
        input_gray_variable = Variable(input_gray, volatile=True).cuda() if use_gpu else Variable(input_gray, volatile=True)
        input_ab_variable = Variable(input_ab, volatile=True).cuda() if use_gpu else Variable(input_ab, volatile=True)
        target_variable = Variable(target, volatile=True).cuda() if use_gpu else Variable(target, volatile=True)
       
        # Record time to load data (above)
        data_time.update(time.time() - end)

        # Run forward pass
        output_ab = model(input_gray_variable) # throw away class predictions
        loss = criterion(output_ab, input_ab_variable) # check this!
        
        # Record loss and measure accuracy
        losses.update(loss.data[0], input_gray.size(0))

        # Save images to file
        if save_images:
            for j in range(len(output_ab)):
                save_path = {'grayscale': os.path.join(args.data, 'outputs/gray/'), 'colorized': os.path.join(args.data, 'outputs/color/')}
                save_name = 'img-{}-epoch-{}.jpg'.format(i * val_loader.batch_size + j, epoch)
                visualize_image(input_gray[j], ab_input=output_ab[j].data, show_image=False, save_path=save_path, save_name=save_name)

        # Record time to do forward passes and save images
        batch_time.update(time.time() - end)
        end = time.time()
        
        # Print model accuracy -- in the code below, val refers to both value and validation
        if i % args.print_freq == 0:
            print('Validate: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                   i, len(val_loader), batch_time=batch_time, loss=losses))

    print('Finished validation.')
    return losses.avg

# if __name__ == '__main__':
#     main()

In [249]:
# Parse arguments and prepare program
parser = argparse.ArgumentParser(description='Training and Using ColorNet')
parser.add_argument('-data', default='', type=str, metavar='DIR', help='path to dataset')
parser.add_argument('-j', '--workers', default=0, type=int, metavar='N', help='number of data loading workers (default: 0)')
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to .pth file checkpoint (default: none)')
parser.add_argument('--epochs', default=10, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (overridden if loading from checkpoint)')
parser.add_argument('-b', '--batch-size', default=16, type=int, metavar='N', help='size of mini-batch (default: 16)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='learning rate at start of training')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='use this flag to validate without training')
parser.add_argument('--print-freq', '-p', default=1, type=int, metavar='N', help='print frequency (default: 10)')
parser.add_argument('-f', '--fuck', default='', type=str, metavar='N', help='fuck off the -f')

_StoreAction(option_strings=['-f', '--fuck'], dest='fuck', nargs=None, const=None, default='', type=<class 'str'>, choices=None, help='fuck off the -f', metavar='N')

In [250]:
# Current best losses
best_losses = 1000.0
use_gpu = torch.cuda.is_available()

In [251]:
global args, best_losses, use_gpu
args = parser.parse_args()
print('Arguments: {}'.format(args))

Arguments: Namespace(batch_size=16, data='./', epochs=10, evaluate=False, fuck='/Users/aaronwg/Library/Jupyter/runtime/kernel-41363dec-d660-413d-954b-d0545c4c4d2b.json', lr=0.1, print_freq=1, resume='', start_epoch=0, weight_decay=0.0001, workers=0)


In [252]:
# Create model  # models.resnet18(num_classes=365)
model = ColorNet()
    
# Use GPU if available
if use_gpu:
    model.cuda()
    print('Loaded model onto GPU.')
 

Loaded colorization net.


In [253]:
   
# Create loss function, optimizer #criterion = nn.CrossEntropyLoss().cuda() if use_gpu else nn.CrossEntropyLoss()
criterion = nn.MSELoss().cuda() if use_gpu else nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

In [257]:
# Load data from pre-defined (imagenet-style) structure
if not args.evaluate:
    train_directory = os.path.join(args.data, 'train')
    train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(256),
    transforms.RandomHorizontalFlip()
    ])
    train_imagefolder = GrayscaleImageFolder(train_directory, train_transforms)
    train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
    print('Loaded training data.')
val_transforms = transforms.Compose([
    transforms.Resize((256,256)),
])
val_directory = os.path.join(args.data, 'val')
val_imagefolder = GrayscaleImageFolder(val_directory , val_transforms)
val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
print('Loaded validation data.')

Loaded training data.
Loaded validation data.


In [244]:
# If in evaluation (validation) mode, do not train
if args.evaluate:
    save_images = True
    epoch = 0
    initial_losses = validate(val_loader, model, criterion, save_images, epoch)

        # # Save checkpoint after evaluation if desired
        # save_checkpoint({
        #     'epoch': epoch,
        #     'best_losses': initial_losses,
        #     'state_dict': model.state_dict(),
        #     'optimizer': optimizer.state_dict(),
        # }, False, 'checkpoints/evaluate-checkpoint.pth.tar')
        
#        return  
    
# Otherwise, train for given number of epochs
validate(val_loader, model, criterion, False, 0) # validate before training

Starting validation.
Validate: [0/1]	Time 0.216 (0.216)	Loss 0.4959 (0.4959)	
Finished validation.




tensor(0.4959)

In [258]:
if args.evaluate:
    save_images = True
    epoch = 0
    initial_losses = validate(val_loader, model, criterion, save_images, epoch)

    # # Save checkpoint after evaluation if desired
    # save_checkpoint({
    #     'epoch': epoch,
    #     'best_losses': initial_losses,
    #     'state_dict': model.state_dict(),
    #     'optimizer': optimizer.state_dict(),
    # }, False, 'checkpoints/evaluate-checkpoint.pth.tar')
else:
    for epoch in range(args.start_epoch, args.epochs):
        
        # Train for one epoch, then validate
        train(train_loader, model, criterion, optimizer, epoch)
        if epoch % 10 == 0:
            save_images = True
            losses = validate(val_loader, model, criterion, save_images, epoch)
        
        # Save checkpoint, and replace the old best model if the current model is better
        is_best_so_far = losses < best_losses
        best_losses = max(losses, best_losses)
        if is_best_so_far:
            save_checkpoint({
                'epoch': epoch + 1,
                'best_losses': best_losses,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, is_best_so_far, os.path.join(args.data, 'checkpoints/checkpoint-epoch-{}.pth.tar'.format(epoch)))

Starting training epoch 0




Epoch: [0][0/2]	Time 8.578 (8.578)	Data 0.424 (0.424)	Loss 0.2348 (0.2348)	
Epoch: [0][1/2]	Time 24.975 (16.777)	Data 0.366 (0.395)	Loss 189.5527 (94.8937)	
Finished training epoch 0
Starting validation.


  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)


Validate: [0/1]	Time 0.269 (0.269)	Loss 15.0242 (15.0242)	
Finished validation.
Starting training epoch 1
Epoch: [1][0/2]	Time 28.231 (28.231)	Data 0.347 (0.347)	Loss 15.0969 (15.0969)	
Epoch: [1][1/2]	Time 31.788 (30.010)	Data 0.404 (0.376)	Loss 34.2928 (24.6949)	
Finished training epoch 1
Starting validation.
Validate: [0/1]	Time 0.260 (0.260)	Loss 104.5425 (104.5425)	
Finished validation.
Starting training epoch 2
Epoch: [2][0/2]	Time 29.388 (29.388)	Data 0.362 (0.362)	Loss 103.8999 (103.8999)	
Epoch: [2][1/2]	Time 29.240 (29.314)	Data 0.370 (0.366)	Loss 86.3789 (95.1394)	
Finished training epoch 2
Starting validation.
Validate: [0/1]	Time 0.347 (0.347)	Loss 28.4680 (28.4680)	
Finished validation.
Starting training epoch 3
Epoch: [3][0/2]	Time 19.535 (19.535)	Data 0.406 (0.406)	Loss 28.5799 (28.5799)	
Epoch: [3][1/2]	Time 27.848 (23.692)	Data 0.393 (0.400)	Loss 8.7702 (18.6750)	
Finished training epoch 3
Starting validation.
Validate: [0/1]	Time 0.273 (0.273)	Loss 35.5792 (35.5792)	