In [16]:
import os
import math

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import PIL.Image

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.models as models
import torchvision.transforms as tfms


In [62]:
# Dataset location. Note that data_dir is appended to the current working
# directory; it is not treated as an absolute filesystem path.
data_dir = '/data/h5'
root_dir = os.getcwd() + data_dir

# Number of DataLoader worker processes (workers > 0 can cause issues on Windows).
workers = 2

# Batch size during training.
batch_size = 4

# Spatial size of the training images (target passed to Resize in the transforms).
size = 400

# Downscaling factor between the high-resolution and low-resolution images.
scale = 2

# Number of channels in the training images (3 for colour images).
nc = 3

# Number of training epochs.
num_epochs = 10

# Learning rate for the optimizers.
# NOTE(review): `lr` means learning rate here, while `lr_*` names elsewhere in
# this notebook mean "low resolution"; consider renaming to `learning_rate`.
lr = 0.00005

# Beta1 hyperparameter for the Adam optimizers.
beta1 = 0.5

# Number of GPUs available; set to 0 to force CPU mode.
ngpu = 1
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# File extension of the images to load.
img_ext = '.png'

# Seed the RNGs so DataLoader shuffling and weight initialisation are
# reproducible under Restart & Run All.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

In [74]:
# Image transforms: resize, convert to a tensor in [0, 1], then normalise each
# channel with mean 0.5 / std 0.5, which maps values to [-1, 1].
# NOTE(review): as written, both pipelines resize to the same `size`, so the LR
# and HR tensors end up at the same resolution. If HR should stay `scale` times
# larger than LR, one of the targets needs adjusting. TODO: confirm intent.


def _make_image_transform(target_size, channels=nc):
    """Build Resize(target_size) -> ToTensor -> Normalize(0.5, 0.5) over `channels` channels."""
    half = (0.5,) * channels
    return tfms.Compose([
        tfms.Resize(target_size),
        tfms.ToTensor(),
        tfms.Normalize(half, half),
    ])


# Kept as two separate objects so the LR and HR pipelines can diverge later.
lr_transforms = _make_image_transform(size)
hr_transforms = _make_image_transform(size)

In [78]:
class SatSRDataset(Dataset):
    """Paired (low-resolution, high-resolution) image dataset for super-resolution.

    High-resolution images are discovered recursively under
    ``root_dir + '/hr_images'``. Each low-resolution input is produced on the
    fly by resizing its HR counterpart according to ``scale_factor``.
    """

    def __init__(self, root_dir, img_ext, scale_factor, lr_transform=None, hr_transform=None):
        """
        :param root_dir: dataset root; HR images are read from ``root_dir + '/hr_images'``
        :param img_ext: file extension (e.g. ``'.png'``) of the images to include
        :param scale_factor: factor by which HR images are downscaled to make LR inputs
        :param lr_transform: optional callable applied to each LR PIL image
        :param hr_transform: optional callable applied to each HR PIL image
        :raises FileNotFoundError: if no image with ``img_ext`` is found
        """
        self.root_dir = root_dir
        self.img_ext = img_ext
        self.scale_factor = scale_factor
        self.lr_transform = lr_transform
        self.hr_transform = hr_transform

        # Collect HR file paths. They are sorted so the index -> file mapping is
        # stable across runs (os.walk order depends on the filesystem).
        hr_dir = self.root_dir + '/hr_images'
        self.hr_images = sorted(
            os.path.join(dir_name, filename)
            for dir_name, _, file_names in os.walk(hr_dir)
            for filename in file_names
            if filename.endswith(self.img_ext)
        )
        if not self.hr_images:
            raise FileNotFoundError(
                f"no '*{self.img_ext}' images found under {hr_dir}")

        # The LR size is derived from the width of the first HR image.
        # NOTE(review): this assumes every HR image shares that size (and is
        # square, since Resize(int) sets the shorter edge). TODO: confirm.
        with PIL.Image.open(self.hr_images[0]) as probe:
            self.sz = probe.size[0]
        self.lr_resize = tfms.Resize(int(self.sz / self.scale_factor))

    def __len__(self):
        """Number of HR images, i.e. the number of LR/HR pairs."""
        return len(self.hr_images)

    @staticmethod
    def _load_rgb(path):
        """Read the image at `path` fully into memory as RGB and close the file."""
        # convert('RGB') forces the pixel data to load, so the file handle can be
        # released. It also maps grayscale/RGBA PNGs to 3-channel RGB.
        with PIL.Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return ``[lr, hr]`` for image ``idx``, applying each transform when it is set."""
        # Use locals only: storing per-sample state on self is unnecessary and
        # keeps the last images alive inside every DataLoader worker.
        hr_img = self._load_rgb(self.hr_images[idx])
        lr_img = self.lr_resize(hr_img)
        if self.lr_transform is not None:
            lr_img = self.lr_transform(lr_img)
        if self.hr_transform is not None:
            hr_img = self.hr_transform(hr_img)
        return [lr_img, hr_img]

In [79]:
# Build the paired LR/HR dataset from the HR images found under root_dir.
satdataset = SatSRDataset(
    root_dir,
    img_ext=img_ext,
    scale_factor=scale,
    lr_transform=lr_transforms,
    hr_transform=hr_transforms,
)

In [80]:
# Shuffled, batched loader over the (LR, HR) pairs.
dataloader = DataLoader(
    satdataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
    pin_memory=False,
)

In [None]:
class UpsampleNet(nn.Module):
    """Small convolutional network (currently the LeNet-style tutorial template).

    NOTE(review): despite its name, this network does not upsample yet. It maps
    a 32x32 image to a 10-value output; the 16 * 6 * 6 flatten size below only
    matches 32x32 inputs. TODO: replace it with a super-resolution architecture
    before training on the (LR, HR) pairs produced by the dataloader.

    :param in_channels: number of channels in the input images (3 for RGB)
    """

    def __init__(self, in_channels=3):
        super().__init__()
        # in_channels -> 6 feature maps, then 6 -> 16, both with 3x3 kernels.
        self.conv1 = nn.Conv2d(in_channels, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        # Affine head (y = Wx + b). 6*6 is the spatial size left after two
        # conv(3x3) + 2x2 max-pool stages on a 32x32 input.
        self.fc1 = nn.Linear(16 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run the network; assumes x is (batch, in_channels, 32, 32) (see class note)."""
        # conv -> ReLU -> 2x2 max-pool, twice (a single int means a square window).
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten every non-batch dimension for the linear head.
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def num_flat_features(self, x):
        """Number of elements per sample: the product of all non-batch dimensions of x."""
        return x.size()[1:].numel()


net = UpsampleNet()
net