In [1]:
# Ignore this. For making the example's imports work in this folder.
# Puts the directory two levels above the current working directory on
# sys.path so package imports such as `examples.basic...` resolve.
import sys, os

_repo_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
# Guard so that re-running this cell does not append duplicate entries.
if _repo_root not in sys.path:
    sys.path.append(_repo_root)

In [2]:
from torch import Tensor
from torch import inference_mode
from torch.nn import Module
from torch.optim import Optimizer
from torchsystem import Loader
from torchsystem import Aggregate
from typing import Callable

class Classifier(Aggregate):
    """Aggregate bundling a model with its loss criterion and optimizer.

    Args:
        model: network applied to each input batch by ``forward``.
        criterion: loss module, called as ``criterion(output, target)``.
        optimizer: optimizer stepped once per ``fit`` call.

    ``epoch`` is initialised to 0; nothing inside this class advances it.
    """

    def __init__(self, model: Module, criterion: Module, optimizer: Optimizer):
        super().__init__()
        # Assignment order is kept as-is: it fixes the submodule registration
        # order (and therefore parameter / state_dict ordering).
        self.epoch = 0
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer

    def forward(self, input: Tensor) -> Tensor:
        """Delegate the forward pass to the wrapped model."""
        prediction = self.model(input)
        return prediction

    def loss(self, output: Tensor, target: Tensor) -> Tensor:
        """Score ``output`` against ``target`` with the configured criterion."""
        value = self.criterion(output, target)
        return value

    def fit(self, input: Tensor, target: Tensor) -> tuple[Tensor, float]:
        """Perform one optimisation step on a single batch.

        Returns:
            The model output for the batch (not detached here) and the batch
            loss converted to a Python float.
        """
        self.optimizer.zero_grad()
        prediction = self(input)
        batch_loss = self.loss(prediction, target)
        batch_loss.backward()
        self.optimizer.step()
        return prediction, batch_loss.item()

    def evaluate(self, input: Tensor, target: Tensor) -> tuple[Tensor, float]:
        """Compute output and loss for a batch without updating parameters.

        Gradient tracking is not disabled here; the caller is expected to do
        that (e.g. with ``inference_mode``).
        """
        prediction = self(input)
        return prediction, self.loss(prediction, target).item()

In [3]:
def train(aggregate: Aggregate, loader: Loader, callback: Callable, device: str):
    """Run one training pass over ``loader``.

    Sets ``aggregate.phase`` to ``'train'``. For every ``(input, target)`` batch
    it moves both tensors to ``device``, calls ``aggregate.fit`` and forwards
    the 1-based batch index, output, target and loss to ``callback``.
    """
    aggregate.phase = 'train'
    for step, (features, labels) in enumerate(loader, start=1):
        features = features.to(device)
        labels = labels.to(device)
        output, loss = aggregate.fit(features, labels)
        callback(step, output, labels, loss)

def evaluate(aggregate: Aggregate, loader: Loader, callback: Callable, device: str):
    """Run one evaluation pass over ``loader`` with autograd disabled.

    Sets ``aggregate.phase`` to ``'evaluation'``, then, inside
    ``inference_mode``, moves each ``(input, target)`` batch to ``device``,
    calls ``aggregate.evaluate`` and forwards the 1-based batch index, output,
    target and loss to ``callback``.
    """
    aggregate.phase = 'evaluation'
    with inference_mode():
        for step, (features, labels) in enumerate(loader, start=1):
            features = features.to(device)
            labels = labels.to(device)
            output, loss = aggregate.evaluate(features, labels)
            callback(step, output, labels, loss)

In [4]:
from torchsystem.metrics import Callback
from torchsystem.metrics.average import Loss, Accuracy

callback = Callback(Loss(), Accuracy())

@callback.handler
def handle_results(batch: int, output: Tensor, target: Tensor, loss: float):
    """Accumulate per-batch metrics and log running averages every 100 batches.

    Registered through ``callback.handler``; presumably dispatched when the
    train/evaluate loops call ``callback(batch, output, target, loss)`` —
    confirm against torchsystem's ``Callback``.

    Args:
        batch: 1-based batch index within the current phase.
        output: model output; predicted classes are its argmax over dim 1.
        target: ground-truth labels for the batch.
        loss: batch loss as a Python float.
    """
    callback.metrics(loss=loss, predictions=output.argmax(dim=1), target=target)
    if batch % 100 == 0:
        # Outer double quotes / inner single quotes: reusing the same quote
        # inside a replacement field only parses on Python >= 3.12 (PEP 701).
        print(f"Batch {batch}: average loss {callback.metrics['loss']:.2f}, average accuracy: {100*callback.metrics['accuracy']:.4f}%")

In [5]:
from torch import cuda
from torchsystem import Compiler
from torchsystem import Depends

# Build pipeline for the aggregate: `build_classifier` is registered on it via
# `@compiler.step`, and `compiler(model, criterion, optimizer)` later produces
# the aggregate (presumably by running the registered step(s) — confirm
# against torchsystem's `Compiler`).
compiler = Compiler()

def get_device():
    """Return ``'cuda'`` when a CUDA device is available, otherwise ``'cpu'``."""
    if cuda.is_available():
        return 'cuda'
    return 'cpu'

@compiler.step
def build_classifier(model, criterion, optimizer, device = Depends(get_device)) -> Classifier:
    """Assemble a Classifier, move it to ``device`` and pass it to ``compiler.compile``.

    ``device`` defaults to ``Depends(get_device)``; presumably the compiler
    resolves that dependency by calling ``get_device`` — confirm in torchsystem.
    """
    assembled = Classifier(model, criterion, optimizer)
    placed = assembled.to(device)
    return compiler.compile(placed)

In [6]:
from examples.basic.model import MLP
from examples.basic.dataset import Fashion
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader

from torch import manual_seed

# Run configuration.
SEED = 42            # fixes weight init, training-set shuffle order (and any dropout)
EPOCHS = 5
BATCH_SIZE = 32
LEARNING_RATE = 0.001

# Seed before the model and the shuffling DataLoader are created so that
# the whole run is repeatable.
manual_seed(SEED)
# Resolve the device once and reuse it for every phase of every epoch.
device = get_device()

# NOTE(review): positional MLP args look like (in=784 = 28*28 pixels,
# hidden=256, classes=10, dropout=0.5, activation) — confirm against
# examples.basic.model.MLP.
model = MLP(784, 256, 10, 0.5, 'relu')
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
loaders = [
    ('train', DataLoader(Fashion(train=True), batch_size=BATCH_SIZE, shuffle=True)),
    ('evaluation', DataLoader(Fashion(train=False), batch_size=BATCH_SIZE, shuffle=False)),
]
aggregate = compiler(model, criterion, optimizer)

for epoch in range(EPOCHS):
    for phase, loader in loaders:
        if phase == 'train':
            train(aggregate, loader, callback, device)
        else:
            evaluate(aggregate, loader, callback, device)
        print(f'Epoch {epoch+1} {phase} completed')
        print(f'Average loss: {callback.metrics["loss"]:.2f}, Average accuracy: {100*callback.metrics["accuracy"]:.4f}%')
        # Presumably clears the accumulated averages so each phase reports
        # its own numbers — confirm against torchsystem's Callback.reset.
        callback.reset()

Batch 100: average loss 1.16, average accuracy: 56.6875%
Batch 200: average loss 0.95, average accuracy: 64.5312%
Batch 300: average loss 0.86, average accuracy: 67.5833%
Batch 400: average loss 0.80, average accuracy: 69.9844%
Batch 500: average loss 0.77, average accuracy: 71.2562%
Batch 600: average loss 0.74, average accuracy: 72.6667%
Batch 700: average loss 0.72, average accuracy: 73.5179%
Batch 800: average loss 0.70, average accuracy: 74.2539%
Batch 900: average loss 0.68, average accuracy: 74.7257%
Batch 1000: average loss 0.67, average accuracy: 75.3406%
Batch 1100: average loss 0.66, average accuracy: 75.8267%
Batch 1200: average loss 0.65, average accuracy: 76.2734%
Batch 1300: average loss 0.64, average accuracy: 76.5168%
Batch 1400: average loss 0.63, average accuracy: 76.7812%
Batch 1500: average loss 0.63, average accuracy: 77.1042%
Batch 1600: average loss 0.62, average accuracy: 77.3887%
Batch 1700: average loss 0.61, average accuracy: 77.5699%
Batch 1800: average los