In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [2]:
# Simple CNN
class CNN(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 8, kernel_size= (3,3), padding= (1,1), stride = (1,1))
        self.pool = nn.MaxPool2d(kernel_size=(2,2), stride=(2,2))
        self.conv2 = nn.Conv2d(8, 16, kernel_size=(3,3), stride=(1,1), padding= (1,1))
        self.fc1 = nn.Linear(16 * 7 *7, num_classes)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(x.shape[0], -1)
        x = self.fc1(x)
        return x

device = 'cpu'

in_channels = 1
num_classes = 10
lr = 1e-4
batch_size = 16
epochs = 10
load_model = True

In [3]:
import pickle, gzip

with gzip.open('C:/Users/OWNER/Desktop/mnist.pkl.gz', 'rb') as f:
    ((x_train, y_train), (x_test, y_test), _) = pickle.load(f, encoding="latin-1")
    
x_train, y_train, x_test, y_test = map(torch.tensor, (x_train, y_train, x_test, y_test))

train_ds = TensorDataset(x_train, y_train)
test_ds = TensorDataset(x_test, y_test)
train_loader = DataLoader(train_ds, batch_size= batch_size)
test_loader = DataLoader(test_ds, batch_size= batch_size)

In [4]:
model = CNN(in_channels, num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = lr)

In [5]:
def save_checkpoint(state, filename = 'my_checkpoint.pth.tar'):
    print("=> Saving checkpoint")
    torch.save(state, filename)

In [6]:
def load_checkpoint(checkpoint):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])

In [7]:
if load_model == True:
    load_checkpoint(torch.load("my_checkpoint.pth.tar"))

=> Loading checkpoint


In [8]:
for epoch in range(epochs):
    losses = []
    
    if epoch % 3 == 0:
        checkpoint = {'state_dict' : model.state_dict(), 'optimizer': optimizer.state_dict()}
        save_checkpoint(checkpoint)
        
    for batch_idx, (data, targets) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)
        data = data.view(-1, 1, 28, 28)
        
        scores = model(data)
        loss = criterion(scores, targets)
        losses.append(loss.item())
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    cost = sum(losses)/len(losses)
    print(f"cost at epoch {epoch} was {cost:.5f}")

=> Saving checkpoint
cost at epoch 0 was 0.12599
cost at epoch 1 was 0.11225
cost at epoch 2 was 0.10024
=> Saving checkpoint
cost at epoch 3 was 0.09089
cost at epoch 4 was 0.08331
cost at epoch 5 was 0.07688
=> Saving checkpoint
cost at epoch 6 was 0.07158
cost at epoch 7 was 0.06710
cost at epoch 8 was 0.06325
=> Saving checkpoint
cost at epoch 9 was 0.05989
