# `cloudpickle` demo

## Train the model and dump it to a pickle file

In [None]:
!cd mnist && python main.py --epochs 1 --dump-model-pkl

'''
import cloudpickle
pickled_lambda = cloudpickle.dumps(model)
with open("model.pkl", "wb") as f:
    f.write(pickled_lambda)
'''

## Helper to plot a batch of images

In [None]:
import torchvision
import matplotlib.pyplot as plt
import numpy as np

def show_batch(batch, normalize=True):
    """Display a batch of images as a single grid with matplotlib.

    Args:
        batch: images accepted by ``torchvision.utils.make_grid`` (a 4D
            tensor, presumably ``(B, C, H, W)``, or a list of same-size
            images).
        normalize: if True (default), min-max rescale the grid to [0, 1]
            before plotting. The test loader in this notebook applies a
            mean/std ``Normalize``, so its values fall outside the [0, 1]
            range ``imshow`` expects for float RGB data. Without rescaling
            they would be clipped (with a warning) and the digits distorted.
    """
    # Move the grid to host memory and drop autograd history so .numpy()
    # also works for batches that live on the GPU or require grad.
    im = torchvision.utils.make_grid(batch, normalize=normalize).detach().cpu()
    # make_grid returns channels-first (C, H, W); matplotlib wants (H, W, C).
    plt.imshow(np.transpose(im.numpy(), (1, 2, 0)))

## Load the model with `cloudpickle`

In [None]:
import cloudpickle

# Step 1: pull the raw pickle bytes written by the training step off disk.
# NOTE: unpickling can execute arbitrary code -- only load a model.pkl you
# produced yourself or otherwise trust.
with open("mnist/model.pkl", "rb") as pkl_file:
    model_s = pkl_file.read()

# Step 2: rebuild the model object from those bytes.
model = cloudpickle.loads(model_s)

## Prepare test data

In [None]:
import torch
from torchvision import datasets, transforms

BATCH_SIZE = 4

# Extra DataLoader options used only when a GPU is available: one
# background worker and pinned host memory.
if torch.cuda.is_available():
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    kwargs = {}

# Un-shuffled MNIST test split, so the first batch is the same on every run.
# The Normalize constants presumably match the training script's
# preprocessing (MNIST mean/std) -- confirm against main.py.
test_dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './data',
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]),
    ),
    batch_size=BATCH_SIZE,
    shuffle=False,
    **kwargs,
)

## Inference

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The unpickled model keeps its parameters on whatever device they were on
# when it was saved. Move it explicitly so it matches the input batch below;
# otherwise a CPU-pickled model fed CUDA inputs raises a device-mismatch error.
model = model.to(device)
model.eval()

with torch.no_grad():
    # This demo only needs the first test batch (the loader is un-shuffled).
    data, _ = next(iter(test_dataloader))
    show_batch(data)  # plot from the CPU copy before moving to `device`
    data = data.to(device)
    output = model(data).cpu()
    # Predicted class per image: index of the highest score along dim 1.
    pred = output.argmax(dim=1)
    print(pred)