In [None]:
import sys
from pathlib import Path

In [None]:
import numpy as np
import tensorflow as tf

from sklearn.model_selection import train_test_split

In [None]:
# Put the sibling "../ssl" directory on the import path (presumably where the
# `src` package imported in the next cell lives — confirm).
# The membership check keeps the cell idempotent: re-running it does not pile
# up duplicate entries in sys.path.
_ssl_root = str(Path("../ssl").resolve())
if _ssl_root not in sys.path:
    sys.path.append(_ssl_root)

In [None]:
from src.models.lenet5.lenet5 import Lenet5
from src.models.lenet5.lenet5_config import Lenet5Config
from src.trainers.basic.categorical_ce_trainer import CategoricalCETrainer
from src.trainers.basic.categorical_ce_trainer_config import CategoricalCETrainerConfig
from src.data.basic_data_loader.categorical_ce_data_loader_config import CategoricalCEDataLoaderConfig
from src.data.basic_data_loader.categorical_ce_data_loader import CategoricalCEDataLoader

## Set up Experiment

In [None]:
class TrainerConfig(CategoricalCETrainerConfig):
    """Trainer settings for this experiment.

    Sets only ``num_epochs``; every other setting is inherited from
    ``CategoricalCETrainerConfig`` (defined outside this notebook).
    """
    num_epochs = 5

# Instance handed to the CategoricalCETrainer constructor further down.
train_config = TrainerConfig()

In [None]:
class ModelConfig(Lenet5Config):
    """LeNet-5 model settings for this experiment.

    Sets the input and output shapes; anything else is inherited from
    ``Lenet5Config`` (defined outside this notebook).
    """
    # 32x32 RGB images, the shape of the CIFAR-10 samples loaded in this notebook.
    input_shape = (32, 32, 3)
    # Equal to the number of CIFAR-10 classes.
    output_shape = 10

# Instance passed to Lenet5 when the model is built further down.
model_config = ModelConfig()

In [None]:
class DataLoaderConfig(CategoricalCEDataLoaderConfig):
    """Data-loader settings shared by the training and validation pipelines.

    Sets the batch size and class count; anything else is inherited from
    ``CategoricalCEDataLoaderConfig`` (defined outside this notebook).
    """
    batch_size = 64
    # CIFAR-10 has ten classes; presumably used by the loader to encode the
    # labels — confirm in CategoricalCEDataLoader.
    num_classes = 10

# One instance, reused for both CategoricalCEDataLoader calls further down.
data_loader_config = DataLoaderConfig()

## Get Datasets

In [None]:
# Load CIFAR-10 as NumPy arrays, already divided into a train and a test split.
# NOTE(review): x_test / y_test are not used by any cell visible in this notebook.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

In [None]:
# Hold out 20% of the training split for validation. Stratifying on the labels
# keeps the class proportions the same in both parts, and the fixed
# random_state makes the split repeatable.
# NOTE: this rebinds x_train / y_train. Reload the data before re-running this
# cell, otherwise the already-reduced training set gets split a second time.
x_train, x_val, y_train, y_val = train_test_split(
    x_train,
    y_train,
    test_size=0.20,
    random_state=42,
    stratify=y_train,
)

In [None]:
# Wrap each (images, labels) pair in a tf.data.Dataset that yields one sample
# per element: first the training pair, then the validation pair.
train_data, val_data = (
    tf.data.Dataset.from_tensor_slices(arrays)
    for arrays in ((x_train, y_train), (x_val, y_val))
)

In [None]:
# Create the training pipeline.
# The source dataset is rebuilt from the arrays here instead of being read from
# the current `train_data` binding: this cell rebinds `train_data`, so reading
# it back would feed the loader its own output whenever the cell is re-run.
train_data = CategoricalCEDataLoader(
    tf.data.Dataset.from_tensor_slices((x_train, y_train)),
    data_loader_config,
)(training=True)

In [None]:
# Create the validation pipeline (built from the held-out validation split,
# not from the CIFAR-10 test split).
# The source dataset is rebuilt from the arrays here instead of being read from
# the current `val_data` binding: this cell rebinds `val_data`, so reading it
# back would feed the loader its own output whenever the cell is re-run.
val_data = CategoricalCEDataLoader(
    tf.data.Dataset.from_tensor_slices((x_val, y_val)),
    data_loader_config,
)(training=False)

## Train Model

In [None]:
# Instantiate Lenet5 with the model config and call the instance; the result of
# that call is the model object passed to the trainer below (presumably a Keras
# model — confirm in Lenet5).
# NOTE(review): no TensorFlow seed is set in the visible cells, so weight
# initialisation presumably differs from run to run — confirm.
model = Lenet5(model_config)()

In [None]:
# Wire the model, the training pipeline and the trainer settings together;
# the validation pipeline is supplied through the `val_dataset` keyword.
trainer = CategoricalCETrainer(
    model,
    train_data,
    train_config,
    val_dataset=val_data,
)

In [None]:
# Start training.
# NOTE(review): presumably runs for train_config.num_epochs (5) epochs and
# evaluates on val_data — confirm in CategoricalCETrainer.
trainer.train()