In [None]:
import pandas as pd
import torch
import torchvision
from torchvision import transforms
from cloudpickle import dumps

def _split_lengths(total, train_pct, val_pct, test_pct):
    """Return ``[train, val, test]`` subset sizes that add up exactly to *total*.

    Truncating each ``pct * total`` with ``int()`` can leave a few samples
    unassigned, and ``random_split`` rejects integer lengths that do not sum
    to the dataset size. The rounding remainder is therefore given to the
    training split.

    Raises:
        ValueError: if any percentage is negative, if the percentages do not
            sum to 1, or if the resulting training split would be negative.
    """
    pcts = (train_pct, val_pct, test_pct)
    if any(p < 0 for p in pcts):
        raise ValueError(f"Split percentages must be non-negative, got {pcts}")
    pct_sum = sum(pcts)
    if abs(pct_sum - 1.0) > 1e-6:
        raise ValueError(f"Split percentages must sum to 1, got {pct_sum}")
    val_len = int(val_pct * total)
    test_len = int(test_pct * total)
    train_len = total - val_len - test_len
    if train_len < 0:
        raise ValueError(
            f"Split percentages {pcts} leave no room for a training split "
            f"of a dataset with {total} samples"
        )
    return [train_len, val_len, test_len]


def _make_loader(dataset, batch_size, shuffle):
    """Wrap *dataset* in a DataLoader that loads in the main process."""
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       num_workers=0)


def handler(context, event):
    """Build train/validation/test DataLoaders from an image folder and log them.

    Reads these values from ``context.inputs``: ``img_dimensions``,
    ``data_download_path``, ``train_pct``, ``val_pct``, ``test_pct`` and
    ``batch_size``. Each value is converted with ``str()`` followed by
    ``int``/``float``. The images are resized to a square of
    ``img_dimensions`` pixels, converted to tensors and normalized. The
    dataset is randomly split, and each split's DataLoader is serialized with
    cloudpickle and logged as an artifact named ``<split>_data_loader``.

    NOTE(review): assumes an MLRun-style ``context`` exposing ``inputs``,
    ``logger``, ``log_artifact`` and ``artifact_path`` -- confirm.

    Args:
        context: Job context providing inputs, logging and artifact storage.
        event: Unused; kept for the handler signature.

    Raises:
        ValueError: if the split percentages are invalid (see
            ``_split_lengths``).
    """
    img_dimensions = int(str(context.inputs['img_dimensions']))
    img_transforms = transforms.Compose([
        transforms.Resize((img_dimensions, img_dimensions)),
        transforms.ToTensor(),
        # Normalize to the ImageNet mean and standard deviation. These could
        # be computed for the cats/dogs data set, but the ImageNet values
        # give acceptable results here.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Load dataset
    context.logger.info("Loading/Transforming Dataset")
    data = torchvision.datasets.ImageFolder(root=str(context.inputs['data_download_path']),
                                            transform=img_transforms)

    # Split into train/validation/test; lengths are ordered to match the
    # unpacking below.
    context.logger.info("Splitting Dataset")
    splits = _split_lengths(len(data),
                            float(str(context.inputs['train_pct'])),
                            float(str(context.inputs['val_pct'])),
                            float(str(context.inputs['test_pct'])))
    train_data, validation_data, test_data = torch.utils.data.random_split(data, splits)

    # Create DataLoaders per split. Only training benefits from shuffling;
    # evaluation loaders keep a stable order.
    context.logger.info("Creating DataLoaders")
    batch_size = int(str(context.inputs['batch_size']))
    loaders = {
        "train_data_loader": _make_loader(train_data, batch_size, shuffle=True),
        "validation_data_loader": _make_loader(validation_data, batch_size, shuffle=False),
        "test_data_loader": _make_loader(test_data, batch_size, shuffle=False),
    }

    # Output DataLoaders (train/validation/test)
    context.logger.info("Logging DataLoaders")
    for name, loader in loaders.items():
        context.log_artifact(item=name,
                             body=dumps(loader),
                             local_path=f"{name}.pkl",
                             artifact_path=context.artifact_path)