### Carga de módulos

In [None]:
import os

from bigdl.nn.layer import *
from bigdl.nn.criterion import *
from bigdl.optim.optimizer import *
from bigdl.util.common import Sample, create_spark_conf, init_engine
from datetime import datetime as dt
from matplotlib import pyplot as plt
from pyspark import SparkContext
import numpy as np

### Inicialización del Spark Context

In [None]:
# Local Spark configuration that uses every available core. create_spark_conf()
# comes from BigDL and is presumably preloaded with the settings BigDL needs.
spark_conf = create_spark_conf().setMaster("local[*]")
# getOrCreate reuses an already-running SparkContext instead of failing.
sc = SparkContext.getOrCreate(conf=spark_conf)
# Initialise the BigDL engine once the SparkContext exists.
init_engine()

### Carga del dataset

In [None]:
from keras.datasets import mnist

(X_train, y_train_lab), (X_test, y_test_lab) = mnist.load_data()
# Add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1). Using -1 lets numpy
# infer N instead of hard-coding the split sizes (assumes 28x28 images, as the
# original hard-coded reshape did).
X_train = X_train.reshape(-1, 28, 28, 1)
X_test = X_test.reshape(-1, 28, 28, 1)

# Images and labels are parallelized separately and zipped below. RDD.zip
# requires identical partitioning; parallelizing equal-length arrays with the
# same default slice count provides that.
rdd_train_images = sc.parallelize(X_train)
rdd_train_labels = sc.parallelize(y_train_lab)
rdd_test_images = sc.parallelize(X_test)
rdd_test_labels = sc.parallelize(y_test_lab)


def _to_sample(image_and_label):
    """Build a BigDL Sample from an ``(image, label)`` pair.

    Pixel values are divided by 255 to map them into [0, 1]. The label is
    shifted by +1, presumably because BigDL classification criteria expect
    1-based class indices (digit 0 -> class 1) - confirm against BigDL docs.
    """
    image, label = image_and_label
    return Sample.from_ndarray(features=image / 255, labels=label + 1)


rdd_train = rdd_train_images.zip(rdd_train_labels).map(_to_sample)
rdd_test = rdd_test_images.zip(rdd_test_labels).map(_to_sample)

### Ejemplo de algunas imágenes

In [None]:
def plot_images_sample(X, n_images=9):
    """Plot a random sample of ``(image, label)`` pairs from an RDD as a grid.

    Args:
        X: RDD whose elements are indexable pairs; element ``[0]`` must be
            reshapeable to 28x28 (e.g. a ``(28, 28, 1)`` array) and element
            ``[1]`` is shown as the panel title.
        n_images: how many elements to draw without replacement. The default
            of 9 gives the original 3x3 grid. Other values get the smallest
            near-square grid that fits them.

    Raises:
        ValueError: if ``n_images`` is smaller than 1.
    """
    if n_images < 1:
        raise ValueError('n_images must be >= 1, got {}'.format(n_images))

    # A fresh seed per call so repeated calls show different samples.
    seed = np.random.randint(low=1, high=np.iinfo(np.int32).max)
    samples = X.takeSample(withReplacement=False, num=n_images, seed=seed)

    # Near-square layout: 9 -> 3x3, 10 -> 3x4, 1 -> 1x1.
    n_cols = int(np.ceil(np.sqrt(n_images)))
    n_rows = int(np.ceil(n_images / n_cols))
    # Keep the original 5x5-inch figure for a 3x3 grid and scale it per panel.
    # squeeze=False keeps `axes` 2-D even for a single row or column.
    _, axes = plt.subplots(n_rows, n_cols,
                           figsize=(5 * n_cols / 3, 5 * n_rows / 3),
                           sharex=True, sharey=True,
                           squeeze=False,
                           subplot_kw=dict(aspect='equal'))
    flat_axes = axes.ravel()

    for ax, pair in zip(flat_axes, samples):
        ax.imshow(np.reshape(pair[0], (28, 28)), cmap='gray_r')
        ax.set_title('Label: {}'.format(pair[1]))
        ax.set_xbound([0, 28])

    # takeSample may return fewer items than requested (small RDD), and the grid
    # may have spare cells. Hide the panels that received no image.
    for ax in flat_axes[len(samples):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Pair each raw training image with its original digit label and show a
# random sample of them.
rdd_to_plot = rdd_train_images.zip(rdd_train_labels)
plot_images_sample(X=rdd_to_plot)

### Definición del modelo (LeNet-5) con un cambio en la primera convolución: relleno *same* (`pad_w = pad_h = -1`) para conservar el tamaño de 28×28

<img src="LeNet_Original_Image.jpg" alt="Arquitectura original de LeNet-5">

In [None]:
# LeNet-5 style classifier for 28x28 grayscale digits.
model = Sequential()
# Samples carry features shaped (28, 28, 1). Because the channel axis has size 1,
# reshaping to (1, 28, 28) gives the channel-first layout without moving pixels.
model.add(Reshape([1, 28, 28]))
# Change vs. the original LeNet-5: pad = -1 requests "same" padding, so conv1
# keeps the 28x28 size (required for the 400-unit flatten below).
model.add(SpatialConvolution(n_input_plane=1, n_output_plane=6,
                             kernel_h=5, kernel_w=5,
                             stride_w=1, stride_h=1,
                             pad_w=-1, pad_h=-1).set_name('conv1'))   # 6 x 28 x 28
model.add(ReLU())
model.add(SpatialMaxPooling(2, 2, 2, 2).set_name('pool1'))            # 6 x 14 x 14
model.add(SpatialConvolution(n_input_plane=6, n_output_plane=16,
                             kernel_h=5, kernel_w=5).set_name('conv2'))  # 16 x 10 x 10
# Non-linearity after conv2, matching the one after conv1.
model.add(ReLU())
model.add(SpatialMaxPooling(2, 2, 2, 2).set_name('pool2'))            # 16 x 5 x 5
model.add(Reshape([400]))                                             # 16 * 5 * 5
model.add(Linear(400, 120).set_name('fc1'))
model.add(ReLU())
# Distinct name: a second 'fc1' would collide, e.g. in TensorBoard parameter tags.
model.add(Linear(120, 84).set_name('fc2'))
model.add(ReLU())
model.add(Linear(84, 10).set_name('score'))
# CrossEntropyCriterion (used for training) applies LogSoftMax itself, so ending
# in SoftMax would apply softmax twice. LogSoftMax is idempotent, so the criterion's
# own LogSoftMax leaves these outputs unchanged. argmax is the same as with SoftMax,
# so the prediction cells are unaffected.
model.add(LogSoftMax())

### Logs para TensorBoard

In [None]:
# Root directory for the TensorBoard event files written by the summaries below.
all_logs_path = os.path.join(os.getcwd(), 'logs')
# Idempotent: no error if the directory already exists.
os.makedirs(all_logs_path, exist_ok=True)

now = dt.now()
# '-' instead of ':' in the time part: ':' is not allowed in Windows path
# components, and app_name becomes a sub-directory of all_logs_path.
str_now = now.strftime('%Y-%m-%d_%H-%M-%S')
app_name = 'MNist_BigDL_{}'.format(str_now)

train_summary = TrainSummary(log_dir=all_logs_path, app_name=app_name)
# Additionally record the "Parameters" summary every 50 iterations.
train_summary.set_summary_trigger("Parameters", SeveralIteration(50))
val_summary = ValidationSummary(log_dir=all_logs_path, app_name=app_name)

### Entrenamiento del modelo

In [None]:
# Same batch size for training and validation.
# NOTE(review): BigDL may require the batch size to be a multiple of the total
# core count; with local[*] that depends on the machine - confirm.
BATCH_SIZE = 64

# Adam with standard hyper-parameters and no learning-rate decay.
adam = Adam(learningrate=1e-3,
            learningrate_decay=0.0,
            beta1=0.9,
            beta2=0.999,
            epsilon=1e-8,
            bigdl_type="float")

# NOTE(review): BigDL's CrossEntropyCriterion applies LogSoftMax internally, so
# the model should end in raw scores or LogSoftMax rather than SoftMax.
optimizer = Optimizer(model=model,
                      training_rdd=rdd_train,
                      criterion=CrossEntropyCriterion(),
                      optim_method=adam,
                      end_trigger=MaxEpoch(4),
                      batch_size=BATCH_SIZE)

# Measure top-1 accuracy on the test set at the end of every epoch.
optimizer.set_validation(batch_size=BATCH_SIZE,
                         val_rdd=rdd_test,
                         trigger=EveryEpoch(),
                         val_method=[Top1Accuracy()])

# Attach the TensorBoard writers.
optimizer.set_train_summary(train_summary)
optimizer.set_val_summary(val_summary)

In [None]:
# Run training and report the wall-clock time it took.
start = dt.now()
trained_model = optimizer.optimize()
training_time = dt.now() - start
print(training_time)

### Evaluación de algunas predicciones

In [None]:
# Score the test Samples and keep the index of the highest output per image.
# Training labels were digit + 1 (presumably 1-based BigDL classes), so this
# 0-based index should equal the predicted digit itself.
rdd_pred = trained_model.predict(rdd_test).map(lambda scores: np.argmax(scores))
# NOTE(review): zip assumes predict() preserves rdd_test's partitioning.
X_test_and_pred = rdd_test_images.zip(rdd_pred)
plot_images_sample(X_test_and_pred)

### Guardar y cargar el modelo

In [None]:
#Guardar modelo
path_saved_models = os.path.join(os.getcwd(), 'saved_models')
model_path = os.path.join(path_saved_models, 'MNist_BigDL.bigdl')
model.save(model_path, True)

#Carga de modelo
model_load = Model.load(model_path, model_path)

### Predicciones del modelo cargado

In [None]:
# Same check as for the freshly trained network, now with the reloaded model:
# score the test Samples and keep the index of the highest output per image.
rdd_pred = model_load.predict(rdd_test).map(lambda scores: np.argmax(scores))
# NOTE(review): zip assumes predict() preserves rdd_test's partitioning.
X_test_and_pred = rdd_test_images.zip(rdd_pred)
plot_images_sample(X_test_and_pred)