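# Effective Training

Train a two-layer MLP classifier on a 9,984-image subset of MNIST with Adam. Training uses `estimate_train` from `src.models`, which logs the per-batch loss and `delta` during training.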

In [1]:
 import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms

from src.models import MLP
from src.models import estimate_train
from src.utils import init_dataloader

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Fix the global RNG so weight initialization (and any shuffling done by the
# loaders via torch's default generator) is reproducible on Restart & Run All.
# torch.manual_seed seeds the CPU and all CUDA devices; the trailing `;`
# suppresses the returned Generator's repr in the cell output.
SEED = 42
torch.manual_seed(SEED);

In [2]:
# Pixel mean/std values commonly used to normalize MNIST (as in the PyTorch MNIST example).
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)
DATA_DIR = 'data/'
BATCH_SIZE = 64
# Largest multiple of BATCH_SIZE not exceeding 10,000 (64 * 156 = 9,984), so the
# subset splits into whole batches. Deriving it from BATCH_SIZE keeps the two in sync.
SUBSET_SIZE = BATCH_SIZE * (10000 // BATCH_SIZE)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])


def make_mnist_loader(train_mode: bool):
    """Build an MNIST loader sharing this cell's transform, batch size, path and subset size.

    Args:
        train_mode: forwarded to ``init_dataloader``; True is used for the
            training loader and False for the test loader below.

    Returns:
        Whatever ``init_dataloader`` returns for these settings.
    """
    return init_dataloader(
        dataset_name='MNIST',
        transform=transform,
        batch_size=BATCH_SIZE,
        dataset_load_path=DATA_DIR,
        train_mode=train_mode,
        size=SUBSET_SIZE,
    )


train_loader = make_mnist_loader(train_mode=True)
test_loader = make_mnist_loader(train_mode=False)

In [3]:
# Hyperparameters for this training run.
LEARNING_RATE = 0.0025
# Tolerance forwarded to estimate_train. Presumably it is compared against the
# per-batch `delta` it logs; TODO confirm against src.models.
DELTA = 0.001
NUM_EPOCHS = 10

# Two hidden layers of width 256 over single-channel 28x28 inputs, 10 output classes.
model = MLP(
    layers_num=2,
    hidden=256,
    input_channels=1,
    input_sizes=(28, 28),
    classes=10,
).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# NOTE(review): DEVICE is not passed to estimate_train. Confirm it moves each
# batch onto the model's device, otherwise a CUDA run would fail.
estimate_train(model, criterion, train_loader, optimizer, delta=DELTA, num_epochs=NUM_EPOCHS, log=True)

Batch [1/156]: loss = 0.7626, delta = 1000000000.0000
Batch [2/156]: loss = 0.4929, delta = 1422.2661
Batch [3/156]: loss = 0.3658, delta = 269.0240
Batch [4/156]: loss = 0.3602, delta = 7.1685
Batch [5/156]: loss = 0.4044, delta = 8.9065
Batch [6/156]: loss = 0.3767, delta = 11.1330
Batch [7/156]: loss = 0.2739, delta = 1.3913
Batch [8/156]: loss = 0.1711, delta = 4.0212
Batch [9/156]: loss = 0.2007, delta = 9.4660
Batch [10/156]: loss = 0.1924, delta = 7.9431
Batch [11/156]: loss = 0.1707, delta = 20.1311
Batch [12/156]: loss = 0.2704, delta = 0.4660
Batch [13/156]: loss = 0.2685, delta = 0.3382
Batch [14/156]: loss = 0.2802, delta = 2.6122
Batch [15/156]: loss = 0.1615, delta = 0.0536
Batch [16/156]: loss = 0.1839, delta = 0.2009
Batch [17/156]: loss = 0.1281, delta = 0.9414
Batch [18/156]: loss = 0.1681, delta = 0.3503
Batch [19/156]: loss = 0.1684, delta = 0.0012
Batch [20/156]: loss = 0.2672, delta = 0.0163
Batch [21/156]: loss = 0.2061, delta = 0.0138
Batch [22/156]: loss = 0.13