# Setup

In [1]:
import typing

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, random_split
import pandas as pd

In [2]:
%load_ext autoreload
%autoreload 2

## Prepare data

In [3]:
def prep_iris_data(
    path: str = "demos/training_data/iris.csv",
    p_test: float = 0.2,
    seed: typing.Optional[int] = None,
    target_col: str = "species",
) -> dict:
    """Load the iris CSV and split it into train/test subsets of a ``TensorDataset``.

    Args:
        path: CSV file to read. It must contain ``target_col``; every other
            column is treated as a numeric input feature.
        p_test: fraction of rows assigned to the test split, in ``[0, 1]``.
            The test size is ``int(p_test * n_rows)`` (truncated).
        seed: if given, the split is drawn from a dedicated ``torch.Generator``
            seeded with this value, so repeated calls return the same split.
            With ``None`` (default) the global torch RNG is used.
        target_col: name of the label column.

    Returns:
        dict with keys:
            "train", "test": ``Subset``s of ``(features float32, label int64)`` pairs,
            "column_names": feature column names, in CSV order,
            "target_list": class names, where index i is integer label i,
            "target_map": class name -> integer label.

    Raises:
        ValueError: if ``p_test`` is outside ``[0, 1]`` or ``target_col`` is missing.
    """
    if not 0.0 <= p_test <= 1.0:
        raise ValueError(f"p_test must be in [0, 1], got {p_test}")

    iris_df: pd.DataFrame = pd.read_csv(path)
    if target_col not in iris_df.columns:
        raise ValueError(f"column {target_col!r} not found in {path}")

    # Class names in order of first appearance; integer label i <-> target_list[i]
    target_list: list[str] = list(iris_df[target_col].unique())
    target_map: dict[str, int] = {name: i for i, name in enumerate(target_list)}

    # Select features by name (not by position) so the label column can never
    # leak into the inputs, wherever it sits in the CSV.
    feature_cols: list[str] = [col for col in iris_df.columns if col != target_col]

    inputs: torch.Tensor = torch.tensor(iris_df[feature_cols].values, dtype=torch.float32)
    targets: torch.Tensor = torch.tensor(
        iris_df[target_col].map(target_map).values, dtype=torch.int64
    )

    dataset: TensorDataset = TensorDataset(inputs, targets)

    total_size: int = len(dataset)
    test_size: int = int(p_test * total_size)
    train_size: int = total_size - test_size

    if seed is None:
        train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    else:
        # Dedicated generator: reproducible split without reseeding the global RNG
        generator = torch.Generator().manual_seed(seed)
        train_dataset, test_dataset = random_split(
            dataset, [train_size, test_size], generator=generator
        )

    return {
        "train": train_dataset,
        "test": test_dataset,
        "column_names": feature_cols,
        "target_list": target_list,
        "target_map": target_map,
    }

In [4]:
# Build the train/test split once and show its structure plus one sample from each side.
IRIS_DATA: dict = prep_iris_data()
print(IRIS_DATA)
for split_name in ("train", "test"):
    print(IRIS_DATA[split_name][0])

{'train': <torch.utils.data.dataset.Subset object at 0x000001E07DB15BB0>, 'test': <torch.utils.data.dataset.Subset object at 0x000001E07C195970>, 'column_names': ['sepal_length', 'sepal_width', 'petal_length', 'petal_width'], 'target_list': ['setosa', 'versicolor', 'virginica'], 'target_map': {'setosa': 0, 'versicolor': 1, 'virginica': 2}}
(tensor([6.7000, 3.0000, 5.0000, 1.7000]), tensor(1))
(tensor([5.2000, 4.1000, 1.5000, 0.1000]), tensor(0))


## Define the DNN

In [5]:
class DNN(nn.Module):
    """Two-layer MLP classifier: Linear -> ReLU -> Linear.

    ``forward`` returns raw, unnormalized class scores (logits). That is what
    ``nn.CrossEntropyLoss`` expects, because it applies log-softmax internally.
    A Softmax layer inside the network would apply softmax twice during
    training. Use ``predict_proba`` when probabilities are needed.

    Args:
        d_input: number of input features.
        d_hidden: width of the hidden layer.
        d_output: number of output classes.
    """

    def __init__(self, d_input: int, d_hidden: int, d_output: int):
        super().__init__()
        # No trailing Softmax. Softmax has no parameters, so the state_dict keys
        # (net.0.*, net.2.*) match a Softmax-terminated Sequential.
        self.net: nn.Module = nn.Sequential(
            nn.Linear(d_input, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_output),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch ``x`` (class scores along dim 1)."""
        return self.net(x)

    def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
        """Return class probabilities: softmax of the logits over dim 1."""
        return torch.softmax(self.forward(x), dim=1)

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """Return the predicted class index per row.

        Softmax is monotonic, so the argmax of the logits equals the argmax of the
        probabilities. No gradient is needed for an argmax, so autograd is disabled.
        """
        with torch.no_grad():
            return torch.argmax(self.forward(x), dim=1)

# Initialize model & dataset

In [6]:
# Initialize the DNN model
model = DNN(
    d_input = 4,
    d_hidden = 10,
    d_output = 3,
)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Prepare the Iris dataset
iris_data = prep_iris_data()  # Assuming this calls the function we defined earlier
TRAIN_LOADER = torch.utils.data.DataLoader(iris_data['train'], batch_size=32, shuffle=True)
TEST_LOADER = torch.utils.data.DataLoader(iris_data['test'], batch_size=32, shuffle=False)


# Initialize logger

In [10]:
from trnbl.loggers.local import LocalLogger
# from trnbl.loggers.wandb import WandbLogger
from trnbl.loggers.tensorboard import TensorBoardLogger
from trnbl.training_manager import TrainingManager

In [14]:


# Record metrics and a string snapshot of the run configuration under demos/local.
# Alternative backend: TensorBoardLogger(log_dir="demos/tensorboard").
logger = LocalLogger(
    project="iris-demo",
    metric_names=[
        "train/loss",
        "train/acc",
        "val/loss",
        "val/acc",
    ],
    train_config={
        "model": str(model),
        "dataset": "iris",
        "optimizer": str(optimizer),
        "criterion": str(criterion),
    },
    base_path="demos/local",
)
	


starting logger


# Define evaluation function

In [15]:

def eval_func(model, loader=None, loss_fn=None):
    """Evaluate ``model`` on a dataloader and return validation metrics.

    Args:
        model: module mapping an input batch to per-class scores, with classes
            along dim 1 (predictions are taken with ``argmax(dim=1)``).
        loader: iterable of ``(inputs, targets)`` batches; defaults to ``TEST_LOADER``.
        loss_fn: loss that averages over the batch (e.g. ``nn.CrossEntropyLoss()``
            with default reduction); defaults to ``criterion``.

    Returns:
        ``{"val/loss": float, "val/acc": float}``: per-sample averages over the
        whole loader (each batch weighted by its size). Both are NaN if the loader
        yields no samples.
    """
    if loader is None:
        loader = TEST_LOADER
    if loss_fn is None:
        loss_fn = criterion

    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    try:
        # Evaluation needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            for inputs, targets in loader:
                outputs = model(inputs)
                n = len(targets)
                # loss_fn averages over the batch; scale back to a sum so a short
                # final batch is not over-weighted in the overall mean.
                total_loss += loss_fn(outputs, targets).item() * n
                total_correct += (torch.argmax(outputs, dim=1) == targets).sum().item()
                total_samples += n
    finally:
        # This may be called mid-training; restore the caller's train/eval mode.
        model.train(was_training)

    if total_samples == 0:
        return {"val/loss": float("nan"), "val/acc": float("nan")}

    return {
        "val/loss": total_loss / total_samples,
        "val/acc": total_correct / total_samples,
    }



# Run training loop

In [16]:
# Train the model for `epochs` epochs inside a trnbl TrainingManager. The manager
# receives per-batch training metrics via `tp.batch_update` and an end-of-epoch
# signal via `tp.epoch_update`.
epochs: int = 80

with TrainingManager(
    model=model,
    dataloader=TRAIN_LOADER,
    logger=logger,
    epochs=epochs,
    # NOTE(review): keys look like trnbl interval strings: eval_func every epoch,
    # logger memory usage every "0.1 runs". Confirm the semantics in the TrainingManager docs.
    evals={
        "1 epochs": eval_func,
        "0.1 runs": lambda model: logger.get_mem_usage(),
    }.items(),
    # presumably saves a checkpoint every 50 epochs; confirm against TrainingManager
    checkpoint_interval="50 epochs",
) as tp:

    # Training loop
    model.train()
    for epoch in range(epochs):
        for inputs, targets in TRAIN_LOADER:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Batch accuracy from `outputs`, which were computed before optimizer.step()
            accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)

            # Report this batch: its sample count plus the "train/*" metrics
            tp.batch_update(
                samples=len(targets),
                **{"train/loss": loss.item(), "train/acc": accuracy},
            )

        tp.epoch_update()


initialized training manager
completed epoch 1/80
completed epoch 2/80
completed epoch 3/80
completed epoch 4/80
completed epoch 5/80
completed epoch 6/80
completed epoch 7/80
completed epoch 8/80
completed epoch 9/80
completed epoch 10/80
completed epoch 11/80
completed epoch 12/80
completed epoch 13/80
completed epoch 14/80
completed epoch 15/80
completed epoch 16/80
completed epoch 17/80
completed epoch 18/80
completed epoch 19/80
completed epoch 20/80
completed epoch 21/80
completed epoch 22/80
completed epoch 23/80
completed epoch 24/80
completed epoch 25/80
completed epoch 26/80
completed epoch 27/80
completed epoch 28/80
completed epoch 29/80
completed epoch 30/80
completed epoch 31/80
completed epoch 32/80
completed epoch 33/80
completed epoch 34/80
completed epoch 35/80
completed epoch 36/80
completed epoch 37/80
completed epoch 38/80
completed epoch 39/80
completed epoch 40/80
completed epoch 41/80
completed epoch 42/80
completed epoch 43/80
completed epoch 44/80
completed ep