In [None]:
# Import the Necessary Support Libraries

In [None]:
import numpy as np
import pandas as pd
import os
import glob
import os.path as osp
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils import data as utils_data
import torch.nn.functional as Func
from torch.autograd import Variable
import warnings
warnings.filterwarnings('ignore')
import preprocess_lib as prelib  # Load Custom Library RLE_decode function
from tqdm import tqdm  # For Progress Bar

# import gc

# CUDA for PyTorch: prefer the first GPU when one is visible, else fall back to the CPU
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Define Global Parameters and Variables

In [None]:
# Common path variables (relative to the notebook's working directory)
train_path = '../train/'  # directory scanned for *.jpg training images
test_path = '../test/'  # not used further in this notebook
train_label_path = '../input/train_ship_segmentations.csv'
test_label_path = '../input/test_ship_segmentations.csv'  # not used further in this notebook
# Loaded via the project-local preprocess_lib; only written out when savetrainlabels is set.
train_label_data = prelib.LoadMyData(train_label_path, pandas=True)

# Global Variables
testimg = 0  # not used further in this notebook
batchsz = 1024  # images per DataLoader batch
sample = 0  # placeholder; reassigned to a loaded image batch below
sample_32 = 0  # placeholder; reassigned to the down-sampled (224x224) copy of `sample`
savetrainlabels = False  # set True to also save train_label_data to ../input/train_label.pt

# Define Supporting Functions and Classes

In [None]:
# Resize an image batch with adaptive average pooling.
# Adapted from https://discuss.pytorch.org/t/resizing-any-simple-direct-way/10316/6
def resize2d(img, size):
    """Downsample ``img`` to spatial ``size`` using adaptive average pooling.

    ``Variable(..., volatile=True)`` is a pre-0.4 PyTorch idiom; ``volatile``
    no longer has any effect and only emits a warning. ``torch.no_grad()`` is
    the supported way to skip autograd bookkeeping.

    Args:
        img: Tensor accepted by ``adaptive_avg_pool2d`` (its last two dims are
            the spatial H, W), e.g. an (N, C, H, W) batch.
        size: Target spatial size, an int or an (H, W) tuple.

    Returns:
        The pooled tensor, detached from any autograd graph.
    """
    with torch.no_grad():
        return Func.adaptive_avg_pool2d(img, size).detach()

# Display helper: plot a single image tensor with matplotlib.
def imshow(img):
    """Plot a channels-first (C, H, W) tensor as an (H, W, C) image; no unnormalization."""
    hwc_pixels = img.numpy().transpose(1, 2, 0)
    plt.imshow(hwc_pixels)
    
def _mask_overlay(image, mask, color=(0.0, 1.0, 0.0), alpha=0.5):
    """Blend ``color`` into ``image`` at every pixel where ``mask`` is non-zero.

    Args:
        image: Float array of shape (H, W, 3).
        mask: Array of shape (H, W) or (H, W, K); a pixel counts as masked if
            any of its K channels is > 0.
        color: RGB color blended into masked pixels.
        alpha: Weight given to ``color`` in the blend (``1 - alpha`` to the image).

    Returns:
        A new (H, W, 3) array; ``image`` itself is not modified.
    """
    mask = np.asarray(mask)
    hit = mask.reshape(mask.shape[0], mask.shape[1], -1).max(axis=2) > 0
    blended = image.copy()
    blended[hit] = (1.0 - alpha) * image[hit] + alpha * np.asarray(color, dtype=blended.dtype)
    return blended


def imshow_2(img, mask, title=None,
             mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Show a normalized image tensor with a segmentation mask blended on top.

    ``mask_overlay`` was called here but is not defined anywhere in this
    notebook, so the call always raised NameError. The overlay is now done by
    the local ``_mask_overlay`` helper.

    Args:
        img: Normalized (C, H, W) image tensor with 3 channels.
        mask: (C, H, W) mask tensor with the same H and W as ``img``.
        title: Optional figure title.
        mean, std: Per-channel statistics used to undo the normalization. The
            defaults are the ImageNet values this function always hard-coded;
            pass ``mean=(0.5,) * 3, std=(0.5,) * 3`` for data normalized that way.
    """
    image = img.numpy().transpose((1, 2, 0))
    image = np.clip(np.asarray(std) * image + np.asarray(mean), 0, 1)
    mask_hwc = np.clip(mask.numpy().transpose((1, 2, 0)), 0, 1)
    plt.figure(figsize=(6, 6))
    plt.imshow(_mask_overlay(image, mask_hwc))
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # run the GUI event loop briefly so the figure gets drawn
    
def imshow_unnorm(img):
    """Undo a mean=0.5 / std=0.5 normalization and plot the (C, H, W) tensor."""
    restored = (img / 2 + 0.5).numpy()  # x * 0.5 + 0.5 maps [-1, 1] back to [0, 1]
    plt.imshow(restored.transpose(1, 2, 0))

In [None]:
class AirbusDS(utils_data.Dataset):
    """Map-style dataset over the ``*.jpg`` files directly inside ``root``.

    Each item is the image passed through ``transforms.ToTensor()`` and then,
    if given, through the user ``transform``.
    """

    def __init__(self, root, transform=None):
        """Index the images under ``root``.

        Previously ``root`` was stored but the global ``train_path`` was
        globbed instead, and ``transform`` was silently discarded.

        Args:
            root: Directory scanned (non-recursively) for ``*.jpg`` files.
            transform: Optional callable applied AFTER ``ToTensor``, e.g.
                ``transforms.Normalize``. Tensor-only transforms like
                ``Normalize`` cannot accept a PIL image, so ``ToTensor`` always
                runs first.
        """
        self.root = root
        if transform is None:
            self.transform = transforms.ToTensor()
        else:
            self.transform = transforms.Compose([transforms.ToTensor(), transform])
        # sorted() so the index -> file mapping does not depend on filesystem order
        self.filenames = sorted(glob.glob(osp.join(root, '*.jpg')))
        self.len = len(self.filenames)

    # You must override __getitem__ and __len__
    def __getitem__(self, index):
        """Load image ``index`` and return it after ``self.transform``."""
        # The context manager releases the file handle once the pixels are converted.
        with Image.open(self.filenames[index]) as image:
            return self.transform(image)

    def __len__(self):
        """Total number of samples in the dataset."""
        return self.len

class Rescale(object):
    """Rescale the image in a sample to a given size.

    Adapted from the PyTorch "Writing Custom Datasets" tutorial. That version
    calls ``transform.resize`` (skimage), but ``skimage`` is never imported in
    this notebook, so every call raised NameError. The resize is now done with
    bilinear ``torch.nn.functional.interpolate``.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        """Resize ``sample['image']`` and scale ``sample['landmarks']`` to match.

        Args:
            sample: dict with ``'image'`` (ndarray, H x W or H x W x C) and
                ``'landmarks'`` (ndarray whose rows are (x, y) points).

        Returns:
            dict with the resized float32 ``'image'`` (values keep the input's
            range; they are not rescaled to [0, 1]) and the scaled ``'landmarks'``.
        """
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = self._resize(image, new_h, new_w)

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}

    @staticmethod
    def _resize(image, new_h, new_w):
        """Bilinearly resize an H x W (x C) ndarray to new_h x new_w (x C) as float32."""
        pixels = torch.from_numpy(np.ascontiguousarray(image, dtype=np.float32))
        is_gray = pixels.dim() == 2
        if is_gray:
            pixels = pixels.unsqueeze(-1)
        # interpolate expects an (N, C, H, W) batch
        batch = pixels.permute(2, 0, 1).unsqueeze(0)
        resized = Func.interpolate(batch, size=(new_h, new_w),
                                   mode='bilinear', align_corners=False)
        out = resized[0].permute(1, 2, 0)
        if is_gray:
            out = out[..., 0]
        return out.numpy()

class ToTensor(object):
    """Turn the ndarrays of an {'image', 'landmarks'} sample into torch tensors.

    The image is reordered from numpy's H x W x C layout to torch's C x H x W
    layout; the landmarks are converted unchanged.
    """

    def __call__(self, sample):
        img_hwc = sample['image']
        points = sample['landmarks']
        img_chw = img_hwc.transpose((2, 0, 1))  # H x W x C -> C x H x W
        return dict(image=torch.from_numpy(img_chw),
                    landmarks=torch.from_numpy(points))


# Instantiate the Dataset and Load a Batch

In [None]:
# Transformation with normalization: Normalize computes (x - 0.5) / 0.5 per
# channel, mapping a [0, 1] tensor to [-1, 1].
# ImageNet alternative: mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225].
normalized_dataset = AirbusDS(train_path, transform=transforms.Compose([
    transforms.Normalize(
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5]),
]))

imgloader = utils_data.DataLoader(normalized_dataset, batch_size=batchsz,
                                  shuffle=True, num_workers=0)

# Pull exactly ONE shuffled batch (saved below as "B01"). The previous loop
# called next() `batchsz` times. That walked `batchsz` whole batches
# (batchsz**2 images) and kept only the last one. It also raised StopIteration
# unless the dataset held at least that many images.
dataiter = iter(imgloader)
sample = next(dataiter)

# Down-sample the batch to 224x224 with adaptive average pooling.
sample_32 = resize2d(sample, (224, 224)).float()


# (Optional) Show Images 

In [None]:
# Plot the loaded batch and its 224x224 down-sampled copy as grids, then report both shapes.
batches = (sample, sample_32)
for batch in batches:
    imshow(torchvision.utils.make_grid(batch))
for batch in batches:
    print(batch.size())

# Save the Batch Data

In [None]:
# Save Batch Data
# NOTE(review): these tensors are loaded from train_path, yet the file names say
# "test_data". The dtype is presumably float32 (ToTensor output), not "64bit";
# confirm. The paths are kept unchanged because other code may load them by name.
torch.save(sample, '../input/test_data_64bit_sz768-B01.pt')
torch.save(sample_32, '../input/test_data_32bit_sz224-B01.pt')

# Optionally persist the label table too (controlled by the savetrainlabels flag).
if savetrainlabels:
    torch.save(train_label_data, '../input/train_label.pt')