In [1]:
import os
import glob
import cv2
import imageio
import base64
import time
import ipywidgets as widgets
from IPython.display import display
import io
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb

import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from fastai.data.external import untar_data, URLs
from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
from fastai.vision.models.unet import DynamicUnet
from torch.cuda.amp import autocast, GradScaler

In [2]:
# Mount Google Drive at /content/drive; the dataset paths used in later cells
# live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# Notebook-wide configuration.
seed = 123   # seed applied to the global NumPy and PyTorch RNGs below
SIZE = 256   # side length (pixels) that images are resized to

# Seed the global generators so that stochastic steps (random flips, weight
# initialisation of new layers, ...) are repeatable across Restart & Run All.
np.random.seed(seed)
torch.manual_seed(seed)

# coco_path = untar_data(URLs.COCO_SAMPLE)
# coco_path.ls()

In [None]:
import shutil

dataset_path = Path("/content/drive/MyDrive/Colab Notebooks/TTTN/Code/")

# One-time setup: move a freshly downloaded dataset folder onto Google Drive.
# `coco_path` only exists if the (commented-out) untar_data download was run in
# this session, so look it up defensively instead of raising NameError on a
# fresh kernel. Also skip the move when the folder already lives under
# `dataset_path` (or a folder of the same name is already there), so that
# re-running the cell cannot relocate data that is already in place.
_downloaded = globals().get("coco_path")
_already_on_drive = (
    _downloaded is not None
    and (
        dataset_path in Path(str(_downloaded)).parents
        or (dataset_path / Path(str(_downloaded)).name).exists()
    )
)

if _downloaded is None or _already_on_drive:
    print("Skipping dataset move: no fresh download in this session, "
          "or the dataset is already on Google Drive.")
else:
    shutil.move(str(_downloaded), dataset_path)
    print("Dataset đã được di chuyển vào Google Drive:", dataset_path)

Dataset đã được di chuyển vào Google Drive: /content/drive/MyDrive/Colab Notebooks/TTTN/Code


In [4]:
# Root folder of the COCO sample dataset on the mounted Google Drive.
coco_path = r'/content/drive/MyDrive/Colab Notebooks/TTTN/Code/coco_sample'

In [5]:
# Point coco_path at the folder holding the training JPEGs. The suffix is only
# appended once, so re-running this cell does not produce
# ".../train_sample/train_sample".
coco_path = str(coco_path)
if not coco_path.endswith("/train_sample"):
    coco_path += "/train_sample"

# glob returns files in arbitrary filesystem order; sort so the split below
# depends only on the seed, not on the directory listing order.
image_files = sorted(glob.glob(coco_path + "/*.jpg"))
num_samples = len(image_files)
print(f"Total number of images available: {num_samples}")

paths_subset = np.array(image_files)

# Use a dedicated generator seeded with the notebook seed: the permutation is
# then identical on every (re-)run of this cell and leaves the global NumPy
# RNG state untouched.
split_rng = np.random.RandomState(seed)
rand_idxs = split_rng.permutation(num_samples)

# 95% of the images go to training, the remaining 5% to validation.
num_train = int(num_samples * 0.95)
train_idxs = rand_idxs[:num_train]
val_idxs = rand_idxs[num_train:]

train_paths = paths_subset[train_idxs]
val_paths = paths_subset[val_idxs]

print(f"Number of training images: {len(train_paths)}")
print(f"Number of validation images: {len(val_paths)}")

Total number of images available: 21837
Number of training images: 20745
Number of validation images: 1092


In [6]:
class ColorizationDataset(Dataset):
    """Dataset that serves images as scaled CIE L*a*b* channels.

    Each item is a dict with two float32 tensors:
      * ``'L'``  - the first Lab channel, computed as ``L / 50 - 1``
      * ``'ab'`` - the remaining two Lab channels, computed as ``ab / 110``
    Both are intended to lie roughly in [-1, 1].
    """

    def __init__(self, paths, split='train'):
        """
        Args:
            paths: indexable collection of image file paths.
            split: ``'train'`` (resize + random horizontal flip) or
                ``'val'`` (resize only).

        Raises:
            ValueError: if ``split`` is neither ``'train'`` nor ``'val'``.
        """
        # Fail early: previously an unknown split left self.transforms unset
        # and only surfaced as an AttributeError inside __getitem__.
        if split not in ('train', 'val'):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        # torchvision expects its own InterpolationMode enum rather than the
        # PIL integer constant.
        resize = transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC)
        if split == 'train':
            self.transforms = transforms.Compose([
                resize,
                transforms.RandomHorizontalFlip(),  # A little data augmentation!
            ])
        else:
            self.transforms = resize

        # Built once instead of once per item.
        self._to_tensor = transforms.ToTensor()

        self.split = split
        self.size = SIZE
        self.paths = paths

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``{'L': ..., 'ab': ...}``."""
        # The context manager closes the underlying file promptly instead of
        # leaving the handle to the garbage collector.
        with Image.open(self.paths[idx]) as raw:
            img = raw.convert("RGB")
        img = self.transforms(img)
        img = np.array(img)
        img_lab = rgb2lab(img).astype("float32")  # Converting RGB to L*a*b
        img_lab = self._to_tensor(img_lab)        # HxWxC array -> CxHxW tensor
        L = img_lab[[0], ...] / 50. - 1.          # roughly between -1 and 1
        ab = img_lab[[1, 2], ...] / 110.          # roughly between -1 and 1

        return {'L': L, 'ab': ab}

    def __len__(self):
        """Number of image paths in the dataset."""
        return len(self.paths)

def make_dataloaders(batch_size=16, n_workers=4, pin_memory=True, shuffle=False, **kwargs):
    """Build a DataLoader over a ColorizationDataset.

    Args:
        batch_size: number of samples per batch.
        n_workers: number of DataLoader worker processes.
        pin_memory: forwarded to DataLoader.
        shuffle: reshuffle the samples at every epoch when True. Defaults to
            False, which keeps the previous behaviour for existing callers.
        **kwargs: forwarded to ColorizationDataset (``paths``, ``split``).

    Returns:
        The configured torch.utils.data.DataLoader.
    """
    dataset = ColorizationDataset(**kwargs)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=n_workers,
                            pin_memory=pin_memory, shuffle=shuffle)
    return dataloader

# Training batches are reshuffled every epoch so the model does not see the
# exact same batch composition each time; validation order stays fixed.
train_dl = make_dataloaders(paths=train_paths, split='train', shuffle=True)
val_dl = make_dataloaders(paths=val_paths, split='val')

class AverageMeter:
    """Running weighted average of a scalar.

    Attributes (all start at 0.0):
        count: accumulated weight.
        sum:   accumulated ``count * val``.
        avg:   ``sum / count`` after the latest update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set all accumulators back to zero."""
        self.count = 0.
        self.avg = 0.
        self.sum = 0.

    def update(self, val, count=1):
        """Fold in ``val`` with weight ``count`` and refresh the average."""
        weighted = count * val
        self.sum += weighted
        self.count += count
        self.avg = self.sum / self.count

def create_loss_meters():
    """Return a dict mapping each tracked loss name to a fresh AverageMeter.

    Keys, in order: loss_D_fake, loss_D_real, loss_D, loss_G_GAN, loss_G_L1,
    loss_G.
    """
    tracked_losses = (
        'loss_D_fake',
        'loss_D_real',
        'loss_D',
        'loss_G_GAN',
        'loss_G_L1',
        'loss_G',
    )
    return {name: AverageMeter() for name in tracked_losses}

def update_losses(model, loss_meter_dict, count):
    """Feed the model's current loss values into the matching meters.

    For every key in ``loss_meter_dict`` the attribute of the same name is read
    from ``model``, converted with ``.item()`` and passed to the meter's
    ``update`` with weight ``count``.
    """
    for loss_name in loss_meter_dict:
        current_value = getattr(model, loss_name).item()
        loss_meter_dict[loss_name].update(current_value, count=count)

def lab_to_rgb(L, ab):
    """Convert a batch of scaled L and ab tensors to RGB images.

    ``L`` is rescaled with ``(L + 1) * 50`` and ``ab`` with ``ab * 110``, the
    two are joined along dim 1, moved to channels-last order and converted
    image by image with ``lab2rgb``.

    Returns:
        A NumPy array stacking the per-image RGB results along axis 0.
    """
    lab_batch = torch.cat([(L + 1.) * 50., ab * 110.], dim=1)
    # (N, C, H, W) -> (N, H, W, C), the layout lab2rgb works on per image.
    lab_channels_last = lab_batch.permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(single) for single in lab_channels_last], axis=0)

def visualize(model, data, save=True):
    """Plot grayscale input, predicted colours and ground truth side by side.

    Runs the generator in eval mode without gradients on ``data``, switches it
    back to train mode, and draws up to five samples in a 3-row grid
    (row 1: L channel, row 2: prediction, row 3: ground truth).

    Args:
        model: object exposing ``net_G``, ``setup_input(data)``, ``forward()``
            and, after the forward pass, ``fake_color``, ``ab`` and ``L``.
        data: one batch, passed unchanged to ``model.setup_input``.
        save: when True, also write the figure to
            ``colorization_<timestamp>.png`` in the working directory.
    """
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()
    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)

    # Never index past the batch: a final, smaller batch may hold fewer than
    # five images.
    n_shown = min(5, len(fake_imgs))

    fig = plt.figure(figsize=(15, 8))
    for i in range(n_shown):
        ax = plt.subplot(3, 5, i + 1)
        ax.imshow(L[i][0].cpu(), cmap='gray')
        ax.axis("off")
        ax = plt.subplot(3, 5, i + 1 + 5)
        ax.imshow(fake_imgs[i])
        ax.axis("off")
        ax = plt.subplot(3, 5, i + 1 + 10)
        ax.imshow(real_imgs[i])
        ax.axis("off")
    # Save before showing: on some backends the figure is torn down by
    # plt.show(), which would leave an empty image on disk.
    if save:
        fig.savefig(f"colorization_{time.time()}.png")
    plt.show()

def log_results(loss_meter_dict):
    """Print one ``name: average`` line (5 decimals) per meter in the dict."""
    for loss_name in loss_meter_dict:
        average = loss_meter_dict[loss_name].avg
        print(f"{loss_name}: {average:.5f}")

def build_res_unet(n_input=1, n_output=2, size=256):
    """Build a DynamicUnet generator on top of a pretrained ResNet-18 body.

    Args:
        n_input: number of input channels given to ``create_body``.
        n_output: number of output channels of the U-Net.
        size: side length; the U-Net is built for ``(size, size)`` inputs.

    Returns:
        The DynamicUnet, moved to CUDA when available and to CPU otherwise.
    """
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Instantiate the pretrained resnet18 and cut off its last two children
    # to obtain the encoder body.
    encoder = create_body(resnet18(pretrained=True), n_in=n_input, cut=-2)
    unet = DynamicUnet(encoder, n_output, (size, size))
    return unet.to(target_device)

def pretrain_generator(net_G, train_dl, opt, criterion, start_epoch, end_epochs):
    """Train the generator alone and keep only the best checkpoint on disk.

    For each epoch in ``range(start_epoch, end_epochs)`` the generator is
    optimised on every batch of ``train_dl`` with ``criterion(net_G(L), ab)``.
    Whenever the epoch's average loss is lower than any seen so far in this
    call, a checkpoint ``best_checkpoint_<epoch>.pt`` is written to the working
    directory and the previously written one is removed.

    Args:
        net_G: generator network.
        train_dl: iterable of batches; each batch is a dict with 'L' and 'ab'.
        opt: optimiser over ``net_G``'s parameters.
        criterion: loss taking ``(prediction, target)``.
        start_epoch: first epoch index (0-based).
        end_epochs: epoch index to stop before.
    """
    lowest_loss = float('inf')
    previous_best_path = None
    best_epoch = start_epoch  # Holds the epoch number of the best checkpoint

    for epoch_idx in range(start_epoch, end_epochs):
        epoch_meter = AverageMeter()
        for batch in tqdm(train_dl):
            L = batch['L'].to(device)
            ab = batch['ab'].to(device)
            loss = criterion(net_G(L), ab)
            opt.zero_grad()
            loss.backward()
            opt.step()

            # Weight the batch loss by the batch size.
            epoch_meter.update(loss.item(), L.size(0))

        current_epoch = epoch_idx + 1  # 1-based number used for display and file names
        print(f"Epoch {current_epoch}/{end_epochs}")
        print(f"L1 Loss: {epoch_meter.avg:.5f}")

        # Nothing to save unless this epoch beat the best loss so far.
        if not (epoch_meter.avg < lowest_loss):
            continue

        lowest_loss = epoch_meter.avg
        best_epoch = current_epoch
        new_best_path = f"best_checkpoint_{best_epoch}.pt"
        checkpoint_state = {
            'epoch': best_epoch,
            'model_state_dict': net_G.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            'loss': lowest_loss,
        }
        torch.save(checkpoint_state, new_best_path)
        print(f"New best checkpoint saved at {new_best_path} with loss {lowest_loss:.5f} at epoch {best_epoch}")

        # Remove the superseded checkpoint only after the new one is written.
        has_stale_file = (
            previous_best_path is not None
            and previous_best_path != new_best_path
            and os.path.exists(previous_best_path)
        )
        if has_stale_file:
            os.remove(previous_best_path)
            print(f"Previous best checkpoint {previous_best_path} has been deleted.")
        previous_best_path = new_best_path

# Generator built by build_res_unet with 1 input channel and 2 output channels
# at SIZE x SIZE resolution.
net_G = build_res_unet(n_input=1, n_output=2, size=SIZE)
# Adam optimiser over the generator's parameters and an L1 loss; both are
# passed to pretrain_generator in a later cell.
opt = optim.Adam(net_G.parameters(), lr=1e-4)
criterion = nn.L1Loss()
# pretrain_generator(net_G, train_dl, opt, criterion, start_epoch=0, end_epochs=100)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 165MB/s]


In [None]:
# checkpoint_path = "/kaggle/working/checkpoints/best_checkpoint.pt"
# checkpoint = torch.load(checkpoint_path, map_location=device)
# net_G = build_res_unet(n_input=1, n_output=2, size=SIZE)
# opt = optim.Adam(net_G.parameters(), lr=1e-4)
# net_G.load_state_dict(checkpoint['model_state_dict'])
# opt.load_state_dict(checkpoint['optimizer_state_dict'])
# start_epoch = checkpoint['epoch']

In [None]:
# Pretrain the generator from scratch (epoch index 0) up to 100 epochs using
# the optimiser and L1 criterion defined above.
pretrain_generator(net_G, train_dl, opt, criterion, start_epoch=0, end_epochs=100)

  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 1/100
L1 Loss: 0.08365
New best checkpoint saved at best_checkpoint_1.pt with loss 0.08365 at epoch 1


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 2/100
L1 Loss: 0.08203
New best checkpoint saved at best_checkpoint_2.pt with loss 0.08203 at epoch 2
Previous best checkpoint best_checkpoint_1.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 3/100
L1 Loss: 0.08123
New best checkpoint saved at best_checkpoint_3.pt with loss 0.08123 at epoch 3
Previous best checkpoint best_checkpoint_2.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 4/100
L1 Loss: 0.08056
New best checkpoint saved at best_checkpoint_4.pt with loss 0.08056 at epoch 4
Previous best checkpoint best_checkpoint_3.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 5/100
L1 Loss: 0.07978
New best checkpoint saved at best_checkpoint_5.pt with loss 0.07978 at epoch 5
Previous best checkpoint best_checkpoint_4.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 6/100
L1 Loss: 0.07914
New best checkpoint saved at best_checkpoint_6.pt with loss 0.07914 at epoch 6
Previous best checkpoint best_checkpoint_5.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 7/100
L1 Loss: 0.07830
New best checkpoint saved at best_checkpoint_7.pt with loss 0.07830 at epoch 7
Previous best checkpoint best_checkpoint_6.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 8/100
L1 Loss: 0.07748
New best checkpoint saved at best_checkpoint_8.pt with loss 0.07748 at epoch 8
Previous best checkpoint best_checkpoint_7.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 9/100
L1 Loss: 0.07659
New best checkpoint saved at best_checkpoint_9.pt with loss 0.07659 at epoch 9
Previous best checkpoint best_checkpoint_8.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 10/100
L1 Loss: 0.07554
New best checkpoint saved at best_checkpoint_10.pt with loss 0.07554 at epoch 10
Previous best checkpoint best_checkpoint_9.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 11/100
L1 Loss: 0.07437
New best checkpoint saved at best_checkpoint_11.pt with loss 0.07437 at epoch 11
Previous best checkpoint best_checkpoint_10.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 12/100
L1 Loss: 0.07333
New best checkpoint saved at best_checkpoint_12.pt with loss 0.07333 at epoch 12
Previous best checkpoint best_checkpoint_11.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 13/100
L1 Loss: 0.07222
New best checkpoint saved at best_checkpoint_13.pt with loss 0.07222 at epoch 13
Previous best checkpoint best_checkpoint_12.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 14/100
L1 Loss: 0.07123
New best checkpoint saved at best_checkpoint_14.pt with loss 0.07123 at epoch 14
Previous best checkpoint best_checkpoint_13.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 15/100
L1 Loss: 0.07045
New best checkpoint saved at best_checkpoint_15.pt with loss 0.07045 at epoch 15
Previous best checkpoint best_checkpoint_14.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 16/100
L1 Loss: 0.06987
New best checkpoint saved at best_checkpoint_16.pt with loss 0.06987 at epoch 16
Previous best checkpoint best_checkpoint_15.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 17/100
L1 Loss: 0.06892
New best checkpoint saved at best_checkpoint_17.pt with loss 0.06892 at epoch 17
Previous best checkpoint best_checkpoint_16.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]