In [None]:
#!yes | pip uninstall torchvison
!pip install -qU torchvision
!pip install sagemaker

# Pipe-Squeeze Experiments on CIFAR-10 with SageMaker

## Contents

1. [Background](#Background)
1. [Data](#Data)
1. [Train](#Train)
1. [Host](#Host)

---

## Background

Setup for running Pipe-Squeeze experiments on CIFAR-10. You must set up SageMaker (AWS credentials, an S3 bucket, and a SageMaker execution role) before running this notebook.

In [None]:
import sagemaker
import os

# AWS configuration.
# Prefer supplying credentials through the environment or `aws configure`.
# Values below that are still "<...>" placeholders are NOT exported: exporting a
# literal placeholder would take precedence over any valid credentials already
# configured (env vars come first in the boto3 credential chain) and make every
# AWS call fail. Replace a placeholder only if you need to override.
AWS_CREDENTIAL_OVERRIDES = {
    "AWS_ACCESS_KEY_ID": "<YOUR_ACCESS_KEY_ID>",
    "AWS_SECRET_ACCESS_KEY": "<YOUR_SECRET_ACCESS_KEY>",
}
for env_var, value in AWS_CREDENTIAL_OVERRIDES.items():
    if value and not value.startswith("<"):
        os.environ[env_var] = value

# Keep a region that is already set in the environment; otherwise use us-east-1.
os.environ.setdefault("AWS_DEFAULT_REGION", "us-east-1")

# sagemaker.Session() reads the AWS region when it is constructed, so create it
# only after the environment above is configured.
sagemaker_session = sagemaker.Session()
bucket = "<YOUR_S3_BUCKET>"  # S3 bucket that receives the uploaded dataset
prefix = "cifar10"           # S3 key prefix for the uploaded dataset

role = "<YOUR_AWS_SAGEMAKER_ROLE>"  # IAM execution role passed to the estimators

## Data - CIFAR-10
### Getting the data



In [None]:
from torchvision import datasets, transforms, models

In [None]:
# CIFAR-10 per-channel (RGB) normalization statistics: mean, std.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training: random-resized-crop augmentation back to 32x32, then normalize.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
# Validation must be deterministic: no random augmentation. CIFAR-10 images are
# already 32x32, so no resize/crop is needed before normalization.
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Download CIFAR-10 into separate train/ and val/ roots under ../data/cifar10,
# the directory that is uploaded to S3 below. Each root gets its own copy of the
# downloaded archive; `train=` only selects which split the Dataset object exposes.
train_set = datasets.CIFAR10(root='../data/cifar10/train', train=True, download=True,
                             transform=train_transform)
val_set = datasets.CIFAR10(root='../data/cifar10/val', train=False, download=True,
                           transform=val_transform)

### Uploading the data to S3
We use the `sagemaker.Session.upload_data` function to upload our datasets to an S3 location. The return value, `inputs`, is the S3 URI of the uploaded data (`s3://<bucket>/<prefix>`). The training jobs below read from this same location.


In [None]:
# Upload everything under ../data/cifar10 (both the train/ and val/ roots) to
# s3://<bucket>/<prefix>; upload_data returns that S3 URI.
inputs = sagemaker_session.upload_data(
    path="../data/cifar10",
    bucket=bucket,
    key_prefix=prefix,
)
print(f"input spec (in this case, just an S3 path): {inputs}")

## Train
### Training script
The `train_sagemaker.py` script (in the `models` source directory) provides all the code we need for training and hosting a SageMaker model. This includes the `model_fn` function, which loads a trained model.

### Run training in SageMaker

In [None]:
from sagemaker.pytorch import PyTorch

# Experiment configuration. The sweep cell below reuses epochs, batch_size,
# microbatches and learning_rate from this cell.
n_pipelines = 3          # used as instance_count (presumably one pipeline per instance)
epochs = 10
batch_size = 1000
microbatches = 8
learning_rate = 0.003
compression_type = 'randomk'
compression_ratio = 0.8

# SageMaker script mode passes each key to train_sagemaker.py as a command-line flag.
hyperparameters = {
    "epochs": epochs,
    "backend": "nccl",
    "batch-size": batch_size,
    "n-microbatches": microbatches,
    "learning-rate": learning_rate,
    "compression-type": compression_type,
    "compression-ratio": compression_ratio,
}

pt_estimator = PyTorch(
    entry_point="train_sagemaker.py",
    source_dir="models",
    role=role,
    instance_count=n_pipelines,
    instance_type="ml.g4dn.12xlarge",
    framework_version='1.12.1',
    py_version='py38',
    hyperparameters=hyperparameters,
    base_job_name=f"three-pipes-{compression_type}",
)

# Launch the job on the uploaded data; fit() waits for the job to finish by default.
pt_estimator.fit(f"s3://{bucket}/{prefix}")


In [None]:

# Sweep over gradient-compression settings: each (type, ratio) pair is one
# training job. Jobs run sequentially because fit() waits for completion by default.
# Reuses epochs, batch_size, microbatches and learning_rate from the cell above.
experiment_dict = {'compression_type': ['randomk', 'randomk', 'randomk', 'powersgd'],
                   'compression_ratio': [0.3, 0.5, 0.8, 1],
                   }
n_pipelines = 3

# The two lists are paired by position, so they must line up one-to-one.
assert len(experiment_dict['compression_type']) == len(experiment_dict['compression_ratio']), \
    "experiment_dict lists must have the same length"

for sweep_type, sweep_ratio in zip(experiment_dict['compression_type'],
                                   experiment_dict['compression_ratio']):
    pt_estimator = PyTorch(
        entry_point="train_sagemaker.py",
        source_dir="models",
        role=role,
        instance_count=n_pipelines,
        instance_type="ml.g4dn.12xlarge",
        framework_version='1.12.1',
        py_version='py38',
        hyperparameters={
            "epochs": epochs,
            "backend": "nccl",
            "batch-size": batch_size,
            "n-microbatches": microbatches,
            "learning-rate": learning_rate,
            "compression-type": sweep_type,
            "compression-ratio": sweep_ratio,
        },
        # Name the job after THIS run's compression type. The previous code used
        # the global `compression_type` from the single-run cell, which labelled
        # every job, including the PowerSGD run, "three-pipes-randomk".
        base_job_name=f"three-pipes-{sweep_type}",
    )
    pt_estimator.fit(f"s3://{bucket}/{prefix}")

Once a `PyTorch` estimator is constructed, calling `fit` with the S3 location of the uploaded data starts the training job. SageMaker copies that data onto each training instance's local filesystem, so the training script can simply read it from disk.
