In [5]:
# Third-party imports used by the dataset and data-loading cells.
import numpy as np
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets

# Use the GPU when PyTorch reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Root directory of the CIFAR-10 files (an earlier setup used '~/cifar10/cifar/').
data_folder = '~/workspace/data/cifar'
# One-time dataset download; uncomment if the files are not on disk yet:
# datasets.CIFAR10(data_folder, download=True)

class Colorize(torchvision.datasets.CIFAR10):
    """CIFAR-10 re-purposed as a colorization dataset.

    Each item is a ``(grayscale, color)`` pair built from the same image; the
    class label returned by the base dataset is discarded.

    Args:
        root: Directory holding (or receiving) the CIFAR-10 files.
        train: ``True`` for the training split, ``False`` for the test split.
        download: Forwarded to ``CIFAR10``; when ``True`` the files are
            downloaded into ``root`` if they are missing. Defaults to
            ``False``, which matches the previous behavior.
    """

    def __init__(self, root, train, download=False):
        # No transform is passed, so the base class is expected to hand back
        # PIL images (``__getitem__`` relies on ``Image.convert``).
        super().__init__(root, train, download=download)

    def __getitem__(self, ix):
        """Return ``(bw, im)`` float32 tensors for the sample at index ``ix``.

        ``bw`` is the image converted to grayscale and back to three channels;
        ``im`` is the original image. Both are divided by 255 and permuted
        with ``(2, 0, 1)`` (channels-first, assuming height x width x channel
        input arrays).

        The tensors are deliberately left on the CPU: the training and
        evaluation loops move whole batches to the target device, and
        creating CUDA tensors per sample inside a dataset both duplicates
        that transfer and breaks DataLoader worker processes.
        """
        im, _ = super().__getitem__(ix)
        bw = im.convert('L').convert('RGB')
        bw, im = np.array(bw)/255., np.array(im)/255.
        bw, im = [torch.tensor(i).permute(2,0,1).float() for i in [bw,im]]
        return bw, im

# Grayscale -> color pairs for the training and validation splits.
trdt = Colorize(root=data_folder, train=True)
valdt = Colorize(root=data_folder, train=False)

# Only the training loader reshuffles its samples.
trdl = DataLoader(dataset=trdt, batch_size=256, shuffle=True)
valdl = DataLoader(dataset=valdt, batch_size=256, shuffle=False)

# Network under training, built by the project-local factory.
from model import get_model

model = get_model('unet')

# Adam optimizer with a mean-squared-error objective.
from torch import nn

opt = torch.optim.Adam(params=model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

# Multiply the learning rate by 0.1 after every 10 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=opt, step_size=10, gamma=0.1)

In [3]:
import torch
from tqdm import tqdm
def train(dl,model,lossf,opt,device='cuda'):
    """Run one optimization pass over every batch yielded by ``dl``.

    Args:
        dl: Iterable of ``(input, target)`` tensor pairs (shown with a tqdm
            progress bar).
        model: Module to optimize; switched to training mode first.
        lossf: Callable ``lossf(prediction, target)`` returning a scalar loss.
        opt: Optimizer holding the parameters of ``model``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        None. ``model`` and ``opt`` are updated in place.
    """
    model.train()
    for inputs, targets in tqdm(dl):
        inputs = inputs.to(device)
        targets = targets.to(device)
        # Clear gradients left over from the previous step.
        opt.zero_grad()
        batch_loss = lossf(model(inputs), targets)
        batch_loss.backward()
        opt.step()

@torch.no_grad()
def test(dl,model,lossf,epoch=None,exist_acc=True,device='cuda'):
    model.eval()
    size, acc , losses = len(dl.dataset) ,0,0
    with torch.no_grad():
        for x,y in tqdm(dl):
            x,y = x.to(device),y.to(device)
            pre = model(x)
            loss = lossf(pre,y)
            
            if exist_acc: 
                acc += (pre.argmax(1)==y).type(torch.float).sum().item()
            losses += loss.item()
    if exist_acc:
        accuracy = round(acc/size,4)
    else:
        accuracy = None
    val_loss = round(losses/size,6)
    print(f'[{epoch}] acc/loss: {accuracy}/{val_loss}' if exist_acc else f'[{epoch}] loss: {val_loss}')
    return accuracy,val_loss 

import copy
def run(trdl,valdl,model,loss,opt,sched,epoch=100,patience = 5,exist_acc=False,device='cuda'):
    """Train with early stopping and return the best snapshot of the model.

    Each epoch calls ``train`` on ``trdl`` and ``test`` on ``valdl`` (both
    defined elsewhere in this notebook). Whenever the validation loss is
    strictly lower than every previous one, the loss is recorded and a deep
    copy of the model is kept. Training stops once ``patience`` epochs have
    passed since the last improvement; otherwise the scheduler is stepped.

    Args:
        trdl: Training data loader.
        valdl: Validation data loader.
        model: Module to train; moved to ``device``.
        loss: Loss callable forwarded to ``train`` and ``test``.
        opt: Optimizer forwarded to ``train``.
        sched: Learning-rate scheduler; ``sched.step()`` is called once per
            completed epoch.
        epoch: Maximum number of epochs.
        patience: Epochs to wait after the best epoch before stopping.
        exist_acc: Forwarded to ``test``.
        device: Device used for training and evaluation.

    Returns:
        ``(best_model, val_losses)`` where ``val_losses`` maps the index of
        every improving epoch to its validation loss. If no epoch ever
        improved (e.g. ``epoch=0`` or NaN losses), ``best_model`` is a copy
        of the model in its final state and ``val_losses`` may be empty.
    """
    val_losses = {}
    # Start from +inf rather than a hard-coded loss of 1 so that objectives
    # whose values are >= 1 still register their first epoch as an improvement.
    best_loss, best_epoch = float('inf'), 0
    best_model = None
    model = model.to(device)
    for i in range(epoch):
        train(trdl,model,loss,opt,device=device)
        _, val_loss = test(valdl,model,loss,epoch=i,exist_acc=exist_acc,device=device)

        if val_loss < best_loss:
            best_loss, best_epoch = val_loss, i
            val_losses[i] = val_loss
            best_model = copy.deepcopy(model)
        if i == best_epoch + patience:
            break

        sched.step()
    if best_model is None:
        # Nothing ever improved: fall back to the current weights instead of
        # failing on an unbound name.
        best_model = copy.deepcopy(model)
    return best_model,val_losses


In [6]:
# Train for up to 100 epochs with early stopping; keep the best snapshot and
# the history of improving validation losses.
# (Optional extra argument, currently disabled: exist_acc=config['exist_acc'])
best_model, val_losses = run(
    trdl=trdl,
    valdl=valdl,
    model=model,
    loss=loss_fn,
    opt=opt,
    sched=scheduler,
    epoch=100,
    device=device,
)

100%|██████████| 196/196 [00:32<00:00,  5.95it/s]
100%|██████████| 40/40 [00:03<00:00, 12.11it/s]


[0] loss: 2.5e-05


100%|██████████| 196/196 [00:33<00:00,  5.90it/s]
100%|██████████| 40/40 [00:03<00:00, 12.27it/s]


[1] loss: 2.3e-05


100%|██████████| 196/196 [00:33<00:00,  5.93it/s]
100%|██████████| 40/40 [00:03<00:00, 12.33it/s]


[2] loss: 2.1e-05


100%|██████████| 196/196 [00:32<00:00,  5.94it/s]
100%|██████████| 40/40 [00:03<00:00, 12.39it/s]


[3] loss: 2.1e-05


100%|██████████| 196/196 [00:32<00:00,  5.95it/s]
100%|██████████| 40/40 [00:03<00:00, 12.20it/s]


[4] loss: 2.4e-05


100%|██████████| 196/196 [00:32<00:00,  5.94it/s]
100%|██████████| 40/40 [00:03<00:00, 12.44it/s]


[5] loss: 2.3e-05


100%|██████████| 196/196 [00:33<00:00,  5.93it/s]
100%|██████████| 40/40 [00:03<00:00, 12.35it/s]


[6] loss: 2.2e-05


100%|██████████| 196/196 [00:33<00:00,  5.92it/s]
100%|██████████| 40/40 [00:03<00:00, 12.36it/s]


[7] loss: 2e-05


100%|██████████| 196/196 [00:33<00:00,  5.94it/s]
100%|██████████| 40/40 [00:03<00:00, 12.25it/s]


[8] loss: 2.4e-05


100%|██████████| 196/196 [00:33<00:00,  5.90it/s]
100%|██████████| 40/40 [00:03<00:00, 12.40it/s]


[9] loss: 2.1e-05


100%|██████████| 196/196 [00:33<00:00,  5.93it/s]
100%|██████████| 40/40 [00:03<00:00, 12.21it/s]


[10] loss: 1.9e-05


100%|██████████| 196/196 [00:33<00:00,  5.90it/s]
100%|██████████| 40/40 [00:03<00:00, 12.21it/s]


[11] loss: 1.9e-05


100%|██████████| 196/196 [00:33<00:00,  5.90it/s]
100%|██████████| 40/40 [00:03<00:00, 12.22it/s]


[12] loss: 1.9e-05


100%|██████████| 196/196 [00:32<00:00,  5.95it/s]
100%|██████████| 40/40 [00:03<00:00, 12.45it/s]


[13] loss: 1.9e-05


100%|██████████| 196/196 [00:32<00:00,  5.94it/s]
100%|██████████| 40/40 [00:03<00:00, 12.41it/s]


[14] loss: 1.9e-05


100%|██████████| 196/196 [00:32<00:00,  5.94it/s]
100%|██████████| 40/40 [00:03<00:00, 12.30it/s]

[15] loss: 1.9e-05





In [7]:
# Display the module structure of the snapshot returned by `run`.
best_model

UNet(
  (d1): DownConv(
    (model): Sequential(
      (0): Identity()
      (1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): LeakyReLU(negative_slope=0.2, inplace=True)
      (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): LeakyReLU(negative_slope=0.2, inplace=True)
    )
  )
  (d2): DownConv(
    (model): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): LeakyReLU(negative_slope=0.2, inplace=True)
      (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): BatchNorm2d(128, eps=1e-05, mo