In [2]:
import torch

In [3]:
from torch import nn, optim

In [4]:
from torch.utils import data

In [5]:
from loguru import logger

In [6]:
import visualize

In [7]:
import numpy as np

In [8]:
import time

In [9]:
import torchvision
import torchvision.transforms

In [10]:
import math

In [11]:
import base, resnet

In [12]:
root = '/data/CIFAR10'

In [13]:
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [14]:
train_transform = torchvision.transforms.Compose(
    [torchvision.transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
     torchvision.transforms.RandomRotation(10),
     torchvision.transforms.RandomHorizontalFlip(),
#     torchvision.transforms.RandomCrop(size=[32,32], padding=4, pad_if_needed=True),
    torchvision.transforms.ToTensor(),
     torchvision.transforms.Normalize((0.49140089750289917, 0.4821591377258301, 0.4465310275554657),
                                      (0.24702748656272888, 0.24348321557044983, 0.26158758997917175))
    ])

In [15]:
test_transform = torchvision.transforms.Compose(
    
    [torchvision.transforms.ToTensor(),
     torchvision.transforms.Normalize((0.49140089750289917, 0.4821591377258301, 0.4465310275554657),
                                      (0.24702748656272888, 0.24348321557044983, 0.26158758997917175))
    ])

In [16]:
train_dataset = torchvision.datasets.CIFAR10(root, transform=train_transform, download=True)

Files already downloaded and verified


In [17]:
test_dataset = torchvision.datasets.CIFAR10(root, train=False, transform=test_transform, download=True)

Files already downloaded and verified


In [18]:
train_dt, valid_dt = data.random_split(train_dataset, (40000, 10000))

In [19]:
net = resnet.ResNet56()

2019-07-17 17:41:59.826 | DEBUG    | resnet:get_residual_unit:158 - in_channels: 16, out_channels: 16, stride: 1
2019-07-17 17:41:59.827 | DEBUG    | resnet:get_residual_unit:158 - in_channels: 16, out_channels: 16, stride: 1
2019-07-17 17:41:59.828 | DEBUG    | resnet:get_residual_unit:158 - in_channels: 16, out_channels: 16, stride: 1
2019-07-17 17:41:59.829 | DEBUG    | resnet:get_residual_unit:158 - in_channels: 16, out_channels: 16, stride: 1
2019-07-17 17:41:59.830 | DEBUG    | resnet:get_residual_unit:158 - in_channels: 16, out_channels: 16, stride: 1
2019-07-17 17:41:59.831 | DEBUG    | resnet:get_residual_unit:158 - in_channels: 16, out_channels: 16, stride: 1
2019-07-17 17:41:59.832 | DEBUG    | resnet:get_residual_unit:158 - in_channels: 16, out_channels: 16, stride: 1
2019-07-17 17:41:59.833 | DEBUG    | resnet:get_residual_unit:158 - in_channels: 16, out_channels: 16, stride: 1
2019-07-17 17:41:59.834 | DEBUG    | resnet:get_residual_unit:158 - in_channels: 16, out_channel

In [20]:
def learning_rate_schedules(batch_num, initial_lr, warmup_batchs, total_batchs):
    """
    """
    if batch_num < warmup_batchs:
        lr = initial_lr * (batch_num+1)/warmup_batchs
    else:
        num = batch_num - warmup_batchs + 1
        lr = 0.5 * (1 + math.cos(math.pi * (num / total_batchs))) * initial_lr
    return lr

In [21]:
batch_size = 2048

In [22]:
epoch = 200

In [23]:
use_cuda = net.use_cuda

In [24]:
if use_cuda:
    net = net.cuda()

In [25]:
use_cuda = net.use_cuda

In [26]:
visualizer = visualize.Visualizer('Net_sgd')



In [27]:
learning_rate = 0.1

In [28]:
# optimizer = optim.Adam(net.parameters(), lr=learning_rate)
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)

In [29]:
total_batch_num = (int(40000 / batch_size) + 1) * epoch

In [30]:
logger.add('resnet_56_sgd_0.log')

1

In [32]:
batch_index = 0
for e in range(epoch):
    start_time = time.time()
    net.train()
    train_dt, valid_dt = data.random_split(train_dataset, (40000, 10000))
    train_data_iter = data.DataLoader(train_dt, batch_size=batch_size, shuffle=True)
    running_loss = 0.0
    for i, (train_data, label) in enumerate(train_data_iter):
#         lr = learning_rate_schedules(batch_index, learning_rate, 5, total_batch_num)
#         optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)
        if use_cuda:
            train_data = train_data.cuda()
            label = label.cuda()
        optimizer.zero_grad()
        prediction = net(train_data)
        loss = net.loss(prediction, label)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        batch_index += 1
    if e == 20:
        learning_rate = learning_rate / 10
        optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
    if e == 80:
        learning_rate = learning_rate / 10
        optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
    valid_train_res, valid_ap = base.validate(net, train_dataset, batch_size, num_class=10)
    valid_test_res, test_ap = base.validate(net, test_dataset, batch_size, num_class=10)
    loss_dct = {'loss': running_loss}
    val_precision = valid_train_res['map']
    val_ap_dct = {classes[i]: v for i, v in enumerate(valid_ap)}
    test_precision = valid_test_res['map']
    logger.info(f'epoch num: {e}, loss: {running_loss}, val_precision: {val_precision}, test_precision: {test_precision}, learning_rate: {learning_rate}, time: {time.time() - start_time}')
    visualizer.plot(loss_dct)
    visualizer.plot({'val_precision': val_precision})
    visualizer.plot(val_ap_dct)
    visualizer.plot({'test_precision': test_precision})

2019-07-17 18:16:58.295 | INFO     | __main__:<module>:33 - epoch num: 0, loss: 27.65863811969757, val_precision: 0.49361997842788696, test_precision: 0.5112999677658081, learning_rate: 0.1, time: 74.74558687210083
2019-07-17 18:18:13.136 | INFO     | __main__:<module>:33 - epoch num: 1, loss: 27.382885336875916, val_precision: 0.49528002738952637, test_precision: 0.5043999552726746, learning_rate: 0.1, time: 74.78244256973267
2019-07-17 18:19:27.929 | INFO     | __main__:<module>:33 - epoch num: 2, loss: 27.280449986457825, val_precision: 0.5008599758148193, test_precision: 0.520799994468689, learning_rate: 0.1, time: 74.73498725891113
2019-07-17 18:20:42.635 | INFO     | __main__:<module>:33 - epoch num: 3, loss: 27.069687724113464, val_precision: 0.5037800073623657, test_precision: 0.5190999507904053, learning_rate: 0.1, time: 74.64627313613892
2019-07-17 18:21:57.327 | INFO     | __main__:<module>:33 - epoch num: 4, loss: 26.60028374195099, val_precision: 0.5061400532722473, test_p

In [149]:
torch.save(net.state_dict(), 'base_net_002.save')

In [2]:
{'loss': 529.4704940319061, 'val_precision': 0.78054, 'test_precision': 0.6408, 'learning_rate': 0.01, 'time': 19.557708978652954}

{'loss': 529.4704940319061,
 'val_precision': 0.78054,
 'test_precision': 0.6408,
 'learning_rate': 0.01,
 'time': 19.557708978652954}