In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [309]:
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

from htools import hdir
from ml_htools.torch_utils import ModelMixin, variable_lr_optimizer, DEVICE

In [260]:
class Data(Dataset):
    
    def __init__(self, n=64, dim=2):
        self.x = torch.rand(n, dim).float()
        self.y = torch.clamp(
            (self.x.mean(1)/2 + torch.randn(n)).round(), 0, 1
        ).unsqueeze(-1)
        
    def __getitem__(self, i):
        return self.x[i], self.y[i]
    
    def __len__(self):
        return len(self.x)

In [489]:
class Model(nn.Module, ModelMixin):
    
    def __init__(self, dim, criterion, callbacks=None, metrics=None):
        super().__init__()
        self.fc = nn.Linear(dim, 1)
        self.criterion = criterion
        self.callbacks = [MetricPrinterCallback()] + (callbacks or [])
        self.metrics = metrics or []
        
    def forward(self, x):
        return self.fc(x)
    
    def fit(self, epochs, loaders, lrs, optim=None, callbacks=None, metrics=None, 
            classification=True, thresh=.5, device=DEVICE):
        if callbacks: self.callbacks.extend(callbacks)
        if metrics: self.metrics.extend(metrics)
        stats = defaultdict(list)
        
        # Set up data loaders, optimizer, and callbacks.
        train_dl, val_dl = loaders
        optim = optim or variable_lr_optimizer(self, lrs=lrs)
        for cb in self.callbacks: cb.on_train_begin()
            
        # Train.
        self.to(device)
        for epoch in range(epochs):
            self.train()
            for i, (xb, yb) in enumerate(train_dl):
                xb, yb = xb.to(device), yb.to(device)
                optim.zero_grad()
                
                # Forward pass.
                y_pred = self(xb)
                loss = self.criterion(y_pred, yb)
                
                # Backward pass.
                loss.backward()
                optim.step()
                
                # For classification problems, round probabilities
                # once instead of in every metric.
                # Keep sklearn pattern with y_true as first argument.
                if classification: y_pred = (y_pred > thresh).float()
                for m in self.metrics: 
                    stats[m.__name__].append(m(yb, y_pred))
                
            val_stats = self.validate(val_dl, classification, thresh)
            for cb in self.callbacks: cb.on_epoch_end(epoch, stats, val_stats)
            
        for cb in self.callbacks: cb.on_train_end()
        return stats
            
    def validate(self, val_dl, classification, thresh):
        val_stats = dict()
        self.eval()
        with torch.no_grad():
            for xb, yb in val_dl:
                y_pred = self(xb)
                loss = self.criterion(y_pred, yb)
                if classification: y_pred = (y_pred > thresh).float()
                for m in self.metrics:
                    val_stats[m.__name__] = m(y_pred, yb)
        return val_stats

In [490]:
class TorchCallback:
    
    def on_train_begin(self):
        pass
    
    def on_train_end(self):
        pass
    
    def on_epoch_begin(self):
        pass

    def on_epoch_end(self):
        pass
    
    def on_batch_begin(self):
        pass
    
    def on_batch_end(self):
        pass

In [491]:
class StopTraining

SyntaxError: invalid syntax (<ipython-input-491-0aa6d34f185e>, line 1)

In [492]:
class EarlyStopper(TorchCallback):
    
    def __init__(self, stat='loss', patience=3):
        self.stat = stat
        self.best_stat = float('inf')
        self.patience = patience
        self.since_improvement = 0
        
    def on_epoch_end(self, loss, stats):
        if loss < self.best_stat:
            self.best_stat = loss
            self.since_improvement = 0
        else:
            self.since_improvement += 1
            if self.since_improvement > self.patience:
                return 

In [493]:
class MetricPrinterCallback(TorchCallback):
    
    def __init__(self):
        pass
    
    def on_epoch_end(self, epoch, stats, val_stats):
        print(f'\nEpoch {epoch}')
        for k, v in stats.items():
            print(k, np.mean(v).round(3), val_stats[k])

In [502]:
def percent_positive(y_true, y_pred):
    return (y_pred == 1).float().numpy().mean()

In [503]:
DIM = 2
metrics = [accuracy_score, precision_score, recall_score, percent_positive]

In [513]:
train = Data(n=32, dim=DIM)
val = Data(n=32, dim=DIM)

dl_train = DataLoader(train, batch_size=8, shuffle=True)
dl_val = DataLoader(val, batch_size=8, shuffle=False)

In [514]:
net = Model(DIM, F.binary_cross_entropy_with_logits, metrics=metrics)
net

Model(
  (fc): Linear(in_features=2, out_features=1, bias=True)
)

In [515]:
x, y = next(iter(dl_train))

In [516]:
x.shape, y.shape

(torch.Size([8, 2]), torch.Size([8, 1]))

In [517]:
F.binary_cross_entropy_with_logits(net(x), y)

tensor(1.1740, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)

In [519]:
stats = net.fit(10, [dl_train, dl_val], [3e-1])


Epoch 0
accuracy_score 0.719 0.75
precision_score 0.25 0.0
recall_score 0.083 0.0
percent_positive 0.031 0.25

Epoch 1
accuracy_score 0.703 0.75
precision_score 0.125 0.0
recall_score 0.042 0.0
percent_positive 0.016 0.25

Epoch 2
accuracy_score 0.698 0.375
precision_score 0.083 0.0
recall_score 0.028 0.0
percent_positive 0.01 0.25

Epoch 3
accuracy_score 0.703 0.25
precision_score 0.125 0.0
recall_score 0.036 0.0
percent_positive 0.016 0.25

Epoch 4
accuracy_score 0.712 0.75
precision_score 0.2 0.0
recall_score 0.071 0.0
percent_positive 0.025 0.25

Epoch 5
accuracy_score 0.714 0.75
precision_score 0.208 0.0
recall_score 0.08 0.0
percent_positive 0.026 0.25

Epoch 6
accuracy_score 0.71 0.75
precision_score 0.179 0.0
recall_score 0.068 0.0
percent_positive 0.022 0.25

Epoch 7
accuracy_score 0.707 0.75
precision_score 0.156 0.0
recall_score 0.06 0.0
percent_positive 0.02 0.25

Epoch 8
accuracy_score 0.708 0.625
precision_score 0.167 0.0
recall_score 0.062 0.0
percent_positive 0.021 0.2