## Dependencies

In [6]:
import os.path as osp
import glob
import cv2
import numpy as np
import torch
import ESRGAN.RRDBNet_arch as arch # Make sure you clone the ESRGAN repository
import os
from tkinter import Tk, filedialog
import matplotlib.pyplot as plt


model_path = os.path.join(os.getcwd(),'ESRGAN','models','RRDB_ESRGAN_x4.pth') # change this if incorrect
print(model_path)


c:\Projects\python\image-video-super-resolution\ESRGAN\models\RRDB_ESRGAN_x4.pth


## Device setup

In [2]:
if torch.cuda.is_available():    
    device = torch.device("cuda")
    current_device= torch.cuda.current_device()
    # Limit memory usage to 80%
    torch.cuda.set_per_process_memory_fraction(0.8,device=current_device)
    # For optimized memory utilization during tensor operations
    torch.backends.cudnn.benchmark = True
    print(f"Using Device: {device}",
          f"\nCurrent GPU: {torch.cuda.get_device_name(current_device)}",
          f"\nCuda version: {torch.version.cuda}",
          f"\ncuDNN available: {torch.backends.cudnn.is_available()}",
          f"\ncuDNN version: {torch.backends.cudnn.version()}",
          f"\nAllocated memory: {torch.cuda.memory_allocated()} bytes",
          f"\nCached memory: {torch.cuda.memory_reserved()} bytes")
else:
    device = torch.device("cpu")
    print(f"Using Device: {device}")

Using Device: cuda 
Current GPU: NVIDIA GeForce RTX 4060 Laptop GPU 
Cuda version: 11.8 
cuDNN available: True 
cuDNN version: 90100 
Allocated memory: 0 bytes 
Cached memory: 0 bytes


In [3]:
# load the model
model = arch.RRDBNet(in_nc=3,out_nc=3, nf=64,nb=23,gc=32)
model.load_state_dict(torch.load(model_path),strict=True)
model.eval()
model = model.to(device=device)

  model.load_state_dict(torch.load(model_path),strict=True)


In [4]:
os.getcwd()

'c:\\Projects\\python\\image-video-super-resolution'

## Main

In [7]:
def super_resolution(model,img_paths):
    os.makedirs(os.path.join(os.getcwd,'outputs'),exist_ok=True)
    fig, axes = plt.subplots(len(img_paths),2,figsize=(10,5*len(img_paths)))
    if len(img_paths)==1:
        axes=[axes]
    
    for idx, path in enumerate(img_paths):
        # Read and process the image
        img = cv2.imread(path,cv2.IMREAD_COLOR)
        img = img/255.0
        img_tensor = torch.from_numpy(np.transpose(img[:,:,[2,1,0]],(2,0,1))).float()
        img_tensor = img_tensor.unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(img_tensor).data.squeeze().float().cpu().clamp_(0,1).numpy()

        # convert the output back to numpy format
        output_img = np.transpose(output[[2,1,0],:,:],(1,2,0))
        output_img = (output_img*255.0).astype(np.uint8)

        # save the output
        base_name = os.path.splitext(os.path.basename(path))[0]
        output_path = os.makedirs(os.path.join(os.getcwd,'outputs'))
        cv2.imwrite(output_path,output_img)
        # try:
        #     output_path = os.makedirs(os.path.join(os.getcwd,'outputs'))
        #     cv2.imwrite(output_path,output_img)
        # except:
        #     print('Image already exists.')
        
        # plot the original and super-resolution images
        axes[idx][0].imshow(cv2.cvtColor(cv2.imread(path),cv2.COLOR_BGR2RGB))
        axes[idx][0].set_title(f'Original {base_name}')
        axes[idx][0].axis('off')

        axes[idx][1].imshow(output_img)
        axes[idx][1].set_title(f'Super resolution {base_name}')
        axes[idx][1].axis('off')
    plt.tight_layout()
    plt.show()   

In [None]:
Tk().withdraw() # hide main tkinter window
image_paths = filedialog.askopenfilenames(title="Select Images",filetypes=[("Image files", "*.jpg;*.png;*.jpeg")])

if image_paths:
    print(f'Selected {len(image_paths)} image(s)')
    for path in image_paths:
        print(path)
    
    super_resolution(model=model,img_paths=image_paths)
else:
    print('No Images Selected')

In [None]:
test_img_folder = 'LR/*'

model = arch.RRDBNet(3, 3, 64, 23, gc=32)
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
model = model.to(device)

print('Model path {:s}. \nTesting...'.format(model_path))

idx = 0
for path in glob.glob(test_img_folder):
    idx += 1
    base = osp.splitext(osp.basename(path))[0]
    print(idx, base)
    # read images
    img = cv2.imread(path, cv2.IMREAD_COLOR)
    img = img * 1.0 / 255
    img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
    img_LR = img.unsqueeze(0)
    img_LR = img_LR.to(device)

    with torch.no_grad():
        output = model(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy()
    output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
    output = (output * 255.0).round()
    cv2.imwrite('results/{:s}_rlt.png'.format(base), output)
