In [2]:
import os
import glob
import cv2
import imageio
import base64
import time
import ipywidgets as widgets
from IPython.display import display
import io
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb
import csv

import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from fastai.data.external import untar_data, URLs
from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
from fastai.vision.models.unet import DynamicUnet

In [1]:
# Colab only: mount Google Drive at /content/drive.
# checkpoint_dir and the checkpoint paths used later in this notebook live under it.
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# Run-wide configuration.
seed = 123  # applied to NumPy's and torch's global RNGs just below
SIZE = 256  # side length that ColorizationDataset resizes every image to

# Apply the seed so NumPy-based sampling (e.g. the index permutation) and torch-side
# randomness (weight init, RandomHorizontalFlip, DataLoader shuffling) start from a
# fixed state. GPU kernels may still be nondeterministic.
np.random.seed(seed)
torch.manual_seed(seed)

coco_path = untar_data(URLs.COCO_SAMPLE)
coco_path = str(coco_path) + "/train_sample"

In [4]:

# Sort the listing: glob returns files in filesystem order, which is not guaranteed
# to be stable, and the split below must only depend on `seed`.
image_files = sorted(glob.glob(coco_path + "/*.jpg"))
num_samples = len(image_files)
if num_samples == 0:
    raise FileNotFoundError(f"No .jpg images found in {coco_path}")
print(f"Total number of images available: {num_samples}")

paths_subset = np.array(image_files)
# Seeded random permutation of all indices. The 95% / 5% train/val split is therefore
# identical in every session, so resuming from a checkpoint never trains on images
# that an earlier session held out for validation.
rand_idxs = np.random.default_rng(seed).permutation(num_samples)

n_train = int(num_samples * 0.95)
train_idxs = rand_idxs[:n_train]
val_idxs = rand_idxs[n_train:]

train_paths = paths_subset[train_idxs]
val_paths = paths_subset[val_idxs]

print(f"Number of training images: {len(train_paths)}")
print(f"Number of validation images: {len(val_paths)}")

class ColorizationDataset(Dataset):
    """Dataset of normalized L*a*b* tensors for image colorization.

    Each item is a dict:
      - 'L':  lightness channel, shape (1, SIZE, SIZE), mapped from [0, 100] to [-1, 1]
      - 'ab': a*/b* channels, shape (2, SIZE, SIZE), divided by 110 (roughly [-1, 1])

    Args:
        paths: sequence of image file paths.
        split: 'train' (bicubic resize + random horizontal flip) or 'val' (resize only).

    Raises:
        ValueError: if `split` is neither 'train' nor 'val'.
    """

    def __init__(self, paths, split='train'):
        resize = transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC)
        if split == 'train':
            self.transforms = transforms.Compose([
                resize,
                transforms.RandomHorizontalFlip(),  # a little data augmentation
            ])
        elif split == 'val':
            self.transforms = resize
        else:
            # Fail fast: otherwise self.transforms is never set and __getitem__ fails later.
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        self.split = split
        self.size = SIZE
        self.paths = paths
        self._to_tensor = transforms.ToTensor()  # built once, reused for every item

    def __getitem__(self, idx):
        # convert() loads the pixels, so the file handle can be closed right away.
        with Image.open(self.paths[idx]) as src:
            img = src.convert("RGB")
        img = np.array(self.transforms(img))
        img_lab = rgb2lab(img).astype("float32")  # RGB -> L*a*b*, HWC float32
        img_lab = self._to_tensor(img_lab)        # HWC -> CHW; float arrays are not rescaled
        L = img_lab[[0], ...] / 50. - 1.          # L in [0, 100] -> [-1, 1]
        ab = img_lab[[1, 2], ...] / 110.          # a*, b* -> roughly [-1, 1]

        return {'L': L, 'ab': ab}

    def __len__(self):
        return len(self.paths)

def make_dataloaders(batch_size=16, n_workers=4, pin_memory=True, shuffle=None, **kwargs):
    """Build a DataLoader over a ColorizationDataset.

    Args:
        batch_size: samples per batch.
        n_workers: number of DataLoader worker processes.
        pin_memory: forwarded to DataLoader.
        shuffle: reshuffle the data every epoch. None (default) means shuffle only for
            the 'train' split, so training batches differ between epochs while the
            validation order stays fixed. DataLoader itself defaults to no shuffling.
        **kwargs: forwarded to ColorizationDataset (e.g. paths=..., split=...).

    Returns:
        The configured DataLoader.
    """
    dataset = ColorizationDataset(**kwargs)
    if shuffle is None:
        shuffle = dataset.split == 'train'
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                      num_workers=n_workers, pin_memory=pin_memory)

# One loader per split; the split name selects that dataset's transform pipeline.
train_dl = make_dataloaders(split='train', paths=train_paths)
val_dl = make_dataloaders(split='val', paths=val_paths)

class AverageMeter:
    """Running, count-weighted mean of a scalar such as a per-batch loss."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the accumulated count, sum and average."""
        self.count = 0.
        self.avg = 0.
        self.sum = 0.

    def update(self, val, count=1):
        """Record `val` as observed `count` times and refresh the running mean."""
        self.sum += val * count
        self.count += count
        self.avg = self.sum / self.count

def create_loss_meters():
    """Return a fresh AverageMeter per tracked loss.

    Keys are the loss attribute names that MainModel sets during a training step
    (update_losses reads them back with getattr).
    """
    loss_names = ('loss_D_fake', 'loss_D_real', 'loss_D',
                  'loss_G_GAN', 'loss_G_L1', 'loss_G')
    return {name: AverageMeter() for name in loss_names}

def update_losses(model, loss_meter_dict, count):
    """Push the model's current scalar losses into their matching meters.

    Every key of `loss_meter_dict` is read as an attribute of `model`
    (e.g. `model.loss_G`). Its `.item()` value is recorded with weight `count`.
    """
    for name in loss_meter_dict:
        current_value = getattr(model, name).item()
        loss_meter_dict[name].update(current_value, count=count)

def lab_to_rgb(L, ab):
    """
    Convert a batch of normalized L*a*b* tensors back to RGB images.

    Undoes the dataset scaling (L: [-1, 1] -> [0, 100]; ab: multiplied by 110),
    stacks the channels, and converts each image with skimage's lab2rgb.

    Args:
        L:  tensor expected as (N, 1, H, W), as in ColorizationDataset batches.
        ab: tensor expected as (N, 2, H, W).

    Returns:
        numpy array of shape (N, H, W, 3).
    """
    lab = torch.cat([(L + 1.) * 50., ab * 110.], dim=1)
    lab_nhwc = lab.permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(img) for img in lab_nhwc], axis=0)

def visualize(model, data, save=True):
    """Plot grayscale inputs, predicted colors and ground truth for up to 5 images.

    Runs `model.net_G` in eval mode under no_grad on `data`, then restores the
    generator's previous train/eval mode.

    Args:
        model: object with net_G, setup_input() and forward(); after forward() it
            must expose L, ab and fake_color (as MainModel does).
        data: batch dict with 'L' and 'ab' tensors.
        save: if True, also write the figure to colorization_<timestamp>.png.
    """
    was_training = model.net_G.training
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    # Restore the caller's mode rather than forcing train mode.
    model.net_G.train(was_training)

    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)

    # Batches can hold fewer than 5 images (e.g. the last, partial batch of a loader).
    n_show = min(5, L.size(0))
    fig, axes = plt.subplots(3, n_show, figsize=(15, 8), squeeze=False)
    for i in range(n_show):
        axes[0, i].imshow(L[i][0].cpu(), cmap='gray')
        axes[1, i].imshow(fake_imgs[i])
        axes[2, i].imshow(real_imgs[i])
    for ax in axes.flat:
        ax.axis("off")

    # Save before plt.show(): the notebook inline backend closes figures after display.
    if save:
        fig.savefig(f"colorization_{time.time()}.png")
    plt.show()
    plt.close(fig)  # release the figure when called repeatedly

def log_results(loss_meter_dict):
    """Print each meter's running average as a `name: value` line (5 decimals)."""
    formatted = (f"{name}: {meter.avg:.5f}" for name, meter in loss_meter_dict.items())
    for line in formatted:
        print(line)

def build_res_unet(n_input=1, n_output=2, size=256, pretrained=True):
    """Build a U-Net generator with a ResNet-18 encoder using fastai's DynamicUnet.

    Args:
        n_input: input channels of the encoder (1 for the L channel).
        n_output: output channels of the decoder (2 for ab).
        size: spatial size passed to DynamicUnet as (size, size).
        pretrained: start the encoder from torchvision's ImageNet ResNet-18 weights.
            Pass False when a checkpoint will overwrite the weights anyway; this
            skips the download.

    Returns:
        The DynamicUnet, moved to CUDA when available, else CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # `weights=` replaces torchvision's deprecated `pretrained=` flag; IMAGENET1K_V1
    # is the weight set that `pretrained=True` used to load.
    resnet18_model = resnet18(weights="IMAGENET1K_V1" if pretrained else None)
    # cut=-2 keeps every child module except the last two (ResNet's avgpool and fc).
    body = create_body(resnet18_model, n_in=n_input, pretrained=pretrained, cut=-2)
    net_G = DynamicUnet(body, n_output, (size, size)).to(device)
    return net_G

class PatchDiscriminator(nn.Module):
    """Convolutional discriminator that outputs a spatial grid of real/fake logits.

    Layout (all convs use kernel 4, padding 1):
      - stem: input_c -> num_filters, stride 2, no normalization
      - n_down blocks: channels double each time; stride 2 except the last block (stride 1)
      - head: -> 1 channel, stride 1, neither normalization nor activation
    """

    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        blocks = [self.get_layers(input_c, num_filters, norm=False)]
        for depth in range(n_down):
            in_ch = num_filters * 2 ** depth
            out_ch = in_ch * 2
            stride = 1 if depth == n_down - 1 else 2  # last down block keeps resolution
            blocks.append(self.get_layers(in_ch, out_ch, s=stride))
        final_in = num_filters * 2 ** n_down
        blocks.append(self.get_layers(final_in, 1, s=1, norm=False, act=False))
        self.model = nn.Sequential(*blocks)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        """Return Conv2d [-> BatchNorm2d] [-> LeakyReLU(0.2)] as one nn.Sequential.

        The conv carries a bias only when no BatchNorm follows it.
        """
        conv = nn.Conv2d(ni, nf, k, s, p, bias=not norm)
        extras = []
        if norm:
            extras.append(nn.BatchNorm2d(nf))
        if act:
            extras.append(nn.LeakyReLU(0.2, True))
        return nn.Sequential(conv, *extras)

    def forward(self, x):
        return self.model(x)

class GANLoss(nn.Module):
    """Adversarial loss of discriminator outputs against a constant real/fake target.

    Args:
        gan_mode: 'vanilla' (BCEWithLogitsLoss; predictions are raw logits) or
            'lsgan' (MSELoss).
        real_label, fake_label: target values. They are registered as buffers so they
            follow .to(device) and are included in state_dict.

    Raises:
        ValueError: for any other gan_mode.
    """

    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            # Otherwise self.loss is never set and the first call fails with AttributeError.
            raise ValueError(f"Unsupported gan_mode {gan_mode!r}; expected 'vanilla' or 'lsgan'")

    def get_labels(self, preds, target_is_real):
        """Broadcast the real or fake target value to the shape of `preds`."""
        labels = self.real_label if target_is_real else self.fake_label
        return labels.expand_as(preds)

    def forward(self, preds, target_is_real):
        """Loss of `preds` against an all-real (True) or all-fake (False) target.

        Defined as forward() rather than overriding __call__, so nn.Module hooks apply.
        """
        labels = self.get_labels(preds, target_is_real)
        return self.loss(preds, labels)

def init_weights(net, init='norm', gain=0.02):
    """Re-initialize the conv and BatchNorm2d layers of `net` in place and return `net`.

    Conv layers (class name contains 'Conv' and the module has a weight):
      - 'norm':    weight ~ N(0, gain)
      - 'xavier':  xavier_normal_ with `gain`
      - 'kaiming': kaiming_normal_ (a=0, mode='fan_in'); `gain` is not used
    Their bias, if present, is set to 0.
    BatchNorm2d layers with affine parameters get weight ~ N(1, gain) and bias 0.

    Raises:
        ValueError: if `init` is not 'norm', 'xavier' or 'kaiming'.
    """
    if init not in ('norm', 'xavier', 'kaiming'):
        # An unknown name would otherwise silently leave conv weights at their defaults.
        raise ValueError(f"Unknown init {init!r}; expected 'norm', 'xavier' or 'kaiming'")

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            if init == 'norm':
                nn.init.normal_(m.weight.data, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            else:  # 'kaiming'
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')

            if getattr(m, 'bias', None) is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname and m.weight is not None:
            # BatchNorm2d(affine=False) has weight/bias set to None: nothing to initialize.
            nn.init.normal_(m.weight.data, 1., gain)
            nn.init.constant_(m.bias.data, 0.)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net

def init_model(model, device):
    """Move `model` to `device`, re-initialize it with init_weights' defaults, and return it."""
    return init_weights(model.to(device))

class MainModel(nn.Module):
    """Conditional GAN for colorization.

    The generator `net_G` predicts ab channels from L. The PatchDiscriminator `net_D`
    scores (L, ab) images concatenated along the channel axis.
    Generator loss = GAN loss (try to be judged real) + lambda_L1 * L1(fake_ab, real_ab).
    Both networks use Adam with the same (beta1, beta2).

    Args:
        net_G: generator module; moved to CUDA if available, else CPU.
        lr_G, lr_D: Adam learning rates for generator / discriminator.
        beta1, beta2: Adam betas shared by both optimizers.
        lambda_L1: weight of the L1 term in the generator loss.
    """
    def __init__(self, net_G, lr_G=2e-4, lr_D=2e-4,
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1
        self.net_G = net_G.to(self.device)
        # input_c=3 because the discriminator sees L (1 channel) concatenated with ab (2 channels).
        self.net_D = init_model(PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device)
        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))

    def set_requires_grad(self, model, requires_grad=True):
        """Set requires_grad on every parameter of `model`."""
        for p in model.parameters():
            p.requires_grad = requires_grad

    def setup_input(self, data):
        """Store the batch's 'L' and 'ab' tensors on self.device as self.L / self.ab."""
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)

    def forward(self):
        """Predict ab from the stored L; the result is kept in self.fake_color.

        Takes no arguments: call setup_input() first.
        """
        self.fake_color = self.net_G(self.L)

    def backward_D(self):
        """Compute and backpropagate the discriminator loss.

        Fake (L, predicted ab) is scored against the fake label, real (L, ab) against
        the real label; loss_D is the mean of the two.
        """
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        # detach() so this backward pass stops at the generator output.
        fake_preds = self.net_D(fake_image.detach())
        self.loss_D_fake = self.GANcriterion(fake_preds, False)
        real_image = torch.cat([self.L, self.ab], dim=1)
        real_preds = self.net_D(real_image)
        self.loss_D_real = self.GANcriterion(real_preds, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Compute and backpropagate the generator loss.

        loss_G = GAN loss with fakes scored against the real label
                 + lambda_L1 * L1(fake_color, ab).
        """
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image)
        self.loss_G_GAN = self.GANcriterion(fake_preds, True)
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize(self):
        """Run one training step on the stored batch: update D first, then G."""
        self.forward()
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()

        self.net_G.train()
        # Freeze D's parameters so backward_G accumulates gradients only for net_G.
        # They stay frozen until the next optimize() call re-enables them above.
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()

# Folder on the mounted Google Drive where train_model writes best_checkpoint_<epoch>.pt.
checkpoint_dir = os.path.join("/content/drive/MyDrive", "Colab Notebooks", "tin")
os.makedirs(checkpoint_dir, exist_ok=True)

def train_model(model, train_dl, start_epochs, end_epochs, display_every=10, csv_filename="training_losses.csv",
                best_loss=float('inf'), best_checkpoint_path=None):
    """Train `model` for epochs start_epochs..end_epochs (inclusive), keeping only the best checkpoint.

    Every `display_every` iterations, the running epoch average of each loss is appended
    to `csv_filename`; a header row is written if the file does not exist yet.

    After each epoch, the epoch-average generator loss ('loss_G') is compared with
    `best_loss`. If it is lower, the model, both optimizer states, the epoch and the loss
    are saved to `checkpoint_dir/best_checkpoint_<epoch>.pt`, and the previous best file
    is then deleted.

    Args:
        model: MainModel-like object (setup_input, optimize, loss_* attributes, opt_G, opt_D).
        train_dl: training DataLoader yielding dicts with 'L' and 'ab'.
        start_epochs, end_epochs: first and last epoch numbers, both inclusive.
        display_every: CSV logging interval, in iterations.
        csv_filename: path of the CSV loss log.
        best_loss: loss a new checkpoint must beat. When resuming, pass the loss stored
            in the loaded checkpoint so a worse epoch does not replace it.
        best_checkpoint_path: current best checkpoint file, if any. It is deleted once
            a better checkpoint has been written.
    """
    loss_names = ['loss_D_fake', 'loss_D_real', 'loss_D', 'loss_G_GAN', 'loss_G_L1', 'loss_G']

    # Write the header only when creating the file; later runs keep appending to it.
    if not os.path.isfile(csv_filename):
        with open(csv_filename, 'w', newline='') as csvfile:
            csv.writer(csvfile).writerow(['Epoch', 'Iteration', *loss_names])

    for e in range(start_epochs, end_epochs + 1):
        loss_meter_dict = create_loss_meters()  # fresh running averages for this epoch

        for i, data in enumerate(tqdm(train_dl, desc=f"Epoch {e}"), start=1):
            model.setup_input(data)
            model.optimize()
            update_losses(model, loss_meter_dict, count=data['L'].size(0))

            # Append the running epoch averages every `display_every` iterations.
            if i % display_every == 0:
                row = [e, i] + [loss_meter_dict[name].avg for name in loss_names]
                with open(csv_filename, 'a', newline='') as csvfile:
                    csv.writer(csvfile).writerow(row)

        if loss_meter_dict['loss_G'].count == 0:
            # Empty loader: the meter's avg stays 0.0 and would wrongly win as the "best" loss.
            print(f"Epoch {e}/{end_epochs}: no batches, skipping checkpoint.")
            continue

        avg_loss = loss_meter_dict['loss_G'].avg  # epoch-average total generator loss

        print(f"Epoch {e}/{end_epochs}")
        print(f"Average Loss: {avg_loss:.5f}")

        if avg_loss < best_loss:
            best_loss = avg_loss
            checkpoint_path = os.path.join(checkpoint_dir, f"best_checkpoint_{e}.pt")
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_G_state_dict': model.opt_G.state_dict(),
                'optimizer_D_state_dict': model.opt_D.state_dict(),
                'epoch': e,
                'loss': avg_loss,
            }, checkpoint_path)
            print(f"New best checkpoint saved at {checkpoint_path} with loss {best_loss:.5f} at epoch {e}")

            # Delete the superseded checkpoint only after the new one has been written.
            if (best_checkpoint_path is not None and best_checkpoint_path != checkpoint_path
                    and os.path.exists(best_checkpoint_path)):
                os.remove(best_checkpoint_path)
                print(f"Previous best checkpoint {best_checkpoint_path} has been deleted.")

            best_checkpoint_path = checkpoint_path

# Start the generator from saved weights. This file's 'model_state_dict' is loaded
# straight into net_G, so it holds generator-only keys (not a full MainModel state).
generator_model_path = os.path.join(checkpoint_dir, "best_checkpoint_78.pt")
net_G = build_res_unet(n_input=1, n_output=2, size=256)
res_checkpoint = torch.load(generator_model_path, map_location=device)
net_G.load_state_dict(res_checkpoint['model_state_dict'])

Total number of images available: 21837
Number of training images: 20745
Number of validation images: 1092


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 109MB/s]


<All keys matched successfully>

In [None]:
# checkpoint_path = "/kaggle/working/cor_best_checkpoint_1.pt"
# checkpoint = torch.load(checkpoint_path, map_location=device)
# model = MainModel(net_G=net_G)
# model.load_state_dict(checkpoint['model_state_dict'])
# model.opt_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
# model.opt_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
# start_epoch = checkpoint['epoch'] + 1
# train_model(model, train_dl, start_epochs=start_epoch, end_epochs=2, display_every=30)

In [5]:
checkpoint_path = os.path.join(checkpoint_dir, "best_checkpoint_29.pt")  # MainModel checkpoint to resume from

In [None]:
model = MainModel(net_G=net_G)

# Resume from the checkpoint when it exists; otherwise start at epoch 1.
if not os.path.exists(checkpoint_path):
    start_epoch = 1  # no checkpoint: begin at epoch 1
    print("No checkpoint found, training from scratch.")
else:
    print(f"Loading checkpoint from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # These keys match the dict that train_model saves.
    model.load_state_dict(checkpoint['model_state_dict'])
    model.opt_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
    model.opt_D.load_state_dict(checkpoint['optimizer_D_state_dict'])

    # Continue with the epoch after the one stored in the checkpoint.
    start_epoch = checkpoint['epoch'] + 1
    best_loss = checkpoint["loss"]  # loaded, but not passed to the train_model call below
    print(f"Checkpoint loaded! Resuming training from epoch {start_epoch}.")

train_model(model, train_dl, start_epochs=start_epoch, end_epochs=65)

model initialized with norm initialization
Loading checkpoint from /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_29.pt...
Checkpoint loaded! Resuming training from epoch 30.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 30/65
Average Loss: 6.10701
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_30.pt with loss 6.10701 at epoch 30


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 31/65
Average Loss: 5.97302
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_31.pt with loss 5.97302 at epoch 31
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_30.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 32/65
Average Loss: 5.86703
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_32.pt with loss 5.86703 at epoch 32
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_31.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 33/65
Average Loss: 5.80044
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_33.pt with loss 5.80044 at epoch 33
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_32.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 34/65
Average Loss: 5.75743
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_34.pt with loss 5.75743 at epoch 34
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_33.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 35/65
Average Loss: 5.70469
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_35.pt with loss 5.70469 at epoch 35
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_34.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 36/65
Average Loss: 5.67186
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_36.pt with loss 5.67186 at epoch 36
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_35.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 37/65
Average Loss: 5.63207
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_37.pt with loss 5.63207 at epoch 37
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_36.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 38/65
Average Loss: 5.59405
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_38.pt with loss 5.59405 at epoch 38
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_37.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 39/65
Average Loss: 5.56241
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_39.pt with loss 5.56241 at epoch 39
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_38.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 40/65
Average Loss: 5.53419
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_40.pt with loss 5.53419 at epoch 40
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_39.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]

Epoch 41/65
Average Loss: 5.49755
New best checkpoint saved at /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_41.pt with loss 5.49755 at epoch 41
Previous best checkpoint /content/drive/MyDrive/Colab Notebooks/tin/best_checkpoint_40.pt has been deleted.


  0%|          | 0/1297 [00:00<?, ?it/s]