In [None]:
# 2021/12/26
# keyword: mnist, dataloader, datasets, transforms, matplotlib
# MNIST classification with a fully connected network
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

import numpy as np
import torch
from torch.nn import Module, Linear
import torch.nn.functional as F
import torchvision
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.utils.data.dataloader import DataLoader
from torch.optim import SGD
from PIL import Image
from matplotlib import pyplot as plt
%matplotlib qt5
from IPython import display


def plot_image(x, n_img=4, batch_size=512, y=None):
    """Show ``n_img`` randomly chosen images from a batch in a single row.

    Args:
        x: Batch of images, indexed along its first axis. Each selected item
            is squeezed and converted with ``.numpy()``, so it has to be a
            CPU tensor.
        n_img: Number of images (subplot columns) to draw.
        batch_size: Upper bound on the number of leading samples to pick
            from. It is clamped to ``len(x)`` so a short final batch cannot
            cause an out-of-range index.
        y: Optional labels aligned with ``x``; used for the subplot titles.
            When omitted, a module-level ``y`` is used if one exists (the
            behaviour of the previous version); otherwise no titles are set.

    Raises:
        ValueError: If there is no sample to choose from.
    """
    n_available = min(batch_size, len(x))
    if n_available <= 0:
        raise ValueError("plot_image needs at least one sample to draw from")

    # Fall back to the notebook-global ``y`` only when no labels were passed.
    labels = y if y is not None else globals().get("y")

    # randint's upper bound is exclusive, so this covers every valid index.
    idx = np.random.randint(0, n_available, n_img)

    # squeeze=False keeps a 2-D axes array even for n_img == 1.
    fig, axes = plt.subplots(1, n_img, squeeze=False)
    for col, sample in enumerate(idx):
        axis = axes[0][col]
        img = Image.fromarray(x[sample].squeeze().numpy() * 255)
        if labels is not None:
            axis.set_title("label %d" % int(labels[sample]), size=10)
        axis.set_axis_off()
        axis.imshow(img)

    fig.tight_layout(pad=1.0, w_pad=1.0, h_pad=1.0)
    plt.show()

class Net(Module):
    """Three-layer fully connected network: 784 -> 256 -> 64 -> 10.

    ReLU is applied after the first two layers; the output of the last
    layer is returned without any activation.
    """

    def __init__(self) -> None:
        super().__init__()

        # Layers are created in forward order so seeded initialisation
        # consumes the random stream in a fixed sequence.
        self.fc1 = Linear(784, 256)
        self.fc2 = Linear(256, 64)
        self.fc3 = Linear(64, 10)

    def forward(self, x):
        """Map inputs whose last dimension is 784 to 10 output scores."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)


if __name__ == '__main__':
    # Trains Net on MNIST with an MSE loss against one-hot targets and draws
    # the loss curves live while training.
    #
    # NOTE: this is intentionally flat module-level code rather than a main()
    # function: a later notebook cell reads net, device and valid_loader as
    # globals.
    batch_size = 512
    train_loader = DataLoader(
        MNIST('data', train=True, download=True, transform=Compose([
            ToTensor(),
            Normalize((0.1307,), (0.3081,))
            ])), batch_size=batch_size, shuffle=True)

    valid_loader = DataLoader(
        MNIST('data', train=False, download=True, transform=Compose([
            ToTensor(),
            Normalize((0.1307,), (0.3081,))
            ])), batch_size=batch_size, shuffle=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    net = Net().to(device)

    lr = 1e-2
    momentum = 0.9
    optim = SGD(net.parameters(), lr=lr, momentum=momentum)

    train_total_loss_hist = []   # per-epoch sum of training batch losses
    train_loss_hist = []         # every training batch loss, in order
    valid_total_loss_hist = []   # per-epoch sum of validation batch losses

    fig = plt.figure("Mnist (Dynamic)", figsize=(12, 6))
    plt.ion()

    n_epoch = 10
    display_per_batch = 10   # redraw the batch-loss curve every N batches
    delay_interval = 0.1     # seconds handed to plt.pause per redraw
    for i_epoch in range(n_epoch):
        # ---- train ----
        net.train()
        train_total_loss = 0
        for i, (x_train, y_train) in enumerate(train_loader):
            x_train = x_train.to(device)
            y_train = y_train.to(device)
            # Flatten each image to a 784-vector for the first Linear layer.
            x_train = x_train.reshape(-1, 28 * 28)
            y_train_onehot = F.one_hot(y_train, num_classes=10).type_as(x_train)

            # Reset the gradients accumulated by the previous batch.
            optim.zero_grad()

            # Forward pass and loss; mse_loss takes (input, target).
            train_loss = F.mse_loss(net(x_train), y_train_onehot)

            train_loss_hist.append(train_loss.item())
            train_total_loss += train_loss.item()

            # Backpropagate and update the parameters.
            train_loss.backward()
            optim.step()

            # Batch loss visualization (also drawn on the epoch's last batch).
            if i % display_per_batch == 0 or i == len(train_loader) - 1:
                display.clear_output(wait=True)

                ax0 = plt.subplot(121)
                plt.cla()
                ax0.set_title("epoch [%d/%d], batch [%d/%d], lr = %.5f" % (i_epoch + 1, n_epoch, i + 1, len(train_loader), lr), size=10)
                ax0.set_xlim(0., n_epoch * len(train_loader))
                ax0.set_ylim(0., .15)
                fig.tight_layout(pad=1.0, w_pad=1.0, h_pad=1.0)
                ax0.plot(torch.arange(0, len(train_loss_hist)), train_loss_hist, color='b')
                plt.pause(delay_interval)

        train_total_loss_hist.append(train_total_loss)

        # ---- valid ----
        # No parameter update happens here, so skip autograd bookkeeping.
        net.eval()
        valid_total_loss = 0
        with torch.no_grad():
            for x_valid, y_valid in valid_loader:
                x_valid = x_valid.to(device)
                y_valid = y_valid.to(device)
                x_valid = x_valid.reshape(-1, 28 * 28)
                y_valid_onehot = F.one_hot(y_valid, num_classes=10).type_as(x_valid)

                valid_loss = F.mse_loss(net(x_valid), y_valid_onehot)
                valid_total_loss += valid_loss.item()

        valid_total_loss_hist.append(valid_total_loss)

        # Train/valid total loss visualization.
        # NOTE(review): both curves are sums over batches and the two loaders
        # have different batch counts, so the absolute levels are not
        # directly comparable.
        display.clear_output(wait=True)
        ax1 = plt.subplot(122)
        plt.cla()
        ax1.set_title("train/valid total loss", size=10)
        ax1.set_xlim(0, n_epoch - 1)
        ax1.set_ylim(0, 10.)
        fig.tight_layout(pad=1.0, w_pad=1.0, h_pad=1.0)
        ax1.plot(torch.arange(0, len(train_total_loss_hist)), train_total_loss_hist, color='r', label='train total loss')
        ax1.plot(torch.arange(0, len(valid_total_loss_hist)), valid_total_loss_hist, color='g', label='valid total loss')
        ax1.legend(loc='upper right')
        plt.pause(delay_interval)

    plt.ioff()
    plt.show()


In [None]:
# Sanity check: print the true labels and the network's predictions for the
# first validation batch so they can be compared by eye.
x, y = next(iter(valid_loader))
x = x.to(device)
y = y.to(device)
x = x.reshape(-1, 28 * 28)
y_onehot = F.one_hot(y, num_classes=10)
y_onehot = y_onehot.type_as(x)
# Inference only: no gradients are needed for the forward pass.
with torch.no_grad():
    y_pred = torch.argmax(net(x), -1)
print(torch.argmax(y_onehot, -1))
print(y_pred)