In [1]:
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
from tqdm.auto import tqdm
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score

In [2]:
print(sys.version_info)
for module in mpl, np, pd, sklearn, torch:
    print(module.__name__, module.__version__)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

sys.version_info(major=3, minor=12, micro=3, releaselevel='final', serial=0)
matplotlib 3.9.1
numpy 1.26.4
pandas 2.2.2
sklearn 1.5.1
torch 2.4.0+cu121
cuda:0


In [3]:
from torchvision import datasets
from torchvision.transforms import ToTensor

In [4]:
train_ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_ds = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

In [5]:
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(test_ds, batch_size=32, shuffle=False)

In [6]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(784, 300),
            nn.ReLU(),
            nn.Linear(300, 100),
            nn.ReLU(),
            nn.Linear(100, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


model = NeuralNetwork()

In [7]:
from torch.utils.tensorboard import SummaryWriter

In [8]:
class TensorBoardCallback:
    def __init__(self, log_dir, flush_secs=10):
        self.writer = SummaryWriter(log_dir=log_dir, flush_secs=flush_secs)

    def draw_model(self, model, input_shape):
        self.writer.add_graph(model, input_to_model=torch.randn(input_shape))

    def add_loss_scalars(self, step, loss, val_loss):
        self.writer.add_scalar(
            tag='train/loss',
            scalar_value={'loss': loss, 'val_loss': val_loss},
            global_step=step
        )

    def add_acc_scalars(self, step, acc, val_acc):
        self.writer.add_scalar(
            tag='train/accuracy',
            scalar_value={'accuracy': acc, 'val_acc': val_acc},
            global_step=step
        )

    def add_lr_scalars(self, step, learning_rate):
        self.writer.add_scalar(
            tag='train/learning_rate',
            scalar_value={'learning_rate': learning_rate},
            global_step=step
        )

    def __call__(self, step, **kwargs):
        #loss曲线
        loss = kwargs.pop('loss', None)
        val_loss = kwargs.pop('val_loss', None)
        if loss is not None and val_loss is not None:
            self.add_loss_scalars(step, loss, val_loss)
        #acc曲线
        acc = kwargs.pop('acc', None)
        val_acc = kwargs.pop('val_acc', None)
        if acc is not None and val_acc is not None:
            self.add_acc_scalars(step, acc, val_acc)
        #lr曲线
        learning_rate = kwargs.pop('lr', None)
        if learning_rate is not None:
            self.add_lr_scalars(step, learning_rate)

In [9]:
class SaveCheckpointsCallback:
    def __init__(self, save_dir, save_step=5000, save_best_only=True):
        self.save_dir = save_dir
        self.save_step = save_step
        self.save_best_only = save_best_only
        self.best_metrics = -1

        if not os.path.exists(self.save_dir):
            os.mkdir(self.save_dir)

    def __call__(self, step, state_dict, metric=None):
        if step % self.save_step != 0:
            return
        if self.save_best_only:
            assert metric is not None
            if metric >= self.best_metrics:
                self.best_metrics = metric
                torch.save(state_dict, os.path.join(self.save_dir, 'best.ckpt'))
        else:
            torch.save(state_dict, os.path.join(self.save_dir, f"{第step步}.ckpt"))

In [10]:
class EarlyStopCallback:
    def __init__(self, patience=5, min_delta=0.01):
        self.patience = patience
        self.min_delta = min_delta
        self.best_metric = -1
        self.counter = 0

    def __call__(self, metric):
        if metric >= self.best_metric + self.min_delta:
            self.best_metric = metric
            self.counter = 0
        else:
            self.counter += 1

    @property
    def early_stop(self):
        return self.counter > self.patience

In [11]:
def training(model, train_loader, val_loader,
             epoch, optmizer, loss_fct,
             eval_step=500, tensorbard_callback=None, save_ckpt_callback=None,
             early_stop_callback=None):
    record_dict = {'train': [], 'val': []}
    global_step = 0
    #开启训练模式
    model.train()
    with tqdm(total=epoch * len(train_loader)) as pbar:
        for epoch_id in range(epoch):
            for datas, labels in train_loader:
                datas = datas.to(device)
                labels = labels.to(device)
                optmizer.zero_grad()
                logits = model(datas)
                loss = loss_fct(logits, labels)
                #反向传播
                loss.backward()
                optmizer.step()
                #记录
                preds = logits.argmax(axis=-1)
                acc = accuracy_score(labels.cpu().numpy(), preds.cpu().numpy())
                loss = loss.cpu().item()
                record_dict['train'].append({
                    'accuracy': acc,
                    'loss': loss,
                    'step': global_step
                })
                #验证
                if global_step % eval_step == 0:
                    model.eval()
                    val_loss, val_acc = evaluating(model, val_loader, loss_fct)
                    record_dict['val'].append({
                        'loss': val_loss,
                        'accuracy': acc,
                        'step': global_step
                    })
                    model.train()
                    #可视化
                    if tensorbard_callback is not None:
                        tensorbard_callback(global_step,
                                            loss=loss, val_loss=val_loss, acc=acc, val_acc=val_acc,
                                            lr=optmizer.param_groups[0]["lr"])
                    #模型保存
                    if save_ckpt_callback is not None:
                        save_ckpt_callback(global_step,
                                           model.state_dict(),
                                           metric=val_acc
                                           )
                    #早停
                    if early_stop_callback is not None:
                        if early_stop_callback(val_acc) is True:
                            print(f"Early stop at epoch {epoch_id}/global_step{global_step}")
                            return record_dict

                global_step += 1
                pbar.update(1)
                pbar.set_postfix({'epoch': epoch_id})
    return record_dict

In [12]:
@torch.no_grad()
def evaluating(model, dataloader, loss_fct):
    loss_list = []
    pred_list = []
    label_list = []
    for datas, labels in dataloader:
        datas = datas.to(device)
        labels = labels.to(device)
        logits = model(datas)
        loss = loss_fct(logits, labels)
        preds = logits.argmax(axis=-1)
        loss_list.append(loss.item())
        pred_list.extend(preds.cpu().numpy().tolist())
        label_list.extend(labels.cpu().numpy().tolist())
    acc = accuracy_score(label_list, pred_list)
    return np.mean(loss_list), acc

In [13]:
loss_fct = nn.CrossEntropyLoss()
optmizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
tensorboard_callback = TensorBoardCallback("runs")
tensorboard_callback.draw_model(model, [1, 28, 28])
save_ckpt_callback = SaveCheckpointsCallback("checkpoints", save_best_only=True)
early_stop_callback = EarlyStopCallback(patience=10)
epoch = 100
#开始训练
model = model.to(device)
record = training(model, train_loader, val_loader, epoch, optmizer, loss_fct,
                  tensorbard_callback=tensorboard_callback,
                  save_ckpt_callback=save_ckpt_callback,
                  early_stop_callback=early_stop_callback,
                  eval_step=1000
                  )

  0%|          | 0/187500 [00:00<?, ?it/s]

NotImplementedError: Got <class 'dict'>, but numpy array or torch tensor are expected.