In [None]:
# Step up one directory from the kernel's starting directory. The relative paths and
# package imports used below ("data/...", "./notebooks/...", `src`, `utils`) are
# presumably meant to resolve from the repository root — confirm the notebook lives
# exactly one level below it.
# NOTE(review): this magic is not idempotent — re-running the cell moves up one more
# level each time, after which the relative paths below no longer resolve.
%cd ..

In [None]:
import os
import pickle
from sklearn.datasets import load_digits
from sklearn.preprocessing import OneHotEncoder
from skimage import io, transform
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from src.linear import Linear
from src.loss import *
from src.activation import TanH, Sigmoid, StableSigmoid, Softmax, LogSoftmax, ReLU, TanH, Softplus
from src.encapsulation import Sequential, Optim

# Seed NumPy's legacy global random state so repeated runs draw the same numbers.
# This only affects code that samples through `np.random.*`; whether the project
# layers/optimizer do so is not visible here — TODO confirm.
np.random.seed(42)


In [None]:
from utils.mltools import *

def normalize_batch_image(X):
    """Min-max scale a whole batch into [0, 1] using its global extrema.

    The minimum and maximum are taken over *all* elements of ``X`` (not per
    image / per row), so relative intensities between samples are preserved.

    Parameters
    ----------
    X : array-like of numbers
        Batch to rescale. Must be non-empty (``np.min`` raises ``ValueError``
        on an empty input).

    Returns
    -------
    numpy.ndarray
        ``(X - min) / (max - min)``. If every element of ``X`` is identical the
        range is zero; in that case an all-zero float array of the same shape is
        returned instead of the NaN values a division by zero would produce.
    """
    mn = np.min(X)
    mx = np.max(X)
    span = mx - mn
    if span == 0:
        # Constant batch: avoid 0 * (1 / 0) -> NaN and map everything to the
        # lower bound of the target interval.
        return np.zeros(np.shape(X), dtype=float)
    return (X - mn) * (1.0 / span)

def load_usps(fn):
    """Read a whitespace-separated USPS text file into features and labels.

    The first line of the file is discarded (presumably a header — confirm
    against the data files). Every following line that splits into more than
    two tokens is parsed as a row of floats; shorter lines are ignored.

    Parameters
    ----------
    fn : str
        Path of the text file to read.

    Returns
    -------
    tuple
        ``(features, labels)`` where ``labels`` is the first column cast to
        ``int`` and ``features`` is the remaining columns passed through
        ``normalize_batch_image``.
    """
    rows = []
    with open(fn, "r") as handle:
        handle.readline()  # drop the first line
        for line in handle:
            fields = line.split()
            if len(fields) > 2:
                rows.append([float(field) for field in fields])

    table = np.array(rows)
    features = normalize_batch_image(table[:, 1:])
    labels = table[:, 0].astype(int)
    return features, labels


# Load both USPS splits (train first, then test) as (features, labels) pairs.
(X_train, y_train), (X_test, y_test) = (
    load_usps(split_path)
    for split_path in ("data/USPS_train.txt", "data/USPS_test.txt")
)

In [None]:
# USPS auto-encoder: build the network, train it to reproduce its input, plot the
# train/test loss per epoch, show one original/reconstruction pair per digit class,
# and pickle the optimizer object.
#
# NOTE(review): `batch_size` is not defined in any visible cell. It presumably comes
# from one of the star imports (`src.loss`, `utils.mltools`) or from a cell that is
# not shown — confirm, otherwise this cell raises NameError on a fresh kernel.

# Figure/axes for the loss curves. Created before training, as in the original
# ordering (SGD_eval is called with online_plot=True; whether it draws on the
# current figure is not visible from this cell).
fig, ax = plt.subplots()

# 256 inputs -> 64-unit code -> 256 outputs (the images are reshaped to 16x16 below).
encoder = [
    Linear(256, 64),
    TanH(),
]
decoder = [
    Linear(64, 256),
    Sigmoid()
]
net_usps = Sequential(*(encoder + decoder))
optimizer = Optim(net_usps.reset(), BCELoss(), eps=1e-3)

# Input and target are both X_train: the network is trained to reconstruct its input.
# The positional `100` is presumably the number of epochs — confirm against
# Optim.SGD_eval's signature.
result_df = optimizer.SGD_eval(
    X_train,
    X_train,
    batch_size,
    100,
    test_size=0.33,
    return_dataframe=True,
    online_plot=True,
    patience=None,
)

# Reshape the per-epoch results to long format so seaborn can colour by split.
loss_long_df = pd.melt(
    result_df,
    id_vars="epoch",
    value_vars=["loss_test", "loss_train"],
    value_name="loss",
    var_name="during",
).replace({"loss_test": "test", "loss_train": "train"})
sns.lineplot(data=loss_long_df, x="epoch", y="loss", hue="during", ax=ax)
fig.set_tight_layout(True)

# Reconstruction grid: top row originals, bottom row network outputs.
n = 10
decoded_imgs = net_usps.forward(X_test)

# Keep a handle on the second figure so layout settings reach it; previously the
# tight-layout call placed here only ever affected the loss figure.
recon_fig = plt.figure(figsize=(20, 4))
recon_fig.set_tight_layout(True)
# Set the default colormap to gray once (same global effect as calling it per image).
plt.gray()

for i in range(n):
    # index of the first test sample labelled i
    idx = np.nonzero(y_test == i)[0][0]

    # display original
    orig_ax = recon_fig.add_subplot(2, n, i + 1)
    orig_ax.imshow(X_test[idx].reshape(16, 16))
    orig_ax.set_title("original")
    orig_ax.get_xaxis().set_visible(False)
    orig_ax.get_yaxis().set_visible(False)

    # display reconstruction
    recon_ax = recon_fig.add_subplot(2, n, i + 1 + n)
    recon_ax.imshow(decoded_imgs[idx].reshape(16, 16))
    recon_ax.set_title("reconstructed")
    recon_ax.get_xaxis().set_visible(False)
    recon_ax.get_yaxis().set_visible(False)
plt.show()

# Persist the optimizer object.
# NOTE(review): the file name says "mnist" although the data loaded in this notebook
# is USPS; the path is left unchanged because other code may read it.
with open("./notebooks/mnist_100_epoch_simple_net.pkl", "wb") as f:
    pickle.dump(optimizer, f)
