In [12]:
import glob
import matplotlib.pyplot as plt
import cv2
import numpy as np
import torch
import torch.nn as nn
import tqdm

from torch.utils.data.dataloader import DataLoader
from torch.optim.adam import Adam
from torch.utils.data.dataset import Dataset
from PIL import Image

In [25]:
## RGB를 LAB로 변환하는 함수
def rgb2lab(rgb):
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)

## LAB를 RGB로 변환하는 함수
def lab2rgb(lab):
    return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

## 데이터셋 정의 클래스
class AutoColoring(Dataset):
    def __init__(self):
        self.data = glob.glob("./data/archive/img_align_celeba/img_align_celeba/*.jpg")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        rgb = np.array(Image.open(self.data[i]).resize((256, 256)))
        lab = rgb2lab(rgb)
        lab = lab.transpose((2, 0, 1)).astype(np.float32) # 파이토치는 채널이 가장 앞에 와야 하므로 transpose를 사용했습니다.

        return lab[0], lab[1:]

In [26]:
## Low level 특징 추출하는 클래스
class LowLevel(nn.Module):
    def __init__(self):
        super(LowLevel, self).__init__()

        self.low1 = nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1)
        self.lb1 = nn.BatchNorm2d(64)
        self.low2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.lb2 = nn.BatchNorm2d(128)
        self.low3 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)
        self.lb3 = nn.BatchNorm2d(128)
        self.low4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.lb4 = nn.BatchNorm2d(256)
        self.low5 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)
        self.lb5 = nn.BatchNorm2d(256)
        self.low6 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.lb6 = nn.BatchNorm2d(512)

        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        low = self.low1(x)
        low = self.lb1(low)
        low = self.sigmoid(low)

        low = self.low2(low)
        low = self.lb2(low)
        low = self.sigmoid(low)

        low = self.low3(low)
        low = self.lb3(low)
        low = self.sigmoid(low)

        low = self.low4(low)
        low = self.lb4(low)
        low = self.sigmoid(low)

        low = self.low5(low)
        low = self.lb5(low)
        low = self.sigmoid(low)

        low = self.low6(low)
        low = self.lb6(low)
        low = self.sigmoid(low)

        return low
    
## Middle level 특징 추출하는 클래스
class MidLevel(nn.Module):
    def __init__(self):
        super(MidLevel, self).__init__()

        self.mid1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.mb1 = nn.BatchNorm2d(512)
        self.mid2 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)
        self.mb2 = nn.BatchNorm2d(256)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mid = self.mid1(x)
        mid = self.mb1(mid)
        mid = self.sigmoid(mid)

        mid = self.mid2(mid)
        mid = self.mb2(mid)
        mid = self.sigmoid(mid)

        return mid

## Global level 특징 추출하는 클래스
class GlobalLevel(nn.Module):
    def __init__(self):
        super(GlobalLevel, self).__init__()

        self.glob1 = nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1)
        self.gb1 = nn.BatchNorm2d(512)
        self.glob2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.gb2 = nn.BatchNorm2d(512)
        self.glob3 = nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1)
        self.gb3 = nn.BatchNorm2d(512)
        self.glob4 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.gb4 = nn.BatchNorm2d(512)

        # 글로벌 레벨 특징 추출기의 MLP층 구성하여 색을 칠하는 특징으로 사용합니다.
        self.fc1 = nn.Linear(in_features=32768, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=512)
        self.fc3 = nn.Linear(in_features=512, out_features=256)

        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        glo = self.glob1(x)
        glo = self.gb1(glo)
        glo = self.sigmoid(glo)

        glo = self.glob2(glo)
        glo = self.gb2(glo)
        glo = self.sigmoid(glo)

        glo = self.glob3(glo)
        glo = self.gb3(glo)
        glo = self.sigmoid(glo)

        glo = self.glob4(glo)
        glo = self.gb4(glo)
        glo = self.sigmoid(glo)

        # 추출된 특징을 1차원으로 펼쳐줍니다.
        glo = torch.flatten(glo, start_dim=1)
        glo = self.fc1(glo)
        glo = self.sigmoid(glo)
        glo = self.fc2(glo)
        glo = self.sigmoid(glo)
        glo = self.fc3(glo)
        glo = self.sigmoid(glo)

        return glo
    
## Colorization Network 클래스
class Colorization(nn.Module):
    def __init__(self):
        super(Colorization, self).__init__()
        
        self.color1 = nn.ConvTranspose2d(256, 128, 3, 1, 1)
        self.cb1 = nn.BatchNorm2d(128)
        self.color2 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.cb2 = nn.BatchNorm2d(64)
        self.color3 = nn.ConvTranspose2d(64, 64, 3, 1, 1)
        self.cb3 = nn.BatchNorm2d(64)
        self.color4 = nn.ConvTranspose2d(64, 32, 2, 2)
        self.cb4 = nn.BatchNorm2d(32)
        self.color5 = nn.ConvTranspose2d(32, 2, 2, 2)

        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        color = self.color1(x)
        color = self.cb1(color)
        color = self.sigmoid(color)
        color = self.color2(color)
        color = self.cb2(color)
        color = self.sigmoid(color)
        color = self.color3(color)
        color = self.cb3(color)
        color = self.sigmoid(color)
        color = self.color4(color)
        color = self.cb4(color)
        color = self.sigmoid(color)
        color = self.color5(color)

        return color

## AutoColoring 모델 정의
class AutoColoringModel(nn.Module):
    def __init__(self):
        super(AutoColoringModel, self).__init__()

        self.low = LowLevel()
        self.mid = MidLevel()
        self.glob = GlobalLevel()
        self.fusion = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1) # 특징 합치기 위한 것입니다.
        self.color = Colorization()
        
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        low = self.low(x)
        mid = self.mid(low)
        glo = self.glob(low)
        
        # 글로벌 레벨 특징 추출기의 출력을 미들 레벨 특징 추출기의 출력 크기가 되도록 반복합니다.
        fusion = glo.repeat(1, mid.shape[2]*mid.shape[2]) 
        fusion = torch.reshape(fusion, (-1, 256, mid.shape[2], mid.shape[2]))
        
        # 글로벌 레벨 특징 추출기의 특징과 미들 레벨 특징 추출기의 특징을 결합합니다.
        fusion = torch.cat([mid, fusion], dim=1)
        fusion = self.fusion(fusion)
        fusion = self.sigmoid(fusion)

        color = self.color(fusion)

        return color

In [27]:
## 모델 정의와 데이터 정의
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoColoringModel().to(DEVICE)
dataset = AutoColoring()
loader = DataLoader(dataset, batch_size=32, shuffle=True)
optim = Adam(params=model.parameters(), lr=0.01)

## 학습
for epoch in range(201):
    print(f'Epoch {epoch + 1} of {200}')
    
    iterator = tqdm.tqdm(loader)
    for L, AB in iterator:        
        L = torch.unsqueeze(L, dim=1).to(DEVICE) # L 채널은 흑백 이미지 이므로 채널 차원을 확보했습니다.
        optim.zero_grad()

        pred = model(L)
        loss = nn.MSELoss()(pred, AB.to(DEVICE))
        loss.backward()
        optim.step()
        
        iterator.set_description(f"epoch:{epoch} loss:{loss.item()}")

torch.save(model.state_dict(), "AutoColor.pth")

Epoch 1 of 200



  0%|                                                               | 0/6332 [00:00<?, ?it/s][A
epoch:0 loss:18614.537109375:   0%|                                 | 0/6332 [00:18<?, ?it/s][A
epoch:0 loss:18614.537109375:   0%|                      | 1/6332 [00:18<32:19:23, 18.38s/it][A
epoch:0 loss:18438.421875:   0%|                         | 1/6332 [00:31<32:19:23, 18.38s/it][A
epoch:0 loss:18438.421875:   0%|                         | 2/6332 [00:31<26:51:59, 15.28s/it][A
epoch:0 loss:18453.267578125:   0%|                      | 2/6332 [00:48<26:51:59, 15.28s/it][A
epoch:0 loss:18453.267578125:   0%|                      | 3/6332 [00:48<28:25:19, 16.17s/it][A
epoch:0 loss:18970.4140625:   0%|                        | 3/6332 [01:01<28:25:19, 16.17s/it][A
epoch:0 loss:18970.4140625:   0%|                        | 4/6332 [01:01<26:07:44, 14.86s/it][A
epoch:0 loss:18674.646484375:   0%|                      | 4/6332 [01:14<26:07:44, 14.86s/it][A
epoch:0 loss:18674.646484375:

epoch:0 loss:13174.931640625:   1%|▎                    | 84/6332 [18:15<22:09:03, 12.76s/it][A
epoch:0 loss:13400.3994140625:   1%|▎                   | 84/6332 [18:27<22:09:03, 12.76s/it][A
epoch:0 loss:13400.3994140625:   1%|▎                   | 85/6332 [18:27<21:49:05, 12.57s/it][A
epoch:0 loss:13675.9921875:   1%|▎                      | 85/6332 [18:40<21:49:05, 12.57s/it][A
epoch:0 loss:13675.9921875:   1%|▎                      | 86/6332 [18:40<21:56:44, 12.65s/it][A
epoch:0 loss:12912.0048828125:   1%|▎                   | 86/6332 [18:52<21:56:44, 12.65s/it][A
epoch:0 loss:12912.0048828125:   1%|▎                   | 87/6332 [18:52<21:20:45, 12.31s/it][A
epoch:0 loss:12953.576171875:   1%|▎                    | 87/6332 [19:09<21:20:45, 12.31s/it][A
epoch:0 loss:12953.576171875:   1%|▎                    | 88/6332 [19:09<24:02:29, 13.86s/it][A
epoch:0 loss:13196.998046875:   1%|▎                    | 88/6332 [19:21<24:02:29, 13.86s/it][A
epoch:0 loss:13196.998046875: 

epoch:0 loss:7568.6875:   3%|▋                         | 168/6332 [36:14<21:27:33, 12.53s/it][A
epoch:0 loss:7078.58837890625:   3%|▌                  | 168/6332 [36:26<21:27:33, 12.53s/it][A
epoch:0 loss:7078.58837890625:   3%|▌                  | 169/6332 [36:26<20:49:15, 12.16s/it][A
epoch:0 loss:7387.0439453125:   3%|▌                   | 169/6332 [36:39<20:49:15, 12.16s/it][A
epoch:0 loss:7387.0439453125:   3%|▌                   | 170/6332 [36:39<21:09:36, 12.36s/it][A
epoch:0 loss:7414.857421875:   3%|▌                    | 170/6332 [36:50<21:09:36, 12.36s/it][A
epoch:0 loss:7414.857421875:   3%|▌                    | 171/6332 [36:50<20:42:02, 12.10s/it][A
epoch:0 loss:7385.271484375:   3%|▌                    | 171/6332 [37:03<20:42:02, 12.10s/it][A
epoch:0 loss:7385.271484375:   3%|▌                    | 172/6332 [37:03<21:06:38, 12.34s/it][A
epoch:0 loss:7034.30908203125:   3%|▌                  | 172/6332 [37:20<21:06:38, 12.34s/it][A
epoch:0 loss:7034.30908203125:

epoch:0 loss:4003.468505859375:   4%|▋                 | 252/6332 [54:36<21:57:01, 13.00s/it][A
epoch:0 loss:3895.6806640625:   4%|▊                   | 252/6332 [54:54<21:57:01, 13.00s/it][A
epoch:0 loss:3895.6806640625:   4%|▊                   | 253/6332 [54:54<24:24:21, 14.45s/it][A
epoch:0 loss:3758.16748046875:   4%|▊                  | 253/6332 [55:12<24:24:21, 14.45s/it][A
epoch:0 loss:3758.16748046875:   4%|▊                  | 254/6332 [55:12<26:21:24, 15.61s/it][A
epoch:0 loss:3713.205322265625:   4%|▋                 | 254/6332 [55:23<26:21:24, 15.61s/it][A
epoch:0 loss:3713.205322265625:   4%|▋                 | 255/6332 [55:23<24:11:59, 14.34s/it][A
epoch:0 loss:3777.8447265625:   4%|▊                   | 255/6332 [55:35<24:11:59, 14.34s/it][A
epoch:0 loss:3777.8447265625:   4%|▊                   | 256/6332 [55:35<22:57:53, 13.61s/it][A
epoch:0 loss:3706.685791015625:   4%|▋                 | 256/6332 [55:47<22:57:53, 13.61s/it][A
epoch:0 loss:3706.685791015625

epoch:0 loss:2026.809814453125:   5%|▊               | 336/6332 [1:12:34<18:57:30, 11.38s/it][A
epoch:0 loss:1989.2716064453125:   5%|▊              | 336/6332 [1:12:46<18:57:30, 11.38s/it][A
epoch:0 loss:1989.2716064453125:   5%|▊              | 337/6332 [1:12:46<19:14:39, 11.56s/it][A
epoch:0 loss:1935.9178466796875:   5%|▊              | 337/6332 [1:12:58<19:14:39, 11.56s/it][A
epoch:0 loss:1935.9178466796875:   5%|▊              | 338/6332 [1:12:58<19:27:59, 11.69s/it][A
epoch:0 loss:1968.04296875:   5%|█                   | 338/6332 [1:13:10<19:27:59, 11.69s/it][A
epoch:0 loss:1968.04296875:   5%|█                   | 339/6332 [1:13:10<19:21:51, 11.63s/it][A
epoch:0 loss:1979.6119384765625:   5%|▊              | 339/6332 [1:13:21<19:21:51, 11.63s/it][A
epoch:0 loss:1979.6119384765625:   5%|▊              | 340/6332 [1:13:21<19:20:39, 11.62s/it][A
epoch:0 loss:1808.918701171875:   5%|▊               | 340/6332 [1:13:39<19:20:39, 11.62s/it][A
epoch:0 loss:1808.918701171875

epoch:0 loss:1020.0028686523438:   7%|▉              | 420/6332 [1:30:40<20:30:40, 12.49s/it][A
epoch:0 loss:1184.63818359375:   7%|█▏               | 420/6332 [1:30:51<20:30:40, 12.49s/it][A
epoch:0 loss:1184.63818359375:   7%|█▏               | 421/6332 [1:30:51<19:53:53, 12.12s/it][A
epoch:0 loss:974.7590942382812:   7%|█               | 421/6332 [1:31:03<19:53:53, 12.12s/it][A
epoch:0 loss:974.7590942382812:   7%|█               | 422/6332 [1:31:03<19:49:12, 12.07s/it][A
epoch:0 loss:1066.655029296875:   7%|█               | 422/6332 [1:31:14<19:49:12, 12.07s/it][A
epoch:0 loss:1066.655029296875:   7%|█               | 423/6332 [1:31:14<19:32:21, 11.90s/it][A
epoch:0 loss:1080.273681640625:   7%|█               | 423/6332 [1:31:26<19:32:21, 11.90s/it][A
epoch:0 loss:1080.273681640625:   7%|█               | 424/6332 [1:31:26<19:23:31, 11.82s/it][A
epoch:0 loss:883.941162109375:   7%|█▏               | 424/6332 [1:31:38<19:23:31, 11.82s/it][A
epoch:0 loss:883.941162109375:

epoch:0 loss:524.3849487304688:   8%|█▎              | 504/6332 [1:48:18<19:36:38, 12.11s/it][A
epoch:0 loss:627.931884765625:   8%|█▎               | 504/6332 [1:48:32<19:36:38, 12.11s/it][A
epoch:0 loss:627.931884765625:   8%|█▎               | 505/6332 [1:48:32<20:37:00, 12.74s/it][A
epoch:0 loss:549.4033813476562:   8%|█▎              | 505/6332 [1:48:43<20:37:00, 12.74s/it][A
epoch:0 loss:549.4033813476562:   8%|█▎              | 506/6332 [1:48:43<20:02:24, 12.38s/it][A
epoch:0 loss:669.6561279296875:   8%|█▎              | 506/6332 [1:48:55<20:02:24, 12.38s/it][A
epoch:0 loss:669.6561279296875:   8%|█▎              | 507/6332 [1:48:55<19:26:25, 12.01s/it][A
epoch:0 loss:527.51220703125:   8%|█▍                | 507/6332 [1:49:12<19:26:25, 12.01s/it][A
epoch:0 loss:527.51220703125:   8%|█▍                | 508/6332 [1:49:12<21:52:02, 13.52s/it][A
epoch:0 loss:609.91162109375:   8%|█▍                | 508/6332 [1:49:23<21:52:02, 13.52s/it][A
epoch:0 loss:609.91162109375: 

epoch:0 loss:373.8142395019531:   9%|█▍              | 588/6332 [2:06:47<22:07:39, 13.87s/it][A
epoch:0 loss:412.7818603515625:   9%|█▍              | 588/6332 [2:07:00<22:07:39, 13.87s/it][A
epoch:0 loss:412.7818603515625:   9%|█▍              | 589/6332 [2:07:00<21:26:11, 13.44s/it][A
epoch:0 loss:424.646728515625:   9%|█▌               | 589/6332 [2:07:15<21:26:11, 13.44s/it][A
epoch:0 loss:424.646728515625:   9%|█▌               | 590/6332 [2:07:15<22:16:06, 13.96s/it][A
epoch:0 loss:458.461669921875:   9%|█▌               | 590/6332 [2:07:28<22:16:06, 13.96s/it][A
epoch:0 loss:458.461669921875:   9%|█▌               | 591/6332 [2:07:28<21:46:27, 13.65s/it][A
epoch:0 loss:411.52215576171875:   9%|█▍             | 591/6332 [2:07:44<21:46:27, 13.65s/it][A
epoch:0 loss:411.52215576171875:   9%|█▍             | 592/6332 [2:07:44<22:58:31, 14.41s/it][A
epoch:0 loss:495.898681640625:   9%|█▌               | 592/6332 [2:07:56<22:58:31, 14.41s/it][A
epoch:0 loss:495.898681640625:

epoch:0 loss:328.28875732421875:  11%|█▌             | 672/6332 [2:24:40<18:48:43, 11.97s/it][A
epoch:0 loss:365.21868896484375:  11%|█▌             | 672/6332 [2:24:52<18:48:43, 11.97s/it][A
epoch:0 loss:365.21868896484375:  11%|█▌             | 673/6332 [2:24:52<18:36:23, 11.84s/it][A
epoch:0 loss:340.8138427734375:  11%|█▋              | 673/6332 [2:25:09<18:36:23, 11.84s/it][A
epoch:0 loss:340.8138427734375:  11%|█▋              | 674/6332 [2:25:09<21:22:16, 13.60s/it][A
epoch:0 loss:342.5919494628906:  11%|█▋              | 674/6332 [2:25:21<21:22:16, 13.60s/it][A
epoch:0 loss:342.5919494628906:  11%|█▋              | 675/6332 [2:25:21<20:20:00, 12.94s/it][A
epoch:0 loss:345.71484375:  11%|██▏                  | 675/6332 [2:25:32<20:20:00, 12.94s/it][A
epoch:0 loss:345.71484375:  11%|██▏                  | 676/6332 [2:25:32<19:34:51, 12.46s/it][A
epoch:0 loss:328.86102294921875:  11%|█▌             | 676/6332 [2:25:45<19:34:51, 12.46s/it][A
epoch:0 loss:328.8610229492187

epoch:0 loss:295.0091857910156:  12%|█▉              | 756/6332 [2:42:13<17:52:39, 11.54s/it][A
epoch:0 loss:259.5849914550781:  12%|█▉              | 756/6332 [2:42:25<17:52:39, 11.54s/it][A
epoch:0 loss:259.5849914550781:  12%|█▉              | 757/6332 [2:42:25<17:51:15, 11.53s/it][A
epoch:0 loss:303.45550537109375:  12%|█▊             | 757/6332 [2:42:36<17:51:15, 11.53s/it][A
epoch:0 loss:303.45550537109375:  12%|█▊             | 758/6332 [2:42:36<17:47:00, 11.49s/it][A
epoch:0 loss:311.4881896972656:  12%|█▉              | 758/6332 [2:42:47<17:47:00, 11.49s/it][A
epoch:0 loss:311.4881896972656:  12%|█▉              | 759/6332 [2:42:47<17:39:20, 11.41s/it][A
epoch:0 loss:297.64996337890625:  12%|█▊             | 759/6332 [2:43:00<17:39:20, 11.41s/it][A
epoch:0 loss:297.64996337890625:  12%|█▊             | 760/6332 [2:43:00<18:13:34, 11.78s/it][A
epoch:0 loss:256.279296875:  12%|██▍                 | 760/6332 [2:43:11<18:13:34, 11.78s/it][A
epoch:0 loss:256.279296875:  1

epoch:0 loss:295.2686767578125:  13%|██              | 840/6332 [2:59:48<19:26:58, 12.75s/it][A
epoch:0 loss:276.8915100097656:  13%|██              | 840/6332 [3:00:02<19:26:58, 12.75s/it][A
epoch:0 loss:276.8915100097656:  13%|██▏             | 841/6332 [3:00:02<20:06:58, 13.19s/it][A
epoch:0 loss:292.152099609375:  13%|██▎              | 841/6332 [3:00:13<20:06:58, 13.19s/it][A
epoch:0 loss:292.152099609375:  13%|██▎              | 842/6332 [3:00:13<19:10:51, 12.58s/it][A
epoch:0 loss:286.2403259277344:  13%|██▏             | 842/6332 [3:00:24<19:10:51, 12.58s/it][A
epoch:0 loss:286.2403259277344:  13%|██▏             | 843/6332 [3:00:24<18:31:19, 12.15s/it][A
epoch:0 loss:287.61651611328125:  13%|█▉             | 843/6332 [3:00:41<18:31:19, 12.15s/it][A
epoch:0 loss:287.61651611328125:  13%|█▉             | 844/6332 [3:00:41<20:29:59, 13.45s/it][A
epoch:0 loss:262.93231201171875:  13%|█▉             | 844/6332 [3:00:52<20:29:59, 13.45s/it][A
epoch:0 loss:262.9323120117187

epoch:0 loss:320.78009033203125:  15%|██▏            | 924/6332 [3:17:09<20:34:46, 13.70s/it][A
epoch:0 loss:237.5956573486328:  15%|██▎             | 924/6332 [3:17:20<20:34:46, 13.70s/it][A
epoch:0 loss:237.5956573486328:  15%|██▎             | 925/6332 [3:17:20<19:28:26, 12.97s/it][A
epoch:0 loss:259.312255859375:  15%|██▍              | 925/6332 [3:17:32<19:28:26, 12.97s/it][A
epoch:0 loss:259.312255859375:  15%|██▍              | 926/6332 [3:17:32<18:51:34, 12.56s/it][A
epoch:0 loss:265.0611267089844:  15%|██▎             | 926/6332 [3:17:47<18:51:34, 12.56s/it][A
epoch:0 loss:265.0611267089844:  15%|██▎             | 927/6332 [3:17:47<20:11:45, 13.45s/it][A
epoch:0 loss:288.2142333984375:  15%|██▎             | 927/6332 [3:17:59<20:11:45, 13.45s/it][A
epoch:0 loss:288.2142333984375:  15%|██▎             | 928/6332 [3:17:59<19:22:49, 12.91s/it][A
epoch:0 loss:248.84735107421875:  15%|██▏            | 928/6332 [3:18:10<19:22:49, 12.91s/it][A
epoch:0 loss:248.8473510742187

epoch:0 loss:114.89585876464844:  16%|█▌        | 1008/6332 [6:38:44<4787:05:19, 3236.95s/it][A
epoch:0 loss:172.7355499267578:  16%|█▊         | 1008/6332 [6:39:10<4787:05:19, 3236.95s/it][A
epoch:0 loss:172.7355499267578:  16%|█▊         | 1009/6332 [6:39:10<3361:41:33, 2273.55s/it][A
epoch:0 loss:135.4989776611328:  16%|█▊         | 1009/6332 [6:39:28<3361:41:33, 2273.55s/it][A
epoch:0 loss:135.4989776611328:  16%|█▊         | 1010/6332 [6:39:28<2360:33:57, 1596.78s/it][A
epoch:0 loss:149.6234130859375:  16%|█▊         | 1010/6332 [6:39:45<2360:33:57, 1596.78s/it][A
epoch:0 loss:149.6234130859375:  16%|█▊         | 1011/6332 [6:39:45<1659:50:38, 1122.99s/it][A
epoch:0 loss:164.5653839111328:  16%|█▊         | 1011/6332 [6:40:03<1659:50:38, 1122.99s/it][A
epoch:0 loss:164.5653839111328:  16%|█▉          | 1012/6332 [6:40:03<1169:29:35, 791.39s/it][A
epoch:0 loss:153.9185791015625:  16%|█▉          | 1012/6332 [6:49:46<1169:29:35, 791.39s/it][A
epoch:0 loss:153.9185791015625

epoch:0 loss:171.3257293701172:  17%|██▌            | 1092/6332 [7:58:03<20:57:16, 14.40s/it][A
epoch:0 loss:139.086669921875:  17%|██▊             | 1092/6332 [7:58:20<20:57:16, 14.40s/it][A
epoch:0 loss:139.086669921875:  17%|██▊             | 1093/6332 [7:58:20<22:03:00, 15.15s/it][A
epoch:0 loss:117.99024200439453:  17%|██▍           | 1093/6332 [7:58:31<22:03:00, 15.15s/it][A
epoch:0 loss:117.99024200439453:  17%|██▍           | 1094/6332 [7:58:31<20:31:40, 14.11s/it][A
epoch:0 loss:235.39324951171875:  17%|██▍           | 1094/6332 [7:58:44<20:31:40, 14.11s/it][A
epoch:0 loss:235.39324951171875:  17%|██▍           | 1095/6332 [7:58:44<19:42:55, 13.55s/it][A
epoch:0 loss:135.1220245361328:  17%|██▌            | 1095/6332 [7:59:01<19:42:55, 13.55s/it][A
epoch:0 loss:135.1220245361328:  17%|██▌            | 1096/6332 [7:59:01<21:30:56, 14.79s/it][A
epoch:0 loss:180.49940490722656:  17%|██▍           | 1096/6332 [7:59:18<21:30:56, 14.79s/it][A
epoch:0 loss:180.4994049072265

epoch:0 loss:168.911865234375:  19%|██▉             | 1176/6332 [8:17:33<20:03:38, 14.01s/it][A
epoch:0 loss:118.10992431640625:  19%|██▌           | 1176/6332 [8:17:49<20:03:38, 14.01s/it][A
epoch:0 loss:118.10992431640625:  19%|██▌           | 1177/6332 [8:17:49<20:53:40, 14.59s/it][A
epoch:0 loss:107.87940979003906:  19%|██▌           | 1177/6332 [8:18:01<20:53:40, 14.59s/it][A
epoch:0 loss:107.87940979003906:  19%|██▌           | 1178/6332 [8:18:01<19:52:38, 13.88s/it][A
epoch:0 loss:120.66273498535156:  19%|██▌           | 1178/6332 [8:18:17<19:52:38, 13.88s/it][A
epoch:0 loss:120.66273498535156:  19%|██▌           | 1179/6332 [8:18:17<20:38:19, 14.42s/it][A
epoch:0 loss:150.9234619140625:  19%|██▊            | 1179/6332 [8:18:29<20:38:19, 14.42s/it][A
epoch:0 loss:150.9234619140625:  19%|██▊            | 1180/6332 [8:18:29<19:31:41, 13.65s/it][A
epoch:0 loss:113.91613006591797:  19%|██▌           | 1180/6332 [8:18:40<19:31:41, 13.65s/it][A
epoch:0 loss:113.9161300659179

epoch:0 loss:129.53073120117188:  20%|██▊           | 1260/6332 [8:35:12<18:12:24, 12.92s/it][A
epoch:0 loss:118.25001525878906:  20%|██▊           | 1260/6332 [8:35:26<18:12:24, 12.92s/it][A
epoch:0 loss:118.25001525878906:  20%|██▊           | 1261/6332 [8:35:26<18:36:39, 13.21s/it][A
epoch:0 loss:124.95378112792969:  20%|██▊           | 1261/6332 [8:35:39<18:36:39, 13.21s/it][A
epoch:0 loss:124.95378112792969:  20%|██▊           | 1262/6332 [8:35:39<18:20:56, 13.03s/it][A
epoch:0 loss:128.82937622070312:  20%|██▊           | 1262/6332 [8:35:50<18:20:56, 13.03s/it][A
epoch:0 loss:128.82937622070312:  20%|██▊           | 1263/6332 [8:35:50<17:40:46, 12.56s/it][A
epoch:0 loss:177.4051971435547:  20%|██▉            | 1263/6332 [8:36:02<17:40:46, 12.56s/it][A
epoch:0 loss:177.4051971435547:  20%|██▉            | 1264/6332 [8:36:02<17:19:32, 12.31s/it][A
epoch:0 loss:126.30663299560547:  20%|██▊           | 1264/6332 [8:36:14<17:19:32, 12.31s/it][A
epoch:0 loss:126.3066329956054

epoch:0 loss:143.48448181152344:  21%|██▎        | 1344/6332 [20:56:50<180:31:09, 130.29s/it][A
epoch:0 loss:91.64515686035156:  21%|██▌         | 1344/6332 [20:57:02<180:31:09, 130.29s/it][A
epoch:0 loss:91.64515686035156:  21%|██▊          | 1345/6332 [20:57:02<131:27:33, 94.90s/it][A
epoch:0 loss:166.1232452392578:  21%|██▊          | 1345/6332 [20:57:13<131:27:33, 94.90s/it][A
epoch:0 loss:166.1232452392578:  21%|██▉           | 1346/6332 [20:57:13<96:28:11, 69.65s/it][A
epoch:0 loss:109.95735168457031:  21%|██▊          | 1346/6332 [20:57:27<96:28:11, 69.65s/it][A
epoch:0 loss:109.95735168457031:  21%|██▊          | 1347/6332 [20:57:27<73:13:03, 52.88s/it][A
epoch:0 loss:165.6008758544922:  21%|██▉           | 1347/6332 [20:57:40<73:13:03, 52.88s/it][A
epoch:0 loss:165.6008758544922:  21%|██▉           | 1348/6332 [20:57:40<56:52:02, 41.08s/it][A
epoch:0 loss:120.0954818725586:  21%|██▉           | 1348/6332 [20:57:53<56:52:02, 41.08s/it][A
epoch:0 loss:120.0954818725586

epoch:0 loss:144.48974609375:  23%|███▌            | 1428/6332 [21:17:41<17:44:33, 13.02s/it][A
epoch:0 loss:86.01618957519531:  23%|███▏          | 1428/6332 [21:17:55<17:44:33, 13.02s/it][A
epoch:0 loss:86.01618957519531:  23%|███▏          | 1429/6332 [21:17:55<18:06:01, 13.29s/it][A
epoch:0 loss:137.7411346435547:  23%|███▏          | 1429/6332 [21:18:09<18:06:01, 13.29s/it][A
epoch:0 loss:137.7411346435547:  23%|███▏          | 1430/6332 [21:18:09<18:14:58, 13.40s/it][A
epoch:0 loss:121.12491607666016:  23%|██▉          | 1430/6332 [21:18:22<18:14:58, 13.40s/it][A
epoch:0 loss:121.12491607666016:  23%|██▉          | 1431/6332 [21:18:22<18:15:04, 13.41s/it][A
epoch:0 loss:187.47267150878906:  23%|██▉          | 1431/6332 [21:18:37<18:15:04, 13.41s/it][A
epoch:0 loss:187.47267150878906:  23%|██▉          | 1432/6332 [21:18:37<18:55:29, 13.90s/it][A
epoch:0 loss:106.2619857788086:  23%|███▏          | 1432/6332 [21:18:49<18:55:29, 13.90s/it][A
epoch:0 loss:106.2619857788086

epoch:0 loss:109.01697540283203:  24%|███          | 1512/6332 [21:36:43<18:06:04, 13.52s/it][A
epoch:0 loss:180.90440368652344:  24%|███          | 1512/6332 [21:36:55<18:06:04, 13.52s/it][A
epoch:0 loss:180.90440368652344:  24%|███          | 1513/6332 [21:36:55<17:20:30, 12.96s/it][A
epoch:0 loss:111.48497772216797:  24%|███          | 1513/6332 [21:37:05<17:20:30, 12.96s/it][A
epoch:0 loss:111.48497772216797:  24%|███          | 1514/6332 [21:37:05<16:21:37, 12.22s/it][A
epoch:0 loss:117.67234802246094:  24%|███          | 1514/6332 [21:37:18<16:21:37, 12.22s/it][A
epoch:0 loss:117.67234802246094:  24%|███          | 1515/6332 [21:37:18<16:35:02, 12.39s/it][A
epoch:0 loss:127.16661071777344:  24%|███          | 1515/6332 [21:37:32<16:35:02, 12.39s/it][A
epoch:0 loss:127.16661071777344:  24%|███          | 1516/6332 [21:37:32<17:02:50, 12.74s/it][A
epoch:0 loss:136.79067993164062:  24%|███          | 1516/6332 [21:37:48<17:02:50, 12.74s/it][A
epoch:0 loss:136.7906799316406

epoch:0 loss:145.1471405029297:  25%|███▌          | 1596/6332 [21:55:00<16:18:55, 12.40s/it][A
epoch:0 loss:105.7800521850586:  25%|███▌          | 1596/6332 [21:55:12<16:18:55, 12.40s/it][A
epoch:0 loss:105.7800521850586:  25%|███▌          | 1597/6332 [21:55:12<16:22:36, 12.45s/it][A
epoch:0 loss:129.534423828125:  25%|███▊           | 1597/6332 [21:55:26<16:22:36, 12.45s/it][A
epoch:0 loss:129.534423828125:  25%|███▊           | 1598/6332 [21:55:26<16:42:40, 12.71s/it][A
epoch:0 loss:96.59400939941406:  25%|███▌          | 1598/6332 [21:55:40<16:42:40, 12.71s/it][A
epoch:0 loss:96.59400939941406:  25%|███▌          | 1599/6332 [21:55:40<17:14:58, 13.12s/it][A
epoch:0 loss:94.15335845947266:  25%|███▌          | 1599/6332 [21:55:53<17:14:58, 13.12s/it][A
epoch:0 loss:94.15335845947266:  25%|███▌          | 1600/6332 [21:55:53<17:10:06, 13.06s/it][A
epoch:0 loss:99.29395294189453:  25%|███▌          | 1600/6332 [21:56:04<17:10:06, 13.06s/it][A
epoch:0 loss:99.29395294189453

epoch:0 loss:165.1075439453125:  27%|███▋          | 1680/6332 [22:12:27<15:57:37, 12.35s/it][A
epoch:0 loss:226.0991668701172:  27%|███▋          | 1680/6332 [22:12:39<15:57:37, 12.35s/it][A
epoch:0 loss:226.0991668701172:  27%|███▋          | 1681/6332 [22:12:39<16:01:41, 12.41s/it][A
epoch:0 loss:199.2105712890625:  27%|███▋          | 1681/6332 [22:12:53<16:01:41, 12.41s/it][A
epoch:0 loss:199.2105712890625:  27%|███▋          | 1682/6332 [22:12:53<16:20:33, 12.65s/it][A
epoch:0 loss:170.1804962158203:  27%|███▋          | 1682/6332 [22:13:05<16:20:33, 12.65s/it][A
epoch:0 loss:170.1804962158203:  27%|███▋          | 1683/6332 [22:13:05<16:20:17, 12.65s/it][A
epoch:0 loss:130.72642517089844:  27%|███▍         | 1683/6332 [22:13:18<16:20:17, 12.65s/it][A
epoch:0 loss:130.72642517089844:  27%|███▍         | 1684/6332 [22:13:18<16:21:32, 12.67s/it][A
epoch:0 loss:180.51568603515625:  27%|███▍         | 1684/6332 [22:13:31<16:21:32, 12.67s/it][A
epoch:0 loss:180.5156860351562

epoch:0 loss:140.56956481933594:  28%|███▌         | 1764/6332 [22:30:10<15:41:46, 12.37s/it][A
epoch:0 loss:118.29019165039062:  28%|███▌         | 1764/6332 [22:30:23<15:41:46, 12.37s/it][A
epoch:0 loss:118.29019165039062:  28%|███▌         | 1765/6332 [22:30:23<15:57:45, 12.58s/it][A
epoch:0 loss:170.37693786621094:  28%|███▌         | 1765/6332 [22:30:37<15:57:45, 12.58s/it][A
epoch:0 loss:170.37693786621094:  28%|███▋         | 1766/6332 [22:30:37<16:27:25, 12.98s/it][A
epoch:0 loss:150.2075653076172:  28%|███▉          | 1766/6332 [22:30:51<16:27:25, 12.98s/it][A
epoch:0 loss:150.2075653076172:  28%|███▉          | 1767/6332 [22:30:51<16:40:00, 13.14s/it][A
epoch:0 loss:112.6707534790039:  28%|███▉          | 1767/6332 [22:31:04<16:40:00, 13.14s/it][A
epoch:0 loss:112.6707534790039:  28%|███▉          | 1768/6332 [22:31:04<16:48:40, 13.26s/it][A
epoch:0 loss:152.5908203125:  28%|████▋            | 1768/6332 [22:31:17<16:48:40, 13.26s/it][A
epoch:0 loss:152.5908203125:  

epoch:0 loss:92.35100555419922:  29%|████          | 1848/6332 [22:48:24<16:46:32, 13.47s/it][A
epoch:0 loss:109.38639068603516:  29%|███▊         | 1848/6332 [22:48:35<16:46:32, 13.47s/it][A
epoch:0 loss:109.38639068603516:  29%|███▊         | 1849/6332 [22:48:35<16:07:17, 12.95s/it][A
epoch:0 loss:145.51473999023438:  29%|███▊         | 1849/6332 [22:48:47<16:07:17, 12.95s/it][A
epoch:0 loss:145.51473999023438:  29%|███▊         | 1850/6332 [22:48:47<15:39:37, 12.58s/it][A
epoch:0 loss:160.9121856689453:  29%|████          | 1850/6332 [22:48:58<15:39:37, 12.58s/it][A
epoch:0 loss:160.9121856689453:  29%|████          | 1851/6332 [22:48:58<14:57:30, 12.02s/it][A
epoch:0 loss:79.76412963867188:  29%|████          | 1851/6332 [22:49:09<14:57:30, 12.02s/it][A
epoch:0 loss:79.76412963867188:  29%|████          | 1852/6332 [22:49:09<14:48:08, 11.89s/it][A
epoch:0 loss:178.14407348632812:  29%|███▊         | 1852/6332 [22:49:23<14:48:08, 11.89s/it][A
epoch:0 loss:178.1440734863281

epoch:0 loss:96.2489013671875:  31%|████▌          | 1932/6332 [23:06:22<15:39:35, 12.81s/it][A
epoch:0 loss:123.60768127441406:  31%|███▉         | 1932/6332 [23:06:36<15:39:35, 12.81s/it][A
epoch:0 loss:123.60768127441406:  31%|███▉         | 1933/6332 [23:06:36<16:03:28, 13.14s/it][A
epoch:0 loss:94.21207427978516:  31%|████▎         | 1933/6332 [23:06:49<16:03:28, 13.14s/it][A
epoch:0 loss:94.21207427978516:  31%|████▎         | 1934/6332 [23:06:49<15:54:25, 13.02s/it][A
epoch:0 loss:162.08309936523438:  31%|███▉         | 1934/6332 [23:07:03<15:54:25, 13.02s/it][A
epoch:0 loss:162.08309936523438:  31%|███▉         | 1935/6332 [23:07:03<16:11:57, 13.26s/it][A
epoch:0 loss:116.66341400146484:  31%|███▉         | 1935/6332 [23:07:14<16:11:57, 13.26s/it][A
epoch:0 loss:116.66341400146484:  31%|███▉         | 1936/6332 [23:07:14<15:29:01, 12.68s/it][A
epoch:0 loss:147.04893493652344:  31%|███▉         | 1936/6332 [23:07:25<15:29:01, 12.68s/it][A
epoch:0 loss:147.0489349365234

epoch:0 loss:117.6636962890625:  32%|████▍         | 2016/6332 [23:24:19<15:12:07, 12.68s/it][A
epoch:0 loss:110.18415832519531:  32%|████▏        | 2016/6332 [23:24:33<15:12:07, 12.68s/it][A
epoch:0 loss:110.18415832519531:  32%|████▏        | 2017/6332 [23:24:33<15:43:47, 13.12s/it][A
epoch:0 loss:108.13462829589844:  32%|████▏        | 2017/6332 [23:24:46<15:43:47, 13.12s/it][A
epoch:0 loss:108.13462829589844:  32%|████▏        | 2018/6332 [23:24:46<15:57:29, 13.32s/it][A
epoch:0 loss:132.57093811035156:  32%|████▏        | 2018/6332 [23:24:59<15:57:29, 13.32s/it][A
epoch:0 loss:132.57093811035156:  32%|████▏        | 2019/6332 [23:24:59<15:48:02, 13.19s/it][A
epoch:0 loss:87.75726318359375:  32%|████▍         | 2019/6332 [23:25:12<15:48:02, 13.19s/it][A
epoch:0 loss:87.75726318359375:  32%|████▍         | 2020/6332 [23:25:12<15:38:26, 13.06s/it][A
epoch:0 loss:124.71192932128906:  32%|████▏        | 2020/6332 [23:25:23<15:38:26, 13.06s/it][A
epoch:0 loss:124.7119293212890

epoch:0 loss:118.29022979736328:  33%|████▎        | 2100/6332 [23:42:07<16:12:08, 13.78s/it][A
epoch:0 loss:160.67384338378906:  33%|████▎        | 2100/6332 [23:42:19<16:12:08, 13.78s/it][A
epoch:0 loss:160.67384338378906:  33%|████▎        | 2101/6332 [23:42:19<15:29:20, 13.18s/it][A
epoch:0 loss:121.3001937866211:  33%|████▋         | 2101/6332 [23:42:33<15:29:20, 13.18s/it][A
epoch:0 loss:121.3001937866211:  33%|████▋         | 2102/6332 [23:42:33<15:31:20, 13.21s/it][A
epoch:0 loss:131.95875549316406:  33%|████▎        | 2102/6332 [23:42:45<15:31:20, 13.21s/it][A
epoch:0 loss:131.95875549316406:  33%|████▎        | 2103/6332 [23:42:45<15:05:57, 12.85s/it][A
epoch:0 loss:110.32357788085938:  33%|████▎        | 2103/6332 [23:42:57<15:05:57, 12.85s/it][A
epoch:0 loss:110.32357788085938:  33%|████▎        | 2104/6332 [23:42:57<14:59:22, 12.76s/it][A
epoch:0 loss:136.09109497070312:  33%|████▎        | 2104/6332 [23:43:08<14:59:22, 12.76s/it][A
epoch:0 loss:136.0910949707031

epoch:0 loss:107.72915649414062:  34%|████▍        | 2184/6332 [23:59:51<14:45:28, 12.81s/it][A
epoch:0 loss:109.88798522949219:  34%|████▍        | 2184/6332 [24:00:03<14:45:28, 12.81s/it][A
epoch:0 loss:109.88798522949219:  35%|████▍        | 2185/6332 [24:00:03<14:42:14, 12.76s/it][A
epoch:0 loss:142.80447387695312:  35%|████▍        | 2185/6332 [24:00:14<14:42:14, 12.76s/it][A
epoch:0 loss:142.80447387695312:  35%|████▍        | 2186/6332 [24:00:14<14:00:49, 12.17s/it][A
epoch:0 loss:89.54278564453125:  35%|████▊         | 2186/6332 [24:00:26<14:00:49, 12.17s/it][A
epoch:0 loss:89.54278564453125:  35%|████▊         | 2187/6332 [24:00:26<13:58:59, 12.14s/it][A
epoch:0 loss:87.51319122314453:  35%|████▊         | 2187/6332 [24:00:38<13:58:59, 12.14s/it][A
epoch:0 loss:87.51319122314453:  35%|████▊         | 2188/6332 [24:00:38<13:52:34, 12.05s/it][A
epoch:0 loss:170.91845703125:  35%|█████▌          | 2188/6332 [24:00:52<13:52:34, 12.05s/it][A
epoch:0 loss:170.91845703125: 

epoch:0 loss:166.21961975097656:  36%|████▋        | 2268/6332 [24:17:42<13:45:41, 12.19s/it][A
epoch:0 loss:106.37609100341797:  36%|████▋        | 2268/6332 [24:17:54<13:45:41, 12.19s/it][A
epoch:0 loss:106.37609100341797:  36%|████▋        | 2269/6332 [24:17:54<13:38:08, 12.08s/it][A
epoch:0 loss:126.1471939086914:  36%|█████         | 2269/6332 [24:18:04<13:38:08, 12.08s/it][A
epoch:0 loss:126.1471939086914:  36%|█████         | 2270/6332 [24:18:04<13:06:44, 11.62s/it][A
epoch:0 loss:151.90447998046875:  36%|████▋        | 2270/6332 [24:18:15<13:06:44, 11.62s/it][A
epoch:0 loss:151.90447998046875:  36%|████▋        | 2271/6332 [24:18:15<12:52:11, 11.41s/it][A
epoch:0 loss:125.24474334716797:  36%|████▋        | 2271/6332 [24:18:30<12:52:11, 11.41s/it][A
epoch:0 loss:125.24474334716797:  36%|████▋        | 2272/6332 [24:18:30<13:59:06, 12.40s/it][A
epoch:0 loss:133.8973846435547:  36%|█████         | 2272/6332 [24:18:43<13:59:06, 12.40s/it][A
epoch:0 loss:133.8973846435547

epoch:0 loss:83.07059478759766:  37%|█████▏        | 2352/6332 [24:35:52<13:41:54, 12.39s/it][A
epoch:0 loss:94.78128814697266:  37%|█████▏        | 2352/6332 [24:36:04<13:41:54, 12.39s/it][A
epoch:0 loss:94.78128814697266:  37%|█████▏        | 2353/6332 [24:36:04<13:33:06, 12.26s/it][A
epoch:0 loss:101.66119384765625:  37%|████▊        | 2353/6332 [24:36:16<13:33:06, 12.26s/it][A
epoch:0 loss:101.66119384765625:  37%|████▊        | 2354/6332 [24:36:16<13:43:05, 12.41s/it][A
epoch:0 loss:114.16592407226562:  37%|████▊        | 2354/6332 [24:36:28<13:43:05, 12.41s/it][A
epoch:0 loss:114.16592407226562:  37%|████▊        | 2355/6332 [24:36:28<13:31:25, 12.24s/it][A
epoch:0 loss:150.8584442138672:  37%|█████▏        | 2355/6332 [24:36:41<13:31:25, 12.24s/it][A
epoch:0 loss:150.8584442138672:  37%|█████▏        | 2356/6332 [24:36:41<13:45:24, 12.46s/it][A
epoch:0 loss:93.46526336669922:  37%|█████▏        | 2356/6332 [24:36:54<13:45:24, 12.46s/it][A
epoch:0 loss:93.46526336669922

epoch:0 loss:120.18956756591797:  38%|█████        | 2436/6332 [24:53:19<13:58:42, 12.92s/it][A
epoch:0 loss:181.3500213623047:  38%|█████▍        | 2436/6332 [24:53:31<13:58:42, 12.92s/it][A
epoch:0 loss:181.3500213623047:  38%|█████▍        | 2437/6332 [24:53:31<13:33:13, 12.53s/it][A
epoch:0 loss:148.2768096923828:  38%|█████▍        | 2437/6332 [24:53:44<13:33:13, 12.53s/it][A
epoch:0 loss:148.2768096923828:  39%|█████▍        | 2438/6332 [24:53:44<13:47:05, 12.74s/it][A
epoch:0 loss:122.79964447021484:  39%|█████        | 2438/6332 [24:53:57<13:47:05, 12.74s/it][A
epoch:0 loss:122.79964447021484:  39%|█████        | 2439/6332 [24:53:57<13:50:58, 12.81s/it][A
epoch:0 loss:149.56356811523438:  39%|█████        | 2439/6332 [24:54:10<13:50:58, 12.81s/it][A
epoch:0 loss:149.56356811523438:  39%|█████        | 2440/6332 [24:54:10<13:48:25, 12.77s/it][A
epoch:0 loss:103.16197204589844:  39%|█████        | 2440/6332 [24:54:23<13:48:25, 12.77s/it][A
epoch:0 loss:103.1619720458984

epoch:0 loss:162.4961395263672:  40%|█████▌        | 2520/6332 [25:11:01<13:18:46, 12.57s/it][A
epoch:0 loss:167.1260528564453:  40%|█████▌        | 2520/6332 [25:11:13<13:18:46, 12.57s/it][A
epoch:0 loss:167.1260528564453:  40%|█████▌        | 2521/6332 [25:11:13<13:12:44, 12.48s/it][A
epoch:0 loss:103.95099639892578:  40%|█████▏       | 2521/6332 [25:11:27<13:12:44, 12.48s/it][A
epoch:0 loss:103.95099639892578:  40%|█████▏       | 2522/6332 [25:11:27<13:27:47, 12.72s/it][A
epoch:0 loss:129.11456298828125:  40%|█████▏       | 2522/6332 [25:11:39<13:27:47, 12.72s/it][A
epoch:0 loss:129.11456298828125:  40%|█████▏       | 2523/6332 [25:11:39<13:15:25, 12.53s/it][A
epoch:0 loss:91.24633026123047:  40%|█████▌        | 2523/6332 [25:11:52<13:15:25, 12.53s/it][A
epoch:0 loss:91.24633026123047:  40%|█████▌        | 2524/6332 [25:11:52<13:24:26, 12.68s/it][A
epoch:0 loss:117.91287994384766:  40%|█████▏       | 2524/6332 [25:12:04<13:24:26, 12.68s/it][A
epoch:0 loss:117.9128799438476

epoch:0 loss:108.85171508789062:  41%|█████▎       | 2604/6332 [25:28:40<12:42:37, 12.27s/it][A
epoch:0 loss:133.3568115234375:  41%|█████▊        | 2604/6332 [25:28:53<12:42:37, 12.27s/it][A
epoch:0 loss:133.3568115234375:  41%|█████▊        | 2605/6332 [25:28:53<12:53:42, 12.46s/it][A
epoch:0 loss:142.1723175048828:  41%|█████▊        | 2605/6332 [25:29:05<12:53:42, 12.46s/it][A
epoch:0 loss:142.1723175048828:  41%|█████▊        | 2606/6332 [25:29:05<12:53:11, 12.45s/it][A
epoch:0 loss:104.34475708007812:  41%|█████▎       | 2606/6332 [25:29:19<12:53:11, 12.45s/it][A
epoch:0 loss:104.34475708007812:  41%|█████▎       | 2607/6332 [25:29:19<13:13:14, 12.78s/it][A
epoch:0 loss:141.3397979736328:  41%|█████▊        | 2607/6332 [25:29:33<13:13:14, 12.78s/it][A
epoch:0 loss:141.3397979736328:  41%|█████▊        | 2608/6332 [25:29:33<13:37:48, 13.18s/it][A
epoch:0 loss:174.23191833496094:  41%|█████▎       | 2608/6332 [25:29:45<13:37:48, 13.18s/it][A
epoch:0 loss:174.2319183349609

epoch:0 loss:101.41455078125:  42%|██████▊         | 2688/6332 [25:46:26<12:51:32, 12.70s/it][A
epoch:0 loss:122.67154693603516:  42%|█████▌       | 2688/6332 [25:46:39<12:51:32, 12.70s/it][A
epoch:0 loss:122.67154693603516:  42%|█████▌       | 2689/6332 [25:46:39<13:00:54, 12.86s/it][A
epoch:0 loss:98.64639282226562:  42%|█████▉        | 2689/6332 [25:46:52<13:00:54, 12.86s/it][A
epoch:0 loss:98.64639282226562:  42%|█████▉        | 2690/6332 [25:46:52<12:57:06, 12.80s/it][A
epoch:0 loss:107.0369873046875:  42%|█████▉        | 2690/6332 [25:47:04<12:57:06, 12.80s/it][A
epoch:0 loss:107.0369873046875:  42%|█████▉        | 2691/6332 [25:47:04<12:47:12, 12.64s/it][A
epoch:0 loss:121.56552124023438:  42%|█████▌       | 2691/6332 [25:47:17<12:47:12, 12.64s/it][A
epoch:0 loss:121.56552124023438:  43%|█████▌       | 2692/6332 [25:47:17<12:42:13, 12.56s/it][A
epoch:0 loss:144.7705078125:  43%|███████▏         | 2692/6332 [25:47:30<12:42:13, 12.56s/it][A
epoch:0 loss:144.7705078125:  

epoch:0 loss:159.43603515625:  44%|███████         | 2772/6332 [26:04:10<14:17:41, 14.46s/it][A
epoch:0 loss:86.66825866699219:  44%|██████▏       | 2772/6332 [26:04:24<14:17:41, 14.46s/it][A
epoch:0 loss:86.66825866699219:  44%|██████▏       | 2773/6332 [26:04:24<14:03:51, 14.23s/it][A
epoch:0 loss:71.6283950805664:  44%|██████▌        | 2773/6332 [26:04:35<14:03:51, 14.23s/it][A
epoch:0 loss:71.6283950805664:  44%|██████▌        | 2774/6332 [26:04:35<13:14:01, 13.39s/it][A
epoch:0 loss:76.46056365966797:  44%|██████▏       | 2774/6332 [26:04:48<13:14:01, 13.39s/it][A
epoch:0 loss:76.46056365966797:  44%|██████▏       | 2775/6332 [26:04:48<13:06:27, 13.27s/it][A
epoch:0 loss:124.2137451171875:  44%|██████▏       | 2775/6332 [26:05:03<13:06:27, 13.27s/it][A
epoch:0 loss:124.2137451171875:  44%|██████▏       | 2776/6332 [26:05:03<13:41:51, 13.87s/it][A
epoch:0 loss:107.43246459960938:  44%|█████▋       | 2776/6332 [26:05:20<13:41:51, 13.87s/it][A
epoch:0 loss:107.4324645996093

epoch:0 loss:104.5057601928711:  45%|██████▎       | 2856/6332 [26:22:19<13:01:07, 13.48s/it][A
epoch:0 loss:120.43685913085938:  45%|█████▊       | 2856/6332 [26:22:37<13:01:07, 13.48s/it][A
epoch:0 loss:120.43685913085938:  45%|█████▊       | 2857/6332 [26:22:37<14:13:34, 14.74s/it][A
epoch:0 loss:101.84324645996094:  45%|█████▊       | 2857/6332 [26:22:55<14:13:34, 14.74s/it][A
epoch:0 loss:101.84324645996094:  45%|█████▊       | 2858/6332 [26:22:55<15:02:51, 15.59s/it][A
epoch:0 loss:99.30322265625:  45%|███████▋         | 2858/6332 [26:23:11<15:02:51, 15.59s/it][A
epoch:0 loss:99.30322265625:  45%|███████▋         | 2859/6332 [26:23:11<15:18:10, 15.86s/it][A
epoch:0 loss:116.04847717285156:  45%|█████▊       | 2859/6332 [26:23:22<15:18:10, 15.86s/it][A
epoch:0 loss:116.04847717285156:  45%|█████▊       | 2860/6332 [26:23:22<13:50:16, 14.35s/it][A
epoch:0 loss:112.2630615234375:  45%|██████▎       | 2860/6332 [26:23:34<13:50:16, 14.35s/it][A
epoch:0 loss:112.2630615234375

epoch:0 loss:140.02281188964844:  46%|██████       | 2940/6332 [26:41:43<14:25:02, 15.30s/it][A
epoch:0 loss:107.95397186279297:  46%|██████       | 2940/6332 [26:41:57<14:25:02, 15.30s/it][A
epoch:0 loss:107.95397186279297:  46%|██████       | 2941/6332 [26:41:57<14:02:34, 14.91s/it][A
epoch:0 loss:193.36431884765625:  46%|██████       | 2941/6332 [26:42:13<14:02:34, 14.91s/it][A
epoch:0 loss:193.36431884765625:  46%|██████       | 2942/6332 [26:42:13<14:21:47, 15.25s/it][A
epoch:0 loss:106.70806884765625:  46%|██████       | 2942/6332 [26:42:30<14:21:47, 15.25s/it][A
epoch:0 loss:106.70806884765625:  46%|██████       | 2943/6332 [26:42:30<14:41:54, 15.61s/it][A
epoch:0 loss:143.96287536621094:  46%|██████       | 2943/6332 [26:42:41<14:41:54, 15.61s/it][A
epoch:0 loss:143.96287536621094:  46%|██████       | 2944/6332 [26:42:41<13:19:40, 14.16s/it][A
epoch:0 loss:146.81207275390625:  46%|██████       | 2944/6332 [26:42:57<13:19:40, 14.16s/it][A
epoch:0 loss:146.8120727539062

epoch:0 loss:122.74517822265625:  48%|██████▏      | 3024/6332 [27:00:31<12:34:32, 13.69s/it][A
epoch:0 loss:142.17745971679688:  48%|██████▏      | 3024/6332 [27:00:43<12:34:32, 13.69s/it][A
epoch:0 loss:142.17745971679688:  48%|██████▏      | 3025/6332 [27:00:43<11:56:17, 13.00s/it][A
epoch:0 loss:107.7353515625:  48%|████████         | 3025/6332 [27:00:55<11:56:17, 13.00s/it][A
epoch:0 loss:107.7353515625:  48%|████████         | 3026/6332 [27:00:55<11:53:00, 12.94s/it][A
epoch:0 loss:128.86268615722656:  48%|██████▏      | 3026/6332 [27:01:06<11:53:00, 12.94s/it][A
epoch:0 loss:128.86268615722656:  48%|██████▏      | 3027/6332 [27:01:06<11:10:31, 12.17s/it][A
epoch:0 loss:165.3642120361328:  48%|██████▋       | 3027/6332 [27:01:17<11:10:31, 12.17s/it][A
epoch:0 loss:165.3642120361328:  48%|██████▋       | 3028/6332 [27:01:17<10:55:47, 11.91s/it][A
epoch:0 loss:126.95848846435547:  48%|██████▏      | 3028/6332 [27:01:28<10:55:47, 11.91s/it][A
epoch:0 loss:126.9584884643554

epoch:0 loss:96.48887634277344:  49%|██████▊       | 3108/6332 [27:19:10<12:17:20, 13.72s/it][A
epoch:0 loss:84.1299819946289:  49%|███████▎       | 3108/6332 [27:19:21<12:17:20, 13.72s/it][A
epoch:0 loss:84.1299819946289:  49%|███████▎       | 3109/6332 [27:19:21<11:35:58, 12.96s/it][A
epoch:0 loss:109.39341735839844:  49%|██████▍      | 3109/6332 [27:19:39<11:35:58, 12.96s/it][A
epoch:0 loss:109.39341735839844:  49%|██████▍      | 3110/6332 [27:19:39<12:55:07, 14.43s/it][A
epoch:0 loss:106.69674682617188:  49%|██████▍      | 3110/6332 [27:19:50<12:55:07, 14.43s/it][A
epoch:0 loss:106.69674682617188:  49%|██████▍      | 3111/6332 [27:19:50<12:02:53, 13.47s/it][A
epoch:0 loss:97.13688659667969:  49%|██████▉       | 3111/6332 [27:20:03<12:02:53, 13.47s/it][A
epoch:0 loss:97.13688659667969:  49%|██████▉       | 3112/6332 [27:20:03<11:50:45, 13.24s/it][A
epoch:0 loss:126.88949584960938:  49%|██████▍      | 3112/6332 [27:20:15<11:50:45, 13.24s/it][A
epoch:0 loss:126.8894958496093

epoch:0 loss:137.31442260742188:  50%|██████▌      | 3192/6332 [27:37:57<13:23:30, 15.35s/it][A
epoch:0 loss:86.01646423339844:  50%|███████       | 3192/6332 [27:38:11<13:23:30, 15.35s/it][A
epoch:0 loss:86.01646423339844:  50%|███████       | 3193/6332 [27:38:11<12:55:59, 14.83s/it][A
epoch:0 loss:79.46331787109375:  50%|███████       | 3193/6332 [27:38:27<12:55:59, 14.83s/it][A
epoch:0 loss:79.46331787109375:  50%|███████       | 3194/6332 [27:38:27<13:12:07, 15.15s/it][A
epoch:0 loss:127.748291015625:  50%|███████▌       | 3194/6332 [27:38:45<13:12:07, 15.15s/it][A
epoch:0 loss:127.748291015625:  50%|███████▌       | 3195/6332 [27:38:45<14:00:41, 16.08s/it][A
epoch:0 loss:114.38103485107422:  50%|██████▌      | 3195/6332 [27:39:01<14:00:41, 16.08s/it][A
epoch:0 loss:114.38103485107422:  50%|██████▌      | 3196/6332 [27:39:01<13:56:27, 16.00s/it][A
epoch:0 loss:123.91590881347656:  50%|██████▌      | 3196/6332 [27:39:15<13:56:27, 16.00s/it][A
epoch:0 loss:123.9159088134765

epoch:0 loss:135.328125:  52%|██████████▊          | 3276/6332 [30:40:59<61:57:41, 72.99s/it][A
epoch:0 loss:77.83094024658203:  52%|███████▏      | 3276/6332 [30:41:11<61:57:41, 72.99s/it][A
epoch:0 loss:77.83094024658203:  52%|███████▏      | 3277/6332 [30:41:11<46:35:39, 54.91s/it][A
epoch:0 loss:124.02973937988281:  52%|██████▋      | 3277/6332 [30:41:22<46:35:39, 54.91s/it][A
epoch:0 loss:124.02973937988281:  52%|██████▋      | 3278/6332 [30:41:22<35:23:28, 41.72s/it][A
epoch:0 loss:82.74150085449219:  52%|███████▏      | 3278/6332 [30:53:11<35:23:28, 41.72s/it][A
epoch:0 loss:82.74150085449219:  52%|██████▏     | 3279/6332 [30:53:11<205:08:15, 241.89s/it][A
epoch:0 loss:90.31734466552734:  52%|██████▏     | 3279/6332 [30:53:24<205:08:15, 241.89s/it][A
epoch:0 loss:90.31734466552734:  52%|██████▏     | 3280/6332 [30:53:24<146:50:12, 173.20s/it][A
epoch:0 loss:207.40643310546875:  52%|█████▋     | 3280/6332 [30:53:37<146:50:12, 173.20s/it][A
epoch:0 loss:207.4064331054687

KeyboardInterrupt: 

In [None]:
# pyplot의 이미지 형식에 맞추기 위한 약간의 변형합니다.
test_L, test_AB = dataset[0]
test_L = np.expand_dims(test_L, axis=0)
real_img = np.concatenate([test_L, test_AB])
real_img = real_img.transpose(1, 2, 0).astype(np.uint8)
real_img = lab2rgb(real_img)

with torch.no_grad():
    model.load_state_dict(torch.load("AutoColor.pth", map_location=device))

    input_tensor = torch.tensor(test_L)
    input_tensor = torch.unsqueeze(input_tensor, dim=0).to(device)
    pred_AB = model(input_tensor)
   
    pred_LAB = torch.cat([input_tensor, pred_AB], dim=1)
    pred_LAB = torch.squeeze(pred_LAB)
    pred_LAB = pred_LAB.permute(1, 2, 0).cpu().numpy()
    pred_LAB = lab2rgb(pred_LAB.astype(np.uint8))

plt.subplot(1, 2, 1)
plt.imshow(real_img)
plt.title("real image")
plt.subplot(1, 2, 2)
plt.imshow(pred_LAB)
plt.title("predicted image")
plt.show()