In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Hand-written two-layer MLP
class mlp():
    """Two-layer perceptron (affine -> ReLU -> affine) built from raw tensors.

    The trainable tensors live in ``self.params`` in the fixed order
    ``[w1, b1, w2, b2]``; ``parameters()`` returns that same list object.
    """

    def __init__(self, input_size, hidden_size, out_size):
        # One (weight, bias) pair per layer, appended in layer order.
        # Weights are drawn from N(0, 0.01); biases start at zero.
        # Every tensor tracks gradients.
        self.params = []
        for fan_in, fan_out in ((input_size, hidden_size), (hidden_size, out_size)):
            self.params.append(torch.normal(0, 0.01, size=[fan_in, fan_out], requires_grad=True))
            self.params.append(torch.zeros(size=[fan_out], requires_grad=True))

    def relu(self, x):
        """Element-wise ``max(0, x)``."""
        return torch.maximum(torch.zeros_like(x), x)

    def __call__(self, x):
        """Forward pass: ``relu(x @ w1 + b1) @ w2 + b2``."""
        w1, b1, w2, b2 = self.params[:4]
        hidden = self.relu(torch.matmul(x, w1) + b1)
        return torch.matmul(hidden, w2) + b2

    def parameters(self):
        """Return the list of trainable tensors ``[w1, b1, w2, b2]``."""
        return self.params

# Hand-written SGD optimizer
class sgd():
    """Minimal stochastic gradient descent: ``param <- param - lr * param.grad``.

    Args:
        params: iterable of tensors created with ``requires_grad=True``.
            It is copied into a list, so a one-shot generator (for example
            the result of ``nn.Module.parameters()``) is also accepted.
        lr: learning rate applied to every parameter.
    """

    def __init__(self, params, lr):
        # Materialise once: a generator would be exhausted after the first
        # pass and every later step()/zero_grad() would silently do nothing.
        self.params = list(params)
        self.lr = lr

    def step(self):
        """Apply one in-place update to every parameter that has a gradient."""
        with torch.no_grad():
            for param in self.params:
                # A parameter that took no part in the last backward pass
                # has grad None; skip it instead of failing on `lr * None`.
                if param.grad is None:
                    continue
                param -= self.lr * param.grad

    def zero_grad(self):
        """Reset accumulated gradients to zero in place.

        Safe to call before the first backward pass, when ``.grad`` is
        still ``None``.
        """
        with torch.no_grad():
            for param in self.params:
                if param.grad is not None:
                    param.grad.zero_()

# Hand-written equivalent of nn.CrossEntropyLoss
class softmax_and_crossentropy():
    """Softmax followed by cross-entropy, averaged over the batch.

    Per row the loss is ``-log(exp(y_hat[y]) / sum_j exp(y_hat[j]))``.
    The row maximum is subtracted from the logits before exponentiating;
    the shift cancels between numerator and denominator, so the value is
    unchanged while ``exp`` never sees a large positive argument.
    """

    def __call__(self, y_hat, y):
        # y_hat: [m, d] raw (non-softmaxed) scores, one row per sample.
        # y: indexes y_hat's columns row by row, so it is assumed to hold
        #    one integer class index per row, i.e. size [m].
        row_max = torch.max(y_hat, dim=1, keepdim=True).values
        shifted = y_hat - row_max

        sample_idx = torch.arange(len(y_hat), device=y_hat.device)
        true_class_score = shifted[sample_idx, y]
        log_normaliser = torch.log(torch.sum(torch.exp(shifted), dim=1))

        per_sample_loss = log_normaliser - true_class_score
        return per_sample_loss.mean()

def accuracy(model, data_iter, loss_fn=None):
    """Evaluate ``model`` on every batch yielded by ``data_iter``.

    Args:
        model: callable that maps a flattened ``[batch, features]`` tensor
            to ``[batch, classes]`` scores. An ``nn.Module`` is switched to
            eval mode for the duration of the call.
        data_iter: iterable of ``(feature, label)`` batches. Each feature
            batch is flattened to two dimensions before the forward pass.
        loss_fn: callable ``(scores, labels) -> scalar mean loss``. When
            omitted, the module-level ``criteria`` object is used.

    Returns:
        ``(accuracy, mean_loss)``: the fraction of correctly classified
        samples and the sample-weighted average loss.

    Raises:
        ZeroDivisionError: if ``data_iter`` yields no samples.
    """
    if loss_fn is None:
        loss_fn = criteria

    # Remember the mode so evaluation does not leave a module that is being
    # trained stuck in eval mode.
    is_module = isinstance(model, nn.Module)
    was_training = is_module and model.training
    if is_module:
        model.eval()

    correct_cnt = 0
    total_cnt = 0
    total_loss = 0
    try:
        with torch.no_grad():
            for feature, label in data_iter:
                feature = feature.reshape(feature.shape[0], -1)
                y_pre = model(feature)
                loss = loss_fn(y_pre, label)

                # The loss is a per-batch mean; weight it by batch size so
                # a smaller final batch does not skew the overall average.
                total_loss += loss.item() * len(y_pre)
                correct_cnt += (y_pre.argmax(dim=1) == label).sum().item()
                total_cnt += len(y_pre)
    finally:
        if was_training:
            model.train()
    return correct_cnt / total_cnt, total_loss / total_cnt

# Training hyperparameters.
n_epochs = 10
lr = 0.01
batch_size = 1000
# Each image is flattened to a 28*28 vector before it reaches the model.
img_size = 28*28
hidden_size = 256
out_size = 10

# Hand-written model, loss and optimizer; the optimizer updates the very
# tensors returned by model.parameters().
model = mlp(img_size, hidden_size, out_size)
criteria = softmax_and_crossentropy()
optimizer = sgd(model.parameters(), lr)

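# Reference version built from torch.nn / torch.optim; uncomment to compare with the hand-written classes above.
#model = nn.Sequential(nn.Linear(img_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, out_size))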
#def init_weights(m):
#  if type(m) == nn.Linear:
#      nn.init.normal_(m.weight, 0, 0.01)
#      nn.init.zeros_(m.bias)
#model.apply(init_weights)
#criteria = nn.CrossEntropyLoss()
#optimizer = optim.SGD(model.parameters(), lr)

# MNIST loaders. download=True fetches the dataset into ./ when it is not
# already there, so the notebook also runs on a fresh checkout.
# Only the training split is shuffled.
train_iter = data.DataLoader(
    torchvision.datasets.MNIST(root="./", train=True, transform=transforms.ToTensor(), download=True),
    batch_size=batch_size,
    shuffle=True,
)
test_iter = data.DataLoader(
    torchvision.datasets.MNIST(root="./", train=False, transform=transforms.ToTensor(), download=True),
    batch_size=batch_size,
    shuffle=False,
)


# Train for n_epochs, reporting (accuracy, mean loss) on the train and test
# sets after every epoch.
for n in range(n_epochs):
    # accuracy() puts an nn.Module into eval mode; switch back to train mode
    # so that every epoch after the first is not trained in eval mode.
    if isinstance(model, nn.Module):
        model.train()

    for feature, label in train_iter:
        # Flatten each image to a vector before the forward pass.
        feature = feature.reshape(feature.shape[0], -1)
        y_pre = model(feature)
        loss = criteria(y_pre, label)
        loss.backward()
        optimizer.step()
        # Clear gradients after the update so the next backward pass starts
        # from zero instead of accumulating.
        optimizer.zero_grad()

    print('epoch', n, accuracy(model, train_iter), accuracy(model, test_iter))


epoch 0 (0.4172, 2.2891356348991394) (0.4223, 2.2888841152191164)
epoch 1 (0.47928333333333334, 2.2711361090342206) (0.4818, 2.2704803943634033)
epoch 2 (0.48486666666666667, 2.2440789341926575) (0.4888, 2.2428484678268434)
epoch 3 (0.4822, 2.2028356909751894) (0.4906, 2.200791025161743)
epoch 4 (0.50495, 2.1419221639633177) (0.5135, 2.1386983394622803)
epoch 5 (0.55045, 2.056370504697164) (0.5592, 2.0514302015304566)
epoch 6 (0.6012, 1.9432695309321086) (0.605, 1.9360023736953735)
epoch 7 (0.6458833333333334, 1.803937671581904) (0.6525, 1.7937395811080932)
epoch 8 (0.6816666666666666, 1.6466139614582063) (0.6858, 1.6333268523216247)
epoch 9 (0.7078166666666666, 1.485437693198522) (0.7123, 1.4695473074913026)


In [25]:
# Scratch check: numel() reports the number of elements in the tensor (12 here).
torch.arange(12).numel()

12