<a href="https://colab.research.google.com/github/karan-dalal/basicneuralnet/blob/main/3_Neural_Network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms as T

# Floating-point dtype every input batch is cast to before the forward pass.
dtype = torch.float32
# Report training loss and evaluation accuracy once per this many iterations.
print_every = 100
# Run on the GPU whenever CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
def flatten(x):
    """
    Collapse every dimension after the first one into a single dimension.

    Returns a view of ``x`` (no copy) with shape (x.size(0), -1); like
    ``Tensor.view``, it raises if ``x`` cannot be viewed that way.
    """
    batch_size = x.size(0)
    return x.view(batch_size, -1)

def check_accuracy_part34(loader, model):
    """
    Evaluate a classifier on every batch of a loader and print its accuracy.

    The model is put in eval mode (and left there) and run under
    torch.no_grad(); callers that keep training must call model.train().

    Inputs:
    - loader: DataLoader yielding (x, y) batches. When its dataset has a truthy
      `train` attribute the split is reported as "validation", otherwise
      (including datasets without that attribute) as "test".
    - model: A PyTorch Module returning class scores with classes along dim 1.

    Returns: The fraction of correctly classified samples as a Python float,
    or 0.0 when the loader yields no samples.
    """
    if getattr(loader.dataset, 'train', False):
        print('Checking accuracy on validation set')
    else:
        print('Checking accuracy on test set')
    num_correct = 0
    num_samples = 0
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=dtype)
            y = y.to(device=device, dtype=torch.long)
            scores = model(x)
            _, preds = scores.max(1)
            # Accumulate on-device; converting once after the loop avoids a
            # host/device synchronization per batch.
            num_correct += (preds == y).sum()
            num_samples += preds.size(0)
    num_correct = int(num_correct)
    # An empty loader would otherwise divide by zero.
    acc = num_correct / num_samples if num_samples > 0 else 0.0
    print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))
    return acc

def train_part34(model, optimizer, loader_train, loader_val, epochs=1):
    """
    Train a model using the PyTorch Module API, printing the training loss and
    the accuracy on loader_val every `print_every` iterations.

    Inputs:
    - model: A PyTorch Module giving the model to train; it is moved to the
      module-level `device`.
    - optimizer: An Optimizer object we will use to train the model
    - loader_train: Dataloader for training
    - loader_val: Dataloader for evaluation
    - epochs: (Optional) A Python integer giving the number of epochs to train for

    Returns: A list holding the accuracy on loader_val measured at the end of
    each epoch.
    """
    loss_fn = nn.CrossEntropyLoss()
    model = model.to(device=device)
    val_accs = []
    for _ in range(epochs):
        for t, (x, y) in enumerate(loader_train):
            # check_accuracy_part34 leaves the model in eval mode, so switch
            # back to training mode before every update.
            model.train()
            x = x.to(device=device, dtype=dtype)
            y = y.to(device=device, dtype=torch.long)

            scores = model(x)
            loss = loss_fn(scores, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # `t` counts iterations within the current epoch (restarts at 0).
            if t % print_every == 0:
                print('Iteration %d, loss = %.4f' % (t, loss.item()))
                check_accuracy_part34(loader_val, model)
                print()
        val_accs.append(check_accuracy_part34(loader_val, model))
    return val_accs

In [None]:
ds_train = datasets.MNIST('.', download=True, train=True, transform=T.ToTensor())
# Shuffle so each epoch visits the training examples in a fresh random order.
loader_train = DataLoader(ds_train, batch_size=32, shuffle=True)
ds_val = datasets.MNIST('.', download=True, train=False, transform=T.ToTensor())
# DataLoader defaults to batch_size=1, which made every accuracy check run one
# forward pass per test image; batching gives identical correct/total counts.
loader_val = DataLoader(ds_val, batch_size=1000)

class Network(nn.Module):
    """
    Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    The defaults (784 -> 256 -> 10) match flattened 28x28 MNIST images and the
    ten digit classes, so Network() builds the same model as before.

    Inputs:
    - in_features: Number of values in each flattened input example.
    - hidden_dim: Width of the hidden layer.
    - num_classes: Number of output scores produced per example.
    """

    def __init__(self, in_features=28 * 28, hidden_dim=256, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return unnormalized class scores of shape (N, num_classes)."""
        x = flatten(x)
        return self.fc2(self.relu(self.fc1(x)))

model = Network()
# Adam with its default hyperparameters (no learning rate is passed).
optimizer = optim.Adam(params=model.parameters())
# Kept as the cell's final expression so the notebook displays the returned
# list of per-epoch accuracies.
train_part34(
    model=model,
    optimizer=optimizer,
    loader_train=loader_train,
    loader_val=loader_val,
)

Iteration 0, loss = 2.3157
Checking accuracy on test set
Got 1327 / 10000 correct (13.27)

Iteration 100, loss = 0.4907
Checking accuracy on test set
Got 8773 / 10000 correct (87.73)

Iteration 200, loss = 0.3601
Checking accuracy on test set
Got 8895 / 10000 correct (88.95)

Iteration 300, loss = 0.2262
Checking accuracy on test set
Got 8992 / 10000 correct (89.92)

Iteration 400, loss = 0.1882
Checking accuracy on test set
Got 9033 / 10000 correct (90.33)

Iteration 500, loss = 0.3959
Checking accuracy on test set
Got 9292 / 10000 correct (92.92)

Iteration 600, loss = 0.1790
Checking accuracy on test set
Got 9373 / 10000 correct (93.73)

Iteration 700, loss = 0.1452
Checking accuracy on test set
Got 9371 / 10000 correct (93.71)

Iteration 800, loss = 0.1414
Checking accuracy on test set
Got 9422 / 10000 correct (94.22)

Iteration 900, loss = 0.1476
Checking accuracy on test set
Got 9434 / 10000 correct (94.34)

Iteration 1000, loss = 0.4404
Checking accuracy on test set
Got 9442 / 1

[0.9577]