In [1]:
import torch
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import torch.nn as nn

In [2]:
# Run on the first CUDA GPU when one is visible, otherwise on the CPU.
# Index 0 is passed in both cases (equivalent to torch.device(type=..., index=0)).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu", 0)

In [3]:
# Load both MNIST splits from ./data (downloading on first run), converting each
# image with ToTensor(). The training split is built first, then the test split.
train_dataset, test_dataset = [
    datasets.MNIST(root="data", train=is_train, download=True, transform=ToTensor())
    for is_train in (True, False)
]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 220643909.61it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 73638719.65it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 61118743.29it/s]

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 11862097.61it/s]


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



In [4]:
batch_size = 64

# Shuffle only the training split: reshuffling each epoch helps training, while
# epoch-level evaluation metrics do not depend on sample order, so the test split
# is iterated in a fixed order to keep evaluation logs reproducible.
train_dl = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_dl = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [5]:
class MNISTNN(nn.Module):
    """Fully connected MNIST classifier: 784 -> 512 -> 256 -> 128 -> 64 -> 32 -> 10.

    Every hidden stage is Linear -> BatchNorm1d -> ReLU. The 10-way output layer
    is followed by its own BatchNorm1d (``bn6``), so ``forward`` returns
    batch-normalized scores rather than the raw output of the last Linear layer.
    NOTE(review): normalizing the final class scores is unusual; confirm intended.

    Expects flattened input of shape (batch, 784).
    """

    def __init__(self):
        super().__init__()
        # Submodules are created in this exact order so that weight
        # initialisation (RNG consumption), parameters() order and
        # state_dict keys stay the same.
        self.inh1 = nn.Linear(784, 512)
        self.relu = nn.ReLU()  # stateless, shared by every hidden stage
        self.bn1 = nn.BatchNorm1d(512)
        self.h2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.h3 = nn.Linear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.h4 = nn.Linear(128, 64)
        self.bn4 = nn.BatchNorm1d(64)
        self.h5 = nn.Linear(64, 32)
        self.bn5 = nn.BatchNorm1d(32)
        self.output = nn.Linear(32, 10)
        self.bn6 = nn.BatchNorm1d(10)

    def forward(self, x):
        """Map a (batch, 784) tensor to (batch, 10) batch-normalized class scores."""
        hidden_stages = (
            (self.inh1, self.bn1),
            (self.h2, self.bn2),
            (self.h3, self.bn3),
            (self.h4, self.bn4),
            (self.h5, self.bn5),
        )
        # Batch normalization is applied before the ReLU in every hidden stage.
        for linear, norm in hidden_stages:
            x = self.relu(norm(linear(x)))
        return self.bn6(self.output(x))
    
# Batch normalization before ReLU (as done here): reduces internal covariate shift, speeds up convergence, and helps avoid vanishing or exploding gradients.
# Batch normalization after ReLU: sometimes gives better accuracy, since the normalized activations can be negative again, letting the next layer learn from negative inputs.

def train_one_epoch(dataloader, model, loss_fn, optimizer, device=None, log_every=100):
    """Train ``model`` for one pass over ``dataloader``.

    Each batch of images is flattened to (batch, features), passed through
    ``model``, scored with ``loss_fn`` and used for one ``optimizer`` step.
    Running statistics are printed every ``log_every`` batches.

    Args:
        dataloader: iterable of ``(imgs, labels)`` batches supporting ``len()``
            (presumably a ``torch.utils.data.DataLoader``).
        model: module mapping flattened images to class scores.
        loss_fn: called as ``loss_fn(pred, labels)``; assumed to use mean
            reduction (the ``nn.CrossEntropyLoss`` default), so that
            ``loss * batch_size`` is the batch's summed loss.
        optimizer: optimizer over ``model``'s parameters.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).
        log_every: print running stats every this many batches; a falsy
            value disables logging.

    Returns:
        ``(epoch_loss, epoch_acc)``: per-sample mean loss and accuracy in
        percent over the whole epoch, both rounded to 2 decimals.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    loss_sum = 0.0  # summed per-sample loss, so a short last batch is weighted correctly
    num_correct = 0
    num_seen = 0
    for i, (imgs, labels) in enumerate(dataloader):
        # Flatten every non-batch dimension, e.g. (N, 1, 28, 28) -> (N, 784).
        imgs = torch.flatten(imgs, start_dim=1).to(device)
        labels = labels.to(device)
        batch_n = labels.size(0)

        optimizer.zero_grad()  # clear stale gradients before this batch's backward pass
        pred = model(imgs)
        loss = loss_fn(pred, labels)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item() * batch_n
        num_correct += (pred.argmax(dim=1) == labels).sum().item()
        num_seen += batch_n

        if log_every and i % log_every == 0:
            running_loss = round(loss_sum / num_seen, 2)
            running_acc = round(100 * num_correct / num_seen, 2)
            print("Batch:", i + 1, "/", len(dataloader), "Running Loss:", running_loss, "Running Accuracy:", running_acc)

    if num_seen == 0:
        raise ValueError("dataloader yielded no samples")
    epoch_loss = loss_sum / num_seen
    epoch_acc = 100 * num_correct / num_seen
    return round(epoch_loss, 2), round(epoch_acc, 2)

def eval_one_epoch(dataloader, model, loss_fn, device=None, log_every=100):
    """Evaluate ``model`` on every batch of ``dataloader`` without updating it.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    Running statistics are printed every ``log_every`` batches.

    Args:
        dataloader: iterable of ``(imgs, labels)`` batches supporting ``len()``
            (presumably a ``torch.utils.data.DataLoader``).
        model: module mapping flattened images to class scores.
        loss_fn: called as ``loss_fn(pred, labels)``; assumed to use mean
            reduction (the ``nn.CrossEntropyLoss`` default), so that
            ``loss * batch_size`` is the batch's summed loss.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).
        log_every: print running stats every this many batches; a falsy
            value disables logging.

    Returns:
        ``(epoch_loss, epoch_acc)``: per-sample mean loss and accuracy in
        percent over the whole loader, both rounded to 2 decimals.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    loss_sum = 0.0  # summed per-sample loss, so a short last batch is weighted correctly
    num_correct = 0
    num_seen = 0
    with torch.no_grad():
        for i, (imgs, labels) in enumerate(dataloader):
            # Flatten every non-batch dimension, e.g. (N, 1, 28, 28) -> (N, 784).
            imgs = torch.flatten(imgs, start_dim=1).to(device)
            labels = labels.to(device)
            batch_n = labels.size(0)

            pred = model(imgs)
            loss = loss_fn(pred, labels)

            loss_sum += loss.item() * batch_n
            num_correct += (pred.argmax(dim=1) == labels).sum().item()
            num_seen += batch_n

            if log_every and i % log_every == 0:
                running_loss = round(loss_sum / num_seen, 2)
                running_acc = round(100 * num_correct / num_seen, 2)
                print("Batch:", i + 1, "/", len(dataloader), "Running Loss:", running_loss, "Running Accuracy:", running_acc)

    if num_seen == 0:
        raise ValueError("dataloader yielded no samples")
    epoch_loss = loss_sum / num_seen
    epoch_acc = 100 * num_correct / num_seen
    return round(epoch_loss, 2), round(epoch_acc, 2)

# Seed the global torch RNG so weight initialisation and the training
# DataLoader's shuffle order repeat across Restart & Run All.
# (Some CUDA kernels may still be non-deterministic.)
SEED = 0
torch.manual_seed(SEED)

model = MNISTNN().to(device)
loss_fn = nn.CrossEntropyLoss()
lr = 0.001
# Adam is the optimiser used here; torch.optim.SGD(params=model.parameters(), lr=lr)
# is a drop-in alternative.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
n_epochs = 30

# NOTE: the MNIST test split (test_dl) doubles as the validation set, so the
# "Inference" numbers are not a held-out final estimate.
for i in range(n_epochs):
    print("Epoch No:", i + 1)
    train_epoch_loss, train_epoch_acc = train_one_epoch(train_dl, model, loss_fn, optimizer)
    val_epoch_loss, val_epoch_acc = eval_one_epoch(test_dl, model, loss_fn)
    print("Training:", "Epoch Loss:", train_epoch_loss, "Epoch Accuracy:", train_epoch_acc)
    print("Inference:", "Epoch Loss:", val_epoch_loss, "Epoch Accuracy:", val_epoch_acc)
    print("--------------------------------------------------")


Epoch No: 1

Batch: 1 / 938 Running Loss: 2.68 Running Accuracy: 9.38

Batch: 101 / 938 Running Loss: 0.78 Running Accuracy: 83.88

Batch: 201 / 938 Running Loss: 0.62 Running Accuracy: 88.5

Batch: 301 / 938 Running Loss: 0.55 Running Accuracy: 90.05

Batch: 401 / 938 Running Loss: 0.49 Running Accuracy: 91.13

Batch: 501 / 938 Running Loss: 0.46 Running Accuracy: 91.74

Batch: 601 / 938 Running Loss: 0.43 Running Accuracy: 92.24

Batch: 701 / 938 Running Loss: 0.41 Running Accuracy: 92.61

Batch: 801 / 938 Running Loss: 0.39 Running Accuracy: 92.84

Batch: 901 / 938 Running Loss: 0.37 Running Accuracy: 93.18

Batch: 1 / 157 Running Loss: 0.2 Running Accuracy: 93.75

Batch: 101 / 157 Running Loss: 0.17 Running Accuracy: 96.67

Training: Epoch Loss: 0.36 Epoch Accuracy: 93.27

Inference: Epoch Loss: 0.17 Epoch Accuracy: 96.64

--------------------------------------------------

Epoch No: 2

Batch: 1 / 938 Running Loss: 0.12 Running Accuracy: 98.44

Batch: 101 / 938 Running Loss: 0.19 R