In [3]:
%pip install -q datasets evaluate ipywidgets

Note: you may need to restart the kernel to use updated packages.


In [4]:
%pip install -q wandb==0.17.2 matplotlib setuptools scikit-learn

Note: you may need to restart the kernel to use updated packages.


If running this notebook in Colab, please ensure that your Hugging Face `HF_TOKEN` and your Weights & Biases `WANDB_API_KEY` are added to your Colab secrets.

Alternatively, please login to Hugging Face and Weights & Biases by running the following two cells.

In [3]:
# !huggingface-cli login

In [4]:
# !wandb login

In [1]:
import os
import random
import numpy as np
import torch

def seed_everything(seed):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)

In [2]:
from datasets import load_dataset

iris = load_dataset("scikit-learn/iris")
iris

DatasetDict({
    train: Dataset({
        features: ['Id', 'SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm', 'Species'],
        num_rows: 150
    })
})

In [3]:
features = iris['train'].features
features

{'Id': Value(dtype='int64', id=None),
 'SepalLengthCm': Value(dtype='float64', id=None),
 'SepalWidthCm': Value(dtype='float64', id=None),
 'PetalLengthCm': Value(dtype='float64', id=None),
 'PetalWidthCm': Value(dtype='float64', id=None),
 'Species': Value(dtype='string', id=None)}

In [4]:
iris.set_format("pandas")
iris_df = iris['train'][:]
iris_df

Unnamed: 0,Id,SepalLengthCm,SepalWidthCm,PetalLengthCm,PetalWidthCm,Species
0,1,5.1,3.5,1.4,0.2,Iris-setosa
1,2,4.9,3.0,1.4,0.2,Iris-setosa
2,3,4.7,3.2,1.3,0.2,Iris-setosa
3,4,4.6,3.1,1.5,0.2,Iris-setosa
4,5,5.0,3.6,1.4,0.2,Iris-setosa
...,...,...,...,...,...,...
145,146,6.7,3.0,5.2,2.3,Iris-virginica
146,147,6.3,2.5,5.0,1.9,Iris-virginica
147,148,6.5,3.0,5.2,2.0,Iris-virginica
148,149,6.2,3.4,5.4,2.3,Iris-virginica


In [5]:
iris_df['Species'].value_counts()

Species
Iris-setosa        50
Iris-versicolor    50
Iris-virginica     50
Name: count, dtype: int64

In [6]:
label2id = {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2}
iris_df['Species'] = [label2id[species] for species in iris_df['Species']]
iris_df['Species'].value_counts()

Species
0    50
1    50
2    50
Name: count, dtype: int64

In [7]:
iris_df[['SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm']].describe()

Unnamed: 0,SepalLengthCm,SepalWidthCm,PetalLengthCm,PetalWidthCm
count,150.0,150.0,150.0,150.0
mean,5.843333,3.054,3.758667,1.198667
std,0.828066,0.433594,1.76442,0.763161
min,4.3,2.0,1.0,0.1
25%,5.1,2.8,1.6,0.3
50%,5.8,3.0,4.35,1.3
75%,6.4,3.3,5.1,1.8
max,7.9,4.4,6.9,2.5


In [8]:
iris_df = iris_df.sample(frac=1, replace=False, random_state=42).reset_index(drop=True)
iris_df

Unnamed: 0,Id,SepalLengthCm,SepalWidthCm,PetalLengthCm,PetalWidthCm,Species
0,74,6.1,2.8,4.7,1.2,1
1,19,5.7,3.8,1.7,0.3,0
2,119,7.7,2.6,6.9,2.3,2
3,79,6.0,2.9,4.5,1.5,1
4,77,6.8,2.8,4.8,1.4,1
...,...,...,...,...,...,...
145,72,6.1,2.8,4.0,1.3,1
146,107,4.9,2.5,4.5,1.7,2
147,15,5.8,4.0,1.2,0.2,0
148,93,5.8,2.6,4.0,1.2,1


In [9]:
X = iris_df[['SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm']].values
y = iris_df['Species'].values
X.shape, y.shape

((150, 4), (150,))

In [10]:
from sklearn.model_selection import train_test_split

X_train_full, X_test, y_train_full, y_test = train_test_split(X, y, test_size=0.1, stratify=y, random_state=42)
X_train_full.shape, X_test.shape, y_train_full.shape, y_test.shape

((135, 4), (15, 4), (135,), (15,))

In [11]:
X_train, X_valid, y_train, y_valid = train_test_split(X_train_full, y_train_full, test_size=0.1, stratify=y_train_full, random_state=42)
X_train.shape, X_valid.shape, y_train.shape, y_valid.shape

((121, 4), (14, 4), (121,), (14,))

In [12]:
X_means, X_stds = X_train.mean(axis=0), X_train.std(axis=0)
X_means, X_stds

(array([5.82479339, 3.02809917, 3.7214876 , 1.18264463]),
 array([0.82193895, 0.43980145, 1.75476714, 0.75434583]))

In [13]:
X_train = (X_train - X_means) / X_stds
X_valid = (X_valid - X_means) / X_stds
X_test = (X_test - X_means) / X_stds
X_train.shape, X_valid.shape, X_test.shape

((121, 4), (14, 4), (15, 4))

In [14]:
from torch.utils.data import Dataset, DataLoader

class IrisDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.int64)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [15]:
train_set = IrisDataset(X_train, y_train)
len(train_set)

121

In [16]:
# Sanity check:
train_set[0]

(tensor([ 1.0648, -1.2008,  1.1845,  0.8184]), tensor(2))

In [17]:
valid_set = IrisDataset(X_valid, y_valid)
test_set = IrisDataset(X_test, y_test)
len(valid_set), len(test_set)

(14, 15)

In [18]:
batch_size = 8
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)
len(train_loader)

15

In [19]:
# Sanity check:
x_batch, y_batch = next(iter(train_loader))
x_batch.shape, y_batch.shape

(torch.Size([8, 4]), torch.Size([8]))

In [20]:
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
len(valid_loader), len(test_loader)

(2, 2)

In [None]:
# If you modified the futhark model, you will need to restart the kernel
# and re-run the notebook to get the updated model.

!rm -rf _model.* model.c model.h
!futhark multicore --library model.fut
!build_futhark_ffi model

In [21]:
import numpy as np
import _model
from futhark_ffi import Futhark

model = Futhark(_model)

In [22]:
import sys
sys.path.append("..")
from deeplearning_utils import parse_spec

In [23]:
# Get futhark model weights specification
spec = parse_spec(model.from_futhark(model.specs()))
spec

{'fc1.weight': {'tpe': 'f32', 'dims': [5, 4]},
 'fc1.bias': {'tpe': 'f32', 'dims': [5]},
 'fc2.weight': {'tpe': 'f32', 'dims': [3, 5]},
 'fc2.bias': {'tpe': 'f32', 'dims': [3]}}

In [24]:
from tqdm.autonotebook import tqdm, trange

def train_epoch(ws):
    #model.train()
    train_loss = 0
    #ws = mnist.init_weights()
    #TODO fixme to group batches as much as possible
    for x_batch, y_batch in tqdm(train_loader, desc="Training"):
        x_batch = x_batch.numpy()
        y_batch = y_batch.numpy()
        (ws, loss) = model.train(ws, batch_size, x_batch, y_batch)
        train_loss += loss
    train_loss /= len(train_set)
    train_loss = round(train_loss, 4)
    return (ws, train_loss)

In [25]:
def validate_epoch(ws):
    valid_loss = 0
    valid_acc = 0
    #TODO fixme to group batches as much as possible
    for x_batch, y_batch in tqdm(valid_loader, desc="Validation"):
        x_batch = x_batch.numpy()
        y_batch = y_batch.numpy()
        (acc, loss) = model.validate(ws, x_batch, y_batch)
        valid_loss += loss
        valid_acc += acc
    valid_loss /= len(valid_set)
    valid_loss = round(valid_loss, 4)
    valid_acc /= len(valid_set)/batch_size
    valid_acc = round(valid_acc, 4)
    return valid_loss, valid_acc

In [26]:
import wandb

n_epochs = 100

wandb_config = {
    'architecture': "MLP",
    'dataset': "Iris",
    'batch_size': batch_size,
    'learning_rate': 0.01,
    'n_epochs': n_epochs,
}
wandb.init(
    project="futhark-mlp-iris",
    config=wandb_config,
    notes="Logging min of `train_loss` & `valid_loss`, and max of `accuracy`."
)
wandb.define_metric("train_loss", summary="min")
wandb.define_metric("valid_loss", summary="min")
wandb.define_metric("accuracy", summary="max")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mleonard-schneider[0m. Use [1m`wandb login --relogin`[0m to force relogin


<wandb.sdk.wandb_metric.Metric at 0x13f8f6ad0>

In [27]:
from deeplearning_utils import push_to_HF_hub

best_epoch = 0
best_acc = 0
ws = model.default_weights()
for epoch in trange(n_epochs, desc="Epoch"):
    (ws, train_loss) = train_epoch(ws)

    valid_loss, acc = validate_epoch(ws)

    wandb.log({'train_loss': train_loss, 'valid_loss': valid_loss, 'accuracy': acc})

    tqdm.write(f"Epoch: {epoch}, Training Loss: {train_loss}, Validation Loss: {valid_loss}, Accuracy: {acc}")

    if acc > best_acc:
        best_epoch = epoch
        best_acc = acc
        push_to_HF_hub(
            model,
            "leoschneider/futhark-mlp-iris",
            "model.fut",
            ws,
            commit_message=f"epoch: {epoch}, accuracy: {acc}",
            token=os.environ["HF_TOKEN"]
        )
wandb.finish()
print("---")
print("Done!")
print(f"Best Epoch: {best_epoch}, Best Accuracy: {best_acc}")

Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 0, Training Loss: 1.2037, Validation Loss: 1.2295, Accuracy: 2.8571


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 1.1231, Validation Loss: 1.1522, Accuracy: 1.7143


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 1.0587, Validation Loss: 1.0932, Accuracy: 1.7143


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 1.0176, Validation Loss: 1.0456, Accuracy: 2.8571


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 0.972, Validation Loss: 1.0069, Accuracy: 2.8571


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 0.9428, Validation Loss: 0.9748, Accuracy: 4.0


No files have been modified since last commit. Skipping to prevent empty commit.


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 0.9164, Validation Loss: 0.9485, Accuracy: 4.0


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 0.889, Validation Loss: 0.9252, Accuracy: 5.1429


No files have been modified since last commit. Skipping to prevent empty commit.


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 0.873, Validation Loss: 0.9052, Accuracy: 5.7143


No files have been modified since last commit. Skipping to prevent empty commit.


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 0.86, Validation Loss: 0.8878, Accuracy: 5.7143


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 0.8414, Validation Loss: 0.8722, Accuracy: 5.7143


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 11, Training Loss: 0.8266, Validation Loss: 0.8581, Accuracy: 5.7143


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 12, Training Loss: 0.8122, Validation Loss: 0.8454, Accuracy: 6.2857


No files have been modified since last commit. Skipping to prevent empty commit.


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 13, Training Loss: 0.8022, Validation Loss: 0.8339, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 14, Training Loss: 0.7896, Validation Loss: 0.8231, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 15, Training Loss: 0.779, Validation Loss: 0.8129, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 16, Training Loss: 0.7695, Validation Loss: 0.8031, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 17, Training Loss: 0.7593, Validation Loss: 0.7938, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 18, Training Loss: 0.7497, Validation Loss: 0.7846, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 19, Training Loss: 0.7444, Validation Loss: 0.7756, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 20, Training Loss: 0.7356, Validation Loss: 0.767, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 21, Training Loss: 0.7229, Validation Loss: 0.7589, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 22, Training Loss: 0.7147, Validation Loss: 0.7508, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 23, Training Loss: 0.7065, Validation Loss: 0.7426, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 24, Training Loss: 0.7029, Validation Loss: 0.7348, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 25, Training Loss: 0.693, Validation Loss: 0.7272, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 26, Training Loss: 0.6828, Validation Loss: 0.72, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 27, Training Loss: 0.6743, Validation Loss: 0.7128, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 28, Training Loss: 0.6614, Validation Loss: 0.7053, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 29, Training Loss: 0.6572, Validation Loss: 0.6979, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 30, Training Loss: 0.6517, Validation Loss: 0.6908, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 31, Training Loss: 0.6408, Validation Loss: 0.684, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 32, Training Loss: 0.6368, Validation Loss: 0.6771, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 33, Training Loss: 0.6352, Validation Loss: 0.67, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 34, Training Loss: 0.6218, Validation Loss: 0.6628, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 35, Training Loss: 0.6111, Validation Loss: 0.6556, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 36, Training Loss: 0.6047, Validation Loss: 0.649, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 37, Training Loss: 0.6041, Validation Loss: 0.643, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 38, Training Loss: 0.6002, Validation Loss: 0.6372, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 39, Training Loss: 0.5863, Validation Loss: 0.632, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 40, Training Loss: 0.583, Validation Loss: 0.6264, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 41, Training Loss: 0.5795, Validation Loss: 0.6201, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 42, Training Loss: 0.5674, Validation Loss: 0.6141, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 43, Training Loss: 0.5667, Validation Loss: 0.6078, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 44, Training Loss: 0.5671, Validation Loss: 0.6023, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 45, Training Loss: 0.5578, Validation Loss: 0.5968, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 46, Training Loss: 0.5534, Validation Loss: 0.5909, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 47, Training Loss: 0.5483, Validation Loss: 0.5858, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 48, Training Loss: 0.5433, Validation Loss: 0.5807, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 49, Training Loss: 0.5405, Validation Loss: 0.5748, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 50, Training Loss: 0.5337, Validation Loss: 0.5699, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 51, Training Loss: 0.5295, Validation Loss: 0.5645, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 52, Training Loss: 0.5272, Validation Loss: 0.5597, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 53, Training Loss: 0.5124, Validation Loss: 0.555, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 54, Training Loss: 0.5152, Validation Loss: 0.5501, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 55, Training Loss: 0.5111, Validation Loss: 0.5449, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 56, Training Loss: 0.5063, Validation Loss: 0.5408, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 57, Training Loss: 0.5022, Validation Loss: 0.5357, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 58, Training Loss: 0.4996, Validation Loss: 0.5306, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 59, Training Loss: 0.4942, Validation Loss: 0.526, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 60, Training Loss: 0.4897, Validation Loss: 0.5216, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 61, Training Loss: 0.4884, Validation Loss: 0.5175, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 62, Training Loss: 0.4814, Validation Loss: 0.5132, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 63, Training Loss: 0.4773, Validation Loss: 0.5093, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 64, Training Loss: 0.4741, Validation Loss: 0.5055, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 65, Training Loss: 0.471, Validation Loss: 0.5012, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 66, Training Loss: 0.466, Validation Loss: 0.497, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 67, Training Loss: 0.4628, Validation Loss: 0.4921, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 68, Training Loss: 0.4515, Validation Loss: 0.4888, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 69, Training Loss: 0.4561, Validation Loss: 0.4844, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 70, Training Loss: 0.4515, Validation Loss: 0.4804, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 71, Training Loss: 0.4482, Validation Loss: 0.4761, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 72, Training Loss: 0.4451, Validation Loss: 0.4714, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 73, Training Loss: 0.4406, Validation Loss: 0.4676, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 74, Training Loss: 0.4379, Validation Loss: 0.463, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 75, Training Loss: 0.4337, Validation Loss: 0.4594, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 76, Training Loss: 0.428, Validation Loss: 0.4558, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 77, Training Loss: 0.4197, Validation Loss: 0.4527, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 78, Training Loss: 0.4189, Validation Loss: 0.4487, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 79, Training Loss: 0.4227, Validation Loss: 0.4454, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 80, Training Loss: 0.418, Validation Loss: 0.4411, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 81, Training Loss: 0.4161, Validation Loss: 0.4363, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 82, Training Loss: 0.4101, Validation Loss: 0.4328, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 83, Training Loss: 0.4048, Validation Loss: 0.4281, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 84, Training Loss: 0.407, Validation Loss: 0.4248, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 85, Training Loss: 0.4011, Validation Loss: 0.421, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 86, Training Loss: 0.3956, Validation Loss: 0.4178, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 87, Training Loss: 0.3948, Validation Loss: 0.4137, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 88, Training Loss: 0.3904, Validation Loss: 0.4103, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 89, Training Loss: 0.39, Validation Loss: 0.4056, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 90, Training Loss: 0.3875, Validation Loss: 0.4017, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 91, Training Loss: 0.3812, Validation Loss: 0.398, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 92, Training Loss: 0.3725, Validation Loss: 0.3957, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 93, Training Loss: 0.3762, Validation Loss: 0.3924, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 94, Training Loss: 0.3735, Validation Loss: 0.3887, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 95, Training Loss: 0.3707, Validation Loss: 0.3852, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 96, Training Loss: 0.3678, Validation Loss: 0.3813, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 97, Training Loss: 0.3654, Validation Loss: 0.3781, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 98, Training Loss: 0.3625, Validation Loss: 0.374, Accuracy: 6.2857


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 99, Training Loss: 0.3614, Validation Loss: 0.3708, Accuracy: 6.2857




---
Done!
Best Epoch: 12, Best Accuracy: 6.2857


In [32]:
# Load model from HF hub, if you want to skip training
from deeplearning_utils import load_weights_from_hf_hub

ws = load_weights_from_hf_hub(
    model,
    "sadhaklal/mlp-iris",
    "model.safetensors",
    token=os.environ["HF_TOKEN"]
)

Loading fc1.weight with shape [5, 4] and dtype f32
[[-0.10955179  0.10089535 -0.24342752  0.29364133]
 [ 0.33094615 -0.54045284  1.4722185   0.9810878 ]
 [ 0.66089314 -0.64838004  0.8264993   0.6493943 ]
 [ 0.33597267  0.1236271  -0.30570328  0.05329869]
 [-0.11930685 -0.36874515  1.6072415   1.9744782 ]]
Loading fc1.bias with shape [5] and dtype f32
[-0.3946851  -0.83004653  2.2240942  -0.29805657 -1.019051  ]
Loading fc2.weight with shape [3, 5] and dtype f32
[[-0.44170353 -0.21288367 -2.118905    0.38234478 -0.5156934 ]
 [-0.14521228 -0.7494736   1.6898736   0.30908802 -1.2574942 ]
 [-0.14103918  1.7627541  -0.00243688  0.20337397  2.2957842 ]]
Loading fc2.bias with shape [3] and dtype f32
[ 3.2907965  -0.76823306 -2.2013984 ]


In [28]:
from tqdm.autonotebook import tqdm
import evaluate

metric = evaluate.load("accuracy")
for x_batch, y_batch in test_loader:
    x_batch, y_batch = x_batch.numpy(), y_batch.numpy()
    logits = model.from_futhark(model.predict(ws, x_batch))
    preds = logits.argmax(axis=-1)
    metric.add_batch(predictions=preds, references=y_batch)
computed_metric = metric.compute()
acc = round(computed_metric['accuracy'], 4)
print(f"Test Set Accuracy: {acc}")

Test Set Accuracy: 0.9333
