In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import os
import re

In [2]:
# --- Optimisation settings -------------------------------------------------
# Number of the last epoch to train; the training loop counts epochs from 1.
epochs = 10
# Number of samples per mini-batch handed to the DataLoader.
batch_size = 100

# --- Data-loading settings -------------------------------------------------
# Worker count passed to the DataLoader (`num_workers`).
num_workers = 4
# Directory passed to the dataset as `root`.
data_root = './data/'

In [3]:
# Training split of Fashion-MNIST, converted to tensors by ToTensor().
# `train` must be a real boolean: the previous value 'true' was a string and
# only selected the training split because any non-empty string is truthy
# (the string 'false' would have selected the training split as well).
train_set = datasets.FashionMNIST(
    root=data_root,
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# Shuffled mini-batch iterator over the training split.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

In [4]:
class Network(nn.Module):
    """Small convolutional classifier: two conv/pool stages and three linear layers.

    The first linear layer takes ``12 * 4 * 4`` features, which is what the
    convolutional stack produces for a single-channel 28x28 input
    (28 -> 24 -> 12 -> 8 -> 4). The final layer emits 10 raw scores per
    sample; no softmax is applied inside the network.
    """

    def __init__(self) -> None:
        super().__init__()

        # Feature extractor: two 5x5 convolutions without padding.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 12, kernel_size=5)

        # Classifier head: 192 -> 120 -> 60 -> 10.
        self.fc1 = nn.Linear(12 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 60)
        self.fc3 = nn.Linear(60, 10)

    def forward(self, x):
        """Return the 10 unnormalised class scores for each sample in ``x``."""
        # Each convolutional stage is conv -> ReLU -> 2x2 max-pool (stride 2).
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), kernel_size=2, stride=2)

        # Collapse everything except the batch dimension.
        x = torch.flatten(x, start_dim=1)

        # Hidden fully connected layers use ReLU; the output layer does not.
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))

        return self.fc3(x)

In [5]:
# Use the GPU when CUDA is available and fall back to the CPU otherwise, so
# constructing the model no longer raises on a machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Network().to(device)

# Adam over all model parameters with a learning rate of 0.01.
optimizer = optim.Adam(model.parameters(), lr=0.01)

In [6]:
def _latest_checkpoint(directory):
    """Return ``(epoch, path)`` for the highest-numbered checkpoint, or ``None``.

    Only entries named exactly ``model_epoch_<N>.pkl`` are considered, and
    they are ranked by the integer ``<N>``. A plain string sort would place
    ``model_epoch_9.pkl`` after ``model_epoch_10.pkl`` and would also pick up
    unrelated files that happen to live in the directory.
    """
    pattern = re.compile(r'model_epoch_(\d+)\.pkl')
    found = []
    for name in os.listdir(directory):
        match = pattern.fullmatch(name)
        if match:
            found.append((int(match.group(1)), os.path.join(directory, name)))
    return max(found, default=None)


def train(model, loader):
    """Train ``model`` on ``loader``, resuming from the newest checkpoint if any.

    The function reads the module-level ``optimizer`` and ``epochs``: the
    optimizer performs the updates and ``epochs`` is the number of the last
    epoch to run (epochs are counted from 1). After every epoch a checkpoint
    holding the model state, optimizer state and epoch number is written to
    ``./checkpoints/model_epoch_<epoch>.pkl`` and a one-line summary is
    printed.

    Args:
        model: Network to optimise. Batches are moved to the device of its
            first parameter before the forward pass.
        loader: Iterable yielding ``(images, labels)`` batches.
    """
    checkpoint_dir = './checkpoints/'
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Follow the model's device instead of assuming CUDA.
    device = next(model.parameters()).device

    first_epoch = 1
    latest = _latest_checkpoint(checkpoint_dir)
    if latest is not None:
        _, path = latest
        # NOTE: torch.load unpickles the file; only load checkpoints from a
        # trusted source.
        checkpoint = torch.load(path, map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        first_epoch = checkpoint['epoch'] + 1

    for epoch in range(first_epoch, epochs + 1):
        train_loss = 0.0
        correct = 0
        num_batches = 0

        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            preds = model(images)
            loss = F.cross_entropy(preds, labels)
            loss.backward()
            optimizer.step()

            # .item() yields a plain Python float, so the running sum does
            # not carry autograd history from every batch.
            train_loss += loss.item()
            correct += preds.argmax(dim=1).eq(labels).sum().item()
            num_batches += 1

        # Mean loss per batch; an empty loader reports 0.0 instead of failing.
        train_loss /= max(num_batches, 1)

        state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
        torch.save(state, os.path.join(checkpoint_dir, f'model_epoch_{epoch}.pkl'))
        print(f'epoch: {epoch}/{epochs} loss: {train_loss} correct: {correct}')

In [None]:
# Start (or resume) training on the Fashion-MNIST training loader.
train(model=model, loader=train_loader)