In [1]:
# Automatically re-import edited modules (e.g. pytensor_ml) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pytensor
import seaborn as sns

from scipy.special import softmax
from sklearn.datasets import load_digits
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder
from tqdm.auto import tqdm

from pytensor_ml.activations import LeakyReLU
from pytensor_ml.layers import Linear, Sequential
from pytensor_ml.loss import CrossEntropy
from pytensor_ml.model import Model
from pytensor_ml.optimizers import ADAGrad

In [None]:
# Digits dataset: X holds the pixel features, y the integer class labels.
X, y = load_digits(return_X_y=True)

# Rescale every feature column into MinMaxScaler's default [0, 1] range.
X_normed = MinMaxScaler().fit_transform(X)

# One-hot targets as a dense array; OneHotEncoder expects a 2-D (n_samples, 1) column.
y_onehot = OneHotEncoder().fit_transform(y.reshape(-1, 1)).toarray()

In [None]:
# Symbolic input batch: any number of rows (None) with 64 features each, matching
# n_in=64 of the first Linear layer of the network.
X_in = pytensor.tensor.tensor("X_in", shape=(None, 64))

In [None]:
# MLP 64 -> 256 -> 128 -> 10 with LeakyReLU between the hidden layers. The last
# Linear layer has no activation, so the network emits raw logits (the loss is
# built with expect_logits=True to match).
prediction_network = Sequential(
    Linear("Linear_1", n_in=64, n_out=256),
    LeakyReLU(),
    Linear("Linear_2", n_in=256, n_out=128),
    LeakyReLU(),
    Linear("Logits", n_in=128, n_out=10),
)

# Symbolic output graph. Deliberately named `logits` rather than `y_hat` so it is
# not confused with (or shadowed by) the numeric `y_hat` class predictions made
# after training.
logits = prediction_network(X_in)
model = Model(X_in, logits)
model

In [None]:
# Targets are the one-hot rows of y_onehot and the network ends in a plain Linear
# "Logits" layer, so the loss is told to expect one-hot labels and raw logits.
# Per-sample losses are averaged over the batch.
loss_fn = CrossEntropy(
    expect_onehot_labels=True,
    expect_logits=True,
    reduction="mean",
)

In [None]:
# Optimizer used by the training loop via optim.step(X_batch, y_batch), which
# returns the batch loss.
# NOTE(review): ndim_out=2 presumably matches the rank of the (batch, 10) network
# output / one-hot targets; confirm against pytensor_ml.optimizers.
optim = ADAGrad(model, loss_fn, ndim_out=2, learning_rate=1e-3)

In [None]:
# Presumably (re)initializes the model's parameters; run before the training loop.
# NOTE(review): the method is spelled `initalize_weights` (not `initialize_weights`).
# This presumably mirrors pytensor_ml's actual API; verify before "fixing" the spelling.
model.initalize_weights()

In [None]:
# Mini-batch training. Each epoch visits every observation exactly once, in a freshly
# shuffled order, split into consecutive batches of at most BATCH_SIZE rows (the last
# batch may be smaller).
SEED = 42
BATCH_SIZE = 1000
N_EPOCHS = 1000

# Seeded generator so the batch order is reproducible under Restart & Run All.
# NOTE(review): this does not seed weight initialization, which happens inside
# pytensor_ml and may use its own RNG.
rng = np.random.default_rng(SEED)

n_obs = X_normed.shape[0]
batch_starts = range(0, n_obs, BATCH_SIZE)
loss_history = []  # one entry per optimizer step, i.e. per mini-batch

for _ in tqdm(range(N_EPOCHS), desc="epochs"):
    perm = rng.permutation(n_obs)
    X_epoch = X_normed[perm]
    y_epoch = y_onehot[perm]
    for start in batch_starts:
        # Slicing past the end is safe in numpy, so the final partial batch needs no special case.
        batch = slice(start, start + BATCH_SIZE)
        loss_history.append(optim.step(X_epoch[batch], y_epoch[batch]))

In [None]:
# Training curve: one point per optimizer step (mini-batch), not per epoch.
fig, ax = plt.subplots()
ax.plot(loss_history)
ax.set(
    title="Training loss",
    xlabel="Optimizer step (one per mini-batch)",
    ylabel="Cross-entropy loss (batch mean)",
);

In [None]:
# Predict on the full dataset.
# NOTE: these are the same rows the model was trained on (there is no hold-out
# split), so the accuracy below is in-sample and optimistic.
# TODO: evaluate on a held-out test set.
y_hat_logits = model.predict(X_normed)
y_hat_probs = softmax(y_hat_logits, axis=-1)  # per-class probabilities; each row sums to 1
# softmax is monotonic, so this argmax equals the argmax of the logits. OneHotEncoder
# orders its columns by sorted category (digits 0-9), so the column index is the label.
y_hat = np.argmax(y_hat_probs, axis=-1)

train_accuracy = np.mean(y_hat == y)
train_accuracy

In [None]:
# sklearn's confusion_matrix puts true labels on the rows and predicted labels on the
# columns; counts are integers, hence fmt="d". Computed on the training data.
fig, ax = plt.subplots(figsize=(8, 7))
sns.heatmap(confusion_matrix(y, y_hat), annot=True, fmt="d", ax=ax)
ax.set(title="Confusion matrix (training data)", xlabel="Predicted digit", ylabel="True digit");