# To Test a high resolution Image

In [4]:
import torch
from utils import *
from PIL import Image, ImageDraw, ImageFont, ImageColor

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
#models checkpoint
srgan_checkpoint = "./models/checkpoint_srgan.pth_25.tar"
srresnet_checkpoint = "./models/checkpoint_srresnet_129.pth.tar"
#load models
srgan = torch.load(srgan_checkpoint)['generator'].to(device)
srgan.eval()
srresnet = torch.load(srresnet_checkpoint)['model'].to(device)
srresnet.eval()

SRResNet(
  (conv_block1): ConvolutionalBlock(
    (conv_block): Sequential(
      (0): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
      (1): PReLU(num_parameters=1)
    )
  )
  (residual_blocks): Sequential(
    (0): ResidualBlock(
      (conv_block1): ConvolutionalBlock(
        (conv_block): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): PReLU(num_parameters=1)
        )
      )
      (conv_block2): ConvolutionalBlock(
        (conv_block): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (1): ResidualBlock(
      (conv_block1): ConvolutionalBlock(
        (conv_block): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1)

In [7]:
# Save image file in images folder and uncomment either .png or .jpg
which_img = "erc"
# img = "./images/"+ which_img+".png"
img = "./images/"+ which_img+".jpg"

In [8]:
# If image size is too large and cuda is out of memory then set this to true
halve = False

In [9]:
# Load image, downsample to obtain low-res version
hr_img = Image.open(img, mode="r")
hr_img = hr_img.convert('RGB')
if halve:
    hr_img = hr_img.resize((int(hr_img.width / 2), int(hr_img.height / 2)),
                           Image.LANCZOS)

lr_img = hr_img.resize((int(hr_img.width / 4), int(hr_img.height / 4)),
                       Image.BICUBIC)

In [10]:
# Bicubic Upsampling
bicubic_img = lr_img.resize((hr_img.width, hr_img.height), Image.BICUBIC)

# Super-resolution (SR) with SRResNet
sr_img_srresnet = srresnet(convert_image(lr_img, source='pil', target='imagenet-norm').unsqueeze(0).to(device))
sr_img_srresnet = sr_img_srresnet.squeeze(0).cpu().detach()
sr_img_srresnet = convert_image(sr_img_srresnet, source='[-1, 1]', target='pil')

# Super-resolution (SR) with SRGAN
sr_img_srgan = srgan(convert_image(lr_img, source='pil', target='imagenet-norm').unsqueeze(0).to(device))
sr_img_srgan = sr_img_srgan.squeeze(0).cpu().detach()
sr_img_srgan = convert_image(sr_img_srgan, source='[-1, 1]', target='pil')


In [None]:
# Create grid
margin = 40
grid_img = Image.new('RGB', (3 * hr_img.width + 3 * margin, 2 * hr_img.height + 4 * margin), (255,255,255))
draw = ImageDraw.Draw(grid_img)
#Font
textsize = 24
try:
    font = ImageFont.truetype("./fonts/Poppins-Light.ttf", size=textsize)
except OSError:
    print(
        "Defaulting to a terrible font. To use a font of your choice, include the link to its TTF file in the function.")
    font = ImageFont.load_default()

font_color = "black"
img_height = hr_img.height
img_width = hr_img.width

In [None]:
# # Place LR image
grid_img.paste(lr_img, (int(margin+img_width*0.375),int(1.5*margin+img_height*0.875)))
width= draw.textlength("LR Image",font)
draw.text(xy=[margin + img_width / 2 - width / 2, (1.5*margin+img_height*0.875) - textsize - 15], text="LR Image", font=font, fill=font_color)
# Place Bicubic
grid_img.paste(bicubic_img, (2*margin + img_width, margin))
width= draw.textlength("Bicubic",font)
draw.text(xy=[2 * margin + img_width + img_width / 2 - width / 2, margin - textsize - 5], text="Bicubic", font=font, fill=font_color)
# Place SRResNet image
grid_img.paste(sr_img_srresnet, (2*margin + img_width, 2 * margin + img_height))
width= draw.textlength("SRResNet",font)
draw.text(
    xy=[2 * margin + img_width + img_width / 2 - width / 2,
        2 * margin + img_height - textsize - 5],
    text="SRResNet",
    font=font,
    fill=font_color)
# Place SRGAN image
grid_img.paste(sr_img_srgan, (3 * margin + 2* img_width, margin))
width= draw.textlength("SRGAN",font)
draw.text(
    xy=[3 * margin + 2*img_width + img_width / 2 - width / 2, margin - textsize - 5],
    text="SRGAN", font=font, fill=font_color)
# Place Original Image
grid_img.paste(hr_img, (3 * margin + 2 * img_width, 2 * margin + img_height))
width= draw.textlength("Original",font)
draw.text(
    xy=[3 * margin + 2*img_width + img_width / 2 - width / 2,
        2 * margin + img_height - textsize - 5],
    text="Original", font=font, fill=font_color)

In [None]:
# To view
grid_img.show()

In [None]:
# To save
# grid_img.save(f"./images/results/{which_img}.jpg")