In [9]:
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.utils.data import Dataset,DataLoader,TensorDataset
from tqdm import tqdm
import torch.nn.functional as F

from Models.ResNetmodel import resnet34
from DataUtils.load_dataset import QuickDrawDataset

In [10]:
# Training parameters.

# Device selection: use the GPU when CUDA is available, otherwise the CPU.
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda") if USE_CUDA else torch.device("cpu")

# Epoch schedule: the learning rate is multiplied by `gamma` when the
# 1-based epoch number appears in `lr_decay_step`.
num_epochs = 30
lr_decay_step = [12, 20]
gamma = 0.1

# SGD optimizer settings.
learning_rate = 0.1
momentum = 0.9
weight_decay = 0.0005

# Side length of the square input images (batches are reshaped to
# (N, 1, 28, 28) before the forward pass).
image_size = 28

In [11]:
# Training data: reshuffled every epoch so SGD sees the samples in a new order.
train_dataset = QuickDrawDataset(dtype='train')
train_loader = DataLoader(dataset=train_dataset, batch_size=256, shuffle=True)

# Test data: iterated in a fixed order. Shuffling brings no benefit for
# evaluation, and because the last batch is partial the per-batch-averaged
# test loss would otherwise vary slightly from run to run.
test_dataset = QuickDrawDataset(dtype='test')
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# Number of target classes, as reported by the training dataset.
num_classes = train_dataset.get_num_classes()

print("Train image num:", len(train_dataset))
print("Test image num:", len(test_dataset))

Train image num: 124000
Test image num: 31000


In [12]:
# Build the `resnet34` network (sized by `num_classes`) and move it to the
# selected device before handing its parameters to the optimizer.
model = resnet34(num_classes)
model = model.to(DEVICE)

# Plain SGD with momentum and L2 weight decay, configured from the
# hyperparameter cell.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=momentum,
    weight_decay=weight_decay,
)



In [13]:
# Training result parameters: module-level metrics that train() and test()
# overwrite (via `global`) at the end of every epoch.
train_loss = train_accuracy = 0.0
test_loss = test_accuracy = 0.0

In [14]:
def train():
    """Run one training epoch over ``train_loader``.

    Updates the module-level ``model`` through ``optimizer`` and stores the
    epoch's metrics in two globals:

    * ``train_loss`` -- running value of the per-batch cross-entropy loss,
      updated as ``0.2 * previous + 0.8 * current`` after every batch.
    * ``train_accuracy`` -- correct predictions divided by the number of
      samples in ``train_loader.dataset``.

    Also prints the number of correct predictions and the accuracy.
    """
    global train_loss
    global train_accuracy

    model.train()
    loss_avg = 0.0
    correct = 0
    data_loader = tqdm(train_loader, desc='Training')

    for data, target in data_loader:
        # Tensors participate in autograd directly; the old
        # torch.autograd.Variable wrapper is a deprecated no-op.
        data, target = data.to(DEVICE), target.to(DEVICE)

        # Reshape to (N, 1, H, W) using the configured image size instead of
        # a hard-coded 28.
        data = data.view(-1, 1, image_size, image_size)
        # In-place division raises on integer pixel tensors, so promote those
        # to float first; floating-point inputs keep their dtype.
        if not data.is_floating_point():
            data = data.float()
        data = data / 255.0  # scale pixel values by 1/255

        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        # Predicted class = index of the largest logit; detached so the
        # bookkeeping never touches the autograd graph.
        pred = output.detach().max(1)[1]
        correct = correct + float(pred.eq(target).sum())
        loss_avg = loss_avg * 0.2 + float(loss) * 0.8

    train_loss = loss_avg
    train_accuracy = correct / len(train_loader.dataset)
    print(correct, "a:", train_accuracy)

In [15]:
def test():
    """Evaluate the module-level ``model`` on ``test_loader``.

    Runs with the model in eval mode and gradient tracking disabled, then
    stores the results in two globals:

    * ``test_loss`` -- mean of the per-batch cross-entropy losses
      (sum of batch losses divided by the number of batches).
    * ``test_accuracy`` -- correct predictions divided by the number of
      samples in ``test_loader.dataset``.

    Also prints the number of correct predictions and the accuracy.
    """
    global test_loss
    global test_accuracy

    model.eval()
    loss_avg = 0.0
    correct = 0

    data_loader = tqdm(test_loader, desc='Testing')

    # No backward pass happens during evaluation, so skip building the
    # autograd graph; this saves memory and time. (The original also called
    # optimizer.zero_grad() here, which has no purpose without backward().)
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)

            # Reshape to (N, 1, H, W) using the configured image size instead
            # of a hard-coded 28.
            data = data.view(-1, 1, image_size, image_size)
            # In-place division raises on integer pixel tensors, so promote
            # those to float first; floating-point inputs keep their dtype.
            if not data.is_floating_point():
                data = data.float()
            data = data / 255.0  # scale pixel values by 1/255

            output = model(data)
            loss = F.cross_entropy(output, target)

            # Predicted class = index of the largest logit.
            pred = output.max(1)[1]
            correct = correct + float(pred.eq(target).sum())

            loss_avg = loss_avg + float(loss)

    test_loss = loss_avg / len(test_loader)
    test_accuracy = correct / len(test_loader.dataset)
    print(correct, "a:", test_accuracy)


In [16]:
# Main loop: train and evaluate once per epoch, checkpointing the weights
# whenever the test accuracy improves on the best value seen so far.
best_accuracy = 0.0

for epoch in range(num_epochs):
    current_epoch = epoch + 1  # 1-based epoch number used for logging/schedule
    print("epoch "+str(current_epoch)+" is running...")

    # Step decay: multiply the base rate by `gamma` once for every milestone
    # in `lr_decay_step` that has been reached. The rate is derived from the
    # base `learning_rate` each epoch rather than overwriting that global, so
    # re-running this cell does not start from an already-decayed value.
    decays_applied = len({step for step in lr_decay_step if 1 <= step <= current_epoch})
    epoch_lr = learning_rate * gamma ** decays_applied
    for param_group in optimizer.param_groups:
        param_group['lr'] = epoch_lr

    train()
    test()
    print("test accuracy:", test_accuracy)

    if test_accuracy > best_accuracy:
        best_accuracy = test_accuracy
        # NOTE: the checkpoint suffix is the zero-based epoch index (one less
        # than the number printed above); kept as-is so existing file names
        # stay stable.
        torch.save(model.state_dict(), os.path.join("./", 'model'+str(epoch)+'.pytorch'))

    print('Best Accuracy: %.4f' % best_accuracy)

epoch 1 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.37it/s]


76625.0 a: 0.6179435483870968


Testing: 100%|██████████| 485/485 [00:04<00:00, 97.80it/s]


22269.0 a: 0.7183548387096774
test accuracy: 0.7183548387096774
Best Accuracy: 0.7184
epoch 2 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.48it/s]


93566.0 a: 0.7545645161290323


Testing: 100%|██████████| 485/485 [00:05<00:00, 95.82it/s]


21297.0 a: 0.687
test accuracy: 0.687
Best Accuracy: 0.7184
epoch 3 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.47it/s]


96949.0 a: 0.7818467741935484


Testing: 100%|██████████| 485/485 [00:05<00:00, 95.82it/s]


23702.0 a: 0.7645806451612903
test accuracy: 0.7645806451612903
Best Accuracy: 0.7646
epoch 4 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.46it/s]


98734.0 a: 0.796241935483871


Testing: 100%|██████████| 485/485 [00:05<00:00, 95.27it/s]


23856.0 a: 0.7695483870967742
test accuracy: 0.7695483870967742
Best Accuracy: 0.7695
epoch 5 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.47it/s]


99729.0 a: 0.804266129032258


Testing: 100%|██████████| 485/485 [00:05<00:00, 94.79it/s]


24262.0 a: 0.7826451612903226
test accuracy: 0.7826451612903226
Best Accuracy: 0.7826
epoch 6 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.49it/s]


100313.0 a: 0.8089758064516129


Testing: 100%|██████████| 485/485 [00:05<00:00, 94.86it/s]


24539.0 a: 0.7915806451612903
test accuracy: 0.7915806451612903
Best Accuracy: 0.7916
epoch 7 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.47it/s]


101053.0 a: 0.8149435483870968


Testing: 100%|██████████| 485/485 [00:05<00:00, 94.94it/s]


24301.0 a: 0.7839032258064517
test accuracy: 0.7839032258064517
Best Accuracy: 0.7916
epoch 8 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.49it/s]


101575.0 a: 0.8191532258064517


Testing: 100%|██████████| 485/485 [00:05<00:00, 95.60it/s]


23939.0 a: 0.7722258064516129
test accuracy: 0.7722258064516129
Best Accuracy: 0.7916
epoch 9 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.48it/s]


102002.0 a: 0.8225967741935484


Testing: 100%|██████████| 485/485 [00:05<00:00, 95.03it/s]


24354.0 a: 0.7856129032258065
test accuracy: 0.7856129032258065
Best Accuracy: 0.7916
epoch 10 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.49it/s]


102526.0 a: 0.8268225806451613


Testing: 100%|██████████| 485/485 [00:05<00:00, 95.45it/s]


24285.0 a: 0.7833870967741936
test accuracy: 0.7833870967741936
Best Accuracy: 0.7916
epoch 11 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.49it/s]


102878.0 a: 0.8296612903225806


Testing: 100%|██████████| 485/485 [00:05<00:00, 96.45it/s]


24064.0 a: 0.776258064516129
test accuracy: 0.776258064516129
Best Accuracy: 0.7916
epoch 12 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.48it/s]


110063.0 a: 0.8876048387096774


Testing: 100%|██████████| 485/485 [00:05<00:00, 96.73it/s]


25529.0 a: 0.823516129032258
test accuracy: 0.823516129032258
Best Accuracy: 0.8235
epoch 13 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.47it/s]


113210.0 a: 0.9129838709677419


Testing: 100%|██████████| 485/485 [00:04<00:00, 97.28it/s]


25404.0 a: 0.8194838709677419
test accuracy: 0.8194838709677419
Best Accuracy: 0.8235
epoch 14 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.49it/s]


115529.0 a: 0.9316854838709677


Testing: 100%|██████████| 485/485 [00:04<00:00, 97.51it/s]


25197.0 a: 0.8128064516129032
test accuracy: 0.8128064516129032
Best Accuracy: 0.8235
epoch 15 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.48it/s]


117842.0 a: 0.9503387096774194


Testing: 100%|██████████| 485/485 [00:04<00:00, 97.61it/s]


25098.0 a: 0.8096129032258065
test accuracy: 0.8096129032258065
Best Accuracy: 0.8235
epoch 16 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.48it/s]


119598.0 a: 0.9645


Testing: 100%|██████████| 485/485 [00:05<00:00, 96.30it/s]


24947.0 a: 0.804741935483871
test accuracy: 0.804741935483871
Best Accuracy: 0.8235
epoch 17 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.50it/s]


120641.0 a: 0.9729112903225806


Testing: 100%|██████████| 485/485 [00:04<00:00, 97.42it/s]


24854.0 a: 0.801741935483871
test accuracy: 0.801741935483871
Best Accuracy: 0.8235
epoch 18 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.49it/s]


121242.0 a: 0.977758064516129


Testing: 100%|██████████| 485/485 [00:04<00:00, 97.25it/s]


24905.0 a: 0.8033870967741935
test accuracy: 0.8033870967741935
Best Accuracy: 0.8235
epoch 19 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.48it/s]


121759.0 a: 0.9819274193548387


Testing: 100%|██████████| 485/485 [00:04<00:00, 97.34it/s]


24671.0 a: 0.7958387096774193
test accuracy: 0.7958387096774193
Best Accuracy: 0.8235
epoch 20 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.50it/s]


123160.0 a: 0.993225806451613


Testing: 100%|██████████| 485/485 [00:04<00:00, 97.61it/s]


24986.0 a: 0.806
test accuracy: 0.806
Best Accuracy: 0.8235
epoch 21 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.47it/s]


123705.0 a: 0.9976209677419355


Testing: 100%|██████████| 485/485 [00:05<00:00, 96.88it/s]


24937.0 a: 0.8044193548387096
test accuracy: 0.8044193548387096
Best Accuracy: 0.8235
epoch 22 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.48it/s]


123805.0 a: 0.9984274193548387


Testing: 100%|██████████| 485/485 [00:04<00:00, 97.26it/s]


24929.0 a: 0.8041612903225807
test accuracy: 0.8041612903225807
Best Accuracy: 0.8235
epoch 23 is running...


Training: 100%|██████████| 485/485 [00:56<00:00,  8.51it/s]


123838.0 a: 0.9986935483870968


Testing: 100%|██████████| 485/485 [00:05<00:00, 96.90it/s]


24938.0 a: 0.8044516129032258
test accuracy: 0.8044516129032258
Best Accuracy: 0.8235
epoch 24 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.49it/s]


123856.0 a: 0.9988387096774194


Testing: 100%|██████████| 485/485 [00:04<00:00, 97.03it/s]


24926.0 a: 0.8040645161290323
test accuracy: 0.8040645161290323
Best Accuracy: 0.8235
epoch 25 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.51it/s]


123862.0 a: 0.9988870967741935


Testing: 100%|██████████| 485/485 [00:05<00:00, 96.89it/s]


24939.0 a: 0.804483870967742
test accuracy: 0.804483870967742
Best Accuracy: 0.8235
epoch 26 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.49it/s]


123875.0 a: 0.998991935483871


Testing: 100%|██████████| 485/485 [00:04<00:00, 97.51it/s]


24926.0 a: 0.8040645161290323
test accuracy: 0.8040645161290323
Best Accuracy: 0.8235
epoch 27 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.48it/s]


123883.0 a: 0.9990564516129032


Testing: 100%|██████████| 485/485 [00:04<00:00, 97.78it/s]


24925.0 a: 0.8040322580645162
test accuracy: 0.8040322580645162
Best Accuracy: 0.8235
epoch 28 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.50it/s]


123895.0 a: 0.9991532258064516


Testing: 100%|██████████| 485/485 [00:04<00:00, 97.49it/s]


24919.0 a: 0.8038387096774193
test accuracy: 0.8038387096774193
Best Accuracy: 0.8235
epoch 29 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.49it/s]


123900.0 a: 0.9991935483870967


Testing: 100%|██████████| 485/485 [00:04<00:00, 97.32it/s]


24935.0 a: 0.8043548387096774
test accuracy: 0.8043548387096774
Best Accuracy: 0.8235
epoch 30 is running...


Training: 100%|██████████| 485/485 [00:57<00:00,  8.49it/s]


123910.0 a: 0.9992741935483871


Testing: 100%|██████████| 485/485 [00:04<00:00, 97.25it/s]

24941.0 a: 0.8045483870967742
test accuracy: 0.8045483870967742
Best Accuracy: 0.8235



