In [9]:
from __future__ import print_function

import os
import random
import shutil
from tqdm import tqdm

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from PIL import Image
from utils import load_model, AverageMeter, accuracy
import cv2

from config import cfg
from dataset import *
from utils import *

# 环境变量设置

In [10]:
os.chdir("./")   #修改当前工作目录

# Use CUDA
use_cuda = torch.cuda.is_available()

seed = 11037
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True

model_file = './tf_efficientnet_b3_ns93.14395092845393.pth.tar'

# 推理

In [11]:
@torch.no_grad()
def infer(testloader, model):
    losses = AverageMeter()
    accs = AverageMeter()
    model.eval()

    predicts = []
    img_names = []

    bar = tqdm(enumerate(testloader), total=len(testloader))
    for steps, (inputs, names) in bar:
        inputs = inputs.to('cuda', dtype=torch.float)

        outputs = model(inputs)
        # loss = cross_entropy(outputs, labels)
        # acc = accuracy(outputs, labels)

        # losses.update(loss.item(), inputs.size(0))
        # accs.update(acc[0].item(), inputs.size(0))
        outputs = torch.softmax(outputs, dim=1)
        target = torch.argmax(outputs,dim=1)



        img_names.extend(names)
        predicts.extend(target.detach().cpu().numpy().tolist())

        print(len(img_names))
        print(len(predicts))

        # bar.set_postfix(test_loss=losses.avg, test_acc=accs.avg,
        #                 lr=optimizer.state_dict()['param_groups'][0]['lr'])

    return losses.avg, accs.avg

In [12]:
testset = MyDataset(data_dir='./data', mode='test', transform=transform_test)
testloader = data.DataLoader(testset, batch_size=16, shuffle=False, num_workers=0)

# Model
model_dict = torch.load(model_file)
model = load_model(cfg['model'],pretrained=False)
model.load_state_dict(model_dict['state_dict'])
model = model.cuda()

train_loss, train_acc = infer(testloader, model)



# generator parameters: 10726972


  0%|          | 0/1250 [00:00<?, ?it/s]

16
16


  0%|          | 2/1250 [00:00<01:51, 11.24it/s]

32
32


  0%|          | 4/1250 [00:00<01:46, 11.74it/s]

48
48
64
64
80
80


  0%|          | 6/1250 [00:00<01:38, 12.67it/s]

96
96
112
112


  1%|          | 8/1250 [00:00<01:41, 12.26it/s]

128
128


  1%|          | 10/1250 [00:00<01:40, 12.37it/s]

144
144
160
160
176
176


  1%|          | 12/1250 [00:00<01:36, 12.84it/s]

192
192
208
208


  1%|▏         | 16/1250 [00:01<01:40, 12.29it/s]

224
224
240
240
256
256


  1%|▏         | 18/1250 [00:01<01:36, 12.72it/s]

272
272
288
288
304
304


  2%|▏         | 22/1250 [00:01<01:41, 12.15it/s]

320
320
336
336
352
352


  2%|▏         | 24/1250 [00:01<01:42, 11.93it/s]

368
368
384
384
400
400


  2%|▏         | 28/1250 [00:02<01:43, 11.77it/s]

416
416
432
432
448
448


  2%|▏         | 30/1250 [00:02<01:44, 11.66it/s]

464
464
480
480
496
496


  3%|▎         | 34/1250 [00:02<01:41, 11.97it/s]

512
512
528
528
544
544


  3%|▎         | 36/1250 [00:02<01:39, 12.23it/s]

560
560
576
576
592
592


  3%|▎         | 40/1250 [00:03<01:39, 12.14it/s]

608
608
624
624
640
640


  4%|▎         | 44/1250 [00:03<01:32, 13.09it/s]

656
656
672
672
688
688
704
704


  4%|▎         | 46/1250 [00:03<01:36, 12.47it/s]

720
720
736
736
752
752


  4%|▍         | 50/1250 [00:04<01:30, 13.21it/s]

768
768
784
784
800
800
816
816


  4%|▍         | 54/1250 [00:04<01:37, 12.27it/s]

832
832
848
848
864
864


  4%|▍         | 56/1250 [00:04<01:33, 12.80it/s]

880
880
896
896
912
912


  5%|▍         | 60/1250 [00:04<01:38, 12.12it/s]

928
928
944
944
960
960


  5%|▍         | 62/1250 [00:05<01:32, 12.85it/s]

976
976
992
992
1008
1008


  5%|▌         | 66/1250 [00:05<01:34, 12.48it/s]

1024
1024
1040
1040
1056
1056


  5%|▌         | 68/1250 [00:05<01:36, 12.25it/s]

1072
1072
1088
1088


  6%|▌         | 72/1250 [00:05<01:33, 12.54it/s]

1104
1104
1120
1120
1136
1136
1152
1152


  6%|▌         | 74/1250 [00:06<01:43, 11.31it/s]

1168
1168
1184
1184
1200
1200


  6%|▌         | 78/1250 [00:06<01:32, 12.63it/s]

1216
1216
1232
1232
1248
1248


  6%|▋         | 80/1250 [00:06<01:40, 11.64it/s]

1264
1264
1280
1280
1296
1296


  7%|▋         | 82/1250 [00:06<01:34, 12.41it/s]

1312
1312
1328
1328


  7%|▋         | 84/1250 [00:07<02:47,  6.97it/s]

1344
1344
1360
1360


  7%|▋         | 86/1250 [00:07<03:29,  5.55it/s]

1376
1376


  7%|▋         | 87/1250 [00:08<03:49,  5.07it/s]

1392
1392


  7%|▋         | 88/1250 [00:08<03:54,  4.95it/s]

1408
1408


  7%|▋         | 89/1250 [00:08<04:26,  4.36it/s]

1424
1424


  7%|▋         | 90/1250 [00:08<04:49,  4.00it/s]

1440
1440


  7%|▋         | 91/1250 [00:09<04:56,  3.91it/s]

1456
1456


  7%|▋         | 92/1250 [00:09<05:14,  3.68it/s]

1472
1472


  7%|▋         | 93/1250 [00:09<05:12,  3.71it/s]

1488
1488


  8%|▊         | 94/1250 [00:10<04:58,  3.87it/s]

1504
1504


  8%|▊         | 95/1250 [00:10<05:01,  3.83it/s]

1520
1520


  8%|▊         | 96/1250 [00:10<04:55,  3.91it/s]

1536
1536


  8%|▊         | 97/1250 [00:10<04:55,  3.90it/s]

1552
1552


  8%|▊         | 98/1250 [00:11<04:57,  3.88it/s]

1568
1568


  8%|▊         | 99/1250 [00:11<04:56,  3.89it/s]

1584
1584


  8%|▊         | 100/1250 [00:11<05:02,  3.80it/s]

1600
1600


  8%|▊         | 101/1250 [00:11<05:01,  3.81it/s]

1616
1616


  8%|▊         | 102/1250 [00:12<05:08,  3.72it/s]

1632
1632


  8%|▊         | 103/1250 [00:12<04:56,  3.87it/s]

1648
1648


  8%|▊         | 104/1250 [00:12<04:49,  3.96it/s]

1664
1664


  8%|▊         | 105/1250 [00:12<05:18,  3.59it/s]

1680
1680


  8%|▊         | 106/1250 [00:13<05:16,  3.61it/s]

1696
1696


  9%|▊         | 107/1250 [00:13<05:05,  3.74it/s]

1712
1712


  9%|▊         | 108/1250 [00:13<05:06,  3.73it/s]

1728
1728


  9%|▊         | 109/1250 [00:13<05:02,  3.77it/s]

1744
1744


  9%|▉         | 110/1250 [00:14<04:56,  3.84it/s]

1760
1760


  9%|▉         | 111/1250 [00:14<02:29,  7.62it/s]

1776
1776





KeyboardInterrupt: 