# Setup
Remember to properly set the global variables in `config.py`  

In [None]:
# Project-wide settings come from ebm.config (presumably DATASET_PATH and the
# set_seed helper used below, since nothing else in this notebook defines them).
from ebm.config import *

# Requested CUDA device; the device cell further down falls back to the CPU
# when CUDA is not available.
gpu_device = "cuda:1"

# Fix the random seed so training is deterministic.
SEED = 0
set_seed(SEED)

# Tensorboard
Doc @ https://pytorch.org/docs/1.7.1/tensorboard.html?highlight=tensorboard  
I don't use it from within the notebook.   
To see the run names correctly, open the parent folder of the run folders in TensorBoard (TB)!

# Import & install libs

In [None]:
%load_ext autoreload
%autoreload 2

# Standard libraries
import numpy as np 
from tqdm.notebook import tqdm

## Imports for plotting
import matplotlib.pyplot as plt
from matplotlib import cm
%matplotlib inline 

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import torch.autograd as autograd
from torch.utils.tensorboard import SummaryWriter

# Torchvision
import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms
from torchvision.utils import make_grid

print(f"Torch version: {torch.__version__}")  # report the installed PyTorch version

In [None]:
# Resolve the compute device: the configured GPU when CUDA is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device(gpu_device)
else:
    device = torch.device("cpu")
print("Currently using the device:", device)

# Dataset

In [None]:
# Create the dataset folder (including any missing parent folders) if it does not exist.
# exist_ok=True turns this into a no-op when the folder is already present, and the
# explicit import avoids relying on `from ebm.config import *` to provide `os`.
import os

os.makedirs(DATASET_PATH, exist_ok=True)

In [None]:
# Transformations applied to each image: convert it to a tensor, then normalize
# with mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5, mapping [0, 1] pixels to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, ), (0.5, ))])

# Constructor arguments shared by both MNIST splits.
mnist_kwargs = dict(root=DATASET_PATH, transform=transform, download=True)

# MNIST training split. NOTE: no train/validation split is done here; the whole
# split is passed on as-is (any validation split would have to happen downstream,
# e.g. in the trainer's prepare_data — TODO confirm).
train_set = MNIST(train=True, **mnist_kwargs)

# MNIST test split.
test_set = MNIST(train=False, **mnist_kwargs)

# Trainer classes

In [None]:
from ebm.train import EBMLangVanilla, EBMLang2Ord
from ebm.models import CNNModel, LeNet

## Test trainer

In [None]:
# Test 1: short debug run of the EBMLangVanilla trainer with a LeNet energy model.
MODEL_NAME = "langVanilla_test"
MODEL_DESCRIPTION = "This is a debug run"
MODEL_FAMILY = "test"
N_EPOCHS = 2  # kept tiny: this run only checks that training works end to end

EBMTrain = EBMLangVanilla(img_shape=(1, 28, 28),
                          cnn=LeNet,
                          batch_size=256,
                          lr=5e-3,
                          weight_decay=1e-3,
                          mcmc_step_size=5e-6,
                          mcmc_steps=2,
                          model_name=MODEL_NAME,
                          model_description=MODEL_DESCRIPTION,
                          model_family=MODEL_FAMILY,
                          overwrite=True,
                          # Use the device resolved in the device cell ("cpu" when CUDA is
                          # unavailable) rather than the raw `gpu_device` string.
                          device=str(device))
EBMTrain.setup()
EBMTrain.prepare_data(train_set, test_set)

try:
    # Train the model for N_EPOCHS epochs
    EBMTrain.fit(N_EPOCHS)
finally:
    # Always run the trainer's clear() hook, even if training fails or is interrupted.
    EBMTrain.clear()

## Reload trained model

### Same *name* and *hyperparams*

In [None]:
# Rebuild the trainer with the same name and hyperparameters as the test run and
# ask it to reload the saved model (reload_model=True).
# NOTE(review): overwrite=True is combined with reload_model=True; confirm in
# ebm.train that overwriting does not discard the checkpoint before it is reloaded.
MODEL_NAME = "langVanilla_test"
MODEL_DESCRIPTION = "This is a debug run"
MODEL_FAMILY = "test"

EBMTrain = EBMLangVanilla(img_shape=(1, 28, 28),
                          cnn=LeNet,
                          batch_size=256,
                          lr=5e-3,
                          weight_decay=1e-3,
                          mcmc_step_size=5e-6,
                          mcmc_steps=2,
                          model_name=MODEL_NAME,
                          model_description=MODEL_DESCRIPTION,
                          model_family=MODEL_FAMILY,
                          overwrite=True,
                          reload_model=True,
                          # Device resolved in the device cell (CPU fallback without CUDA).
                          device=str(device))
EBMTrain.setup()
EBMTrain.prepare_data(train_set, test_set)

Generate some samples from the pretrained model

In [None]:
# Draw a batch of 64 samples with `mcmc_iter` MCMC steps from the reloaded model.
mcmc_iter = 20
EBMTrain.final_sampled_images = EBMTrain.tb_mcmc_images(
    batch_size=64, mcmc_steps=mcmc_iter, name="final_images_sample", evaluation=True)
# Plot them
print(f"Final sample after {mcmc_iter} mcmc iterations:")
# Assumes tb_mcmc_images returns a (C, H, W) image-grid tensor (TODO confirm); permute
# to (H, W, C) for matplotlib. detach().cpu() lets matplotlib convert it to NumPy even
# when the tensor lives on the GPU or still tracks gradients.
sample_grid = EBMTrain.final_sampled_images.detach().cpu().permute(1, 2, 0)
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(sample_grid)
plt.show()

In [None]:
# Clear
# Call the trainer's clear() hook before another trainer is built below
# (presumably releases GPU memory / logging resources — TODO confirm in ebm.train).
EBMTrain.clear()

### Reload from given path
Hyperparams to be explicitly set:
- mcmc_step_size
- gpu_device
- cnn  


They must match the values used during training (except for the GPU device).

In [None]:
# Placeholder: replace "..." with the saved run to reload before executing this cell.
# mcmc_step_size and cnn must match the values used when that run was trained.
model_root = "saved_models/MNIST/..."
EBMTrain = EBMLangVanilla(mcmc_step_size=1e-3,
                          cnn=CNNModel,
                          reload_model=model_root,
                          # Device resolved in the device cell (CPU fallback without CUDA).
                          device=str(device))
EBMTrain.setup()

Generate some samples from the pretrained model

In [None]:
# Draw a batch of 64 samples with `mcmc_iter` MCMC steps from the reloaded model.
mcmc_iter = 500
EBMTrain.final_sampled_images = EBMTrain.tb_mcmc_images(
    batch_size=64, mcmc_steps=mcmc_iter, name="final_images_sample", evaluation=True)
# Plot them
print(f"Final sample after {mcmc_iter} mcmc iterations:")
# Assumes tb_mcmc_images returns a (C, H, W) image-grid tensor (TODO confirm); permute
# to (H, W, C) for matplotlib. detach().cpu() lets matplotlib convert it to NumPy even
# when the tensor lives on the GPU or still tracks gradients.
sample_grid = EBMTrain.final_sampled_images.detach().cpu().permute(1, 2, 0)
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(sample_grid)
plt.show()

In [None]:
# Clear
# Call the trainer's clear() hook once sampling is done
# (presumably releases GPU memory / logging resources — TODO confirm in ebm.train).
EBMTrain.clear()