In [32]:
#linear regression is a supervised regression model, but logistic regression is a supervised classification model
# Goal: predict which digit an image shows, i.e. classify it using a model fit to the training data

import torch
#The torchvision package consists of popular datasets, model architectures, and common image transformations 
#for computer vision.
import torchvision
# MNIST is a very popular benchmark dataset of handwritten-digit images (much like the iris dataset for tabular data)
from torchvision.datasets import MNIST
import numpy
import torchvision.transforms as transforms
import numpy as np

from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn as nn
import torch.nn.functional as F

In [11]:
# Training split (60,000 images) and test split (10,000 images), each converted to tensors.
# download=True fetches the files into data/ on the first run. Without it, a fresh checkout
# raises a RuntimeError ("Dataset not found") instead of reproducing the notebook.
dataset = MNIST(root="data/", train=True, download=True, transform=transforms.ToTensor())
test_dataset = MNIST(root="data/", train=False, download=True, transform=transforms.ToTensor())

In [12]:
# Show a summary of both splits (size, root, split name, transform), one per line.
print(dataset, test_dataset, sep="\n")

Dataset MNIST
    Number of datapoints: 60000
    Root location: data/
    Split: Train
    StandardTransform
Transform: ToTensor()
Dataset MNIST
    Number of datapoints: 10000
    Root location: data/
    Split: Test
    StandardTransform
Transform: ToTensor()


In [22]:
# Validation set: hold out a random 20% of the training images.
# Seeding makes the split repeatable under Restart & Run All. torch's global RNG is
# seeded too, because the SubsetRandomSampler shuffling and the weight initialisation
# of the model built later both draw from it.
SEED = 42
torch.manual_seed(SEED)
rng = np.random.default_rng(SEED)
perms = rng.permutation(len(dataset))
san_per = 0.2                           # fraction of training images held out for validation
indices = int(len(dataset) * san_per)   # number of validation images
print(indices)
train_indices = perms[indices:]
val_indices = perms[:indices]

12000
12000


In [26]:
# Training loader: batches of 100 drawn in random order from the training indices only.
train_sampler = SubsetRandomSampler(train_indices)
train_dl = DataLoader(dataset, batch_size=100, sampler=train_sampler)

In [30]:
# Validation loader: batches of 100 drawn from the held-out validation indices only.
val_sampler = SubsetRandomSampler(val_indices)
val_dl = DataLoader(dataset, batch_size=100, sampler=val_sampler)

In [35]:
# Model

class MnistModel(nn.Module):
    """Multinomial logistic regression: one linear layer from flattened pixels to class scores.

    Args:
        input_size: number of input features after flattening (28*28 for MNIST).
        num_classes: number of output classes (10 digits for MNIST).
    """

    def __init__(self, input_size=28 * 28, num_classes=10):
        super().__init__()
        # One linear map from every pixel to every class. The outputs are raw,
        # unnormalised scores (logits), not probabilities. F.cross_entropy applies
        # log-softmax itself, and the argmax over logits equals the argmax over probabilities.
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, xb):
        """Flatten the input into rows of ``input_size`` values and return one row of logits per row."""
        # Take the flattened size from the layer instead of hard-coding 784 a second time.
        xb = xb.reshape(-1, self.linear.in_features)
        return self.linear(xb)
        

In [36]:
# Cross-entropy takes the model's raw class scores (logits) and integer labels.
loss_fn = F.cross_entropy
model = MnistModel()


In [37]:
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)  # plain SGD over all model parameters

In [38]:
def loss_on_batch(loss_fn, model, xb, yb, opt):
    """Run one optimisation step on a single batch and return its loss.

    Args:
        loss_fn: callable ``loss_fn(outputs, targets)`` returning a scalar tensor
            (``F.cross_entropy`` in this notebook).
        model: module mapping ``xb`` to outputs.
        xb: input batch.
        yb: targets for ``xb``.
        opt: optimiser over ``model``'s parameters.

    Returns:
        The batch loss as a Python float, measured before the update is applied,
        so callers can track training progress.
    """
    outputs = model(xb)
    loss = loss_fn(outputs, yb)

    loss.backward()   # accumulate gradients into each parameter's .grad
    opt.step()        # apply the parameter update
    opt.zero_grad()   # clear gradients so the next batch starts fresh
    return loss.item()
    
    

In [60]:
def accuracy(outputs, labels):
    """Return the fraction of samples whose highest-scoring class equals the label.

    Args:
        outputs: tensor of per-class scores, classes along dim 1, one row per sample.
        labels: tensor of integer class indices, one per row of ``outputs``.

    Returns:
        Python float in [0, 1].
    """
    _, predicted = torch.max(outputs, dim=1)
    n_correct = (predicted == labels).sum().item()
    return n_correct / len(labels)
    

In [42]:
# Model training

def train(epochs, loss_fn, model, opt, train_dl):
    """Train ``model`` for ``epochs`` full passes over ``train_dl``.

    Each (inputs, targets) batch yielded by the loader is passed to
    ``loss_on_batch``, which performs one optimiser step with ``opt``.
    """
    for _epoch in range(epochs):
        for batch in train_dl:
            inputs, targets = batch
            loss_on_batch(loss_fn, model, inputs, targets, opt)

In [43]:
train(epochs=5, loss_fn=loss_fn, model=model, opt=optimizer, train_dl=train_dl)

In [44]:
# test_dataset is created in the data-loading cell near the top of the notebook; it is not reloaded here.
def predict_image(img, model, label, acc_arr, verbose=True):
    """Classify one image, record whether the prediction was correct, and return it.

    Args:
        img: a single image tensor without a batch dimension (a batch axis of size 1 is added).
        model: classifier returning per-class scores (classes along dim 1) for a batch.
        label: the true class index.
        acc_arr: list that receives 1 for a correct prediction and 0 otherwise.
        verbose: when True (default), print ``True`` for a hit, or the true label and
            predicted label followed by ``False`` for a miss.

    Returns:
        The predicted class index as a Python int.
    """
    xb = img.unsqueeze(0)
    # Inference only: no autograd graph is needed for this forward pass.
    with torch.no_grad():
        scores = model(xb)
    _, preds = torch.max(scores, dim=1)
    predicted = preds[0].item()

    correct = label == predicted
    acc_arr.append(1 if correct else 0)
    if verbose:
        if correct:
            print(True)
        else:
            print(label, predicted)
            print(False)
    return predicted

In [59]:
avg_acc = []
def validate(valid_dl, model):
    """Return the accuracy of ``model`` over every sample yielded by ``valid_dl``.

    The module-level list ``avg_acc`` holds the per-batch accuracies from the most
    recent call. It is cleared on entry so that calling ``validate`` more than once
    does not mix in stale results from earlier calls.

    The overall score is weighted by batch size, so every sample counts equally
    even when the last batch is smaller.

    Raises:
        ValueError: if ``valid_dl`` yields no samples.
    """
    avg_acc.clear()
    n_correct = 0.0
    n_seen = 0
    # Evaluation only: no gradients are needed.
    with torch.no_grad():
        for xb, yb in valid_dl:
            batch_acc = accuracy(model(xb), yb)
            avg_acc.append(batch_acc)
            n_correct += batch_acc * len(yb)
            n_seen += len(yb)
    if n_seen == 0:
        raise ValueError("validate: valid_dl yielded no samples")
    return n_correct / n_seen

# Overall accuracy of the trained model on the held-out validation split.
print(validate(val_dl, model))

tensor([8, 4, 6, 3, 3, 1, 0, 9, 0, 1, 4, 9, 5, 9, 7, 5, 9, 8, 8, 7, 8, 6, 9, 1,
        3, 7, 0, 2, 7, 0, 7, 4, 1, 2, 1, 0, 2, 7, 8, 2, 3, 5, 1, 4, 8, 1, 5, 0,
        3, 6, 4, 3, 1, 5, 7, 4, 1, 1, 1, 5, 6, 1, 6, 9, 7, 1, 3, 9, 1, 1, 3, 8,
        9, 8, 3, 4, 9, 9, 1, 2, 8, 1, 5, 9, 0, 0, 9, 6, 1, 2, 1, 3, 2, 8, 1, 0,
        1, 1, 0, 4])
tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True, False,
         True, False,  True,  True,  True,  True,  True,  True, False,  True,
         True, False,  True,  True,  True,  True, False,  True,  True,  True,
        False,  True,  True,  True,  True,  True, False,  True, False,  True,
        False,  True,  True, False,  True,  True,  True,  True,  True,  True,
         True, False,  True,  True,  True,  True,  True, False, False,  True,
         True, False,  True,  True,  True,  True, False,  True,  True,  True,
         True, False,  True, False,  True, False,  True,  True,  True,  True,
         True,  True,  True,  True,

        0, 3, 2, 2])
tensor([ True,  True,  True,  True,  True,  True,  True,  True, False,  True,
        False,  True,  True,  True,  True,  True, False,  True,  True,  True,
         True, False, False, False, False,  True,  True,  True, False,  True,
         True,  True,  True,  True, False,  True,  True,  True,  True,  True,
         True, False, False, False,  True, False,  True,  True,  True,  True,
         True, False,  True, False,  True,  True,  True, False, False,  True,
         True, False,  True, False,  True,  True, False,  True,  True,  True,
        False,  True, False,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True, False, False,  True, False,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True])
tensor([7, 3, 3, 8, 4, 6, 7, 7, 8, 1, 1, 1, 7, 7, 7, 9, 0, 0, 5, 8, 2, 7, 2, 1,
        6, 5, 1, 7, 7, 1, 3, 7, 6, 0, 8, 2, 9, 3, 3, 9, 7, 7, 9, 0, 1, 4, 3, 7,
        7, 6, 8, 0, 0, 4, 7, 7, 0, 2, 

tensor([7, 7, 5, 6, 1, 7, 0, 6, 1, 1, 8, 7, 4, 6, 6, 9, 5, 3, 8, 1, 1, 6, 5, 9,
        9, 5, 9, 2, 5, 7, 3, 7, 6, 0, 5, 8, 8, 3, 8, 2, 9, 2, 9, 6, 0, 6, 7, 3,
        3, 6, 3, 1, 2, 1, 3, 9, 2, 3, 6, 6, 0, 4, 7, 7, 8, 6, 6, 2, 2, 9, 1, 3,
        7, 1, 2, 1, 3, 9, 8, 4, 2, 5, 6, 3, 5, 8, 3, 6, 1, 6, 0, 7, 0, 2, 4, 6,
        9, 9, 8, 6])
tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True, False,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True, False,  True,  True,  True,
         True,  True,  True, False,  True,  True, False, False, False,  True,
         True,  True,  True,  True,  True,  True,  True,  True, False,  True,
         True,  True,  True,  True, False,  True,  True,  True, False, False,
         True,  True,  True,  True,  True, False,  True,  True,  True,  True,
        False, False,  True, False,  True,  True,  True, False,  True, False,
         True,  True,  True,  True,

tensor([8, 0, 2, 8, 0, 6, 8, 4, 8, 5, 8, 9, 4, 9, 1, 1, 7, 9, 6, 8, 8, 5, 8, 1,
        4, 8, 0, 6, 5, 3, 2, 9, 3, 3, 8, 7, 2, 4, 8, 0, 7, 9, 1, 0, 7, 3, 3, 1,
        1, 3, 8, 6, 0, 0, 6, 2, 7, 0, 4, 2, 4, 8, 2, 3, 6, 7, 8, 7, 9, 0, 2, 8,
        6, 3, 2, 4, 1, 0, 5, 5, 1, 6, 7, 4, 1, 5, 1, 2, 9, 7, 6, 4, 3, 9, 6, 4,
        5, 7, 6, 0])
tensor([False,  True,  True,  True,  True,  True, False,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True, False,  True,  True, False,  True,  True,  True, False,
        False, False, False, False, False,  True, False,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True, False,  True,  True,  True,
        False,  True,  True,  True,  True,  True,  True, False,  True,  True,
         True,  True,  True,  True,  True,  True,  True, False, False,  True,
         True,  True,  True,  True,

tensor([1, 2, 2, 4, 0, 8, 5, 7, 2, 1, 7, 1, 8, 5, 3, 1, 6, 4, 2, 5, 3, 2, 9, 1,
        1, 9, 2, 1, 7, 4, 1, 2, 3, 7, 0, 8, 8, 1, 2, 6, 9, 8, 1, 5, 9, 3, 5, 8,
        7, 6, 7, 0, 6, 8, 0, 3, 6, 7, 0, 6, 6, 6, 9, 3, 9, 6, 1, 1, 9, 6, 1, 6,
        9, 1, 6, 6, 7, 3, 5, 7, 1, 6, 1, 0, 3, 7, 3, 8, 7, 5, 7, 1, 0, 2, 3, 8,
        7, 2, 9, 2])
tensor([ True, False,  True,  True,  True,  True,  True,  True,  True,  True,
        False,  True, False,  True,  True, False, False,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True, False, False,  True, False,  True,  True,  True,  True,
         True, False,  True,  True,  True, False,  True,  True,  True,  True,
        False,  True,  True,  True,  True, False,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True, False,  True,  True, False,
         True,  True,  True, False, False,  True,  True,  True,  True, False,
        False,  True,  True,  True,

tensor([ True,  True, False,  True, False,  True,  True, False,  True,  True,
         True,  True,  True, False,  True, False,  True,  True,  True,  True,
         True, False,  True, False,  True, False,  True,  True,  True,  True,
         True,  True,  True,  True, False,  True,  True, False,  True, False,
         True, False,  True, False,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True, False,  True,  True,  True,
         True, False, False,  True, False,  True,  True, False,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True, False,
         True, False, False,  True,  True,  True,  True,  True,  True,  True,
        False,  True, False,  True,  True, False,  True, False,  True, False])
tensor([4, 8, 4, 9, 7, 8, 5, 4, 2, 4, 2, 6, 1, 1, 9, 0, 8, 6, 9, 8, 8, 2, 1, 8,
        9, 0, 3, 1, 1, 3, 3, 8, 7, 9, 9, 1, 3, 9, 9, 8, 9, 2, 8, 2, 0, 4, 6, 6,
        1, 9, 1, 4, 1, 9, 6, 2, 7, 2, 6, 1, 1, 9, 8, 6, 0, 

tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
        False,  True,  True,  True, False,  True, False,  True, False,  True,
         True,  True,  True,  True,  True,  True,  True, False,  True,  True,
         True,  True,  True, False,  True,  True,  True,  True,  True, False,
        False,  True,  True,  True,  True,  True, False,  True,  True, False,
         True,  True,  True,  True,  True, False,  True,  True,  True, False,
        False,  True,  True,  True, False,  True,  True,  True, False,  True,
         True,  True,  True,  True, False,  True,  True,  True,  True,  True,
        False, False, False,  True,  True,  True,  True,  True,  True,  True,
         True, False, False,  True,  True,  True,  True,  True,  True, False])
tensor([4, 6, 2, 4, 9, 7, 0, 1, 0, 2, 6, 3, 4, 1, 7, 7, 4, 8, 8, 4, 9, 4, 2, 0,
        6, 7, 7, 9, 0, 7, 6, 8, 7, 1, 1, 9, 1, 7, 2, 4, 0, 3, 9, 9, 3, 1, 2, 4,
        3, 6, 6, 1, 7, 7, 8, 3, 9, 6, 7, 6, 2, 3, 3, 7, 1, 

tensor([3, 9, 4, 4, 0, 9, 3, 6, 4, 3, 1, 1, 2, 3, 4, 4, 2, 8, 9, 4, 4, 2, 1, 3,
        3, 7, 0, 4, 1, 6, 6, 4, 9, 7, 4, 9, 6, 3, 4, 4, 7, 1, 3, 4, 5, 6, 8, 1,
        9, 4, 5, 1, 7, 6, 8, 1, 8, 4, 2, 6, 9, 7, 8, 0, 2, 7, 8, 7, 1, 0, 4, 4,
        0, 1, 0, 1, 1, 7, 7, 7, 9, 6, 1, 1, 0, 3, 3, 3, 4, 8, 4, 0, 8, 4, 0, 6,
        1, 0, 4, 8])
tensor([ True, False,  True,  True,  True,  True,  True,  True,  True,  True,
         True, False,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
        False,  True,  True,  True,  True,  True,  True,  True,  True,  True,
        False,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True, False,  True,  True,  True,
        False,  True, False, False, False,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True, False,  True, False,  True,  True,
         True,  True, False,  True,

tensor([7, 6, 6, 7, 7, 0, 0, 4, 7, 8, 0, 8, 0, 6, 0, 6, 9, 8, 8, 3, 1, 1, 4, 3,
        8, 8, 7, 7, 2, 9, 4, 2, 0, 3, 5, 1, 7, 1, 8, 1, 0, 6, 2, 0, 0, 9, 7, 7,
        7, 7, 2, 7, 0, 0, 2, 3, 4, 4, 3, 9, 8, 8, 6, 4, 1, 7, 6, 1, 1, 6, 0, 3,
        4, 9, 5, 3, 8, 4, 9, 6, 8, 7, 3, 3, 7, 6, 2, 9, 1, 6, 7, 3, 0, 3, 1, 2,
        6, 1, 8, 9])
tensor([ True,  True,  True,  True,  True,  True,  True,  True, False,  True,
         True,  True,  True, False,  True,  True, False,  True,  True, False,
         True, False, False, False, False, False,  True,  True, False,  True,
         True,  True,  True,  True,  True,  True,  True, False,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True, False, False,
         True, False,  True,  True,  True, False, False,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True, False,  True,  True, False,  True,  True,  True,  True,
         True,  True,  True,  True,

In [47]:
# Score the first 100 test images one at a time; predict_image appends 1 (hit) or 0 (miss) to acc_arr.
acc_arr = []
for x in range(100):
    img, lbl = test_dataset[x]
    predict_image(img, model, lbl, acc_arr)

sm = sum(acc_arr)
print(sm / len(acc_arr))

True
True
True
True
True
True
True
True
5 2
False
9 7
False
True
6 0
False
True
True
True
5 3
False
True
True
True
True
9 7
False
True
True
True
True
True
True
True
True
True
True
True
True
4 0
False
True
True
True
True
2 3
False
True
True
True
True
True
True
5 3
False
1 3
False
True
True
True
True
True
True
5 3
False
6 2
False
True
True
True
True
5 7
False
True
True
True
3 2
False
True
4 9
False
6 2
False
True
True
True
True
True
True
9 8
False
True
True
True
2 7
False
9 1
False
True
7 9
False
True
True
True
True
True
True
True
True
True
True
True
9 8
False
True
True
True
True
7 1
False
True
True
0.79
