In [0]:
from __future__ import print_function
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from bigdl.orca import init_orca_context, stop_orca_context
from bigdl.orca.learn.pytorch import Estimator
from bigdl.orca.learn.metrics import Accuracy
from bigdl.orca.learn.trigger import EveryEpoch


def train_data_creator(config={}, batch_size=4, download=True, data_dir='./data'):
    """Build a shuffled DataLoader over the FashionMNIST training split.

    Args:
        config: Accepted as the first argument but not read by this function.
        batch_size: Number of samples per batch.
        download: Forwarded to the dataset's ``download`` argument.
        data_dir: Root directory handed to the dataset.

    Returns:
        A ``torch.utils.data.DataLoader`` created with ``shuffle=True`` and
        ``num_workers=0``.
    """
    # Convert to tensors, then normalize the single channel with mean 0.5 / std 0.5.
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    fashion_train = torchvision.datasets.FashionMNIST(
        root=data_dir, download=download, train=True, transform=preprocessing)
    return torch.utils.data.DataLoader(
        fashion_train, batch_size=batch_size, shuffle=True, num_workers=0)


def validation_data_creator(config={}, batch_size=4, download=True, data_dir='./data'):
    """Build an unshuffled DataLoader over the FashionMNIST test split.

    Args:
        config: Accepted as the first argument but not read by this function.
        batch_size: Number of samples per batch.
        download: Forwarded to the dataset's ``download`` argument.
        data_dir: Root directory handed to the dataset.

    Returns:
        A ``torch.utils.data.DataLoader`` created with ``shuffle=False`` and
        ``num_workers=0``.
    """
    # Convert to tensors, then normalize the single channel with mean 0.5 / std 0.5.
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    fashion_test = torchvision.datasets.FashionMNIST(
        root=data_dir, train=False, download=download, transform=preprocessing)
    return torch.utils.data.DataLoader(
        fashion_test, batch_size=batch_size, shuffle=False, num_workers=0)


def matplotlib_imshow(img, one_channel=False):
    """Draw an image tensor on the current matplotlib axes.

    Args:
        img: Tensor to display; dim 0 is treated as the channel axis (it is
            averaged away or moved last below).
        one_channel: If True, average over dim 0 and draw the result with the
            "Greys" colormap. Otherwise move dim 0 to the end and draw as-is.
    """
    if one_channel:
        img = img.mean(dim=0)
    pixels = (img / 2 + 0.5).numpy()  # unnormalize, then hand over to numpy
    if not one_channel:
        # matplotlib wants the channel axis last.
        plt.imshow(np.transpose(pixels, (1, 2, 0)))
    else:
        plt.imshow(pixels, cmap="Greys")


class Net(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers.

    The flatten width of ``16 * 4 * 4`` matches single-channel 28x28 inputs
    (28 -> 24 -> 12 -> 8 -> 4 through the conv/pool stack). The final layer
    emits 10 raw scores per sample.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Convolutional stage; one 2x2 pooling module is shared by both convs.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Fully connected stage: 256 -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the 10 unnormalized class scores for each sample in ``x``."""
        feature_maps = self.pool(F.relu(self.conv1(x)))
        feature_maps = self.pool(F.relu(self.conv2(feature_maps)))
        # Collapse channel and spatial dims into one vector per sample.
        flattened = feature_maps.view(-1, 16 * 4 * 4)
        hidden = F.relu(self.fc1(flattened))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)


def model_creator(config):
    """Return a newly constructed ``Net``; ``config`` is accepted but not used."""
    return Net()


def optimizer_creator(model, config):
    """Create the SGD optimizer for ``model``.

    Args:
        model: Module whose ``parameters()`` will be optimized.
        config: Optional dict of hyper-parameters. The keys ``"lr"`` and
            ``"momentum"`` are honoured when present; ``None``, an empty
            dict, or a dict without those keys falls back to the defaults
            (lr=0.001, momentum=0.9).

    Returns:
        A ``torch.optim.SGD`` instance bound to the model's parameters.
    """
    # Tolerate config=None so the function can also be called directly.
    settings = config or {}
    return optim.SGD(model.parameters(),
                     lr=settings.get("lr", 0.001),
                     momentum=settings.get("momentum", 0.9))

In [0]:
# Execution environment
cluster_mode = "spark-submit"
backend = "ray"  # ray or spark

# Training hyper-parameters
epochs, batch_size = 2, 4

# Dataset location and whether the dataset creators should download it
download = True
data_dir = "./data"

# Where the trained model is written
model_dir = "/dbfs/FileStore/model/fashion/"
save_path = "{}fashion.pth".format(model_dir)

In [0]:
# Initialize the Orca context using the `cluster_mode` value assigned earlier in the notebook.
init_orca_context(cluster_mode=cluster_mode)

In [0]:
# Label names for the ten FashionMNIST classes
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

# Preview: fetch one batch of (shuffled) training images
dataiter = iter(train_data_creator(config={}, batch_size=4,
                                   download=download, data_dir=data_dir))
# Use the built-in next(); the iterator's .next() method is not available
# in recent PyTorch releases.
images, labels = next(dataiter)

# Tile the batch into a single image grid and display it
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)

# Loss function used for both training and evaluation
criterion = nn.CrossEntropyLoss()

orca_estimator = Estimator.from_torch(model=model_creator,
                                      optimizer=optimizer_creator,
                                      loss=criterion,
                                      metrics=[Accuracy()],
                                      model_dir=model_dir,
                                      use_tqdm=True,
                                      backend=backend)

try:
    stats = orca_estimator.fit(train_data_creator, epochs=epochs, batch_size=batch_size)
    print("Train stats: {}".format(stats))

    val_stats = orca_estimator.evaluate(validation_data_creator, batch_size=batch_size)
    print("Validation stats: {}".format(val_stats))

    print("Saving model to: ", save_path)
    orca_estimator.save(save_path)
    # The saved model can be restored later with: orca_estimator.load(save_path)
finally:
    # Always shut the estimator down, even if fit/evaluate/save raised.
    orca_estimator.shutdown()


In [0]:
# Stop the Orca context that was initialized earlier in this notebook.
stop_orca_context()