In [1]:
import torch
import torch.nn.functional as F
import torchvision

from pytorchtrainutils import trainer
from pytorchtrainutils import utils
from pytorchtrainutils import metrics

device = torch.device('cpu')

In [2]:
transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
dataset = torchvision.datasets.MNIST('~/data', train=True, download=True, transform=transforms)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=100, shuffle=False)

In [3]:
mean, std = utils.get_mean_and_std(dataloader)
print('Mean, std =', mean, std)

HBox(children=(FloatProgress(value=0.0, max=600.0), HTML(value='')))


Mean, std = tensor([0.1307]) tensor([0.3015])


In [4]:
class Dataset(torch.utils.data.dataset.Dataset):
    def __init__(self, x, y):
        super().__init__()

        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        return self.x[index], self.y[index]

D_in, H, D_out =  1000, 100, 10
train_x, train_y = torch.randn(100, D_in), torch.max(torch.randn(100, D_out), 1)[1]
val_x, val_y = torch.randn(20, D_in), torch.max(torch.randn(20, D_out), 1)[1]
test_x, test_y = torch.randn(30, D_in), torch.max(torch.randn(30, D_out), 1)[1]


train_dataset = Dataset(train_x, train_y)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10, num_workers=0, shuffle=True)
val_dataset = Dataset(val_x, val_y)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=10, num_workers=0, shuffle=False)
test_dataset = Dataset(test_x, test_y)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=10, num_workers=0, shuffle=False)

In [5]:
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
    torch.nn.Softmax(dim=1)
)

criterion = F.cross_entropy
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=1e-4)

tracked_metrics = [
    metrics.Accuracy(multiclass=True),
    metrics.FScore(multiclass=True),
    metrics.RocAuc(multiclass=True),
    metrics.ConfusionMatrix(multiclass=True)
]

name = 'runs/test'
best_model = trainer.fit(
    model, train_dataloader=train_loader, val_dataloader=val_loader,
    test_dataloader=test_loader, test_every=2, criterion=criterion,
    optimizer=optimizer, scheduler=None, metrics=tracked_metrics, n_epochs=20,
    name=name, device=device,
    callbacks={'train': lambda: utils.save_cm(
        cm=tracked_metrics[-1], title='train', path=f'{name}/cm-train.png',
        normalized=True, format=".1f", vmin=0., vmax=1.,
        yticklabels=['a']*10
    )})

test_logs = trainer.test(
    best_model, test_dataloader=test_loader,
    criterion=criterion, metrics=tracked_metrics,
    device=device
)

HBox(children=(FloatProgress(value=0.0, description='train', max=10.0, style=ProgressStyle(description_width='…



HBox(children=(FloatProgress(value=0.0, description='val', max=2.0, style=ProgressStyle(description_width='ini…

Epoch: 000 | VAL acc: 0.1000 - f-score: 0.1143 - auc: 0.5556 - loss: 2.3013 | TRAIN acc: 0.0500 - f-score: 0.0339 - auc: 0.4856 - loss: 2.3046 |


HBox(children=(FloatProgress(value=0.0, description='train', max=10.0, style=ProgressStyle(description_width='…



HBox(children=(FloatProgress(value=0.0, description='val', max=2.0, style=ProgressStyle(description_width='ini…

Epoch: 001 | VAL acc: 0.1000 - f-score: 0.1143 - auc: 0.5556 - loss: 2.3012 | TRAIN acc: 0.0900 - f-score: 0.0658 - auc: 0.5089 - loss: 2.3029 |


HBox(children=(FloatProgress(value=0.0, description='test', max=3.0, style=ProgressStyle(description_width='in…

TEST | acc: 0.1000 - f-score: 0.0843 - auc: 0.4778 - loss: 2.3069


HBox(children=(FloatProgress(value=0.0, description='train', max=10.0, style=ProgressStyle(description_width='…



HBox(children=(FloatProgress(value=0.0, description='val', max=2.0, style=ProgressStyle(description_width='ini…

Epoch: 002 | VAL acc: 0.1000 - f-score: 0.1143 - auc: 0.5556 - loss: 2.3011 | TRAIN acc: 0.1200 - f-score: 0.0840 - auc: 0.5378 - loss: 2.3011 |


HBox(children=(FloatProgress(value=0.0, description='train', max=10.0, style=ProgressStyle(description_width='…



HBox(children=(FloatProgress(value=0.0, description='val', max=2.0, style=ProgressStyle(description_width='ini…

Epoch: 003 | VAL acc: 0.1000 - f-score: 0.1143 - auc: 0.5556 - loss: 2.3011 | TRAIN acc: 0.1400 - f-score: 0.1062 - auc: 0.5667 - loss: 2.2994 |


HBox(children=(FloatProgress(value=0.0, description='test', max=3.0, style=ProgressStyle(description_width='in…

TEST | acc: 0.1000 - f-score: 0.0843 - auc: 0.4815 - loss: 2.3068


HBox(children=(FloatProgress(value=0.0, description='train', max=10.0, style=ProgressStyle(description_width='…



HBox(children=(FloatProgress(value=0.0, description='val', max=2.0, style=ProgressStyle(description_width='ini…

Epoch: 004 | VAL acc: 0.1000 - f-score: 0.1143 - auc: 0.5500 - loss: 2.3010 | TRAIN acc: 0.1600 - f-score: 0.1308 - auc: 0.5811 - loss: 2.2975 |


HBox(children=(FloatProgress(value=0.0, description='train', max=10.0, style=ProgressStyle(description_width='…



HBox(children=(FloatProgress(value=0.0, description='val', max=2.0, style=ProgressStyle(description_width='ini…

Epoch: 005 | VAL acc: 0.1000 - f-score: 0.1143 - auc: 0.5444 - loss: 2.3009 | TRAIN acc: 0.1700 - f-score: 0.1467 - auc: 0.5956 - loss: 2.2956 |


HBox(children=(FloatProgress(value=0.0, description='test', max=3.0, style=ProgressStyle(description_width='in…

TEST | acc: 0.1000 - f-score: 0.0843 - auc: 0.4815 - loss: 2.3068


HBox(children=(FloatProgress(value=0.0, description='train', max=10.0, style=ProgressStyle(description_width='…



HBox(children=(FloatProgress(value=0.0, description='val', max=2.0, style=ProgressStyle(description_width='ini…

Epoch: 006 | VAL acc: 0.1000 - f-score: 0.1143 - auc: 0.5500 - loss: 2.3009 | TRAIN acc: 0.1900 - f-score: 0.1742 - auc: 0.6133 - loss: 2.2937 |


HBox(children=(FloatProgress(value=0.0, description='train', max=10.0, style=ProgressStyle(description_width='…



HBox(children=(FloatProgress(value=0.0, description='val', max=2.0, style=ProgressStyle(description_width='ini…

Epoch: 007 | VAL acc: 0.1500 - f-score: 0.1743 - auc: 0.5500 - loss: 2.3008 | TRAIN acc: 0.2200 - f-score: 0.2039 - auc: 0.6378 - loss: 2.2917 |


HBox(children=(FloatProgress(value=0.0, description='test', max=3.0, style=ProgressStyle(description_width='in…

TEST | acc: 0.1000 - f-score: 0.0843 - auc: 0.4778 - loss: 2.3067


HBox(children=(FloatProgress(value=0.0, description='train', max=10.0, style=ProgressStyle(description_width='…



HBox(children=(FloatProgress(value=0.0, description='val', max=2.0, style=ProgressStyle(description_width='ini…

Epoch: 008 | VAL acc: 0.1500 - f-score: 0.1743 - auc: 0.5500 - loss: 2.3007 | TRAIN acc: 0.2400 - f-score: 0.2191 - auc: 0.6544 - loss: 2.2897 |


HBox(children=(FloatProgress(value=0.0, description='train', max=10.0, style=ProgressStyle(description_width='…



HBox(children=(FloatProgress(value=0.0, description='val', max=2.0, style=ProgressStyle(description_width='ini…

Epoch: 009 | VAL acc: 0.1500 - f-score: 0.1743 - auc: 0.5500 - loss: 2.3007 | TRAIN acc: 0.2700 - f-score: 0.2386 - auc: 0.6822 - loss: 2.2875 |


HBox(children=(FloatProgress(value=0.0, description='test', max=3.0, style=ProgressStyle(description_width='in…

TEST | acc: 0.1000 - f-score: 0.0843 - auc: 0.4815 - loss: 2.3067


HBox(children=(FloatProgress(value=0.0, description='train', max=10.0, style=ProgressStyle(description_width='…



HBox(children=(FloatProgress(value=0.0, description='val', max=2.0, style=ProgressStyle(description_width='ini…

Epoch: 010 | VAL acc: 0.1500 - f-score: 0.1743 - auc: 0.5500 - loss: 2.3006 | TRAIN acc: 0.2800 - f-score: 0.2491 - auc: 0.6989 - loss: 2.2854 |


HBox(children=(FloatProgress(value=0.0, description='train', max=10.0, style=ProgressStyle(description_width='…



HBox(children=(FloatProgress(value=0.0, description='val', max=2.0, style=ProgressStyle(description_width='ini…

Epoch: 011 | VAL acc: 0.1500 - f-score: 0.1743 - auc: 0.5556 - loss: 2.3005 | TRAIN acc: 0.3100 - f-score: 0.2727 - auc: 0.7211 - loss: 2.2831 |


HBox(children=(FloatProgress(value=0.0, description='test', max=3.0, style=ProgressStyle(description_width='in…

TEST | acc: 0.1000 - f-score: 0.0869 - auc: 0.4852 - loss: 2.3066


HBox(children=(FloatProgress(value=0.0, description='train', max=10.0, style=ProgressStyle(description_width='…



HBox(children=(FloatProgress(value=0.0, description='val', max=2.0, style=ProgressStyle(description_width='ini…

Epoch: 012 | VAL acc: 0.1500 - f-score: 0.1743 - auc: 0.5500 - loss: 2.3005 | TRAIN acc: 0.3400 - f-score: 0.3114 - auc: 0.7389 - loss: 2.2808 |


HBox(children=(FloatProgress(value=0.0, description='train', max=10.0, style=ProgressStyle(description_width='…



HBox(children=(FloatProgress(value=0.0, description='val', max=2.0, style=ProgressStyle(description_width='ini…

Epoch: 013 | VAL acc: 0.1500 - f-score: 0.1743 - auc: 0.5444 - loss: 2.3004 | TRAIN acc: 0.3600 - f-score: 0.3370 - auc: 0.7533 - loss: 2.2784 |


HBox(children=(FloatProgress(value=0.0, description='test', max=3.0, style=ProgressStyle(description_width='in…

TEST | acc: 0.1000 - f-score: 0.0869 - auc: 0.4778 - loss: 2.3066


HBox(children=(FloatProgress(value=0.0, description='train', max=10.0, style=ProgressStyle(description_width='…



HBox(children=(FloatProgress(value=0.0, description='val', max=2.0, style=ProgressStyle(description_width='ini…

Epoch: 014 | VAL acc: 0.1500 - f-score: 0.1743 - auc: 0.5444 - loss: 2.3003 | TRAIN acc: 0.3800 - f-score: 0.3610 - auc: 0.7711 - loss: 2.2760 |


HBox(children=(FloatProgress(value=0.0, description='train', max=10.0, style=ProgressStyle(description_width='…



HBox(children=(FloatProgress(value=0.0, description='val', max=2.0, style=ProgressStyle(description_width='ini…

Epoch: 015 | VAL acc: 0.1500 - f-score: 0.1743 - auc: 0.5500 - loss: 2.3002 | TRAIN acc: 0.4000 - f-score: 0.3806 - auc: 0.7844 - loss: 2.2734 |


HBox(children=(FloatProgress(value=0.0, description='test', max=3.0, style=ProgressStyle(description_width='in…

TEST | acc: 0.1000 - f-score: 0.0869 - auc: 0.4778 - loss: 2.3065


HBox(children=(FloatProgress(value=0.0, description='train', max=10.0, style=ProgressStyle(description_width='…



HBox(children=(FloatProgress(value=0.0, description='val', max=2.0, style=ProgressStyle(description_width='ini…

Epoch: 016 | VAL acc: 0.1500 - f-score: 0.1743 - auc: 0.5556 - loss: 2.3001 | TRAIN acc: 0.4300 - f-score: 0.4089 - auc: 0.7956 - loss: 2.2708 |


HBox(children=(FloatProgress(value=0.0, description='train', max=10.0, style=ProgressStyle(description_width='…



HBox(children=(FloatProgress(value=0.0, description='val', max=2.0, style=ProgressStyle(description_width='ini…

Epoch: 017 | VAL acc: 0.1500 - f-score: 0.1743 - auc: 0.5556 - loss: 2.3000 | TRAIN acc: 0.4500 - f-score: 0.4282 - auc: 0.8089 - loss: 2.2681 |


HBox(children=(FloatProgress(value=0.0, description='test', max=3.0, style=ProgressStyle(description_width='in…

TEST | acc: 0.1000 - f-score: 0.0869 - auc: 0.4815 - loss: 2.3065


HBox(children=(FloatProgress(value=0.0, description='train', max=10.0, style=ProgressStyle(description_width='…



HBox(children=(FloatProgress(value=0.0, description='val', max=2.0, style=ProgressStyle(description_width='ini…

Epoch: 018 | VAL acc: 0.1500 - f-score: 0.1743 - auc: 0.5500 - loss: 2.2999 | TRAIN acc: 0.4700 - f-score: 0.4479 - auc: 0.8244 - loss: 2.2653 |


HBox(children=(FloatProgress(value=0.0, description='train', max=10.0, style=ProgressStyle(description_width='…



HBox(children=(FloatProgress(value=0.0, description='val', max=2.0, style=ProgressStyle(description_width='ini…

Epoch: 019 | VAL acc: 0.1500 - f-score: 0.1743 - auc: 0.5500 - loss: 2.2998 | TRAIN acc: 0.5100 - f-score: 0.4972 - auc: 0.8367 - loss: 2.2624 |


HBox(children=(FloatProgress(value=0.0, description='test', max=3.0, style=ProgressStyle(description_width='in…

TEST | acc: 0.1000 - f-score: 0.0887 - auc: 0.4778 - loss: 2.3065
Training finished


HBox(children=(FloatProgress(value=0.0, description='test', max=3.0, style=ProgressStyle(description_width='in…

TEST | acc: 0.1000 - f-score: 0.0887 - auc: 0.4778 - loss: 2.3065


<Figure size 432x288 with 0 Axes>