In [106]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

import numpy as np
import torch
import torchvision
from PIL import Image
from matplotlib import pyplot as plt
%matplotlib qt5
from IPython import display


def plot_image(n_img=4, batch_size=512):
    idx = np.random.randint(0, batch_size - 1, n_img)

    fig, ax = plt.subplots(1, n_img)
    for i in range(n_img):
        img = Image.fromarray(x[idx[i]].squeeze().numpy() * 255)
        ax[i].set_title("label %d" % y[idx[i]].data.item(), size=10, )
        ax[i].set_axis_off()
        ax[i].imshow(img)

    fig.tight_layout(pad=1.0, w_pad=1.0, h_pad=1.0)
    plt.show()

class Net(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.fc1 = torch.nn.Linear(28 * 28, 256)
        self.fc2 = torch.nn.Linear(256, 64)
        self.fc3 = torch.nn.Linear(64, 10)
    
    def forward(self, x):
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x


if __name__ == '__main__':
    batch_size = 512
    train_loader = torch.utils.data.dataloader.DataLoader(
        torchvision.datasets.MNIST('mnist_data', train=True, download=True, transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,))
            ])), batch_size=batch_size, shuffle=True)

    valid_loader = torch.utils.data.dataloader.DataLoader(
        torchvision.datasets.MNIST('mnist_data', train=False, download=True, transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,))
            ])), batch_size=batch_size, shuffle=False)

    # x, y = next(iter(train_loader))
    # plot_image(10, batch_size)
    
    device = torch.device('cuda')

    net = Net()
    net.to(device)
    lr = 1e-2
    momentum = 0.9
    optim = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum)

    train_total_loss = []

    n_epoch = 10
    for epoch in range(n_epoch):
        total_loss = 0
        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)
            x = x.reshape(-1, 28 * 28)
            y_onehot = torch.nn.functional.one_hot(y, num_classes=10)
            y_onehot = y_onehot.type_as(x)
            # loss = mse(y - y_pred)
            loss = torch.nn.functional.mse_loss(net(x), y_onehot)
            total_loss += loss.item()
            # calculate grad
            optim.zero_grad()
            loss.backward()
            # w' = w - lr * grad
            optim.step()
        train_total_loss.append(total_loss)
    
    print(train_total_loss)



[8.177257902920246, 5.064475622028112, 4.094593511894345, 3.5508006773889065, 3.183746946975589, 2.9241901338100433, 2.7209360878914595, 2.5563995204865932, 2.422335173934698, 2.310547461733222]


In [109]:
for x, y in valid_loader:
    x = x.to(device)
    y = y.to(device)
    x = x.reshape(-1, 28 * 28)
    y_onehot = torch.nn.functional.one_hot(y, num_classes=10)
    y_onehot = y_onehot.type_as(x)
    loss = torch.nn.functional.mse_loss(net(x), y_onehot)
    print(loss)

tensor(0.0201, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0215, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0249, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0233, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0224, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0201, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0212, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0234, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0227, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0203, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0126, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0193, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0147, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0145, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0142, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0152, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.0138, device='cuda:0', grad_fn=

In [111]:
x, y = next(iter(valid_loader))
x = x.to(device)
y = y.to(device)
x = x.reshape(-1, 28 * 28)
y_onehot = torch.nn.functional.one_hot(y, num_classes=10)
y_onehot = y_onehot.type_as(x)
y_pred = torch.argmax(net(x), -1)
print(torch.argmax(y_onehot, -1))
print(y_pred)

tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4, 9, 6, 6, 5,
        4, 0, 7, 4, 0, 1, 3, 1, 3, 4, 7, 2, 7, 1, 2, 1, 1, 7, 4, 2, 3, 5, 1, 2,
        4, 4, 6, 3, 5, 5, 6, 0, 4, 1, 9, 5, 7, 8, 9, 3, 7, 4, 6, 4, 3, 0, 7, 0,
        2, 9, 1, 7, 3, 2, 9, 7, 7, 6, 2, 7, 8, 4, 7, 3, 6, 1, 3, 6, 9, 3, 1, 4,
        1, 7, 6, 9, 6, 0, 5, 4, 9, 9, 2, 1, 9, 4, 8, 7, 3, 9, 7, 4, 4, 4, 9, 2,
        5, 4, 7, 6, 7, 9, 0, 5, 8, 5, 6, 6, 5, 7, 8, 1, 0, 1, 6, 4, 6, 7, 3, 1,
        7, 1, 8, 2, 0, 2, 9, 9, 5, 5, 1, 5, 6, 0, 3, 4, 4, 6, 5, 4, 6, 5, 4, 5,
        1, 4, 4, 7, 2, 3, 2, 7, 1, 8, 1, 8, 1, 8, 5, 0, 8, 9, 2, 5, 0, 1, 1, 1,
        0, 9, 0, 3, 1, 6, 4, 2, 3, 6, 1, 1, 1, 3, 9, 5, 2, 9, 4, 5, 9, 3, 9, 0,
        3, 6, 5, 5, 7, 2, 2, 7, 1, 2, 8, 4, 1, 7, 3, 3, 8, 8, 7, 9, 2, 2, 4, 1,
        5, 9, 8, 7, 2, 3, 0, 4, 4, 2, 4, 1, 9, 5, 7, 7, 2, 8, 2, 6, 8, 5, 7, 7,
        9, 1, 8, 1, 8, 0, 3, 0, 1, 9, 9, 4, 1, 8, 2, 1, 2, 9, 7, 5, 9, 2, 6, 4,
        1, 5, 8, 2, 9, 2, 0, 4, 0, 0, 2,