In [2]:
"""
Single-Image Super-resolution Generator (for testing)

Instrustion on running the script:
1. Download test data to ../../data/test
2. change 'img_height' and 'img_width' with the size of test data
4. Run the sript using command 'python3 srgan.py'
"""

import argparse
import os
from PIL.Image import Image
import numpy as np
import math
import itertools
import sys
import rasterio
import datetime
import time

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

from torch.utils.data import DataLoader
from torch.autograd import Variable

from models import *
from datasets import *

import torch.nn as nn
import torch.nn.functional as F
import torch



In [3]:
cuda = torch.cuda.is_available()

In [4]:
cuda = torch.cuda.is_available()

img_height = 128 # LR image height
img_width = 128 #LR imgae width
channels = 3
model = "saved_models/noscale11/generator_199.pth"
source_path = "../../data2/test/"
saved_path = "../../data2/test/result/noscale11"
batch_size = 1
n_cpu = 8


input_shape = (img_height, img_width)

# Initialize generator
generator = GeneratorResNet(in_channels=channels, out_channels=channels)

if cuda:
    generator = generator.cuda()

generator.load_state_dict(torch.load(model))

Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

hr_shape = (img_height, img_width)


dataloader = DataLoader(
    ImageDatasetNoScale(source_path, hr_shape=hr_shape),
    batch_size=batch_size,
    shuffle=False,
    num_workers=n_cpu,
)

print('Generating')
print(len(dataloader))

prev_time = time.time()
with torch.no_grad():
    for i, batch in enumerate(dataloader):
        # Set model input
        lr = Variable(batch["lr"].type(Tensor))
        hr = Variable(batch["hr"].type(Tensor))
        min = -10802 # Got from MinMax Normalization of HR images
        max = 6787   # Got from MinMax Normalization of HR images 
        mean = 0.5
        var = 0.5
        input_path = batch["input_path"]
        img_root, img_name = os.path.split(input_path[0])

        # ------------------
        #  Generating
        # ------------------
        output = generator(lr)
        #save PNG#
        # save_image(output, os.path.join(opt.saved_path, "gen_{}.png".format(img_name)), normalize=True)

        output = (output * var + mean) * (max-min) + min
        output = output.detach().cpu().numpy()
        output = np.float32(output)

        hr_out = (hr * var + mean) * (max-min) + min
        hr_out = hr_out.detach().cpu().numpy()
        hr_out = np.float32(hr_out)
#        print(output.shape)

        # save TIF
        print(np.shape(output[0][0]))

        with rasterio.Env():
            with rasterio.open(input_path[0]) as src: 
                profile = src.meta
                print(profile)
                # t = src.transform*src.transform.scale(1/4, 1/4)
                
                profile.update(
                    width = img_width * 4, 
                    height = img_height * 4,
                    dtype = rasterio.float32,
                    # transform = t 
                )
                print(profile)
            with rasterio.open(os.path.join(saved_path, "gen_{}.tif".format(img_name)), 'w', **profile) as dst:
                dst.write(output[0][0], indexes = 1)
            # with rasterio.open(os.path.join(saved_path, "hr_{}.tif".format(img_name)), 'w', **profile) as dst:
            #     dst.write(hr_out[0][0], indexes = 1)



Generating
400
(128, 128)
{'driver': 'GTiff', 'dtype': 'int16', 'nodata': -9999.0, 'width': 128, 'height': 128, 'count': 1, 'crs': CRS.from_epsg(4326), 'transform': Affine(0.016666666666666666, 0.0, 167.90833333333336,
       0.0, -0.016666666666666663, -44.7083333333333)}
{'driver': 'GTiff', 'dtype': 'float32', 'nodata': -9999.0, 'width': 128, 'height': 128, 'count': 1, 'crs': CRS.from_epsg(4326), 'transform': Affine(0.016666666666666666, 0.0, 167.90833333333336,
       0.0, -0.016666666666666663, -44.7083333333333)}
(128, 128)
{'driver': 'GTiff', 'dtype': 'int16', 'nodata': -9999.0, 'width': 128, 'height': 128, 'count': 1, 'crs': CRS.from_epsg(4326), 'transform': Affine(0.016666666666666666, 0.0, 169.64166666666665,
       0.0, -0.016666666666666663, -45.058333333333294)}
{'driver': 'GTiff', 'dtype': 'float32', 'nodata': -9999.0, 'width': 128, 'height': 128, 'count': 1, 'crs': CRS.from_epsg(4326), 'transform': Affine(0.016666666666666666, 0.0, 169.64166666666665,
       0.0, -0.01666