In [91]:
from sklearn.datasets import fetch_openml
# Fetch the OpenML "mnist_784" dataset. as_frame=True is requested explicitly
# because the .to_numpy() calls below only exist on pandas objects; relying on
# the library default makes this cell depend on the installed scikit-learn version.
mnist = fetch_openml('mnist_784', as_frame=True)

# NumPy views of the feature table and the label column, consumed by the
# PyTorch dataset defined further down. NOTE(review): `train` holds every
# sample returned by fetch_openml — no train/test split is made here.
train = mnist.data.to_numpy()
target = mnist.target.to_numpy()

# Last expression of the cell, so the notebook displays both shapes.
train.shape, target.shape

In [148]:
import torch 
import torch.nn as nn

class Model(nn.Module):
    """Digit classifier with two selectable architectures.

    * ``model="linear"``: fully connected 784 -> 128 -> 10 network.
    * ``model="cnn"``: two 3x3 convolutions (stride 1, then stride 2)
      followed by fully connected layers 4032 -> 512 -> 10.

    ``forward`` returns raw, un-normalised class scores with one row per
    input sample and 10 columns.

    Args:
        model: architecture name, one of ``SUPPORTED_MODELS``.

    Raises:
        ValueError: if ``model`` is not a supported architecture name.
    """

    SUPPORTED_MODELS = ("linear", "cnn")

    def __init__(self, model="linear"):
        super().__init__()
        # Fail fast: without this check an unknown name would create a module
        # with no layers that only breaks later, inside forward().
        if model not in self.SUPPORTED_MODELS:
            raise ValueError(
                f"Unknown model type {model!r}; expected one of {self.SUPPORTED_MODELS}"
            )
        self.model_type = model

        if model == "linear":
            self.linear1 = nn.Linear(784, 128)
            self.linear2 = nn.Linear(128, 10)
        else:
            self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
            self.conv2 = nn.Conv2d(in_channels=32, out_channels=28, kernel_size=3, stride=2)
            # For a 1x28x28 input: conv1 -> 32x26x26, conv2 -> 28x12x12,
            # and 28 * 12 * 12 = 4032 flattened features.
            self.fc1 = nn.Linear(in_features=4032, out_features=512)
            self.fc2 = nn.Linear(in_features=512, out_features=10)

    def forward(self, x):
        """Compute class scores for a batch ``x`` whose first dimension is the batch."""
        if self.model_type == "linear":
            # flatten() copies when needed, so non-contiguous input is accepted
            # (a bare .view() raises on such tensors).
            x = x.flatten(start_dim=1)
            x = nn.functional.relu(self.linear1(x))
            return self.linear2(x)

        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.relu(self.conv2(x))
        x = x.flatten(start_dim=1)
        x = nn.functional.relu(self.fc1(x))
        return self.fc2(x)



In [149]:
# Build one network per architecture, in the order linear then cnn;
# .float() casts the floating-point parameters of each to float32.
linear_model, cnn_model = (
    Model(model=architecture).float() for architecture in ("linear", "cnn")
)

# TODO: move both models to CUDA.

In [150]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms 
import numpy as np


class MnistDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs.

    Args:
        data: indexable collection of flattened images; every item must be
            reshapeable to ``(28, 28, 1)``.
        target: indexable collection of labels; every item must be
            convertible with ``int()``.
        transform: optional callable applied to the ``(28, 28, 1)`` uint8
            image array. When ``None`` the array is returned as is.
    """

    def __init__(self, data, target, transform=None):
        self.data = data
        self.target = target
        self.transform = transform

    def __len__(self):
        """Number of samples, taken from ``data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``idx``-th image (transformed if configured) and its label tensor."""
        # uint8 is the input dtype torchvision's ToTensor is documented to
        # rescale to [0, 1]; float input in 0..255 is not reliably rescaled.
        # Assumes pixel intensities are integers in 0..255 — TODO confirm
        # before reusing this class with differently scaled data.
        sample = np.asarray(self.data[idx]).reshape(28, 28, 1).astype(np.uint8)

        # transform is optional (default None), so only call it when provided.
        if self.transform is not None:
            sample = self.transform(sample)

        # int() accepts both numeric labels and digit strings.
        target = torch.tensor(int(self.target[idx]), dtype=torch.long)
        return sample, target

# Per-sample preprocessing, applied in order: array -> PIL image -> tensor,
# then Normalize subtracts 0.5 and divides by 0.5.
preprocessing_steps = [
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=0.5, std=0.5),
]
transform = transforms.Compose(preprocessing_steps)

# Wrap the NumPy arrays in the dataset and serve them as shuffled
# mini-batches of 32, loaded in the main process (num_workers=0).
dataset = MnistDataset(data=train, target=target, transform=transform)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
)


In [155]:
import torch.optim as optim

# nn.CrossEntropyLoss applies log-softmax internally, so it must be fed the
# networks' raw scores; applying softmax first would normalise them twice.
criterion = nn.CrossEntropyLoss()

# Both networks are trained side by side on identical mini-batches, each with
# its own SGD optimizer, so their losses can be compared directly.
models = [linear_model, cnn_model]
optimizers = [optim.SGD(net.parameters(), lr=0.001, momentum=0.9) for net in models]

epochs = 10
log_every = 2000  # print the mean loss once per this many mini-batches

# The loop runs all `epochs` passes over the data.
for epoch in range(epochs):
    # One running total per model, in the same order as `models`.
    running_losses = [0.0 for _ in models]

    for i_batch, (data, labels) in enumerate(dataloader):
        for slot, (net, optimizer) in enumerate(zip(models, optimizers)):
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            loss = criterion(net(data), labels)
            loss.backward()
            optimizer.step()

            running_losses[slot] += loss.item()

        # print statistics
        if i_batch % log_every == log_every - 1:
            mean_linear, mean_cnn = (total / log_every for total in running_losses)
            print(f'[{epoch + 1}, {i_batch + 1:5d}] loss: {mean_linear:.3f}  {mean_cnn:.3f}')
            # Reset so the next report averages only the following window.
            running_losses = [0.0 for _ in models]

KeyboardInterrupt: 