<a href="https://colab.research.google.com/github/niikun/ml_duke_univ/blob/main/MultiLayer_Perception.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import numpy as np

# Load the data
mnist_train = datasets.MNIST(root="./datasets", train=True, transform=transforms.ToTensor(), download=True)
mnist_test = datasets.MNIST(root="./datasets", train=False, transform=transforms.ToTensor(), download=True)



Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./datasets/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:01<00:00, 6974723.20it/s]


Extracting ./datasets/MNIST/raw/train-images-idx3-ubyte.gz to ./datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./datasets/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 466658.55it/s]


Extracting ./datasets/MNIST/raw/train-labels-idx1-ubyte.gz to ./datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./datasets/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 4233222.26it/s]


Extracting ./datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to ./datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3745680.06it/s]


Extracting ./datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./datasets/MNIST/raw



In [2]:
torch.manual_seed(0)
train_loader = torch.utils.data.DataLoader(mnist_train,batch_size=64,shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_test,batch_size=64,shuffle=False)

In [12]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28,64)
        self.fc2 = nn.Linear(64,64)
        self.fc3 = nn.Linear(64,10)

    def forward(self,x):
        x = x.view(-1,28*28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
net = Net()

In [16]:
import torch.optim as optim

max_epochs = 10
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(),lr=0.001)


In [18]:
train_losses = []
train_accs = []
test_losses = []
test_accs = []

for epoch in range(max_epochs):
    sum_train_losses = 0
    sum_train_acc = 0
    sum_test_losses = 0
    sum_test_acc = 0
    net.train()
    for images,labels in tqdm(train_loader):
        optimizer.zero_grad()
        output = net(images)
        loss = loss_func(output,labels)
        loss.backward()
        optimizer.step()
        sum_train_losses += loss.item()
        sum_train_acc += (output.argmax(dim=1) == labels).float().mean().item()
        train_losses.append(loss)

        train_accs.append(sum_train_acc/len(train_loader))
    print(f"{epoch}/{max_epochs} train_loss: {sum_train_losses/len(train_loader)}, train_acc: {sum_train_acc/len(train_loader)}")
    net.eval()
    with torch.no_grad():
        for images,labels in tqdm(test_loader):
            output = net(images)
            loss = loss_func(output,labels)
            test_losses.append(loss)
            test_acc = (output.argmax(dim=1) == labels).float().mean()

            sum_test_losses += loss.item()
            sum_test_acc += (output.argmax(dim=1) == labels).float().mean().item()
            test_losses.append(loss)

            test_accs.append(sum_test_acc/len(train_loader))
    print(f"{epoch}/{max_epochs} test_loss: {sum_test_losses/len(test_loader)}, test_acc: {sum_test_acc/len(test_loader)}")


  0%|          | 0/938 [00:00<?, ?it/s]

0/10 train_loss: 0.23017486245996915, train_acc: 0.9316697761194029


  0%|          | 0/157 [00:00<?, ?it/s]

0/10 test_loss: 0.1680823552190878, test_acc: 0.9501393312101911


  0%|          | 0/938 [00:00<?, ?it/s]

1/10 train_loss: 0.15058675575365962, train_acc: 0.955023987206823


  0%|          | 0/157 [00:00<?, ?it/s]

1/10 test_loss: 0.14247035061062616, test_acc: 0.956906847133758


  0%|          | 0/938 [00:00<?, ?it/s]

2/10 train_loss: 0.11359125942285699, train_acc: 0.9652018923240938


  0%|          | 0/157 [00:00<?, ?it/s]

2/10 test_loss: 0.10667265250397977, test_acc: 0.9673566878980892


  0%|          | 0/938 [00:00<?, ?it/s]

3/10 train_loss: 0.09114742544102374, train_acc: 0.9716484541577826


  0%|          | 0/157 [00:00<?, ?it/s]

3/10 test_loss: 0.10065124090316047, test_acc: 0.9689490445859873


  0%|          | 0/938 [00:00<?, ?it/s]

4/10 train_loss: 0.07572412180147374, train_acc: 0.9769123134328358


  0%|          | 0/157 [00:00<?, ?it/s]

4/10 test_loss: 0.09278455411197446, test_acc: 0.9720342356687898


  0%|          | 0/938 [00:00<?, ?it/s]

5/10 train_loss: 0.06284804037436525, train_acc: 0.9807269456289979


  0%|          | 0/157 [00:00<?, ?it/s]

5/10 test_loss: 0.09694000017391013, test_acc: 0.9708399681528662


  0%|          | 0/938 [00:00<?, ?it/s]

6/10 train_loss: 0.053694958340869084, train_acc: 0.9833422174840085


  0%|          | 0/157 [00:00<?, ?it/s]

6/10 test_loss: 0.08655598206011056, test_acc: 0.9743232484076433


  0%|          | 0/938 [00:00<?, ?it/s]

7/10 train_loss: 0.04566427750283602, train_acc: 0.9854077825159915


  0%|          | 0/157 [00:00<?, ?it/s]

7/10 test_loss: 0.09099880167534115, test_acc: 0.9729299363057324


  0%|          | 0/938 [00:00<?, ?it/s]

8/10 train_loss: 0.03854272852440539, train_acc: 0.9880563699360341


  0%|          | 0/157 [00:00<?, ?it/s]

8/10 test_loss: 0.09969880418423015, test_acc: 0.9730294585987261


  0%|          | 0/938 [00:00<?, ?it/s]

9/10 train_loss: 0.034802126100481445, train_acc: 0.9886227345415778


  0%|          | 0/157 [00:00<?, ?it/s]

9/10 test_loss: 0.10201073540751875, test_acc: 0.9715366242038217


In [20]:
28*28*500 + 500 + 500*10 + 10

397510