In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import v2
from torchvision import datasets
from torch.utils.tensorboard import SummaryWriter

In [None]:
class Config:
    """Runtime configuration; currently only selects the compute device."""

    def __init__(self):
        # Preference order: NVIDIA CUDA, then Apple MPS, then plain CPU.
        if torch.cuda.is_available():
            backend = "cuda"
        elif torch.backends.mps.is_available():
            backend = "mps"
        else:
            backend = "cpu"
        self.device = torch.device(backend)


def getDataLoaders(batch_size, image_size=28, root="../data", num_workers=4):
    """Build FashionMNIST train/test DataLoaders.

    Each image is resized to ``image_size`` x ``image_size``, converted to a
    float32 tensor in [0, 1] and normalized with mean 0.5 / std 0.5, which maps
    pixel values onto [-1, 1]. The dataset is downloaded into ``root`` if it
    is not already there.

    Args:
        batch_size: Number of samples per batch for both loaders.
        image_size: Side length images are resized to. FashionMNIST is
            natively 28x28, which is what ``LeNet``'s default classifier
            (16*4*4 input features) expects; other sizes need a model whose
            first fully connected layer matches.
        root: Directory where the FashionMNIST files are stored/downloaded.
        num_workers: Number of worker processes used by each DataLoader.

    Returns:
        ``(train_loader, test_loader)``. The train loader reshuffles every
        epoch; the test loader keeps a fixed order.
    """
    transform = v2.Compose(
        [
            v2.Resize((image_size, image_size)),
            # ToImage + ToDtype(scale=True) is the documented replacement for
            # the deprecated v2.ToTensor: uint8 [0, 255] -> float32 [0, 1].
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            # (x - 0.5) / 0.5 maps [0, 1] onto [-1, 1] for the single channel.
            v2.Normalize(mean=[0.5], std=[0.5]),
        ]
    )

    train_dataset = datasets.FashionMNIST(
        root=root, train=True, download=True, transform=transform
    )
    test_dataset = datasets.FashionMNIST(
        root=root, train=False, download=True, transform=transform
    )
    # DataLoader's first parameter is named `dataset`; pass it positionally.
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return train_loader, test_loader


class LeNet(nn.Module):
    """LeNet-5 style CNN for single-channel square images.

    Feature extractor: two stages of 5x5 convolution (no padding) + ReLU +
    2x2/stride-2 pooling (max pooling after the first stage, average pooling
    after the second), with 6 and 16 output channels. Head: a 3-layer MLP
    ``flat -> 120 -> 84 -> num_classes`` that returns raw logits.

    Args:
        num_classes: Number of output logits.
        input_size: Side length of the square input images. The width of the
            first fully connected layer is derived from it: 28 (the default,
            native FashionMNIST) gives 16*4*4 = 256 features, 32 (classic
            LeNet-5) gives 16*5*5 = 400. Sizes below 16 are too small for the
            conv/pool stack and make construction raise ``RuntimeError``.
    """

    def __init__(self, num_classes=10, input_size=28):
        super().__init__()
        # nn.Sequential takes modules as positional arguments, not a list.
        self.features = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5),  # 28x28 input: 1*28*28 -> 6*24*24
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 6*24*24 -> 6*12*12
            nn.Conv2d(6, 16, kernel_size=5),  # 6*12*12 -> 16*8*8
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),  # 16*8*8 -> 16*4*4
        )
        flat_features = self._flat_feature_count(input_size)
        self.classifier = nn.Sequential(
            nn.Linear(flat_features, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_classes),
        )

    def _flat_feature_count(self, input_size):
        """Return how many values ``self.features`` emits per sample."""
        # A zero probe consumes no RNG, so parameter initialization order and
        # values are the same as with a hard-coded width.
        with torch.no_grad():
            probe = torch.zeros(1, 1, input_size, input_size)
            return self.features(probe).numel()

    def forward(self, x):
        """Map a batch ``(N, 1, input_size, input_size)`` to logits ``(N, num_classes)``."""
        x = self.features(x)
        x = torch.flatten(x, 1)  # keep the batch dim, flatten C*H*W
        x = self.classifier(x)
        return x


class Visualizer:
    """Thin wrapper around a TensorBoard ``SummaryWriter``.

    Keeps its own ``step`` counter, used as the x-axis value by every
    ``log_*`` call; advance it with :meth:`increment_step`. Supports the
    context-manager protocol so the writer is closed even if the code using
    it raises::

        with Visualizer("runs/exp") as vis:
            ...
    """

    def __init__(self, log_dir):
        """Create a SummaryWriter that writes event files under ``log_dir``."""
        self.writer = SummaryWriter(log_dir=log_dir)
        self.step = 0

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never suppress exceptions raised inside the block

    def log_scalars(
        self, train_loss=None, test_loss=None, train_acc=None, test_acc=None
    ):
        """Log the given scalar metrics to TensorBoard at the current step.

        Arguments left as ``None`` are skipped (``0``/``0.0`` are still logged).
        """
        for tag, value in (
            ("Loss/train", train_loss),
            ("Loss/test", test_loss),
            ("Accuracy/train", train_acc),
            ("Accuracy/test", test_acc),
        ):
            if value is not None:
                self.writer.add_scalar(tag, value, self.step)

    # NOTE(review): the default tag is misspelled ("Trainning"); it is kept
    # as-is so existing TensorBoard runs keep a consistent tag.
    def log_images(self, images, tag="Trainning_images"):
        """Tile ``images`` into one grid with ``make_grid`` and log it at the current step."""
        img_grid = torchvision.utils.make_grid(images)
        self.writer.add_image(tag, img_grid, self.step)

    def log_model_graph(self, model, dummy_input):
        """Log ``model``'s graph, using ``dummy_input`` as the example input."""
        self.writer.add_graph(model, dummy_input)

    def close(self):
        """Close the underlying SummaryWriter."""
        self.writer.close()

    def increment_step(self):
        """Advance the step counter used by subsequent ``log_*`` calls by one."""
        self.step += 1