In [48]:
import torch
torch.manual_seed(0)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset
import time
import glob

import pandas as pd
import numpy as np
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True  # Globally allows loading truncated images

from tqdm import tqdm_notebook

In [49]:
label_index = pd.read_csv("chinese_herbal_medicine.csv")
label_index = {k:v for v,k in label_index["category"].to_dict().items()}

In [50]:
train_path = glob.glob('./train_set/*/*.jpg')
np.random.shuffle(train_path)

train_label = [label_index[x.split('/')[-2]] for x in train_path]

test_path = glob.glob('./test_set/*.jpg')
test_path.sort()

In [51]:
len(train_path)

8525

In [66]:
for path in train_path:
    try:
        img = Image.open(path).convert('RGB')
        img.size
    except:
        train_path.remove(path)

In [67]:
len(train_path)

8122

In [68]:
train_label = [label_index[x.split('/')[-2]] for x in train_path]

In [69]:
test_path[:10]

['./test_set/006f490b-c352-414c-be06-0e14a39eb3ee.jpg',
 './test_set/007d7487-413d-4d11-9df1-1c40651f6428.jpg',
 './test_set/00b15eef-38c6-42e0-b1ed-5bbd34d9f20f.jpg',
 './test_set/00d3d2ea-2f66-40b7-b402-b2bcd36cf776.jpg',
 './test_set/01207b72-00c6-43c4-aa14-85e76c138022.jpg',
 './test_set/01385aa6-d1ca-4aa9-831f-a63ab7a02d66.jpg',
 './test_set/01631a84-1657-438e-94bc-3e5de97a7907.jpg',
 './test_set/0172672f-e02d-4809-a40c-b1ec16b25281.jpg',
 './test_set/0185e720-8592-44cb-8332-9e39c9e490ba.jpg',
 './test_set/018b4cd5-77ef-495d-a641-1df43725aa3b.jpg']

In [70]:
len(label_index)

54

In [71]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

class ProgressMeter(object):
    def __init__(self, num_batches, *meters):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = ""


    def pr2int(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

def validate(val_loader, model, criterion):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(len(val_loader), batch_time, losses, top1)

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (input, target) in tqdm_notebook(enumerate(val_loader), total=len(val_loader)):
            input = input.cuda()
            target = target.cuda()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc = (output.argmax(1).view(-1) == target.float().view(-1)).float().mean() * 100
            losses.update(loss.item(), input.size(0))
            top1.update(acc, input.size(0))
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f}'
              .format(top1=top1))
        return top1

def predict(test_loader, model, tta=10):
    # switch to evaluate mode
    model.eval()
    
    test_pred_tta = None
    for _ in range(tta):
        test_pred = []
        with torch.no_grad():
            end = time.time()
            for i, (input, target) in tqdm_notebook(enumerate(test_loader), total=len(test_loader)):
                input = input.cuda()
                target = target.cuda()

                # compute output
                output = model(input)
                output = F.softmax(output, dim=1)
                output = output.data.cpu().numpy()

                test_pred.append(output)
        test_pred = np.vstack(test_pred)
    
        if test_pred_tta is None:
            test_pred_tta = test_pred
        else:
            test_pred_tta += test_pred
    
    return test_pred_tta

def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(len(train_loader), batch_time, losses, top1)

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # measure accuracy and record loss
        losses.update(loss.item(), input.size(0))

        acc = (output.argmax(1).view(-1) == target.float().view(-1)).float().mean() * 100
        top1.update(acc, input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 100 == 0:
            progress.pr2int(i)

In [86]:
class XFDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label
        
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None
    
    def __getitem__(self, index):
        try:
            img = Image.open(self.img_path[index]).convert('RGB')
        except:
            img = Image.new('RGB', (224, 224), color='white')

        label = torch.from_numpy(np.array(self.img_label[index]))
        if self.transform is not None:
            img = self.transform(img)
        
        return img, label
    
    def __len__(self):
        return len(self.img_path)

In [82]:
import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

import timm
model = timm.create_model('resnet18', pretrained=True, num_classes=len(label_index))
model = model.cuda()

In [83]:
train_loader = torch.utils.data.DataLoader(
    XFDataset(train_path[:-500], train_label[:-500], 
            transforms.Compose([
                        transforms.Resize((256, 256)),
                        transforms.RandomHorizontalFlip(),
                        transforms.RandomVerticalFlip(),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    ), batch_size=10, shuffle=True, num_workers=4, pin_memory=True
)

val_loader = torch.utils.data.DataLoader(
    XFDataset(train_path[-500:], train_label[-500:], 
            transforms.Compose([
                        transforms.Resize((256, 256)),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    ), batch_size=10, shuffle=False, num_workers=4, pin_memory=True
)

criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), 0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.85)
best_acc = 0.0
for epoch in range(20):
    scheduler.step()
    print('Epoch: ', epoch)

    train(train_loader, model, criterion, optimizer, epoch)
    val_acc = validate(val_loader, model, criterion)
    
    if val_acc.avg.item() > best_acc:
        best_acc = round(val_acc.avg.item(), 2)

Epoch:  0
[  0/763]	Time  0.280 ( 0.280)	Loss 4.0691e+00 (4.0691e+00)	Acc@1   0.00 (  0.00)
[100/763]	Time  0.028 ( 0.029)	Loss 3.3308e+00 (3.7872e+00)	Acc@1  10.00 (  7.43)
[200/763]	Time  0.120 ( 0.029)	Loss 2.8703e+00 (3.4719e+00)	Acc@1  30.00 ( 14.08)
[300/763]	Time  0.171 ( 0.028)	Loss 3.0330e+00 (3.2177e+00)	Acc@1  30.00 ( 20.47)
[400/763]	Time  0.025 ( 0.028)	Loss 1.8518e+00 (3.0140e+00)	Acc@1  50.00 ( 24.14)
[500/763]	Time  0.025 ( 0.028)	Loss 1.8098e+00 (2.8778e+00)	Acc@1  70.00 ( 27.25)
[600/763]	Time  0.028 ( 0.027)	Loss 2.1014e+00 (2.7492e+00)	Acc@1  60.00 ( 30.25)
[700/763]	Time  0.026 ( 0.027)	Loss 2.4719e+00 (2.6548e+00)	Acc@1  30.00 ( 32.52)


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i, (input, target) in tqdm_notebook(enumerate(val_loader), total=len(val_loader)):


  0%|          | 0/50 [00:00<?, ?it/s]

 * Acc@1 51.200
Epoch:  1
[  0/763]	Time  0.268 ( 0.268)	Loss 1.3543e+00 (1.3543e+00)	Acc@1  60.00 ( 60.00)
[100/763]	Time  0.022 ( 0.032)	Loss 1.6205e+00 (1.8530e+00)	Acc@1  50.00 ( 49.50)
[200/763]	Time  0.028 ( 0.030)	Loss 1.0775e+00 (1.8112e+00)	Acc@1  70.00 ( 51.04)
[300/763]	Time  0.024 ( 0.029)	Loss 1.0627e+00 (1.7675e+00)	Acc@1  90.00 ( 52.19)
[400/763]	Time  0.024 ( 0.029)	Loss 1.8049e+00 (1.7593e+00)	Acc@1  40.00 ( 52.69)
[500/763]	Time  0.027 ( 0.029)	Loss 1.3464e+00 (1.7254e+00)	Acc@1  40.00 ( 53.43)
[600/763]	Time  0.028 ( 0.028)	Loss 2.3905e+00 (1.7023e+00)	Acc@1  40.00 ( 53.76)
[700/763]	Time  0.024 ( 0.028)	Loss 1.9786e+00 (1.6916e+00)	Acc@1  60.00 ( 54.14)


  0%|          | 0/50 [00:00<?, ?it/s]

 * Acc@1 60.000
Epoch:  2
[  0/763]	Time  0.257 ( 0.257)	Loss 1.2047e+00 (1.2047e+00)	Acc@1  70.00 ( 70.00)
[100/763]	Time  0.027 ( 0.029)	Loss 1.5879e+00 (1.3169e+00)	Acc@1  50.00 ( 64.46)
[200/763]	Time  0.030 ( 0.028)	Loss 9.2711e-01 (1.3696e+00)	Acc@1  70.00 ( 63.48)
[300/763]	Time  0.026 ( 0.028)	Loss 1.3268e+00 (1.3503e+00)	Acc@1  70.00 ( 63.65)
[400/763]	Time  0.025 ( 0.027)	Loss 1.5782e+00 (1.3623e+00)	Acc@1  60.00 ( 63.37)
[500/763]	Time  0.024 ( 0.027)	Loss 1.2918e+00 (1.3691e+00)	Acc@1  60.00 ( 63.29)
[600/763]	Time  0.023 ( 0.027)	Loss 7.2814e-01 (1.3826e+00)	Acc@1  80.00 ( 62.71)
[700/763]	Time  0.033 ( 0.027)	Loss 1.4433e+00 (1.3761e+00)	Acc@1  60.00 ( 62.80)


  0%|          | 0/50 [00:00<?, ?it/s]

 * Acc@1 64.600
Epoch:  3
[  0/763]	Time  0.444 ( 0.444)	Loss 9.5304e-01 (9.5304e-01)	Acc@1  70.00 ( 70.00)
[100/763]	Time  0.025 ( 0.030)	Loss 8.8548e-01 (1.1500e+00)	Acc@1  70.00 ( 68.02)
[200/763]	Time  0.025 ( 0.028)	Loss 8.4484e-01 (1.1028e+00)	Acc@1  80.00 ( 69.25)
[300/763]	Time  0.025 ( 0.028)	Loss 1.5266e+00 (1.0951e+00)	Acc@1  60.00 ( 69.50)
[400/763]	Time  0.026 ( 0.027)	Loss 1.5764e+00 (1.0960e+00)	Acc@1  40.00 ( 69.73)
[500/763]	Time  0.026 ( 0.027)	Loss 1.9737e-01 (1.0938e+00)	Acc@1 100.00 ( 69.80)
[600/763]	Time  0.026 ( 0.027)	Loss 9.2997e-01 (1.0865e+00)	Acc@1  80.00 ( 70.07)
[700/763]	Time  0.024 ( 0.027)	Loss 7.6612e-01 (1.0878e+00)	Acc@1  80.00 ( 70.07)


  0%|          | 0/50 [00:00<?, ?it/s]

 * Acc@1 68.600
Epoch:  4
[  0/763]	Time  0.310 ( 0.310)	Loss 7.7532e-01 (7.7532e-01)	Acc@1  80.00 ( 80.00)
[100/763]	Time  0.022 ( 0.033)	Loss 4.6986e-01 (9.1653e-01)	Acc@1  80.00 ( 74.55)
[200/763]	Time  0.028 ( 0.030)	Loss 1.2301e+00 (9.4135e-01)	Acc@1  80.00 ( 73.88)
[300/763]	Time  0.026 ( 0.029)	Loss 3.9588e-01 (9.5421e-01)	Acc@1 100.00 ( 73.75)
[400/763]	Time  0.023 ( 0.028)	Loss 9.7661e-01 (9.3484e-01)	Acc@1  60.00 ( 73.94)
[500/763]	Time  0.025 ( 0.028)	Loss 6.8852e-01 (9.2715e-01)	Acc@1  80.00 ( 74.09)
[600/763]	Time  0.026 ( 0.027)	Loss 7.5543e-01 (9.2445e-01)	Acc@1  80.00 ( 74.16)
[700/763]	Time  0.025 ( 0.028)	Loss 7.9635e-01 (9.1969e-01)	Acc@1  70.00 ( 74.37)


  0%|          | 0/50 [00:00<?, ?it/s]

 * Acc@1 67.200
Epoch:  5
[  0/763]	Time  0.262 ( 0.262)	Loss 9.1862e-01 (9.1862e-01)	Acc@1  70.00 ( 70.00)
[100/763]	Time  0.025 ( 0.035)	Loss 6.2567e-01 (8.1192e-01)	Acc@1  80.00 ( 77.43)
[200/763]	Time  0.024 ( 0.031)	Loss 1.1895e+00 (8.2263e-01)	Acc@1  70.00 ( 76.82)
[300/763]	Time  0.026 ( 0.030)	Loss 6.8614e-01 (8.0490e-01)	Acc@1  80.00 ( 76.84)
[400/763]	Time  0.027 ( 0.029)	Loss 5.2197e-01 (8.1178e-01)	Acc@1  80.00 ( 76.76)
[500/763]	Time  0.024 ( 0.028)	Loss 1.1788e-01 (8.0821e-01)	Acc@1 100.00 ( 76.67)
[600/763]	Time  0.026 ( 0.028)	Loss 6.6387e-01 (8.0427e-01)	Acc@1  90.00 ( 76.84)
[700/763]	Time  0.022 ( 0.028)	Loss 3.3604e-01 (8.1190e-01)	Acc@1  80.00 ( 76.69)


  0%|          | 0/50 [00:00<?, ?it/s]

 * Acc@1 75.800
Epoch:  6
[  0/763]	Time  0.422 ( 0.422)	Loss 6.0120e-01 (6.0120e-01)	Acc@1  80.00 ( 80.00)
[100/763]	Time  0.028 ( 0.033)	Loss 5.2314e-01 (6.3812e-01)	Acc@1  80.00 ( 81.78)
[200/763]	Time  0.026 ( 0.030)	Loss 5.8560e-01 (6.9037e-01)	Acc@1  80.00 ( 80.45)
[300/763]	Time  0.022 ( 0.030)	Loss 5.4326e-01 (6.8081e-01)	Acc@1  80.00 ( 81.10)
[400/763]	Time  0.025 ( 0.029)	Loss 8.9760e-01 (6.7259e-01)	Acc@1  80.00 ( 81.47)
[500/763]	Time  0.022 ( 0.029)	Loss 3.6175e-01 (6.7185e-01)	Acc@1  90.00 ( 81.12)
[600/763]	Time  0.024 ( 0.029)	Loss 8.3030e-01 (6.9241e-01)	Acc@1  80.00 ( 80.37)
[700/763]	Time  0.025 ( 0.029)	Loss 1.3354e+00 (6.9447e-01)	Acc@1  60.00 ( 80.40)


  0%|          | 0/50 [00:00<?, ?it/s]

 * Acc@1 69.200
Epoch:  7
[  0/763]	Time  0.280 ( 0.280)	Loss 6.7122e-01 (6.7122e-01)	Acc@1  80.00 ( 80.00)
[100/763]	Time  0.028 ( 0.028)	Loss 3.7805e-01 (5.7931e-01)	Acc@1  90.00 ( 83.17)
[200/763]	Time  0.025 ( 0.027)	Loss 8.5391e-01 (5.5207e-01)	Acc@1  70.00 ( 83.58)
[300/763]	Time  0.025 ( 0.029)	Loss 1.2158e+00 (5.5693e-01)	Acc@1  60.00 ( 83.59)
[400/763]	Time  0.026 ( 0.028)	Loss 1.3337e+00 (5.4136e-01)	Acc@1  90.00 ( 83.94)
[500/763]	Time  0.032 ( 0.028)	Loss 5.8981e-01 (5.4663e-01)	Acc@1  90.00 ( 83.77)
[600/763]	Time  0.025 ( 0.028)	Loss 2.8530e-01 (5.5021e-01)	Acc@1  90.00 ( 83.69)
[700/763]	Time  0.026 ( 0.028)	Loss 1.0885e+00 (5.4219e-01)	Acc@1  80.00 ( 84.11)


  0%|          | 0/50 [00:00<?, ?it/s]

 * Acc@1 76.200
Epoch:  8
[  0/763]	Time  0.266 ( 0.266)	Loss 7.4283e-01 (7.4283e-01)	Acc@1  80.00 ( 80.00)
[100/763]	Time  0.026 ( 0.028)	Loss 1.7537e-01 (5.0845e-01)	Acc@1 100.00 ( 84.65)
[200/763]	Time  0.022 ( 0.028)	Loss 4.0385e-01 (4.9886e-01)	Acc@1  90.00 ( 84.98)
[300/763]	Time  0.025 ( 0.028)	Loss 8.5943e-01 (4.7970e-01)	Acc@1  60.00 ( 85.91)
[400/763]	Time  0.030 ( 0.028)	Loss 3.4421e-01 (4.6605e-01)	Acc@1  90.00 ( 86.63)
[500/763]	Time  0.025 ( 0.028)	Loss 9.0612e-01 (4.7277e-01)	Acc@1  70.00 ( 86.27)
[600/763]	Time  0.029 ( 0.028)	Loss 3.9275e-01 (4.8043e-01)	Acc@1  90.00 ( 85.82)
[700/763]	Time  0.026 ( 0.028)	Loss 4.0308e-01 (4.8174e-01)	Acc@1  90.00 ( 85.91)


  0%|          | 0/50 [00:00<?, ?it/s]

 * Acc@1 75.200
Epoch:  9
[  0/763]	Time  0.350 ( 0.350)	Loss 1.4037e-01 (1.4037e-01)	Acc@1 100.00 (100.00)
[100/763]	Time  0.024 ( 0.031)	Loss 1.6095e-01 (4.1281e-01)	Acc@1 100.00 ( 88.22)
[200/763]	Time  0.032 ( 0.029)	Loss 5.3435e-01 (4.1404e-01)	Acc@1  70.00 ( 87.76)
[300/763]	Time  0.022 ( 0.028)	Loss 1.6529e-01 (4.0936e-01)	Acc@1 100.00 ( 88.07)
[400/763]	Time  0.031 ( 0.028)	Loss 1.6344e+00 (4.2420e-01)	Acc@1  60.00 ( 87.81)
[500/763]	Time  0.024 ( 0.028)	Loss 3.1651e-01 (4.3008e-01)	Acc@1  90.00 ( 87.45)
[600/763]	Time  0.027 ( 0.027)	Loss 2.1770e-01 (4.3468e-01)	Acc@1 100.00 ( 87.22)
[700/763]	Time  0.023 ( 0.027)	Loss 4.5741e-01 (4.3364e-01)	Acc@1  80.00 ( 87.12)


  0%|          | 0/50 [00:00<?, ?it/s]

 * Acc@1 73.400
Epoch:  10
[  0/763]	Time  0.272 ( 0.272)	Loss 6.8335e-01 (6.8335e-01)	Acc@1  80.00 ( 80.00)
[100/763]	Time  0.027 ( 0.029)	Loss 9.9385e-01 (3.8495e-01)	Acc@1  70.00 ( 88.81)
[200/763]	Time  0.027 ( 0.028)	Loss 3.7084e-01 (3.6849e-01)	Acc@1  90.00 ( 88.91)
[300/763]	Time  0.030 ( 0.028)	Loss 2.3021e-01 (4.0337e-01)	Acc@1 100.00 ( 87.91)
[400/763]	Time  0.034 ( 0.028)	Loss 7.4769e-01 (4.1197e-01)	Acc@1  80.00 ( 87.66)
[500/763]	Time  0.030 ( 0.028)	Loss 4.9439e-03 (4.0283e-01)	Acc@1 100.00 ( 88.18)
[600/763]	Time  0.038 ( 0.028)	Loss 1.1385e-01 (4.0294e-01)	Acc@1  90.00 ( 88.02)
[700/763]	Time  0.026 ( 0.028)	Loss 1.7387e-01 (4.0155e-01)	Acc@1  90.00 ( 87.96)


  0%|          | 0/50 [00:00<?, ?it/s]

 * Acc@1 76.800
Epoch:  11
[  0/763]	Time  0.287 ( 0.287)	Loss 8.6691e-02 (8.6691e-02)	Acc@1 100.00 (100.00)
[100/763]	Time  0.159 ( 0.032)	Loss 5.4101e-01 (3.1472e-01)	Acc@1  70.00 ( 90.20)
[200/763]	Time  0.028 ( 0.029)	Loss 2.7388e-01 (2.9242e-01)	Acc@1  90.00 ( 91.09)
[300/763]	Time  0.024 ( 0.028)	Loss 1.4696e-01 (3.0125e-01)	Acc@1 100.00 ( 91.00)
[400/763]	Time  0.025 ( 0.028)	Loss 1.4508e-01 (3.0089e-01)	Acc@1 100.00 ( 90.95)
[500/763]	Time  0.023 ( 0.028)	Loss 4.0768e-01 (2.9382e-01)	Acc@1  90.00 ( 91.12)
[600/763]	Time  0.026 ( 0.027)	Loss 2.6833e-01 (2.9592e-01)	Acc@1  90.00 ( 91.10)
[700/763]	Time  0.024 ( 0.027)	Loss 3.5922e-01 (3.0312e-01)	Acc@1  80.00 ( 91.00)


  0%|          | 0/50 [00:00<?, ?it/s]

 * Acc@1 77.800
Epoch:  12
[  0/763]	Time  0.307 ( 0.307)	Loss 5.9140e-02 (5.9140e-02)	Acc@1 100.00 (100.00)
[100/763]	Time  0.027 ( 0.030)	Loss 3.6131e-01 (2.4005e-01)	Acc@1  90.00 ( 93.07)
[200/763]	Time  0.024 ( 0.028)	Loss 3.1794e-01 (2.5326e-01)	Acc@1  90.00 ( 92.39)
[300/763]	Time  0.024 ( 0.028)	Loss 4.9144e-02 (2.6090e-01)	Acc@1 100.00 ( 92.29)
[400/763]	Time  0.029 ( 0.028)	Loss 5.4448e-01 (2.7511e-01)	Acc@1  90.00 ( 91.67)
[500/763]	Time  0.028 ( 0.028)	Loss 3.6184e-01 (2.7062e-01)	Acc@1  80.00 ( 91.68)
[600/763]	Time  0.031 ( 0.028)	Loss 2.7262e-01 (2.7568e-01)	Acc@1  90.00 ( 91.55)
[700/763]	Time  0.027 ( 0.028)	Loss 4.7171e-01 (2.7443e-01)	Acc@1  90.00 ( 91.73)


  0%|          | 0/50 [00:00<?, ?it/s]

 * Acc@1 76.600
Epoch:  13
[  0/763]	Time  0.258 ( 0.258)	Loss 1.9827e-01 (1.9827e-01)	Acc@1  90.00 ( 90.00)
[100/763]	Time  0.029 ( 0.031)	Loss 2.4568e-01 (2.4959e-01)	Acc@1  90.00 ( 92.67)
[200/763]	Time  0.027 ( 0.029)	Loss 4.3320e-01 (2.3579e-01)	Acc@1  90.00 ( 93.18)
[300/763]	Time  0.028 ( 0.028)	Loss 7.2648e-02 (2.4254e-01)	Acc@1 100.00 ( 92.82)
[400/763]	Time  0.024 ( 0.029)	Loss 5.5048e-03 (2.4367e-01)	Acc@1 100.00 ( 92.74)
[500/763]	Time  0.023 ( 0.028)	Loss 5.0476e-01 (2.4983e-01)	Acc@1  90.00 ( 92.61)
[600/763]	Time  0.029 ( 0.028)	Loss 1.1394e+00 (2.5612e-01)	Acc@1  80.00 ( 92.48)
[700/763]	Time  0.021 ( 0.028)	Loss 7.2782e-02 (2.6730e-01)	Acc@1 100.00 ( 92.23)


  0%|          | 0/50 [00:00<?, ?it/s]

 * Acc@1 78.800
Epoch:  14
[  0/763]	Time  0.284 ( 0.284)	Loss 9.3699e-02 (9.3699e-02)	Acc@1 100.00 (100.00)
[100/763]	Time  0.028 ( 0.031)	Loss 7.2797e-01 (1.9805e-01)	Acc@1  70.00 ( 94.26)
[200/763]	Time  0.030 ( 0.029)	Loss 2.1312e-01 (2.1036e-01)	Acc@1 100.00 ( 94.28)
[300/763]	Time  0.024 ( 0.029)	Loss 7.0070e-02 (2.0815e-01)	Acc@1 100.00 ( 94.12)
[400/763]	Time  0.025 ( 0.029)	Loss 1.4800e-01 (2.2250e-01)	Acc@1 100.00 ( 93.72)
[500/763]	Time  0.056 ( 0.028)	Loss 1.8086e-01 (2.2117e-01)	Acc@1 100.00 ( 93.69)
[600/763]	Time  0.028 ( 0.028)	Loss 4.2974e-01 (2.1985e-01)	Acc@1  80.00 ( 93.69)
[700/763]	Time  0.022 ( 0.028)	Loss 3.4819e-01 (2.1937e-01)	Acc@1  80.00 ( 93.61)


  0%|          | 0/50 [00:00<?, ?it/s]

 * Acc@1 77.800
Epoch:  15
[  0/763]	Time  0.304 ( 0.304)	Loss 3.2110e-02 (3.2110e-02)	Acc@1 100.00 (100.00)
[100/763]	Time  0.026 ( 0.029)	Loss 8.8859e-02 (1.5692e-01)	Acc@1 100.00 ( 95.15)
[200/763]	Time  0.026 ( 0.028)	Loss 9.6691e-02 (1.6744e-01)	Acc@1 100.00 ( 94.93)
[300/763]	Time  0.024 ( 0.027)	Loss 2.8211e-01 (1.7563e-01)	Acc@1  90.00 ( 94.78)
[400/763]	Time  0.024 ( 0.028)	Loss 4.5133e-02 (1.8172e-01)	Acc@1 100.00 ( 94.59)
[500/763]	Time  0.025 ( 0.028)	Loss 1.0679e-01 (1.7807e-01)	Acc@1 100.00 ( 94.81)
[600/763]	Time  0.025 ( 0.028)	Loss 1.4572e-01 (1.7651e-01)	Acc@1  90.00 ( 94.81)
[700/763]	Time  0.025 ( 0.028)	Loss 1.7495e-01 (1.8202e-01)	Acc@1 100.00 ( 94.66)


  0%|          | 0/50 [00:00<?, ?it/s]

 * Acc@1 79.400
Epoch:  16
[  0/763]	Time  0.262 ( 0.262)	Loss 4.0298e-02 (4.0298e-02)	Acc@1 100.00 (100.00)
[100/763]	Time  0.025 ( 0.031)	Loss 1.3506e-01 (1.3146e-01)	Acc@1 100.00 ( 95.74)
[200/763]	Time  0.023 ( 0.028)	Loss 1.3349e-01 (1.4464e-01)	Acc@1  90.00 ( 95.42)
[300/763]	Time  0.024 ( 0.028)	Loss 2.0964e-01 (1.4551e-01)	Acc@1  90.00 ( 95.42)
[400/763]	Time  0.025 ( 0.028)	Loss 3.8734e-01 (1.5165e-01)	Acc@1  90.00 ( 95.21)
[500/763]	Time  0.028 ( 0.028)	Loss 1.2457e-01 (1.6077e-01)	Acc@1  90.00 ( 94.99)
[600/763]	Time  0.022 ( 0.028)	Loss 1.4505e-01 (1.6493e-01)	Acc@1  90.00 ( 94.84)
[700/763]	Time  0.025 ( 0.028)	Loss 1.1527e-01 (1.6684e-01)	Acc@1  90.00 ( 94.86)


  0%|          | 0/50 [00:00<?, ?it/s]

 * Acc@1 78.800
Epoch:  17
[  0/763]	Time  0.261 ( 0.261)	Loss 2.4832e-02 (2.4832e-02)	Acc@1 100.00 (100.00)
[100/763]	Time  0.026 ( 0.029)	Loss 5.6986e-02 (1.2706e-01)	Acc@1 100.00 ( 97.13)
[200/763]	Time  0.025 ( 0.029)	Loss 2.1062e-01 (1.3177e-01)	Acc@1  90.00 ( 96.17)
[300/763]	Time  0.027 ( 0.028)	Loss 1.4096e-01 (1.4132e-01)	Acc@1 100.00 ( 95.65)
[400/763]	Time  0.026 ( 0.027)	Loss 1.2485e-01 (1.4123e-01)	Acc@1  90.00 ( 95.49)
[500/763]	Time  0.029 ( 0.027)	Loss 9.0011e-02 (1.4930e-01)	Acc@1 100.00 ( 95.31)
[600/763]	Time  0.021 ( 0.028)	Loss 3.6787e-01 (1.5649e-01)	Acc@1  90.00 ( 94.99)
[700/763]	Time  0.025 ( 0.028)	Loss 2.2477e-01 (1.5583e-01)	Acc@1  90.00 ( 95.05)


  0%|          | 0/50 [00:00<?, ?it/s]

 * Acc@1 77.200
Epoch:  18
[  0/763]	Time  0.291 ( 0.291)	Loss 1.8693e-01 (1.8693e-01)	Acc@1  90.00 ( 90.00)
[100/763]	Time  0.025 ( 0.031)	Loss 1.7509e-01 (1.9842e-01)	Acc@1  90.00 ( 93.96)
[200/763]	Time  0.025 ( 0.030)	Loss 3.7665e-01 (1.6763e-01)	Acc@1  80.00 ( 94.88)
[300/763]	Time  0.026 ( 0.029)	Loss 2.0604e-02 (1.6482e-01)	Acc@1 100.00 ( 95.12)
[400/763]	Time  0.025 ( 0.028)	Loss 2.3494e-02 (1.6603e-01)	Acc@1 100.00 ( 95.06)
[500/763]	Time  0.025 ( 0.028)	Loss 4.0113e-02 (1.5512e-01)	Acc@1 100.00 ( 95.49)
[600/763]	Time  0.020 ( 0.028)	Loss 1.6110e-01 (1.5024e-01)	Acc@1 100.00 ( 95.76)
[700/763]	Time  0.027 ( 0.028)	Loss 7.6947e-02 (1.5441e-01)	Acc@1 100.00 ( 95.61)


  0%|          | 0/50 [00:00<?, ?it/s]

 * Acc@1 77.600
Epoch:  19
[  0/763]	Time  0.311 ( 0.311)	Loss 2.1786e-01 (2.1786e-01)	Acc@1  90.00 ( 90.00)
[100/763]	Time  0.029 ( 0.030)	Loss 7.6430e-02 (1.2448e-01)	Acc@1 100.00 ( 96.24)
[200/763]	Time  0.031 ( 0.028)	Loss 5.2715e-03 (1.0296e-01)	Acc@1 100.00 ( 97.26)
[300/763]	Time  0.023 ( 0.028)	Loss 6.3700e-02 (1.0464e-01)	Acc@1 100.00 ( 97.21)
[400/763]	Time  0.025 ( 0.028)	Loss 2.3657e-03 (1.0401e-01)	Acc@1 100.00 ( 97.21)
[500/763]	Time  0.025 ( 0.028)	Loss 4.1642e-02 (1.0472e-01)	Acc@1 100.00 ( 97.11)
[600/763]	Time  0.025 ( 0.028)	Loss 2.3447e-02 (1.0926e-01)	Acc@1 100.00 ( 96.94)
[700/763]	Time  0.029 ( 0.028)	Loss 1.1835e-01 (1.1077e-01)	Acc@1  90.00 ( 96.78)


  0%|          | 0/50 [00:00<?, ?it/s]

 * Acc@1 78.200


In [84]:
len(train_path)

8122

In [89]:
test_path = list("./test_set/" + pd.read_csv("example_A.csv")["ImageID"])

test_loader = torch.utils.data.DataLoader(
    XFDataset(test_path, [0] * len(test_path), 
            transforms.Compose([
                        transforms.Resize((256, 256)),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    ), batch_size=40, shuffle=False, num_workers=4, pin_memory=True
)

val_label = pd.DataFrame()
val_label['ImageID'] = [x.split('/')[-1] for x in test_path]
val_label['label'] = predict(test_loader, model, 1).argmax(1)
val_label.to_csv('submit.csv', index=None)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i, (input, target) in tqdm_notebook(enumerate(test_loader), total=len(test_loader)):


  0%|          | 0/46 [00:00<?, ?it/s]

In [90]:
val_label

Unnamed: 0,ImageID,label
0,006f490b-c352-414c-be06-0e14a39eb3ee.jpg,5
1,007d7487-413d-4d11-9df1-1c40651f6428.jpg,15
2,00b15eef-38c6-42e0-b1ed-5bbd34d9f20f.jpg,18
3,00d3d2ea-2f66-40b7-b402-b2bcd36cf776.jpg,1
4,01207b72-00c6-43c4-aa14-85e76c138022.jpg,53
...,...,...
1820,ff3d0831-0110-4106-b76a-c9f5c8291a18.jpg,18
1821,ff5c00d4-91eb-4347-84c2-2824654eb625.jpg,25
1822,ff74d365-3be6-49e9-878f-806e8c8b83de.jpg,34
1823,ff864ef5-5cb2-4d63-a6d5-d6f8980f6f23.jpg,37
