In [None]:
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import patchify
import numpy as np
import matplotlib.gridspec as gridspec
import glob as glob
import os
import cv2
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.optim as optim
from model.srcnn import SRCNN
# from model.sr_resnet import SRResNet
import math
from torchvision.utils import save_image
# DataLoader batch sizes: training batches, and one image at a time for validation.
TRAIN_BATCH_SIZE, TEST_BATCH_SIZE = 64, 1
# When True, create_patches() displays the patch grid of the last image it processed.
SHOW_PATCHES = True
# Step (in pixels) between neighbouring patches, passed to patchify.
STRIDE = 14
# NOTE(review): SIZE is not referenced by any visible cell (create_patches uses 64x64 patches).
SIZE = 32

In [None]:
# !pip install patchify

In [None]:
def show_patches(patches):
    """Display a grid of image patches as produced by ``patchify.patchify``.

    Args:
        patches: array indexed as ``patches[i, j, 0, :, :, :]``, where ``i`` is
            the patch row and ``j`` the patch column. Each patch is passed
            directly to ``imshow``, so channels should already be in RGB order.
    """
    n_rows, n_cols = patches.shape[0], patches.shape[1]
    if n_rows == 0 or n_cols == 0:
        # Nothing to draw; a zero-sized figure would only raise.
        return
    # figsize is (width, height): the number of columns sets the width and
    # the number of rows sets the height (the original had these swapped).
    fig = plt.figure(figsize=(n_cols, n_rows))
    gs = gridspec.GridSpec(n_rows, n_cols)
    gs.update(wspace=0.01, hspace=0.02)
    for i in range(n_rows):
        for j in range(n_cols):
            ax = fig.add_subplot(gs[i, j])
            ax.imshow(patches[i, j, 0, :, :, :])
            ax.axis('off')
    plt.show()

In [None]:
def psnr(label, outputs, max_val=1.):
    """Peak signal-to-noise ratio, in dB, between ``label`` and ``outputs``.

    The squared error is averaged over every element, so a whole batch
    yields a single PSNR value. Inputs may be torch tensors (on any device,
    with or without grad) or NumPy arrays / array-likes. The error is
    accumulated in float64 for accuracy.

    Args:
        label: reference values.
        outputs: predicted values, same shape as ``label``.
        max_val: largest possible pixel value (1.0 for images scaled to [0, 1]).

    Returns:
        The PSNR as a float, or the sentinel ``100`` when the inputs are
        identical (the true value would be infinite).
    """
    def _as_float64(x):
        # Tensors are detached and copied to host memory; everything else
        # goes through np.asarray.
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        return np.asarray(x, dtype=np.float64)

    diff = _as_float64(outputs) - _as_float64(label)
    rmse = math.sqrt(np.mean(diff ** 2))
    if rmse == 0:
        return 100
    return 20 * math.log10(max_val / rmse)

In [None]:
def save_plot(train_loss, val_loss, train_psnr, val_psnr,
              out_dir='/media/hero/Study/User/Study/data/model_out/image_out'):
    """Save per-epoch loss and PSNR curves as ``loss.png`` and ``psnr.png``.

    Args:
        train_loss: per-epoch training loss values.
        val_loss: per-epoch validation loss values.
        train_psnr: per-epoch training PSNR values (dB).
        val_psnr: per-epoch validation PSNR values (dB).
        out_dir: directory both PNGs are written to; created if missing.
    """
    os.makedirs(out_dir, exist_ok=True)

    def _plot_curves(curves, ylabel, filename):
        # One figure per metric. Close it right away so that calling this
        # every epoch does not accumulate open figures.
        fig, ax = plt.subplots(figsize=(10, 7))
        for values, color, label in curves:
            ax.plot(values, color=color, label=label)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(ylabel)
        ax.legend()
        fig.savefig(os.path.join(out_dir, filename))
        plt.close(fig)

    _plot_curves(
        [(train_loss, 'orange', 'train loss'),
         (val_loss, 'red', 'validation loss')],
        'Loss', 'loss.png',
    )
    _plot_curves(
        [(train_psnr, 'green', 'train PSNR dB'),
         (val_psnr, 'blue', 'validation PSNR dB')],
        'PSNR (dB)', 'psnr.png',
    )

In [None]:
def save_model_state(model, name_model="model",
                     out_dir='/media/hero/Study/User/Study/data/model_out/image_out'):
    """Save only ``model.state_dict()`` to ``<out_dir>/<name_model>.pth``.

    The file is overwritten if it already exists. Intended for inference
    or warm-starting, since no optimizer state is included.

    Args:
        model: module whose parameters are saved.
        name_model: file stem of the checkpoint.
        out_dir: destination directory; created if missing.
    """
    print('Saving model...')
    os.makedirs(out_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(out_dir, f'{name_model}.pth'))
def save_model(epochs, model, optimizer, criterion,
               out_dir='/media/hero/Study/User/Study/data/model_out/image_out'):
    """Save a resumable training checkpoint to ``<out_dir>/model_ckpt.pth``.

    The previous checkpoint is overwritten by ``torch.save``; nothing is
    deleted beforehand.

    Args:
        epochs: the caller's zero-based epoch index; stored as ``epochs + 1``
            under ``'epoch'``.
        model: its ``state_dict()`` is stored under ``'model_state_dict'``.
        optimizer: its ``state_dict()`` is stored under ``'optimizer_state_dict'``.
        criterion: the loss object itself is pickled under ``'loss'`` (key
            kept for compatibility with existing checkpoints).
        out_dir: destination directory; created if missing.
    """
    os.makedirs(out_dir, exist_ok=True)
    torch.save({
                'epoch': epochs+1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': criterion,
                }, os.path.join(out_dir, "model_ckpt.pth"))

In [None]:
def save_validation_results(outputs, epoch, batch_iter,
                            out_dir='/media/hero/Study/User/Study/data/model_out/image_out'):
    """Save reconstructed validation images with torchvision's ``save_image``.

    Written to ``<out_dir>/val_sr_<epoch>_<batch_iter>.png``.

    Args:
        outputs: image tensor(s) passed straight to ``save_image``.
        epoch: epoch number used in the file name.
        batch_iter: batch index used in the file name.
        out_dir: destination directory; created if missing.
    """
    os.makedirs(out_dir, exist_ok=True)
    save_image(
        outputs,
        os.path.join(out_dir, f"val_sr_{epoch}_{batch_iter}.png")
    )

# Tạo dữ liệu (Create data)

In [None]:
def _write_patch_pair(patch, hr_file, lr_file, scale):
    """Write ``patch`` to ``hr_file`` and a bicubic-degraded copy of it to ``lr_file``."""
    cv2.imwrite(hr_file, patch)
    h, w = patch.shape[:2]
    # Bicubic downscale by `scale`, then bicubic upscale back to (w, h). The
    # LR patch keeps the HR patch's size but loses high-frequency detail.
    low_res_img = cv2.resize(patch, (int(w * scale), int(h * scale)),
                             interpolation=cv2.INTER_CUBIC)
    high_res_upscale = cv2.resize(low_res_img, (w, h),
                                  interpolation=cv2.INTER_CUBIC)
    cv2.imwrite(lr_file, high_res_upscale)


def create_patches(
    input_paths, out_hr_path, out_lr_path, patch_size=64, scale=0.25,
):
    """Cut every image in ``input_paths`` into matching HR and LR patches.

    For each readable image, overlapping ``patch_size x patch_size`` windows
    are extracted with ``patchify``, using the module-level ``STRIDE``. Patch
    ``k`` of image ``<stem>`` is written to ``out_hr_path/<stem>_<k>.png``. A
    degraded copy with the same file name goes to ``out_lr_path``, so HR/LR
    pairs can be matched by file name.

    Args:
        input_paths: directory of source images.
        out_hr_path: output directory for HR patches (created if missing).
        out_lr_path: output directory for LR patches (created if missing).
        patch_size: side length, in pixels, of each square patch.
        scale: downscale factor used to synthesise the LR patch.

    Note:
        An earlier version applied ``COLOR_RGB2BGR`` to images that
        ``cv2.imread`` had already loaded as BGR. That swapped red and blue
        in every saved patch. Regenerate patches made by that version so
        training colours match the validation images.
    """
    os.makedirs(out_hr_path, exist_ok=True)
    os.makedirs(out_lr_path, exist_ok=True)
    # Sort so the processing order, and the image shown at the end, is deterministic.
    all_paths = sorted(os.listdir(input_paths))
    print(f"Creating patches for {len(all_paths)} images")
    patches = None  # patches of the last processed image, for show_patches()
    for file_name in tqdm(all_paths, total=len(all_paths)):
        image_path = os.path.join(input_paths, file_name)
        # cv2.imread yields BGR, and cv2.imwrite expects BGR, so patches are
        # written without any colour conversion.
        image = cv2.imread(image_path)
        if image is None:
            print(f"Skipping unreadable file: {image_path}")
            continue
        if image.shape[0] < patch_size or image.shape[1] < patch_size:
            print(f"Skipping {image_path}: smaller than {patch_size}x{patch_size}")
            continue
        stem = os.path.splitext(file_name)[0]
        patches = patchify.patchify(image, (patch_size, patch_size, 3), STRIDE)

        counter = 0
        for i in range(patches.shape[0]):
            for j in range(patches.shape[1]):
                counter += 1
                # patchify may return strided views; give OpenCV a contiguous copy.
                patch = np.ascontiguousarray(patches[i, j, 0, :, :, :])
                _write_patch_pair(
                    patch,
                    f"{out_hr_path}/{stem}_{counter}.png",
                    f"{out_lr_path}/{stem}_{counter}.png",
                    scale,
                )
    if SHOW_PATCHES and patches is not None:
        # Reverse the channel axis (BGR -> RGB) so matplotlib shows true colours.
        show_patches(patches[..., ::-1])

In [None]:
# Source images (T91) and the output directories for the generated HR / LR patches.
input_paths, out_hr_path, out_lr_path = (
    "/media/hero/Study/User/Study/data/upscale_image/T91",
    "/media/hero/Study/User/Study/data/upscale_image/t91_hr_patches",
    "/media/hero/Study/User/Study/data/upscale_image/t91_lr_patches",
)
# Uncomment to (re)generate the patch dataset:
# create_patches(input_paths, out_hr_path, out_lr_path)

In [None]:
# Downscale factor for synthesising the degraded validation images.
scale_factor = 0.25
# Validation output folders: degraded (bicubic) inputs and HR labels, going by their names.
# NOTE(review): the commented-out generator in the next cell writes the ORIGINAL
# image into save_path_lr and the degraded one into save_path_hr (the reverse of
# these names), and it reads from the T91 training images. Verify the data on disk.
save_path_lr = '/media/hero/Study/User/Study/data/upscale_image/test_bicubic_rgb_xx'
save_path_hr = '/media/hero/Study/User/Study/data/upscale_image/test_hr_xx'
os.makedirs(save_path_lr, exist_ok=True)
os.makedirs(save_path_hr, exist_ok=True)

In [None]:
# for image_name in os.listdir(input_paths):
#     image_path = os.path.join(input_paths+"/", image_name)

#     # orig_img = Image.open(image_path)
#     orig_img = cv2.imread(image_path)
#     # image_name = image_name.replace(".png", "")
#     # print(type(orig_img))
#     w, h = orig_img.shape[:2]
#     # print(f"Original image dimensions: {w}, {h}")
#     cv2.imwrite(f"{save_path_lr}/{image_name}",orig_img)

#     low_res_img = cv2.resize(orig_img, (int(h*scale_factor), int(w*scale_factor)))
#     high_res_upscale = cv2.resize(low_res_img, (h, w))
    
#     cv2.imwrite(f"{save_path_hr}/{image_name}",high_res_upscale)

In [None]:
class SRCNNDataset(Dataset):
    """Paired image dataset for SRCNN: (degraded input, high-resolution label).

    ``image_paths`` and ``label_paths`` are directories. Each file in the
    first is paired with the file of the same stem (name without extension)
    in the second. ``__getitem__`` returns two float tensors of shape
    (3, H, W) with pixel values divided by 255.

    Raises:
        ValueError: at construction, if the two directories do not contain
            one-to-one matching file stems.
    """
    def __init__(self, image_paths, label_paths):
        # glob returns files in filesystem order, and items are paired purely
        # by list position, so both listings must be sorted the same way.
        self.all_image_paths = sorted(glob.glob(f"{image_paths}/*"))
        self.all_label_paths = sorted(glob.glob(f"{label_paths}/*"))
        self._check_pairing()

    def _check_pairing(self):
        """Raise ValueError unless inputs and labels match one-to-one by file stem."""
        if len(self.all_image_paths) != len(self.all_label_paths):
            raise ValueError(
                f"{len(self.all_image_paths)} input images but "
                f"{len(self.all_label_paths)} label images"
            )
        for image_path, label_path in zip(self.all_image_paths, self.all_label_paths):
            image_stem = os.path.splitext(os.path.basename(image_path))[0]
            label_stem = os.path.splitext(os.path.basename(label_path))[0]
            if image_stem != label_stem:
                raise ValueError(f"Unpaired files: {image_path!r} vs {label_path!r}")

    def __len__(self):
        return len(self.all_image_paths)

    @staticmethod
    def _load_chw(path):
        """Load ``path`` as RGB and return a float tensor of shape (3, H, W) scaled by 1/255."""
        # The context manager closes the file handle deterministically.
        with Image.open(path) as img:
            array = np.array(img.convert('RGB'), dtype=np.float32)
        array /= 255.
        return torch.tensor(array.transpose([2, 0, 1]), dtype=torch.float)

    def __getitem__(self, index):
        return (
            self._load_chw(self.all_image_paths[index]),
            self._load_chw(self.all_label_paths[index]),
        )

In [None]:
def get_datasets(
    train_image_paths, train_label_paths,
    valid_image_path, valid_label_paths
):
    """Build the training and validation ``SRCNNDataset`` objects.

    Each pair of arguments is (input-image directory, label-image directory).

    Returns:
        ``(dataset_train, dataset_valid)``; the training set is built first.
    """
    return (
        SRCNNDataset(train_image_paths, train_label_paths),
        SRCNNDataset(valid_image_path, valid_label_paths),
    )
# Prepare the data loaders
def get_dataloaders(dataset_train, dataset_valid,
                    train_batch_size=None, test_batch_size=None):
    """Wrap the training and validation datasets in ``DataLoader``s.

    The training loader shuffles; the validation loader keeps dataset order.

    Args:
        dataset_train: training dataset.
        dataset_valid: validation dataset.
        train_batch_size: training batch size; defaults to the module-level
            ``TRAIN_BATCH_SIZE`` when ``None``.
        test_batch_size: validation batch size; defaults to the module-level
            ``TEST_BATCH_SIZE`` when ``None``.

    Returns:
        ``(train_loader, valid_loader)``.
    """
    if train_batch_size is None:
        train_batch_size = TRAIN_BATCH_SIZE
    if test_batch_size is None:
        test_batch_size = TEST_BATCH_SIZE
    train_loader = DataLoader(
        dataset_train,
        batch_size=train_batch_size,
        shuffle=True
    )
    valid_loader = DataLoader(
        dataset_valid,
        batch_size=test_batch_size,
        shuffle=False
    )
    return train_loader, valid_loader

In [None]:
# Training pairs: HR patch labels and their degraded LR inputs (from create_patches).
TRAIN_LABEL_PATHS = '/media/hero/Study/User/Study/data/upscale_image/t91_hr_patches'
TRAIN_IMAGE_PATHS = '/media/hero/Study/User/Study/data/upscale_image/t91_lr_patches'
# Backward-compatible alias for the original misspelt name, which later cells use.
TRAN_IMAGE_PATHS = TRAIN_IMAGE_PATHS
# Validation pairs: HR labels and bicubic-degraded inputs.
VALID_LABEL_PATHS = '/media/hero/Study/User/Study/data/upscale_image/test_hr_xx'
VALID_IMAGE_PATHS = '/media/hero/Study/User/Study/data/upscale_image/test_bicubic_rgb_xx'
# When True, validate() writes reconstructed images to disk on its save epochs.
SAVE_VALIDATION_RESULTS = True

In [None]:
# dataset_valid = SRCNNDataset(
#     VALID_IMAGE_PATHS, VALID_LABEL_PATHS
# )
# len(dataset_valid)

In [None]:
dataset_train, dataset_valid = get_datasets(
    train_image_paths=TRAN_IMAGE_PATHS,
    train_label_paths=TRAIN_LABEL_PATHS,
    valid_image_path=VALID_IMAGE_PATHS,
    valid_label_paths=VALID_LABEL_PATHS,
)
train_loader, valid_loader = get_dataloaders(dataset_train, dataset_valid)
# Sanity check: number of samples in each split.
print(f"Training samples: {len(dataset_train)}")
print(f"Validation samples: {len(dataset_valid)}")

# Train mô hình (Train the model)

In [None]:
def train(model, dataloader, optimizer, criterion, device):
    """Run one training epoch over ``dataloader``.

    Args:
        model: network to optimise; put into train mode.
        dataloader: yields batches where ``batch[0]`` is the input and
            ``batch[1]`` the target.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss function called as ``criterion(outputs, label)``.
        device: device each batch is moved to.

    Returns:
        ``(mean_loss, mean_psnr)``: the per-batch loss and PSNR, each averaged
        over the number of batches. Earlier versions divided the loss by the
        number of samples, which shrank it by roughly the batch size.

    Raises:
        ValueError: if ``dataloader`` has no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("train(): dataloader has no batches")
    model.train()
    running_loss = 0.0
    running_psnr = 0.0
    for data in tqdm(dataloader, total=num_batches):
        image_data = data[0].to(device)
        label = data[1].to(device)

        optimizer.zero_grad()
        outputs = model(image_data)
        loss = criterion(outputs, label)
        loss.backward()
        optimizer.step()

        # With reduction='mean' (the nn.MSELoss default used in this notebook),
        # loss.item() is already a per-batch mean, so the epoch value is
        # averaged over batches rather than over samples.
        running_loss += loss.item()
        # One PSNR per batch, averaged over batches below.
        running_psnr += psnr(label, outputs)
    final_loss = running_loss / num_batches
    final_psnr = running_psnr / num_batches
    return final_loss, final_psnr

def validate(model, dataloader, epoch, criterion, device, save_every=500):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: network to evaluate; put into eval mode.
        dataloader: yields batches where ``batch[0]`` is the input and
            ``batch[1]`` the target.
        epoch: epoch number; drives the save schedule and the saved file names.
        criterion: loss function called as ``criterion(outputs, label)``.
        device: device each batch is moved to.
        save_every: when ``SAVE_VALIDATION_RESULTS`` is true, every batch's
            outputs are saved on epochs divisible by ``save_every``. ``0``
            disables saving.

    Returns:
        ``(mean_loss, mean_psnr)``: per-batch values averaged over the number
        of batches.

    Raises:
        ValueError: if ``dataloader`` has no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("validate(): dataloader has no batches")
    # NOTE(review): the training loop in this notebook runs 300 epochs, so with
    # the default save_every=500 no validation images are ever written.
    save_this_epoch = bool(SAVE_VALIDATION_RESULTS and save_every
                           and epoch % save_every == 0)
    model.eval()
    running_loss = 0.0
    running_psnr = 0.0
    with torch.no_grad():
        for bi, data in tqdm(enumerate(dataloader), total=num_batches):
            image_data = data[0].to(device)
            label = data[1].to(device)

            outputs = model(image_data)
            loss = criterion(outputs, label)
            # Per-batch mean loss (nn.MSELoss default reduction), averaged over batches below.
            running_loss += loss.item()
            running_psnr += psnr(label, outputs)
            if save_this_epoch:
                save_validation_results(outputs, epoch, bi)
    final_loss = running_loss / num_batches
    final_psnr = running_psnr / num_batches
    return final_loss, final_psnr

In [None]:

lr = 0.0003  # Adam learning rate.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

model = SRCNN().to(device)
# model = SRResNet(2).to(device)

# Warm-start from the weights that save_model_state() writes (the training
# loop overwrites this file every epoch). map_location lets a GPU-saved file
# load on a CPU-only machine. A missing file now means "train from scratch"
# instead of crashing a fresh run.
pretrained_path = "/media/hero/Study/User/Study/data/model_out/image_out/model.pth"
if os.path.exists(pretrained_path):
    model.load_state_dict(torch.load(pretrained_path, map_location=device))
    print(f"Loaded weights from {pretrained_path}")
else:
    print(f"No weights found at {pretrained_path}; training from scratch")
# Print the architecture without switching the model to eval mode.
print(model)

# Optimizer.
optimizer = optim.Adam(model.parameters(), lr=lr)
# optimizer = optim.SGD(model.parameters(), lr=lr)

# Loss function (pixel-wise regression).
criterion = nn.MSELoss()


In [19]:
import time

train_loss, val_loss = [], []
train_psnr, val_psnr = [], []
epochs = 300
start = time.time()
# Best train / validation PSNR seen so far. Starting at -inf (not 0) guarantees
# the first epoch writes best_train.pth / best_val.pth, even if its PSNR is
# negative (possible when outputs fall outside [0, max_val], since psnr() does
# not clamp).
psnr_best_p = float('-inf')
psnr_best_v = float('-inf')
for epoch in range(epochs):
    print(f"Epoch {epoch + 1} of {epochs}")
    train_epoch_loss, train_epoch_psnr = train(model, train_loader, optimizer, criterion, device)
    val_epoch_loss, val_epoch_psnr = validate(model, valid_loader, epoch + 1, criterion, device)

    print(f"Train PSNR: {train_epoch_psnr:.3f}")
    print(f"Val PSNR: {val_epoch_psnr:.3f}")
    train_loss.append(train_epoch_loss)
    train_psnr.append(train_epoch_psnr)
    val_loss.append(val_epoch_loss)
    val_psnr.append(val_epoch_psnr)

    if train_epoch_psnr > psnr_best_p:
        psnr_best_p = train_epoch_psnr
        save_model_state(model, name_model="best_train")
    if val_epoch_psnr > psnr_best_v:
        psnr_best_v = val_epoch_psnr
        save_model_state(model, name_model="best_val")

    # Full checkpoint (model + optimizer) every 100 epochs, for resuming training.
    if (epoch + 1) % 100 == 0:
        save_model(epoch, model, optimizer, criterion)
    # Weights only, every epoch: small, and reloaded by the setup cell above.
    save_model_state(model)
    # Refresh the loss / PSNR plots every epoch.
    save_plot(train_loss, val_loss, train_psnr, val_psnr)
end = time.time()
print(f"Finished training in: {((end-start)/60):.3f} minutes")

100%|██████████| 128/128 [00:45<00:00,  2.79it/s]


Train PSNR: 27.019
Val PSNR: 24.122
Saving model...
Epoch 63 of 300


 54%|█████▎    | 68/127 [01:11<00:58,  1.00it/s]

In [None]:
# Ask PyTorch to release cached GPU memory, e.g. after interrupting the
# training cell above.
torch.cuda.empty_cache()

In [None]:
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = 'cpu'
# print(device)
# model = SRCNN().to(device)
# model.load_state_dict(torch.load("/media/hero/Study/User/Study/data/model_out/image_out/model.pth"))
# model.eval()

In [None]:
# path_test = "/media/hero/Study/User/Project/image_processing/image_general/processed/infinity_ldv_create_a_sketch_of_a_house_under_a_forest_preceded_d26638e1-c9df-4ef5-9ae4-27cbf3d9e1aenum_3.png"
# image = Image.open(path_test).convert('RGB')
# image = np.array(image, dtype=np.float32)
# image = image/255
# image = image.transpose([2, 0, 1])
# image_processed = torch.tensor(image, dtype=torch.float, device=device)
# outputs = model(image_processed)
# out = np.array(outputs.detach().numpy().T*255, dtype=np.uint8)
# out = out.transpose([1,0,2])

# print(out.shape)
# # print(torch.from_numpy(np.asarray(outputs.T)))
# cv2.imwrite("/media/hero/Study/User/Study/data/model_out/image_video/Image_high_quality_.png", out)
# # cv2.imshow("image_high_quality", out)
# cv2.waitKey(0)

