In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as datasets 
import torchvision.transforms as transform
from tqdm import tqdm


In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
class CNN(nn.Module):
    """Small two-stage convolutional classifier for square images.

    Architecture: conv(3x3, 30 ch) -> ReLU -> maxpool(2) -> conv(3x3, 15 ch)
    -> ReLU -> maxpool(2) -> flatten -> linear. The 3x3 convolutions use
    stride=1 and padding=1, so only the two pooling layers shrink the spatial
    size (each halves it, rounding down).

    Args:
        input_chanels: number of channels in the input images (1 for MNIST).
            The misspelling is kept so existing keyword callers keep working.
        num_classes: number of output logits.
        image_size: height/width of the square input images. Defaults to 28,
            which reproduces the previously hard-coded ``7 * 7 * 15`` features.

    Raises:
        ValueError: if ``image_size`` is smaller than 4 (the pooled feature map
            would be empty).
    """

    def __init__(self, input_chanels=1, num_classes=10, image_size=28):
        super().__init__()
        if image_size < 4:
            raise ValueError(f"image_size must be >= 4, got {image_size}")
        self.conv1 = nn.Conv2d(input_chanels, 30, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(30, 15, kernel_size=3, stride=1, padding=1)
        # Two 2x2/stride-2 pools with floor rounding (MaxPool2d's default
        # ceil_mode=False): image_size -> image_size // 2 -> image_size // 4.
        pooled = image_size // 4
        self.fc = nn.Linear(pooled * pooled * 15, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        ``x`` is expected to be (batch, input_chanels, image_size, image_size).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, start_dim=1)  # keep the batch dim, flatten the rest
        return self.fc(x)

In [None]:
# Seed torch's global RNG so weight initialisation and DataLoader shuffling are
# reproducible under Restart & Run All (some GPU kernels may still be
# nondeterministic). Done first so the cell's last line is not a displayed value.
seed = 0
torch.manual_seed(seed)

# Training hyperparameters.
batch_size = 64  # samples per mini-batch (used by both the train and test loaders)
num_epoch = 3    # full passes over the training set
lr = 0.001       # Adam learning rate

In [None]:
# MNIST train/test splits, converted to tensors by ToTensor().
# download=True fetches the files into 'datasets/' on a fresh machine; torchvision
# skips the download when they are already present, so Restart & Run All works anywhere.
load_train = datasets.MNIST('datasets/', train=True, download=True, transform=transform.ToTensor())
load_test = datasets.MNIST('datasets/', train=False, download=True, transform=transform.ToTensor())

# Shuffle only the training data: evaluation order does not change accuracy,
# and a fixed order keeps test-time iteration deterministic.
train = DataLoader(load_train, batch_size, shuffle=True)
test = DataLoader(load_test, batch_size, shuffle=False)

In [None]:
# Build a fresh model and optimiser each time this cell runs, so re-running it
# retrains from scratch rather than continuing from leftover kernel state.
model = CNN().to(device)

optimizer = optim.Adam(model.parameters(), lr)
criterion = nn.CrossEntropyLoss()

# Set train mode explicitly instead of relying on the nn.Module default.
model.train()
for epoch in range(num_epoch):
    progress = tqdm(train, desc=f"epoch {epoch + 1}/{num_epoch}")
    for x_batch, y_batch in progress:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        optimizer.zero_grad()
        logits = model(x_batch)
        loss = criterion(logits, y_batch)
        loss.backward()
        optimizer.step()

        # Surface the latest mini-batch loss on the progress bar.
        progress.set_postfix(loss=f"{loss.item():.4f}")

In [None]:
def check_accuracy(model, data, device=None):
    """Compute, print and return the classification accuracy of ``model`` on ``data``.

    The model is evaluated in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored afterwards, even if an exception is raised.

    Args:
        model: an ``nn.Module`` mapping an input batch to per-class scores; the
            prediction is the argmax over dim 1.
        data: an iterable of ``(inputs, targets)`` batches, e.g. a DataLoader.
        device: where to move each batch. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters).

    Returns:
        Accuracy as a Python float percentage in [0, 100].

    Raises:
        ValueError: if ``data`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()
    num_samples = 0
    num_correct = 0
    try:
        with torch.no_grad():
            for x_batch, y_batch in data:
                x_batch = x_batch.to(device)
                y_batch = y_batch.to(device)
                y_pred = model(x_batch).argmax(dim=1)
                # .item() keeps the counter a plain int rather than a device tensor.
                num_correct += (y_pred == y_batch).sum().item()
                num_samples += y_pred.shape[0]
    finally:
        model.train(was_training)

    if num_samples == 0:
        raise ValueError("check_accuracy: data yielded no samples")

    accuracy = num_correct / num_samples * 100
    print(f"Accuracy: {accuracy:.2f}% ({num_correct}/{num_samples})")
    return accuracy

In [None]:
# Evaluate on the held-out split; the trailing ';' keeps the cell from echoing a return value.
check_accuracy(model=model, data=test);