In [1]:
from cnn_finetune import make_model
import torch
import torchvision
from torchvision import transforms, utils
import matplotlib.pyplot as plt
import os
from tensorboardX import SummaryWriter 
from tqdm import tqdm
import numpy as np
from torchvision import datasets

In [2]:
IMAGE_SIZE=(299,299)
LR = 0.001    
BATCH_SIZE=8
EPOCHS=1000
EARLY_STOP_Threshold=5

In [3]:
import torch.nn as nn

def make_classifier(in_features, num_classes):
    return nn.Sequential(
        nn.Linear(in_features, num_classes)
    )

model = make_model('xception', num_classes=120, pretrained=True, input_size=(299, 299), classifier_factory=make_classifier)

In [4]:
model

XceptionWrapper(
  (_features): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): Block(
      (skip): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (skipbn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (rep): Sequential(
        (0): SeparableConv2d(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
          (pointwise): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
 

In [5]:
from torchsummary import summary
summary(model.cuda(), (3, 299, 299))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 149, 149]             864
       BatchNorm2d-2         [-1, 32, 149, 149]              64
              ReLU-3         [-1, 32, 149, 149]               0
            Conv2d-4         [-1, 64, 147, 147]          18,432
       BatchNorm2d-5         [-1, 64, 147, 147]             128
            Conv2d-6         [-1, 64, 147, 147]             576
            Conv2d-7        [-1, 128, 147, 147]           8,192
   SeparableConv2d-8        [-1, 128, 147, 147]               0
       BatchNorm2d-9        [-1, 128, 147, 147]             256
             ReLU-10        [-1, 128, 147, 147]               0
             ReLU-11        [-1, 128, 147, 147]               0
           Conv2d-12        [-1, 128, 147, 147]           1,152
           Conv2d-13        [-1, 128, 147, 147]          16,384
  SeparableConv2d-14        [-1, 128, 1

In [6]:
# convert data to a normalized torch.FloatTensor
train_transforms = transforms.Compose([transforms.Resize(IMAGE_SIZE),
                                      transforms.ToTensor(),
                                      
                                      transforms.ColorJitter(),
                                      transforms.RandomRotation(30),
                                      transforms.RandomHorizontalFlip(p=0.5),
                                      
                                      transforms.Normalize([0.485, 0.456, 0.406],
                                                           [0.229, 0.224, 0.225])])

valid_transforms = transforms.Compose([transforms.Resize(IMAGE_SIZE),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406],
                                                           [0.229, 0.224, 0.225])])

test_transforms = transforms.Compose([transforms.Resize(IMAGE_SIZE),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406],
                                                           [0.229, 0.224, 0.225])])

In [7]:
# choose the training and test datasets
train_data = datasets.ImageFolder('data/train', transform=train_transforms)
val_data = datasets.ImageFolder('data/val',transform=valid_transforms)

In [8]:
class_dir=train_data.class_to_idx
#print(class_dir)

In [9]:
# prepare data loaders (combine dataset and sampler)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE,shuffle=True)
valid_loader = torch.utils.data.DataLoader(val_data, batch_size=BATCH_SIZE,shuffle=True)

In [10]:
import matplotlib.pyplot as plt
%matplotlib inline
classes = list(class_dir.keys())
mean , std = torch.tensor([0.485, 0.456, 0.406]),torch.tensor([0.229, 0.224, 0.225])


def denormalize(image):
  image = transforms.Normalize(-mean/std,1/std)(image) #denormalize
  image = image.permute(1,2,0) #Changing from 3x224x224 to 224x224x3
  image = torch.clamp(image,0,1)
  return image

# helper function to un-normalize and display an image
def imshow(img):
    img = denormalize(img) 
    plt.imshow(img)

In [11]:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE,shuffle=True)
print("train_loader:",len(train_loader))
valid_loader = torch.utils.data.DataLoader(val_data, batch_size=BATCH_SIZE,shuffle=True)
print("val_loader:",len(valid_loader))

train_loader: 1973
val_loader: 450


In [12]:
model.cuda()
print(model)  # net architecture

XceptionWrapper(
  (_features): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): Block(
      (skip): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (skipbn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (rep): Sequential(
        (0): SeparableConv2d(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
          (pointwise): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
 

In [13]:
optimizer = torch.optim.Adam(model.parameters(), lr=LR)   # optimize all cnn parameters
#optimizer = torch.optim.RMSprop(cnn.parameters())   # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()   # the target label is not one-hotted

In [14]:
root_logdir = os.path.join(os.curdir, "my_logs")
def get_run_logdir():
    import time
    run_id = time.strftime("run_%Y_%m_%d-%H_%M_%S")
    return os.path.join(root_logdir, run_id)

run_logdir = get_run_logdir()
model_name=run_logdir.split("/")[-1]
run_logdir

'./my_logs/run_2021_01_07-20_31_01'

In [15]:
from torch.autograd import Variable
# training and testing
writer = SummaryWriter(run_logdir)
Max_val_accuracy=0
no_improve_count=0
for epoch in range(EPOCHS):
    train_loss=0.0
    val_loss=0.0
    #train_accuracy=0.0
    #val_accuracy=0.0
    train_correct = 0.
    train_total = 0.
    val_correct = 0.
    val_total = 0.
    
    for step, (x, y) in enumerate(tqdm(train_loader)):   # 分配 batch data, normalize x when iterate train_loader
        b_x=Variable(x).cuda()
        b_y=Variable(y).cuda()
        
        output = model(b_x)               # cnn output
        loss = loss_func(output, b_y)   # cross entropy loss
        #print("loss:",loss)
        optimizer.zero_grad()           # clear gradients for this training step
        loss.backward()               # backpropagation, compute gradients
        optimizer.step()                # apply gradients
        train_loss+=loss.item()*x.size(0)
        
        pred = output.data.max(1, keepdim=True)[1].cpu()
        # compare predictions to true label
        train_correct += np.sum(np.squeeze(pred.eq(y.data.view_as(pred))).cpu().numpy())
        train_total += x.size(0)
    
    for step, (x, y) in enumerate(valid_loader):  # 每一步 loader 释放一小批数据用来学习
        # 假设这里就是你训练的地方...
        b_x=Variable(x).cuda()
        b_y=Variable(y).cuda()
        # 打出来一些数据
        output = model(b_x)
        loss = loss_func(output, b_y)
        val_loss+=loss.item()*x.size(0)
         
        pred = output.data.max(1, keepdim=True)[1].cpu()
        # compare predictions to true label
        
        val_correct += np.sum(np.squeeze(pred.eq(y.data.view_as(pred))).cpu().numpy())
        val_total += x.size(0)
        
    train_loss = train_loss/len(train_loader.dataset)
    val_loss=val_loss/len(valid_loader.dataset)
    train_accuracy=round((100. *train_correct / train_total),2)
    val_accuracy=round((100. *val_correct / val_total),2)
    writer.add_scalars('loss', {'train_loss':train_loss,'valid_loss':val_loss}, epoch)
    writer.add_scalars('accuracy', {'train_accuracy':train_accuracy,'valid_accuracy':val_accuracy}, epoch)
    print('epoch ',epoch,' , train loss : ',train_loss,', train_accuracy:',train_accuracy,' , valid loss :',val_loss,', val_accuracy:',val_accuracy)
    
    if(val_accuracy>Max_val_accuracy):
        Max_val_accuracy=val_accuracy
        no_improve_count=0
        torch.save(model.state_dict(), run_logdir+'_weights.pkl')
        print("save model, current best val_accuracy:",Max_val_accuracy)
    else:
        no_improve_count+=1
    if(no_improve_count>=EARLY_STOP_Threshold):
        break


100%|██████████| 1973/1973 [06:16<00:00,  5.25it/s]
  0%|          | 0/1973 [00:00<?, ?it/s]

epoch  0  , train loss :  4.745896166106339 , train_accuracy: 1.7  , valid loss : 4.578082690238952 , val_accuracy: 1.67
save model, current best val_accuracy: 1.67


100%|██████████| 1973/1973 [06:07<00:00,  5.37it/s]


epoch  1  , train loss :  4.403734483814971 , train_accuracy: 2.95  , valid loss : 4.365975757704841 , val_accuracy: 3.03


  0%|          | 1/1973 [00:00<06:09,  5.34it/s]

save model, current best val_accuracy: 3.03


100%|██████████| 1973/1973 [05:45<00:00,  5.71it/s]


epoch  2  , train loss :  4.286642393780885 , train_accuracy: 4.11  , valid loss : 4.283263317214118 , val_accuracy: 3.94


  0%|          | 1/1973 [00:00<06:10,  5.32it/s]

save model, current best val_accuracy: 3.94


100%|██████████| 1973/1973 [05:44<00:00,  5.73it/s]
  0%|          | 0/1973 [00:00<?, ?it/s]

epoch  3  , train loss :  4.2145597417361085 , train_accuracy: 4.22  , valid loss : 4.228917011155023 , val_accuracy: 3.92


100%|██████████| 1973/1973 [05:42<00:00,  5.76it/s]


epoch  4  , train loss :  4.156142563744034 , train_accuracy: 4.79  , valid loss : 4.152307096587287 , val_accuracy: 3.97


  0%|          | 1/1973 [00:00<06:04,  5.41it/s]

save model, current best val_accuracy: 3.97


100%|██████████| 1973/1973 [05:39<00:00,  5.82it/s]


epoch  5  , train loss :  4.00627447708273 , train_accuracy: 5.79  , valid loss : 3.9850289323594836 , val_accuracy: 6.44


  0%|          | 1/1973 [00:00<06:06,  5.38it/s]

save model, current best val_accuracy: 6.44


100%|██████████| 1973/1973 [05:39<00:00,  5.81it/s]


epoch  6  , train loss :  3.810696746106459 , train_accuracy: 7.89  , valid loss : 3.80162294599745 , val_accuracy: 8.36


  0%|          | 1/1973 [00:00<05:56,  5.52it/s]

save model, current best val_accuracy: 8.36


100%|██████████| 1973/1973 [05:39<00:00,  5.82it/s]


epoch  7  , train loss :  3.6858621290317997 , train_accuracy: 9.8  , valid loss : 3.674140354262458 , val_accuracy: 9.33


  0%|          | 1/1973 [00:00<06:03,  5.42it/s]

save model, current best val_accuracy: 9.33


100%|██████████| 1973/1973 [05:39<00:00,  5.82it/s]


epoch  8  , train loss :  3.5705615636611725 , train_accuracy: 10.72  , valid loss : 3.610132368935479 , val_accuracy: 10.11


  0%|          | 1/1973 [00:00<06:07,  5.36it/s]

save model, current best val_accuracy: 10.11


100%|██████████| 1973/1973 [05:39<00:00,  5.82it/s]


epoch  9  , train loss :  3.4116901496160965 , train_accuracy: 13.32  , valid loss : 3.4610679483413698 , val_accuracy: 12.44


  0%|          | 1/1973 [00:00<06:06,  5.38it/s]

save model, current best val_accuracy: 12.44


100%|██████████| 1973/1973 [05:38<00:00,  5.82it/s]


epoch  10  , train loss :  3.2736502954492823 , train_accuracy: 14.92  , valid loss : 3.3089305464426677 , val_accuracy: 13.67


  0%|          | 1/1973 [00:00<06:15,  5.25it/s]

save model, current best val_accuracy: 13.67


100%|██████████| 1973/1973 [05:39<00:00,  5.81it/s]


epoch  11  , train loss :  3.151006791141219 , train_accuracy: 17.06  , valid loss : 3.2795512178209094 , val_accuracy: 14.83


  0%|          | 1/1973 [00:00<06:22,  5.16it/s]

save model, current best val_accuracy: 14.83


100%|██████████| 1973/1973 [05:38<00:00,  5.82it/s]


epoch  12  , train loss :  3.081050160704959 , train_accuracy: 17.74  , valid loss : 3.2415892963939243 , val_accuracy: 16.83


  0%|          | 1/1973 [00:00<06:12,  5.29it/s]

save model, current best val_accuracy: 16.83


100%|██████████| 1973/1973 [05:39<00:00,  5.82it/s]


epoch  13  , train loss :  3.0118450519211946 , train_accuracy: 19.6  , valid loss : 3.143849498960707 , val_accuracy: 17.5


  0%|          | 0/1973 [00:00<?, ?it/s]

save model, current best val_accuracy: 17.5


100%|██████████| 1973/1973 [05:39<00:00,  5.82it/s]
  0%|          | 0/1973 [00:00<?, ?it/s]

epoch  14  , train loss :  2.962388872518551 , train_accuracy: 20.56  , valid loss : 3.1561494000752766 , val_accuracy: 17.25


100%|██████████| 1973/1973 [05:39<00:00,  5.81it/s]


epoch  15  , train loss :  2.896212814727417 , train_accuracy: 21.34  , valid loss : 3.110415157477061 , val_accuracy: 18.31


  0%|          | 1/1973 [00:00<05:52,  5.59it/s]

save model, current best val_accuracy: 18.31


100%|██████████| 1973/1973 [05:39<00:00,  5.82it/s]


epoch  16  , train loss :  2.82822743570618 , train_accuracy: 22.76  , valid loss : 3.0218755729993183 , val_accuracy: 20.19


  0%|          | 1/1973 [00:00<05:52,  5.60it/s]

save model, current best val_accuracy: 20.19


100%|██████████| 1973/1973 [05:38<00:00,  5.83it/s]
  0%|          | 1/1973 [00:00<05:38,  5.82it/s]

epoch  17  , train loss :  2.792148658928441 , train_accuracy: 23.24  , valid loss : 3.0092999341752793 , val_accuracy: 19.78


100%|██████████| 1973/1973 [05:38<00:00,  5.82it/s]


epoch  18  , train loss :  2.713296422200463 , train_accuracy: 24.72  , valid loss : 2.985661091539595 , val_accuracy: 21.58


  0%|          | 1/1973 [00:00<05:49,  5.64it/s]

save model, current best val_accuracy: 21.58


100%|██████████| 1973/1973 [05:46<00:00,  5.70it/s]


epoch  19  , train loss :  2.676281605997393 , train_accuracy: 25.81  , valid loss : 2.896484646267361 , val_accuracy: 22.19


  0%|          | 1/1973 [00:00<06:23,  5.14it/s]

save model, current best val_accuracy: 22.19


100%|██████████| 1973/1973 [05:38<00:00,  5.82it/s]


epoch  20  , train loss :  2.612349528604243 , train_accuracy: 26.73  , valid loss : 2.8713919626341924 , val_accuracy: 22.86


  0%|          | 0/1973 [00:00<?, ?it/s]

save model, current best val_accuracy: 22.86


100%|██████████| 1973/1973 [05:38<00:00,  5.82it/s]


epoch  21  , train loss :  2.5125164281688073 , train_accuracy: 28.35  , valid loss : 2.7850259224573772 , val_accuracy: 24.5


  0%|          | 1/1973 [00:00<06:03,  5.42it/s]

save model, current best val_accuracy: 24.5


100%|██████████| 1973/1973 [05:39<00:00,  5.82it/s]


epoch  22  , train loss :  2.4060978168346967 , train_accuracy: 31.52  , valid loss : 2.719565830760532 , val_accuracy: 26.92


  0%|          | 1/1973 [00:00<06:11,  5.32it/s]

save model, current best val_accuracy: 26.92


100%|██████████| 1973/1973 [05:38<00:00,  5.83it/s]


epoch  23  , train loss :  2.324350468581587 , train_accuracy: 32.44  , valid loss : 2.682601016097599 , val_accuracy: 27.83


  0%|          | 1/1973 [00:00<06:14,  5.26it/s]

save model, current best val_accuracy: 27.83


100%|██████████| 1973/1973 [05:38<00:00,  5.82it/s]


epoch  24  , train loss :  2.264321516777854 , train_accuracy: 34.42  , valid loss : 2.680494812859429 , val_accuracy: 28.83


  0%|          | 1/1973 [00:00<06:19,  5.19it/s]

save model, current best val_accuracy: 28.83


100%|██████████| 1973/1973 [05:38<00:00,  5.82it/s]


epoch  25  , train loss :  2.181957296669668 , train_accuracy: 36.17  , valid loss : 2.605824468400743 , val_accuracy: 29.94


  0%|          | 1/1973 [00:00<05:52,  5.60it/s]

save model, current best val_accuracy: 29.94


100%|██████████| 1973/1973 [05:38<00:00,  5.83it/s]
  0%|          | 1/1973 [00:00<05:31,  5.95it/s]

epoch  26  , train loss :  2.14539528050209 , train_accuracy: 37.89  , valid loss : 2.6081701425711312 , val_accuracy: 29.44


100%|██████████| 1973/1973 [05:38<00:00,  5.83it/s]


epoch  27  , train loss :  2.044620782569437 , train_accuracy: 39.29  , valid loss : 2.585956199698978 , val_accuracy: 31.11


  0%|          | 1/1973 [00:00<05:35,  5.89it/s]

save model, current best val_accuracy: 31.11


100%|██████████| 1973/1973 [05:38<00:00,  5.83it/s]


epoch  28  , train loss :  1.9791188016576349 , train_accuracy: 40.92  , valid loss : 2.6059053349494934 , val_accuracy: 31.97


  0%|          | 1/1973 [00:00<06:17,  5.22it/s]

save model, current best val_accuracy: 31.97


100%|██████████| 1973/1973 [05:42<00:00,  5.75it/s]


epoch  29  , train loss :  1.9515381361400306 , train_accuracy: 41.51  , valid loss : 2.561903702682919 , val_accuracy: 33.08


  0%|          | 1/1973 [00:00<06:18,  5.21it/s]

save model, current best val_accuracy: 33.08


100%|██████████| 1973/1973 [05:43<00:00,  5.75it/s]


epoch  30  , train loss :  1.9025309515419444 , train_accuracy: 43.29  , valid loss : 2.514534538057115 , val_accuracy: 34.94


  0%|          | 0/1973 [00:00<?, ?it/s]

save model, current best val_accuracy: 34.94


100%|██████████| 1973/1973 [07:08<00:00,  4.61it/s]
  0%|          | 0/1973 [00:00<?, ?it/s]

epoch  31  , train loss :  1.8095540521987847 , train_accuracy: 45.23  , valid loss : 2.621728361050288 , val_accuracy: 33.03


100%|██████████| 1973/1973 [07:24<00:00,  4.43it/s]


epoch  32  , train loss :  1.7611868225264227 , train_accuracy: 46.68  , valid loss : 2.5384230497148303 , val_accuracy: 35.53


  0%|          | 0/1973 [00:00<?, ?it/s]

save model, current best val_accuracy: 35.53


100%|██████████| 1973/1973 [07:23<00:00,  4.45it/s]


epoch  33  , train loss :  1.7453016705352533 , train_accuracy: 47.29  , valid loss : 2.4702437346511417 , val_accuracy: 36.42


  0%|          | 0/1973 [00:00<?, ?it/s]

save model, current best val_accuracy: 36.42


100%|██████████| 1973/1973 [06:41<00:00,  4.91it/s]


epoch  34  , train loss :  1.6893817278601324 , train_accuracy: 48.24  , valid loss : 2.4955158176687027 , val_accuracy: 37.36


  0%|          | 1/1973 [00:00<06:01,  5.46it/s]

save model, current best val_accuracy: 37.36


100%|██████████| 1973/1973 [05:51<00:00,  5.62it/s]
  0%|          | 1/1973 [00:00<06:01,  5.45it/s]

epoch  35  , train loss :  1.6228334398367565 , train_accuracy: 50.29  , valid loss : 2.5516058366828496 , val_accuracy: 36.39


100%|██████████| 1973/1973 [06:21<00:00,  5.17it/s]


epoch  36  , train loss :  1.5868358907307334 , train_accuracy: 51.32  , valid loss : 2.431609226067861 , val_accuracy: 38.25


  0%|          | 0/1973 [00:00<?, ?it/s]

save model, current best val_accuracy: 38.25


100%|██████████| 1973/1973 [06:26<00:00,  5.11it/s]


epoch  37  , train loss :  1.5614239642493013 , train_accuracy: 51.98  , valid loss : 2.4030416227711573 , val_accuracy: 38.94


  0%|          | 0/1973 [00:00<?, ?it/s]

save model, current best val_accuracy: 38.94


100%|██████████| 1973/1973 [06:36<00:00,  4.97it/s]
  0%|          | 1/1973 [00:00<06:15,  5.25it/s]

epoch  38  , train loss :  1.4687791605837044 , train_accuracy: 54.62  , valid loss : 2.522684846056832 , val_accuracy: 38.69


100%|██████████| 1973/1973 [06:13<00:00,  5.28it/s]
  0%|          | 0/1973 [00:00<?, ?it/s]

epoch  39  , train loss :  1.43697383988681 , train_accuracy: 55.62  , valid loss : 2.497762788799074 , val_accuracy: 38.56


100%|██████████| 1973/1973 [06:24<00:00,  5.13it/s]


epoch  40  , train loss :  1.3882030402687053 , train_accuracy: 56.95  , valid loss : 2.586329621473948 , val_accuracy: 39.31


  0%|          | 0/1973 [00:00<?, ?it/s]

save model, current best val_accuracy: 39.31


100%|██████████| 1973/1973 [06:35<00:00,  4.99it/s]


epoch  41  , train loss :  1.3393697275076766 , train_accuracy: 58.03  , valid loss : 2.425975351995892 , val_accuracy: 42.25


  0%|          | 0/1973 [00:00<?, ?it/s]

save model, current best val_accuracy: 42.25


100%|██████████| 1973/1973 [06:25<00:00,  5.12it/s]


epoch  42  , train loss :  1.3336233568248945 , train_accuracy: 57.89  , valid loss : 2.4552696969111762 , val_accuracy: 43.17


  0%|          | 0/1973 [00:00<?, ?it/s]

save model, current best val_accuracy: 43.17


100%|██████████| 1973/1973 [06:31<00:00,  5.04it/s]
  0%|          | 0/1973 [00:00<?, ?it/s]

epoch  43  , train loss :  1.2947553227372222 , train_accuracy: 59.85  , valid loss : 2.4615666131178537 , val_accuracy: 43.06


100%|██████████| 1973/1973 [06:43<00:00,  4.89it/s]


epoch  44  , train loss :  1.2418294906771088 , train_accuracy: 60.76  , valid loss : 2.3766355173455342 , val_accuracy: 44.83


  0%|          | 0/1973 [00:00<?, ?it/s]

save model, current best val_accuracy: 44.83


100%|██████████| 1973/1973 [06:34<00:00,  5.00it/s]
  0%|          | 0/1973 [00:00<?, ?it/s]

epoch  45  , train loss :  1.1857181183473942 , train_accuracy: 62.8  , valid loss : 2.501040957503849 , val_accuracy: 43.19


100%|██████████| 1973/1973 [06:39<00:00,  4.93it/s]
  0%|          | 0/1973 [00:00<?, ?it/s]

epoch  46  , train loss :  1.1733848714459811 , train_accuracy: 63.0  , valid loss : 2.4356797742181353 , val_accuracy: 43.92


100%|██████████| 1973/1973 [06:32<00:00,  5.03it/s]
  0%|          | 1/1973 [00:00<06:08,  5.36it/s]

epoch  47  , train loss :  1.1408546399734751 , train_accuracy: 63.83  , valid loss : 2.4262655148241254 , val_accuracy: 44.25


100%|██████████| 1973/1973 [05:46<00:00,  5.70it/s]
  0%|          | 1/1973 [00:00<05:41,  5.77it/s]

epoch  48  , train loss :  1.1016402352716739 , train_accuracy: 65.0  , valid loss : 2.5625516823265286 , val_accuracy: 44.67


100%|██████████| 1973/1973 [05:40<00:00,  5.79it/s]


epoch  49  , train loss :  1.0716851581558768 , train_accuracy: 65.82  , valid loss : 2.5937047643793956 , val_accuracy: 44.47


In [None]:
#回復最佳model
cnn = make_model('xception', num_classes=120, pretrained=True, input_size=(299, 299), classifier_factory=make_classifier)
cnn.load_state_dict(torch.load(run_logdir+'_weights.pkl'))
cnn.eval()

In [None]:
#torch.save(cnn, model_name+'.pkl')
torch.save(cnn.state_dict(), model_name+'_weights.pkl')

In [None]:
# test all

In [None]:
#回復最佳model
cnn = make_model('xception', num_classes=120, pretrained=True, input_size=(299, 299), classifier_factory=make_classifier)
cnn.load_state_dict(torch.load('run_2021_01_07-15_10_01_weights_61.pkl'))
cnn.eval()

In [None]:
def test(loaders, model, criterion, use_cuda):

    # monitor test loss and accuracy
    test_loss = 0.
    correct = 0.
    total = 0.

    model.eval()
    for batch_idx, (data, target) in enumerate(loaders):
        # move to GPU
        if use_cuda:
            data, target = data.cuda(), target.cuda()
        # forward pass: compute predicted outputs by passing inputs to the model
        output = model(data)
        # calculate the loss
        loss = criterion(output, target)
        # update average test loss 
        test_loss = test_loss + ((1 / (batch_idx + 1)) * (loss.data - test_loss))
        # convert output probabilities to predicted class
        pred = output.data.max(1, keepdim=True)[1]
        # compare predictions to true label
        correct += np.sum(np.squeeze(pred.eq(target.data.view_as(pred))).cpu().numpy())
        total += data.size(0)
            
    print('Test Loss: {:.6f}'.format(test_loss))

    print('Test Accuracy: %2d%% (%2d/%2d)' % (
        100. * correct / total, correct, total))

In [None]:
test_transforms = transforms.Compose([transforms.Resize(IMAGE_SIZE),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406],
                                                           [0.229, 0.224, 0.225])])

In [None]:
test_data = datasets.ImageFolder('data/test', transform=test_transforms)

In [None]:
BATCH_SIZE=8
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE)

In [None]:
use_cuda = torch.cuda.is_available()
cnn=cnn.cuda()
test(test_loader, cnn, loss_func, use_cuda)

In [None]:
test(train_loader, cnn, loss_func, use_cuda)