In [1]:
from pathlib import Path
from random import randint

import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt
from progress.bar import IncrementalBar as Bar
from torch import nn
from torch import optim
from torch.nn import functional as F
from torchvision import transforms

%matplotlib inline

In [2]:

class Net(nn.Module):
    '''Two-conv / three-fc classifier producing 10 class logits.

    Input:  (batch, 3, 32, 32) float tensor. The 32 x 32 spatial size is
            implied by ``fc1``: two 3x3 convolutions without padding turn
            32 x 32 into 28 x 28.
    Output: (batch, 10) unnormalised logits (no softmax is applied).
    '''

    def __init__(self):
        super().__init__()

        # (N, 3, 32, 32) -> 3x3 conv, no padding -> (N, 32, 30, 30)
        self.__conv1 = nn.Conv2d(3, 32, 3)
        # (N, 32, 30, 30) -> 3x3 conv, no padding -> (N, 96, 28, 28)
        self.__conv2 = nn.Conv2d(32, 96, 3)
        # flattened conv features -> 128 -> 64 -> 10 logits
        self.__fc1 = nn.Linear(96 * 28 * 28, 128)
        self.__fc2 = nn.Linear(128, 64)
        self.__fc3 = nn.Linear(64, 10)

    def forward(self, x):
        out = F.relu(self.__conv1(x))
        out = F.relu(self.__conv2(out))

        # Flatten per sample, keeping the batch axis. A wrongly sized input
        # then fails loudly in fc1 instead of having its elements silently
        # regrouped across samples (which ``view(-1, ...)`` allows).
        out = torch.flatten(out, 1)
        out = F.relu(self.__fc1(out))
        out = F.relu(self.__fc2(out))

        return self.__fc3(out)


def define_loader(root='./data', batch_size=10, num_workers=4, download=True):
    '''Build CIFAR-10 train and test DataLoaders.

    Images go through ``ToTensor`` and then ``Normalize`` with mean 0.5 and
    std 0.5 on each of the 3 channels.

    Args:
        root: directory holding (or receiving) the CIFAR-10 files.
        batch_size: samples per batch for both loaders.
        num_workers: worker processes per loader.
        download: forwarded to ``torchvision.datasets.CIFAR10``.

    Returns:
        ``(train_loader, test_loader)``. Only the training loader shuffles.
    '''
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((.5, .5, .5), (.5, .5, .5))
    ])

    def _make_loader(train, shuffle):
        # One dataset/loader pair; train and test differ only in these flags.
        dataset = torchvision.datasets.CIFAR10(
            root=root, train=train, transform=transform, download=download)
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle,
            num_workers=num_workers)

    return _make_loader(True, True), _make_loader(False, False)


def imshow(img, path='./data/charts/1.png', unnormalize=True):
    '''Render a (C, H, W) image tensor with matplotlib and save it as an image.

    Args:
        img: image tensor, channels first. It is detached and moved to CPU
            before conversion, so model outputs and GPU tensors also work.
        path: output file. Missing parent directories are created.
        unnormalize: map values back from [-1, 1] to [0, 1] by inverting
            ``Normalize(mean=.5, std=.5)``, as applied by ``define_loader``.
            Matplotlib clips float RGB data to [0, 1], so without this step
            the negative half of the range would be lost.
    '''
    img = img.detach().cpu()
    if unnormalize:
        img = img / 2 + .5
    img = np.transpose(img.numpy(), (1, 2, 0))  # (C, H, W) -> (H, W, C)

    # savefig does not create directories; without this it raises
    # FileNotFoundError on a fresh checkout.
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    plt.imshow(img)
    plt.savefig(path)


def _sample_test_batch(test):
    '''Return one randomly chosen ``(images, labels)`` batch from ``test``.

    For a ``DataLoader``, the batch is built directly from ``test.dataset``
    using ``test.batch_size`` random indices, so no other batches are loaded.
    Any other sized iterable of batches falls back to walking to a uniformly
    chosen batch position.

    Raises:
        ValueError: if ``test`` is empty.
    '''
    if len(test) == 0:
        raise ValueError('test set is empty')

    dataset = getattr(test, 'dataset', None)
    batch_size = getattr(test, 'batch_size', None)
    if dataset is not None and batch_size:
        picks = torch.randperm(len(dataset))[:batch_size].tolist()
        return torch.utils.data.dataloader.default_collate(
            [dataset[i] for i in picks])

    # randint is inclusive at both ends, so the last valid position is len - 1.
    target = randint(0, len(test) - 1)
    for position, batch in enumerate(test):
        if position == target:
            return batch
    raise ValueError('test yielded fewer batches than len(test) reports')


def train_net(net: nn.Module, train, test, n_epoch=10, lr=1e-3, momentum=.9,
              log_every=250):
    '''Train ``net`` with SGD on cross-entropy loss and plot the loss curves.

    Every ``log_every`` optimisation steps (counted across epochs) two values
    are recorded: the mean training loss over those steps, and the loss on one
    random test batch. A partial window left at the end of training is not
    logged.

    Args:
        net: model returning class logits for a batch of images.
        train: sized iterable of ``(images, labels)`` training batches.
        test: held-out ``(images, labels)`` batches, normally a DataLoader.
        n_epoch: number of passes over ``train``.
        lr: SGD learning rate.
        momentum: SGD momentum.
        log_every: optimisation steps per logged point; must be >= 1.

    Returns:
        ``(train_losses, test_losses)``: two lists with one float per logged
        point.

    Raises:
        ValueError: if ``log_every`` < 1 or ``test`` is empty.
    '''
    if log_every < 1:
        raise ValueError('log_every must be >= 1')

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    bar = Bar('train progress', max=n_epoch * len(train) // log_every)

    train_losses = []
    test_losses = []
    running_loss = 0.
    step = 0
    net.train()
    try:
        for _ in range(n_epoch):
            for images, labels in train:
                logits = net(images)
                loss = criterion(logits, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                step += 1
                # Log only once a full window has accumulated, so every
                # point is the mean of exactly ``log_every`` batch losses.
                if step % log_every == 0:
                    net.eval()
                    with torch.no_grad():
                        test_images, test_labels = _sample_test_batch(test)
                        test_loss = criterion(net(test_images), test_labels)
                    net.train()

                    test_losses.append(test_loss.item())
                    train_losses.append(running_loss / log_every)
                    running_loss = 0.
                    bar.next()
    finally:
        bar.finish()

    _, ax = plt.subplots()
    steps = np.arange(1, len(train_losses) + 1) * log_every
    ax.plot(steps, train_losses, label=f'train (mean of {log_every} batches)')
    ax.plot(steps, test_losses, label='test (one random batch)')
    ax.set(xlabel='optimisation step', ylabel='cross-entropy loss',
           title='Training curves')
    ax.legend()

    return train_losses, test_losses


In [3]:
def main():
    '''Load CIFAR-10, build a fresh ``Net`` and train it for 20 epochs.'''
    train_loader, test_loader = define_loader()
    model = Net()
    train_net(model, train_loader, test_loader,
              n_epoch=20, lr=1e-3, momentum=0.9)

In [None]:
# Entry point: builds the CIFAR-10 loaders, trains a fresh Net for 20 epochs
# and plots the loss curves. Likely a long-running cell.
main()