In [1]:
import os
import glob
import cv2
import imageio
import base64
import time
import ipywidgets as widgets
from IPython.display import display
import io
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb

import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from fastai.data.external import untar_data, URLs
from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
from fastai.vision.models.unet import DynamicUnet
from torch.cuda.amp import autocast, GradScaler

In [2]:
# Mount Google Drive at /content/drive (Colab only); the checkpoint directory
# used further down in this notebook lives under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
seed = 123
SIZE = 256
coco_path = untar_data(URLs.COCO_SAMPLE)
coco_path = str(coco_path) + "/train_sample"

# Collect the image files. glob returns entries in a filesystem-dependent order,
# so sort them: the split below must depend only on `seed`.
image_files = sorted(glob.glob(coco_path + "/*.jpg"))
num_samples = len(image_files)
print(f"Total number of images available: {num_samples}")

# Split the dataset into train/val (95% / 5%).
# The permutation is drawn from a generator seeded with `seed` so the split is
# identical after every kernel restart; with an unseeded permutation, resuming
# training from a checkpoint would move former training images into validation.
paths_subset = np.array(image_files)
rand_idxs = np.random.default_rng(seed).permutation(num_samples)
split_at = int(num_samples * 0.95)
train_idxs = rand_idxs[:split_at]
val_idxs = rand_idxs[split_at:]
train_paths = paths_subset[train_idxs]
val_paths = paths_subset[val_idxs]

print(f"Number of training images: {len(train_paths)}")
print(f"Number of validation images: {len(val_paths)}")

Total number of images available: 21837
Number of training images: 20745
Number of validation images: 1092


In [8]:
# Dataset definition
class ColorizationDataset(Dataset):
    """Image dataset that yields the L and ab channels of each picture.

    Every image is opened as RGB, resized to ``SIZE x SIZE`` with bicubic
    interpolation (plus a random horizontal flip for the ``'train'`` split),
    converted to L*a*b and returned as ``{'L': ..., 'ab': ...}`` tensors.
    """

    def __init__(self, paths, split='train'):
        """Store the image paths and build the per-split PIL transform.

        Args:
            paths: Indexable collection of image file paths.
            split: ``'train'`` enables flip augmentation; any other value
                uses the plain resize.
        """
        resize = transforms.Resize((SIZE, SIZE), Image.BICUBIC)
        if split != 'train':
            self.transforms = resize
        else:
            # Only the training split is augmented.
            self.transforms = transforms.Compose([
                resize,
                transforms.RandomHorizontalFlip(),
            ])

        self.paths = paths
        self.split = split

    def __len__(self):
        """Number of images in this split."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Load image ``idx`` and return its normalized ``L`` and ``ab`` tensors."""
        rgb = Image.open(self.paths[idx]).convert("RGB")
        rgb = np.array(self.transforms(rgb))
        lab = rgb2lab(rgb).astype("float32")  # RGB -> L*a*b
        # ToTensor puts the channel axis first, so channel 0 is L and 1-2 are a/b.
        lab = transforms.ToTensor()(lab)
        lightness = lab[[0], ...] / 50. - 1.  # normalize the L channel
        chroma = lab[[1, 2], ...] / 110.  # normalize the ab channels
        return {'L': lightness, 'ab': chroma}

# DataLoader factory
def make_dataloaders(batch_size=16, n_workers=4, pin_memory=True, shuffle=None, **kwargs):
    """Build a ``DataLoader`` over a ``ColorizationDataset``.

    Args:
        batch_size: Number of samples per batch.
        n_workers: Number of loader worker processes.
        pin_memory: Forwarded to ``DataLoader(pin_memory=...)``.
        shuffle: Whether to reshuffle the samples every epoch. ``None`` (the
            default) shuffles only when the dataset's split is ``'train'``, so
            training batches differ between epochs while the validation order
            stays fixed. Pass ``True``/``False`` to force either behavior.
        **kwargs: Forwarded to ``ColorizationDataset`` (``paths``, ``split``).

    Returns:
        The configured ``DataLoader``.
    """
    dataset = ColorizationDataset(**kwargs)
    if shuffle is None:
        # Without shuffling every epoch would replay identical batches in
        # identical order, which is undesirable for training.
        shuffle = dataset.split == 'train'
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=n_workers,
        pin_memory=pin_memory,
        shuffle=shuffle,
    )
    return dataloader

# Build one loader per split from the path arrays produced by the train/val split.
train_dl = make_dataloaders(split='train', paths=train_paths)
val_dl = make_dataloaders(split='val', paths=val_paths)

# Loss-tracking helper
class AverageMeter:
    """Running weighted average of a scalar, e.g. a per-batch loss."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the sample count, the mean and the running sum back to zero."""
        self.count = 0.
        self.avg = 0.
        self.sum = 0.

    def update(self, val, count=1):
        """Fold ``val`` into the average with weight ``count``.

        Args:
            val: New value (typically the mean loss of one batch).
            count: Number of samples ``val`` stands for.
        """
        contribution = count * val
        self.count += count
        self.sum += contribution
        self.avg = self.sum / self.count

def create_loss_meters():
    """Return a dict with one fresh ``AverageMeter`` per tracked loss name."""
    meter_names = (
        'loss_D_fake',
        'loss_D_real',
        'loss_D',
        'loss_G_GAN',
        'loss_G_L1',
        'loss_G',
    )
    return {name: AverageMeter() for name in meter_names}

# Convert images from L*a*b back to RGB
def lab_to_rgb(L, ab):
    """Turn a batch of normalized L and ab tensors into RGB images.

    Undoes the dataset normalization (``L / 50 - 1`` and ``ab / 110``), joins
    the channels along dim 1, moves the channel axis last and converts every
    image with ``lab2rgb``.

    Args:
        L: 4-D tensor holding the normalized L channel on dim 1.
        ab: 4-D tensor holding the normalized a/b channels on dim 1.

    Returns:
        NumPy array with the RGB images stacked along axis 0.
    """
    lightness = (L + 1.) * 50.
    chroma = ab * 110.
    lab_batch = torch.cat([lightness, chroma], dim=1)
    # Channels-last NumPy layout is what lab2rgb is fed, one image at a time.
    lab_batch = lab_batch.permute(0, 2, 3, 1).cpu().numpy()
    rgb_imgs = []
    for lab_img in lab_batch:
        rgb_imgs.append(lab2rgb(lab_img))
    return np.stack(rgb_imgs, axis=0)

# Visualize the results
def visualize(model, data, save=True):
    """Plot grayscale inputs, generated colors and ground truth for up to 5 images.

    The figure has three rows (input L channel, generator output, real image)
    and five columns; columns beyond the batch size are left blank, so batches
    with fewer than five images no longer raise ``IndexError``.

    Args:
        model: Object exposing ``net_G``, ``setup_input(data)`` and
            ``forward()``, and after the forward pass the attributes ``L``,
            ``fake_color`` and ``ab``.
        data: Batch handed to ``model.setup_input``.
        save: When true, the figure is also written to
            ``colorization_<timestamp>.png`` in the current directory.
    """
    # Remember the generator's mode so it is restored afterwards instead of
    # being forced into training mode unconditionally.
    was_training = model.net_G.training
    model.net_G.eval()
    try:
        with torch.no_grad():
            model.setup_input(data)
            model.forward()
    finally:
        model.net_G.train(was_training)
    fake_imgs = lab_to_rgb(model.L, model.fake_color.detach())
    real_imgs = lab_to_rgb(model.L, model.ab)

    n_cols = 5
    # The last batch of a loader can hold fewer than n_cols images.
    n_show = min(n_cols, len(fake_imgs))
    fig, axes = plt.subplots(3, n_cols, figsize=(15, 8))
    for i in range(n_show):
        axes[0, i].imshow(model.L[i][0].cpu(), cmap='gray')
        axes[1, i].imshow(fake_imgs[i])
        axes[2, i].imshow(real_imgs[i])
    # Hide the frames of every panel, including unused ones.
    for ax in axes.ravel():
        ax.axis("off")
    plt.show()

    if save:
        fig.savefig(f"colorization_{time.time()}.png")

def log_results(loss_meter_dict):
    """Print ``name: average`` (five decimals) for every meter in the dict.

    Args:
        loss_meter_dict: Mapping from loss name to an object with an ``avg``
            attribute, printed in the mapping's iteration order.
    """
    for loss_name in loss_meter_dict:
        average = loss_meter_dict[loss_name].avg
        print(f"{loss_name}: {average:.5f}")

# Build a U-Net on top of a ResNet18 backbone
def build_res_unet(n_input=1, n_output=2, size=256):
    """Create the generator: a fastai ``DynamicUnet`` over a pretrained ResNet18 body.

    Args:
        n_input: Number of input channels given to ``create_body`` as ``n_in``.
        n_output: Number of output channels of the U-Net.
        size: Side length; the U-Net is built for ``(size, size)`` inputs.

    Returns:
        The network, moved to the module-level ``device``.
    """
    # cut=-2 is passed to create_body to truncate the ResNet — presumably
    # dropping its final pooling and classification layers; confirm in fastai docs.
    backbone = create_body(resnet18(pretrained=True), n_in=n_input, cut=-2)
    generator = DynamicUnet(backbone, n_output, (size, size))
    return generator.to(device)

# Directory on the mounted Google Drive where checkpoints are written;
# created (with parents) if it does not exist yet.
checkpoint_dir = "/content/drive/MyDrive/Colab Notebooks/ImageColorization"
Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)

# Training routine
def pretrain_generator(net_G, train_dl, opt, criterion, start_epoch, end_epochs,
                       best_loss=float('inf'), best_checkpoint_path=None):
    """Train ``net_G`` to predict ab from L, keeping only the best checkpoint.

    For every epoch in ``range(start_epoch, end_epochs)`` the generator is
    optimized over ``train_dl``. Whenever the epoch's sample-weighted average
    training loss is lower than ``best_loss``, a checkpoint named
    ``best_checkpoint_<epoch>.pt`` is saved into the module-level
    ``checkpoint_dir`` and the previously tracked best checkpoint file is
    deleted. Note that "best" refers to the training loss; no validation is run.

    Args:
        net_G: Generator network, called as ``net_G(L)``.
        train_dl: Iterable of batches; each batch is a dict with ``'L'`` and
            ``'ab'`` tensors.
        opt: Optimizer updating ``net_G``'s parameters.
        criterion: Loss, called as ``criterion(preds, ab)``.
        start_epoch: First epoch index to run.
        end_epochs: Exclusive upper bound of the epoch range.
        best_loss: Best average loss seen so far. Defaults to infinity (fresh
            run). When resuming, pass the loss stored in the restored
            checkpoint; otherwise the first resumed epoch is always treated
            as a new best, even if it is worse.
        best_checkpoint_path: Path of the checkpoint file that currently holds
            the best model, or ``None``. When resuming, pass the path of the
            restored checkpoint so it is removed once a better one is saved
            instead of being left behind.
    """
    for epoch in range(start_epoch, end_epochs):
        loss_meter = AverageMeter()
        for data in tqdm(train_dl):
            L, ab = data['L'].to(device), data['ab'].to(device)
            preds = net_G(L)
            loss = criterion(preds, ab)
            opt.zero_grad()
            loss.backward()
            opt.step()
            # Weight by batch size so a smaller final batch does not skew the mean.
            loss_meter.update(loss.item(), L.size(0))

        print(f"Epoch {epoch}/{end_epochs - 1} - L1 Loss: {loss_meter.avg:.5f}")

        if loss_meter.avg < best_loss:
            best_loss = loss_meter.avg
            best_epoch = epoch
            checkpoint_path = os.path.join(checkpoint_dir, f"best_checkpoint_{best_epoch}.pt")
            torch.save({
                'epoch': best_epoch,
                'model_state_dict': net_G.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
                'loss': best_loss,
            }, checkpoint_path)
            print(f"New best checkpoint saved at {checkpoint_path}")

            # Delete the superseded file only after the new one is on disk, and
            # never the file that was just written (possible when an epoch is
            # re-run and both paths coincide).
            if (best_checkpoint_path
                    and best_checkpoint_path != checkpoint_path
                    and os.path.exists(best_checkpoint_path)):
                os.remove(best_checkpoint_path)
                print(f"Previous best checkpoint {best_checkpoint_path} deleted.")
            best_checkpoint_path = checkpoint_path

In [9]:
# Checkpoint on Google Drive to resume from (going by its file name, the best
# checkpoint saved at epoch 43 — adjust to the latest file before re-running).
checkpoint_path = "/content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_43.pt"

In [10]:
net_G = build_res_unet(n_input=1, n_output=2, size=SIZE)
opt = optim.Adam(net_G.parameters(), lr=1e-4)
criterion = nn.L1Loss()

# Fresh-run defaults, overwritten below when a checkpoint is restored. Setting
# them up front guarantees `start_epoch` and `best_loss` exist on both paths.
start_epoch = 0
best_loss = float('inf')

if os.path.exists(checkpoint_path):
    print(f"Loading checkpoint from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Restore both the weights and the optimizer state so training continues
    # where it stopped.
    net_G.load_state_dict(checkpoint["model_state_dict"])
    opt.load_state_dict(checkpoint["optimizer_state_dict"])
    # The stored epoch is the last completed one; resume with the next.
    start_epoch = checkpoint["epoch"] + 1
    best_loss = checkpoint["loss"]
    print(f"Resuming training from epoch {start_epoch}, best loss: {best_loss:.5f}")
else:
    print("Checkpoint not found! Starting from scratch...")

Loading checkpoint from /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_43.pt...
Resuming training from epoch 44, best loss: 0.05500


In [None]:
# Continue the L1 pretraining from `start_epoch` up to (but excluding) epoch 100.
pretrain_generator(
    net_G=net_G,
    train_dl=train_dl,
    opt=opt,
    criterion=criterion,
    start_epoch=start_epoch,
    end_epochs=100,
)

  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 44/99 - L1 Loss: 0.05418
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_44.pt


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 45/99 - L1 Loss: 0.05343
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_45.pt
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_44.pt deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 46/99 - L1 Loss: 0.05304
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_46.pt
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_45.pt deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 47/99 - L1 Loss: 0.05288
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_47.pt
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_46.pt deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 48/99 - L1 Loss: 0.05267
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_48.pt
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_47.pt deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 49/99 - L1 Loss: 0.05232
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_49.pt
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_48.pt deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 50/99 - L1 Loss: 0.05180
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_50.pt
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_49.pt deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 51/99 - L1 Loss: 0.05134
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_51.pt
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_50.pt deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 52/99 - L1 Loss: 0.05096
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_52.pt
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_51.pt deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 53/99 - L1 Loss: 0.05050
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_53.pt
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_52.pt deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 54/99 - L1 Loss: 0.05010
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_54.pt
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_53.pt deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 55/99 - L1 Loss: 0.04990
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_55.pt
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_54.pt deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 56/99 - L1 Loss: 0.04951
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_56.pt
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_55.pt deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 57/99 - L1 Loss: 0.04893
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_57.pt
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/ImageColorization/best_checkpoint_56.pt deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]