In [1]:
import os
import torch
import torchvision
from torchvision import transforms, datasets
import torch.nn as nn
import torch.nn.functional as F

# Specify the directory where the dataset is stored
root_dir = './data'

# Ensure the necessary directory structure
if not os.path.exists(os.path.join(root_dir, 'MNIST', 'raw')):
    os.makedirs(os.path.join(root_dir, 'MNIST', 'raw'))


train = datasets.MNIST(root = root_dir, train = True, download=False, transform = transforms.Compose([transforms.ToTensor()]))
test = datasets.MNIST(root = root_dir, train = False, download=False, transform = transforms.Compose([transforms.ToTensor()]))


trainset = torch.utils.data.DataLoader(train, batch_size = 10, shuffle = True)
testset = torch.utils.data.DataLoader(test, batch_size = 10, shuffle = False)

In [2]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1) 
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1) 
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x)) 
        x = F.relu(self.conv2(x))  
        x = x.view(-1, 64 * 7 * 7)  
        x = F.relu(self.fc1(x))     
        x = self.fc2(x)             
        return F.log_softmax(x, dim=1)  
net = Net()
print(net)

Net(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (fc1): Linear(in_features=3136, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)


In [None]:
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
loss_function = nn.CrossEntropyLoss()

num_epochs = 3
for epoch in range(num_epochs):
    for data in trainset:
        X, y = data
        net.zero_grad()
        output = net(X)
        loss = loss_function(output, y)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1} completed out of {num_epochs}, Loss: {loss.item()}")

Epoch 1 completed out of 3, Loss: 4.029192950838478e-06
