In [0]:
import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

from tqdm import tqdm
from time import sleep

In [0]:
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs (slowly) on a runtime without a GPU instead of failing at `.to(device)`.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [0]:
# Fetch and unpack the training photos.
# -nc (no-clobber) skips the download when the archive is already present, so
# re-running this cell does not fetch a second copy as
# `universum_compressed.tar.1`. Delete a partially downloaded archive by hand
# before re-running, otherwise it will be kept as-is.
!wget -nc http://sereja.me/f/universum_compressed.tar
!tar xf universum_compressed.tar

--2019-02-12 09:18:45--  http://sereja.me/f/universum_compressed.tar
Resolving sereja.me (sereja.me)... 213.159.215.132
Connecting to sereja.me (sereja.me)|213.159.215.132|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 72028160 (69M) [application/x-tar]
Saving to: ‘universum_compressed.tar’


2019-02-12 09:19:05 (3.48 MB/s) - ‘universum_compressed.tar’ saved [72028160/72028160]



In [30]:
# Unpack the second image set (roi.tar) into the working directory.
# NOTE(review): nothing in this notebook downloads roi.tar; presumably it was
# uploaded to the runtime by hand. Confirm and document where it comes from.
# NOTE(review): the saved output of this cell shows tar aborting with
# "Unexpected EOF in archive", i.e. the archive present at that time was
# truncated.
!tar xf roi.tar

tar: Unexpected EOF in archive
tar: rmtlseek not stopped at a record boundary
tar: Error is not recoverable: exiting now


In [0]:
def to_numpy_image(img, size=(128, 128)):
    """Convert a channels-first RGB tensor into a channels-last NumPy array.

    Args:
        img: tensor holding exactly ``3 * height * width`` values laid out as
            (3, height, width); extra singleton dimensions (e.g. a leading
            batch dimension of 1) are fine. It may live on any device and may
            be attached to an autograd graph.
        size: ``(height, width)`` of the image. Defaults to the 128x128 crops
            used throughout this notebook, so existing one-argument calls
            behave as before.

    Returns:
        numpy.ndarray of shape (height, width, 3), the layout expected by
        ``plt.imshow``.
    """
    height, width = size
    # reshape (unlike view) also accepts non-contiguous inputs.
    chw = img.detach().cpu().reshape(3, height, width)
    # (C, H, W) -> (H, W, C)
    return chw.permute(1, 2, 0).numpy()

In [0]:
import os
from PIL import Image

class ColorizationDataset(Dataset):
    """In-memory dataset of (network input, color target) pairs.

    Every ``.jpg`` / ``.JPG`` file found under ``path`` (searched recursively)
    is decoded once at construction time and kept in memory as an RGB PIL
    image. Files that cannot be decoded are skipped and recorded in
    ``skipped`` instead of aborting the load.

    Args:
        path: root directory to scan for images.
        transform_x: callable applied to the *transformed target* to derive
            the network input (e.g. color tensor -> grayscale tensor).
        transform_y: callable applied to the stored PIL image to produce the
            target (e.g. random crop + flip + ToTensor).

    Attributes:
        images: list of decoded PIL images.
        skipped: list of ``(filename, exception)`` for files that failed to
            load.
    """

    def __init__(self, path, transform_x, transform_y):
        self.transform_x = transform_x
        self.transform_y = transform_y

        filenames = [
            os.path.join(root, file)
            for root, _dirs, files in os.walk(path)
            for file in files
            if file.endswith(('.jpg', '.JPG'))
        ]

        self.images = []
        self.skipped = []
        for filename in tqdm(filenames):
            try:
                with Image.open(filename) as image:
                    # convert() fully decodes the file and returns an
                    # independent copy, so the handle can be closed. Forcing
                    # RGB keeps grayscale/CMYK JPEGs from producing tensors
                    # with a channel count other than 3 further down.
                    self.images.append(image.convert('RGB'))
            except Exception as exc:
                # Best effort: a corrupt file must not abort the whole load,
                # but (unlike a bare `except`) Ctrl-C still interrupts it and
                # the failure is kept for inspection.
                self.skipped.append((filename, exc))

        if self.skipped:
            print('Skipped {} unreadable image(s) under {!r}'.format(
                len(self.skipped), path))

    def __len__(self):
        """Return the number of successfully loaded images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(X, Y)`` where ``Y = transform_y(image)`` and ``X = transform_x(Y)``."""
        img = self.images[idx]
        # The input is derived from the already-augmented target so both share
        # the same random crop / flip.
        Y = self.transform_y(img)
        X = self.transform_x(Y)
        return X, Y

In [0]:
# Augmentation applied to every stored PIL image to produce the color target:
# a random crop resized to 128x128, a random horizontal flip, then conversion
# to a float tensor.
transform_all = transforms.Compose([
    transforms.RandomResizedCrop(size=128),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
])

def to_grayscale(x):
    """Collapse an RGB tensor into a single-channel, inverted luminance map.

    Args:
        x: tensor of shape (3, H, W) holding the R, G and B planes in that
            order.

    Returns:
        Tensor of shape (1, H, W) equal to
        ``1 - (0.299 * R + 0.587 * G + 0.114 * B)``. The result is inverted on
        purpose (bright pixels map to small values); the preview code renders
        it with the 'Greys' colormap, which draws small values as white.
    """
    luminance = x[0] * 0.299 + x[1] * 0.587 + x[2] * 0.114
    # unsqueeze adds the channel axis for any H x W, not just 128 x 128.
    return (1 - luminance).unsqueeze(0)

In [0]:
# Main training set: the target is the augmented color crop, the input is its
# grayscale version. Batches are reshuffled every epoch.
dataset = ColorizationDataset(
    path='universum-photos',
    transform_x=to_grayscale,
    transform_y=transform_all,
)
loader = DataLoader(dataset=dataset, shuffle=True, batch_size=64)

100%|██████████| 1254/1254 [00:01<00:00, 686.43it/s]


In [31]:
# Second image set (from the `roi` directory), built with the same transforms
# as the main dataset; used further down for previews only.
dataset2 = ColorizationDataset(
    path='roi',
    transform_x=to_grayscale,
    transform_y=transform_all,
)

100%|██████████| 66/66 [00:05<00:00, 13.00it/s]


In [0]:
def Block_pool(c_in, c_out):
    """Downsampling stage: 3x3 same-padded conv, 2x2 max-pool, then ReLU.

    Maps (N, c_in, H, W) to (N, c_out, H // 2, W // 2).
    """
    layers = [
        nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
        nn.MaxPool2d(kernel_size=2),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)

def Block_up(c_in, c_out):
    """Upsampling stage: 3x3 same-padded conv, 2x upsample, then ReLU.

    Maps (N, c_in, H, W) to (N, c_out, 2 * H, 2 * W).
    """
    layers = [
        nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
        nn.Upsample(scale_factor=2),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)

def Block_relu(c_in, c_out):
    """Resolution-preserving stage: 3x3 same-padded conv followed by ReLU.

    Maps (N, c_in, H, W) to (N, c_out, H, W).
    """
    conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)
    activation = nn.ReLU()
    return nn.Sequential(conv, activation)


def Block_sigm(c_in, c_out):
    """Output stage: 3x3 same-padded conv followed by a sigmoid.

    Maps (N, c_in, H, W) to (N, c_out, H, W) with every value squashed into
    the (0, 1) range.
    """
    conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)
    squash = nn.Sigmoid()
    return nn.Sequential(conv, squash)

In [0]:
class Colorizer(nn.Module):
    """Convolutional encoder-decoder that predicts RGB from a 1-channel image.

    ``preconcat`` halves the resolution three times, applies three
    resolution-preserving conv blocks, then doubles the resolution three
    times, ending with 64 feature channels. The original 1-channel input is
    concatenated onto those features (65 channels) and ``postconcat`` reduces
    them to 3 sigmoid-activated output channels.

    Because of the three halvings and doublings, the input height and width
    have to be multiples of 8 for the concatenation to line up.
    """

    def __init__(self):
        super().__init__()

        encoder = [Block_pool(1, 32), Block_pool(32, 64), Block_pool(64, 128)]
        bottleneck = [
            Block_relu(128, 256),
            Block_relu(256, 256),
            Block_relu(256, 256),
        ]
        decoder = [Block_up(256, 256), Block_up(256, 128), Block_up(128, 64)]

        # Kept as one flat Sequential so sub-module indices (and therefore the
        # state_dict keys of saved checkpoints) stay preconcat.0 .. preconcat.8.
        self.preconcat = nn.Sequential(*encoder, *bottleneck, *decoder)

        # 64 decoded feature channels + the 1 input channel = 65.
        self.postconcat = nn.Sequential(
            Block_relu(65, 32),
            Block_sigm(32, 3),
        )

    def forward(self, x):
        """Map a (N, 1, H, W) batch to a (N, 3, H, W) batch."""
        features = self.preconcat(x)
        # Skip connection: give the head direct access to the input.
        stacked = torch.cat([features, x], dim=1)
        return self.postconcat(stacked)

In [0]:
# Install google-drive-ocamlfuse (a FUSE client for Google Drive) from the
# alessandro-strada PPA.
# NOTE(review): the saved output of this cell reports
# "Package 'python-software-properties' has no installation candidate" for the
# first command; check whether that package can simply be dropped.
# NOTE(review): `2>&1 > /dev/null` discards stdout only. stderr is duplicated
# onto the original stdout before stdout is redirected, so errors still reach
# the notebook. Use `> /dev/null 2>&1` if both streams were meant to be
# silenced.
!apt-get install -y -qq software-properties-common python-software-properties module-init-tools
!add-apt-repository -y ppa:alessandro-strada/ppa 2>&1 > /dev/null
!apt-get update -qq 2>&1 > /dev/null
!apt-get -y install -qq google-drive-ocamlfuse fuse

E: Package 'python-software-properties' has no installation candidate
Selecting previously unselected package google-drive-ocamlfuse.
(Reading database ... 113597 files and directories currently installed.)
Preparing to unpack .../google-drive-ocamlfuse_0.7.1-0ubuntu3~ubuntu18.04.1_amd64.deb ...
Unpacking google-drive-ocamlfuse (0.7.1-0ubuntu3~ubuntu18.04.1) ...
Setting up google-drive-ocamlfuse (0.7.1-0ubuntu3~ubuntu18.04.1) ...
Processing triggers for man-db (2.8.3-2ubuntu0.1) ...


In [0]:
# Authenticate the Colab user; the application-default credentials read in the
# next cell presumably depend on this having run first.
from google.colab import auth
auth.authenticate_user()

In [0]:
# Headless OAuth handshake for google-drive-ocamlfuse using the
# application-default credentials' client id and secret.
from oauth2client.client import GoogleCredentials
creds = GoogleCredentials.get_application_default()
import getpass
# First invocation gets an empty stdin, so it can only print the authorization
# URL; grep keeps just that line.
!google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret} < /dev/null 2>&1 | grep URL
# Prompt (without echo) for the verification code obtained from that URL.
vcode = getpass.getpass()
# Second invocation receives the verification code on stdin.
# NOTE(review): vcode is interpolated into the shell command unquoted and the
# client secret is passed as a command-line argument; both are visible to the
# shell / process list. Consider feeding them through stdin or a subprocess
# call with an argument list instead.
!echo {vcode} | google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret}

Please, open the following URL in a web browser: https://accounts.google.com/o/oauth2/auth?client_id=32555940559.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive&response_type=code&access_type=offline&approval_prompt=force
··········
Please, open the following URL in a web browser: https://accounts.google.com/o/oauth2/auth?client_id=32555940559.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive&response_type=code&access_type=offline&approval_prompt=force
Please enter the verification code: Access token retrieved correctly.


In [0]:
# Mount Google Drive at ./godrive (the directory is created if missing); the
# checkpoint paths used later in the notebook live under godrive/saves.
!mkdir -p godrive
!google-drive-ocamlfuse godrive

# List the mount point as a quick check that the mount worked.
!ls godrive

'Colab Notebooks'   Olymp.ods   saves   win.ods   Документы


In [0]:
epoch = 0

model = Colorizer().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.L1Loss()

# Resume from a checkpoint stored on the mounted Drive.
# map_location remaps the saved tensors onto the device selected for this
# session, so a checkpoint written on one device still loads on another.
# NOTE: torch.load unpickles arbitrary Python objects (the stored `criterion`
# is one); only load checkpoint files you trust.
checkpoint = torch.load('godrive/saves/cnn_model_2100', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# The training loop writes a checkpoint *after* finishing the epoch it is
# labelled with, so resume at the following epoch. Otherwise that epoch number
# is trained twice and its checkpoint file is overwritten.
epoch = checkpoint['epoch'] + 1
criterion = checkpoint['criterion']

In [0]:
# Train indefinitely; stop by interrupting the cell. The epoch counter
# continues from whatever value `epoch` currently holds.
while True:
    losses = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

        # Accumulate the per-batch loss for the epoch average.
        losses += loss.item()

    # Mean loss over the batches of this epoch.
    losses /= len(loader)
    print("Epoch: {} ; Loss: {}".format(epoch, losses))

    # Persist a full training snapshot to the mounted Drive every 100 epochs.
    if epoch % 100 == 0:
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'criterion': criterion,
            },
            'godrive/saves/cnn_model_' + str(epoch),
        )

    epoch += 1



Epoch: 1701 ; Loss: 0.02892567601520568
Epoch: 1702 ; Loss: 0.028645427664741874
Epoch: 1703 ; Loss: 0.029179743374697864
Epoch: 1704 ; Loss: 0.028089103987440467


In [0]:
def show_pic(cur_dataset):
    """Plot a random sample: grayscale input, model prediction, ground truth.

    Relies on the notebook-level ``model``, ``device`` and ``to_numpy_image``.

    Args:
        cur_dataset: indexable dataset whose items are ``(gray, color)``
            tensor pairs with shapes (1, H, W) and (3, H, W), e.g. a
            ColorizationDataset.
    """
    t = np.random.randint(len(cur_dataset))
    img_gray, img_true = cur_dataset[t]

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        img_pred = model(img_gray.to(device).unsqueeze(0))
    img_pred = to_numpy_image(img_pred[0])
    img_true = to_numpy_image(img_true)

    fig = plt.figure(figsize=(20, 20))

    # A 1x5 grid with panels in slots 1, 3 and 5 leaves a gap between images.
    ax_gray = fig.add_subplot(1, 5, 1)
    ax_gray.axis('off')
    ax_gray.set_title('Input (grayscale)')
    # The colormap is passed per image instead of via plt.set_cmap, which
    # would change the global default for every later plot. 'Greys' draws
    # small values as white, matching the inverted output of to_grayscale.
    ax_gray.imshow(img_gray[0].cpu().numpy(), cmap='Greys')

    ax_pred = fig.add_subplot(1, 5, 3)
    ax_pred.axis('off')
    ax_pred.set_title('Prediction')
    ax_pred.imshow(img_pred)

    ax_true = fig.add_subplot(1, 5, 5)
    ax_true.axis('off')
    ax_true.set_title('Ground truth')
    ax_true.imshow(img_true)

    plt.show()

In [0]:
# Preview 20 randomly chosen samples from the main dataset (the set the model
# is trained on), each as input / prediction / ground truth.
for i in range(20):
    show_pic(dataset)

Output hidden; open in https://colab.research.google.com to view.

In [32]:
# Preview 20 randomly chosen samples from the second image set (`roi`), which
# is not fed to the training loop in this notebook.
for i in range(20):
    show_pic(dataset2)

Output hidden; open in https://colab.research.google.com to view.