In [None]:
from torchvision.datasets import MNIST

# download=True fetches the MNIST files into `root` when they are not there
# yet, so this cell also works on a fresh machine / fresh checkout.
train_dataset = MNIST(root='MNIST', download=True)
len(train_dataset)

In [None]:
# train=False selects the held-out test split stored under the same root.
test_dataset = MNIST(
    root='MNIST',
    train=False,
)
len(test_dataset)

In [None]:
import matplotlib.pyplot as plt

# Each dataset item unpacks into an (image, label) pair; no transform is set
# yet, so `image` is presumably still a PIL image -- confirm from the output.
image, label = train_dataset[0]

print(image)
plt.imshow(image, cmap='gray')

In [None]:
import torchvision.transforms as transforms

# Rebuild both splits so every sample is passed through ToTensor().
train_dataset = MNIST(
    root='MNIST',
    transform=transforms.ToTensor(),
)
print(train_dataset[0])

test_dataset = MNIST(
    root='MNIST',
    train=False,
    transform=transforms.ToTensor(),
)
print(test_dataset[0])

In [None]:
# Image half of the first test sample (element 0 of the (image, label) pair).
img_tensor = test_dataset[0][0]
# Inspect a 5x5 window of raw values across the whole leading axis.
print(img_tensor[:, 10:15, 10:15])

# Show a 20x20 crop taken from index 0 of the leading axis.
plt.imshow(img_tensor[0, 5:25, 5:25], cmap='gray')

In [None]:
import numpy as np

def split_indices(dataset_len,val_percent,seed=None):
    """Randomly split ``range(dataset_len)`` into train and validation indices.

    Args:
        dataset_len: Number of samples to split.
        val_percent: Fraction of samples reserved for validation, in [0, 1].
            The validation count is ``int(dataset_len * val_percent)``,
            i.e. it is rounded down.
        seed: Optional seed. When given, the permutation comes from a
            dedicated generator and is reproducible across runs. When None,
            NumPy's global random state is used, as before.

    Returns:
        Tuple ``(train_indices, val_indices)`` of disjoint index arrays that
        together cover every index exactly once.

    Raises:
        ValueError: If ``val_percent`` is outside [0, 1].
    """
    if not 0 <= val_percent <= 1:
        raise ValueError(f'val_percent must be within [0, 1], got {val_percent}')
    validation_n = int(dataset_len*val_percent)
    if seed is None:
        indx = np.random.permutation(dataset_len)
    else:
        indx = np.random.default_rng(seed).permutation(dataset_len)
    # first - train_indx, second - val_indx
    return indx[validation_n:],indx[:validation_n]

In [None]:
# Hold out 20% of the training set for validation.
train_indices, val_indices = split_indices(
    dataset_len=len(train_dataset),
    val_percent=0.2,
)

In [None]:
from torch.utils.data import TensorDataset, DataLoader,SubsetRandomSampler

batch_size_ = 100

# Both loaders read from train_dataset; each one is restricted to its own
# index subset by a SubsetRandomSampler. shuffle stays off because the
# sampler decides which items are drawn.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size_,
    shuffle=False,
    sampler=SubsetRandomSampler(train_indices),
)
val_loader = DataLoader(
    train_dataset,
    batch_size=batch_size_,
    shuffle=False,
    sampler=SubsetRandomSampler(val_indices),
)


In [None]:
# Peek at a single batch only: iterating the whole loader would print every
# batch of the training subset and bury the notebook in tensor dumps.
for x_batch,y_batch in train_loader:
    print(x_batch)
    print(y_batch)
    break

In [None]:
from torch.nn import Linear


input_size = 28 * 28  # number of input features per sample
output_size = 10  # num of classes
# Single linear layer mapping input_size features to output_size outputs.
model = Linear(in_features=input_size, out_features=output_size)

In [None]:
# Shapes of the layer's two learnable tensors.
print(model.weight.size())
print(model.bias.size())

In [None]:
from torch.nn import Module

class MnistModel(Module):
    """Single linear layer applied to flattened inputs.

    Args:
        input_size: Number of features per sample after flattening.
        output_size: Number of output scores per sample.
    """

    def __init__(self,input_size,output_size):
        super().__init__()
        self.linear = Linear(input_size,output_size)

    def forward(self,xb):
        """Flatten ``xb`` to ``(-1, input_size)`` and apply the linear layer."""
        # Use the width this instance was built with rather than a
        # module-level global, so the model works for any input_size and
        # does not depend on notebook state.
        xb = xb.reshape(-1,self.linear.in_features)
        out = self.linear(xb)
        return out

In [None]:
# Replace the bare Linear layer with the wrapper that flattens its input.
model = MnistModel(input_size, output_size)
print([param for param in model.parameters()])

In [None]:
# Run the untrained model on the first batch only.
images, labels = next(iter(train_loader))
outputs = model(images)

print(outputs.shape)
print(outputs)

In [None]:
from torch.nn.functional import softmax

# Normalise each row of raw outputs along dim 1.
probs = softmax(outputs, dim=1)

print(probs)

In [None]:
import torch
# Per row: the largest probability and the index where it occurs.
max_probs, pred = probs.max(dim=1)
print(max_probs, pred)

In [None]:
def accuracy(lbl_pred,lbl_true):
    """Return the fraction of entries in ``lbl_pred`` equal to ``lbl_true``."""
    matches = lbl_pred == lbl_true
    n_hits = torch.sum(matches)
    return n_hits / len(lbl_pred)

In [None]:
from torch.nn.functional import cross_entropy
from torch.optim import SGD
# Plain SGD over the parameters of the current `model`.
opt = SGD(model.parameters(), lr = 0.001)

n_epochs = 10
for epoch in range(n_epochs):
    # Running totals so that the reported accuracy covers the whole epoch
    # instead of printing a single-batch figure on every step.
    n_correct = 0
    n_seen = 0
    for images,labels in train_loader:
        opt.zero_grad()
        y_pred = model(images)
        # The loss is computed on the raw model outputs.
        loss = cross_entropy(y_pred,labels)
        loss.backward()
        opt.step()
        # The index of the largest raw output is the predicted class;
        # a softmax would not change which index is largest.
        pred = torch.argmax(y_pred, dim = 1)
        n_correct += (pred == labels).sum().item()
        n_seen += len(labels)
    # max(..., 1) keeps an empty loader from dividing by zero.
    accuracy_ = n_correct / max(n_seen, 1)
    print(f'Accuracy {epoch+1}/{n_epochs} : {accuracy_}')

In [None]:
# Take one (image, label) sample from the test split to try a prediction on.
img, label = test_dataset[10]

def predict(model,img):
    """Return the predicted class index for a single input as a Python int.

    Args:
        model: Callable returning a 2-D tensor of per-class scores with one
            row for the given input.
        img: Input passed to ``model`` unchanged.

    Returns:
        Index of the highest-scoring class, via ``Tensor.item()``; this
        requires the model output to contain exactly one row.
    """
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        output = model(img)
    # softmax is monotonic, so the argmax of the raw scores is the same class
    # the softmax probabilities would select.
    pred = torch.argmax(output, dim = 1)
    return pred.item()

# Predicted class next to the true label for the sample picked above.
print(predict(model,img),label)

In [None]:
# Persist only the learned parameters (the state dict), not the module object.
torch.save(obj=model.state_dict(), f='mnist-logistic.pth')

In [None]:
# Build a fresh model and restore the saved parameters into it.
model2 = MnistModel(input_size,output_size)
model2.load_state_dict(torch.load('mnist-logistic.pth'))

img, label = test_dataset[159]

# Predict with the reloaded model2 (not the in-memory `model`), so this cell
# actually checks that the save/load round trip worked.
print('Predicted: ',predict(model2,img))
print('True: ',label)