In [6]:
import torch
torch.manual_seed(0)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset
import time
import glob

import pandas as pd
import numpy as np
from PIL import Image
from tqdm import tqdm_notebook

In [25]:
train_path = glob.glob('./train/*/*/*.jpg')
np.random.shuffle(train_path)

train_label = [int(x.split('/')[-3]) for x in train_path]

test_path = glob.glob('./test/*.jpg')
test_path.sort()

In [26]:
len(train_path)

22312

In [16]:
set(train_label)

{0, 1, 2, 3, 4}

In [28]:
test_path[:10]

['./test/001.jpg',
 './test/002.jpg',
 './test/003.jpg',
 './test/004.jpg',
 './test/005.jpg',
 './test/006.jpg',
 './test/007.jpg',
 './test/008.jpg',
 './test/009.jpg',
 './test/010.jpg']

In [17]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

class ProgressMeter(object):
    def __init__(self, num_batches, *meters):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = ""


    def pr2int(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
def validate(val_loader, model, criterion):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(len(val_loader), batch_time, losses, top1)

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (input, target) in tqdm_notebook(enumerate(val_loader), total=len(val_loader)):
            input = input.cuda()
            target = target.cuda()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc = (output.argmax(1).view(-1) == target.float().view(-1)).float().mean() * 100
            losses.update(loss.item(), input.size(0))
            top1.update(acc, input.size(0))
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f}'
              .format(top1=top1))
        return top1

def predict(test_loader, model, tta=10):
    # switch to evaluate mode
    model.eval()
    
    test_pred_tta = None
    for _ in range(tta):
        test_pred = []
        with torch.no_grad():
            end = time.time()
            for i, (input, target) in tqdm_notebook(enumerate(test_loader), total=len(test_loader)):
                input = input.cuda()
                target = target.cuda()

                # compute output
                output = model(input)
                output = F.softmax(output, dim=1)
                output = output.data.cpu().numpy()

                test_pred.append(output)
        test_pred = np.vstack(test_pred)
    
        if test_pred_tta is None:
            test_pred_tta = test_pred
        else:
            test_pred_tta += test_pred
    
    return test_pred_tta

def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(len(train_loader), batch_time, losses, top1)

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # measure accuracy and record loss
        losses.update(loss.item(), input.size(0))

        acc = (output.argmax(1).view(-1) == target.float().view(-1)).float().mean() * 100
        top1.update(acc, input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 100 == 0:
            progress.pr2int(i)

In [18]:
class XFDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label
        
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None
    
    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')
        
        if self.transform is not None:
            img = self.transform(img)
        
        return img, torch.from_numpy(np.array(self.img_label[index]))
    
    def __len__(self):
        return len(self.img_path)

In [20]:
import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

import timm
model = timm.create_model('resnet18', pretrained=True, num_classes=5)
model = model.cuda()

In [22]:
train_loader = torch.utils.data.DataLoader(
    XFDataset(train_path[:-500], train_label[:-500], 
            transforms.Compose([
                        transforms.Resize((256, 256)),
                        transforms.RandomHorizontalFlip(),
                        transforms.RandomVerticalFlip(),
                        transforms.ColorJitter(brightness=.5, hue=.3),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    ), batch_size=20, shuffle=True, num_workers=4, pin_memory=True
)

val_loader = torch.utils.data.DataLoader(
    XFDataset(train_path[-500:], train_label[-500:], 
            transforms.Compose([
                        transforms.Resize((256, 256)),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    ), batch_size=20, shuffle=False, num_workers=4, pin_memory=True
)

criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), 0.005)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.85)
best_acc = 0.0
for epoch in range(10):
    scheduler.step()
    print('Epoch: ', epoch)

    train(train_loader, model, criterion, optimizer, epoch)
    val_acc = validate(val_loader, model, criterion)
    
    if val_acc.avg.item() > best_acc:
        best_acc = round(val_acc.avg.item(), 2)

Epoch:  0
[   0/1091]	Time  0.324 ( 0.324)	Loss 4.5023e-03 (4.5023e-03)	Acc@1 100.00 (100.00)
[ 100/1091]	Time  0.044 ( 0.046)	Loss 7.8433e-02 (6.3462e-02)	Acc@1  95.00 ( 97.67)
[ 200/1091]	Time  0.044 ( 0.044)	Loss 6.1962e-03 (6.5124e-02)	Acc@1 100.00 ( 97.84)
[ 300/1091]	Time  0.045 ( 0.044)	Loss 1.7800e-01 (6.1890e-02)	Acc@1  95.00 ( 98.02)
[ 400/1091]	Time  0.044 ( 0.044)	Loss 2.4579e-03 (5.9559e-02)	Acc@1 100.00 ( 98.18)
[ 500/1091]	Time  0.043 ( 0.044)	Loss 2.0658e-03 (5.5116e-02)	Acc@1 100.00 ( 98.31)
[ 600/1091]	Time  0.043 ( 0.044)	Loss 2.6326e-03 (5.1129e-02)	Acc@1 100.00 ( 98.40)
[ 700/1091]	Time  0.043 ( 0.043)	Loss 5.6864e-03 (5.3521e-02)	Acc@1 100.00 ( 98.35)
[ 800/1091]	Time  0.045 ( 0.043)	Loss 1.4429e-01 (5.1473e-02)	Acc@1  90.00 ( 98.37)
[ 900/1091]	Time  0.043 ( 0.043)	Loss 6.3551e-03 (5.1053e-02)	Acc@1 100.00 ( 98.38)
[1000/1091]	Time  0.042 ( 0.043)	Loss 2.9711e-03 (4.9586e-02)	Acc@1 100.00 ( 98.43)


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i, (input, target) in tqdm_notebook(enumerate(val_loader), total=len(val_loader)):


  0%|          | 0/25 [00:00<?, ?it/s]

 * Acc@1 99.800
Epoch:  1
[   0/1091]	Time  0.244 ( 0.244)	Loss 9.8599e-02 (9.8599e-02)	Acc@1  95.00 ( 95.00)
[ 100/1091]	Time  0.043 ( 0.045)	Loss 2.1680e-03 (3.1402e-02)	Acc@1 100.00 ( 98.66)
[ 200/1091]	Time  0.045 ( 0.044)	Loss 1.1284e-04 (2.5259e-02)	Acc@1 100.00 ( 99.10)
[ 300/1091]	Time  0.048 ( 0.044)	Loss 1.1205e-04 (2.6241e-02)	Acc@1 100.00 ( 99.02)
[ 400/1091]	Time  0.043 ( 0.044)	Loss 1.4951e-02 (3.1412e-02)	Acc@1 100.00 ( 98.92)
[ 500/1091]	Time  0.044 ( 0.044)	Loss 1.9477e-01 (3.7644e-02)	Acc@1  95.00 ( 98.70)
[ 600/1091]	Time  0.040 ( 0.043)	Loss 2.2300e-03 (3.7396e-02)	Acc@1 100.00 ( 98.70)
[ 700/1091]	Time  0.039 ( 0.043)	Loss 1.9081e-02 (3.5318e-02)	Acc@1 100.00 ( 98.77)
[ 800/1091]	Time  0.043 ( 0.043)	Loss 1.0084e-01 (3.4443e-02)	Acc@1  95.00 ( 98.83)
[ 900/1091]	Time  0.042 ( 0.043)	Loss 3.8712e-03 (3.3015e-02)	Acc@1 100.00 ( 98.88)
[1000/1091]	Time  0.044 ( 0.043)	Loss 3.4781e-04 (3.2336e-02)	Acc@1 100.00 ( 98.92)


  0%|          | 0/25 [00:00<?, ?it/s]

 * Acc@1 100.000
Epoch:  2
[   0/1091]	Time  0.259 ( 0.259)	Loss 1.4724e-03 (1.4724e-03)	Acc@1 100.00 (100.00)
[ 100/1091]	Time  0.045 ( 0.045)	Loss 1.5199e-02 (1.5595e-02)	Acc@1 100.00 ( 99.60)
[ 200/1091]	Time  0.042 ( 0.044)	Loss 4.7581e-04 (1.6610e-02)	Acc@1 100.00 ( 99.53)
[ 300/1091]	Time  0.044 ( 0.044)	Loss 1.9028e-02 (1.9979e-02)	Acc@1 100.00 ( 99.40)
[ 400/1091]	Time  0.043 ( 0.044)	Loss 1.0502e-03 (2.2769e-02)	Acc@1 100.00 ( 99.29)
[ 500/1091]	Time  0.043 ( 0.044)	Loss 4.2360e-03 (2.0434e-02)	Acc@1 100.00 ( 99.35)
[ 600/1091]	Time  0.040 ( 0.044)	Loss 5.9249e-05 (1.9844e-02)	Acc@1 100.00 ( 99.39)
[ 700/1091]	Time  0.047 ( 0.044)	Loss 1.8846e-05 (1.7879e-02)	Acc@1 100.00 ( 99.46)
[ 800/1091]	Time  0.043 ( 0.044)	Loss 5.5814e-03 (1.8370e-02)	Acc@1 100.00 ( 99.41)
[ 900/1091]	Time  0.047 ( 0.043)	Loss 1.2090e-02 (1.9222e-02)	Acc@1 100.00 ( 99.38)
[1000/1091]	Time  0.043 ( 0.043)	Loss 1.3726e-03 (2.1988e-02)	Acc@1 100.00 ( 99.28)


  0%|          | 0/25 [00:00<?, ?it/s]

 * Acc@1 99.000
Epoch:  3
[   0/1091]	Time  0.256 ( 0.256)	Loss 4.7306e-02 (4.7306e-02)	Acc@1  95.00 ( 95.00)
[ 100/1091]	Time  0.041 ( 0.045)	Loss 6.9081e-06 (1.3448e-02)	Acc@1 100.00 ( 99.41)
[ 200/1091]	Time  0.043 ( 0.044)	Loss 9.6763e-04 (1.4315e-02)	Acc@1 100.00 ( 99.45)
[ 300/1091]	Time  0.036 ( 0.044)	Loss 9.7373e-02 (1.1404e-02)	Acc@1  95.00 ( 99.58)
[ 400/1091]	Time  0.045 ( 0.044)	Loss 6.9967e-04 (1.2339e-02)	Acc@1 100.00 ( 99.56)
[ 500/1091]	Time  0.042 ( 0.044)	Loss 2.7265e-03 (1.2358e-02)	Acc@1 100.00 ( 99.55)
[ 600/1091]	Time  0.042 ( 0.044)	Loss 1.2706e-04 (1.2163e-02)	Acc@1 100.00 ( 99.56)
[ 700/1091]	Time  0.043 ( 0.044)	Loss 4.7863e-03 (1.2750e-02)	Acc@1 100.00 ( 99.54)
[ 800/1091]	Time  0.044 ( 0.044)	Loss 1.5516e-03 (1.3346e-02)	Acc@1 100.00 ( 99.51)
[ 900/1091]	Time  0.043 ( 0.044)	Loss 4.5972e-03 (1.4055e-02)	Acc@1 100.00 ( 99.48)
[1000/1091]	Time  0.044 ( 0.044)	Loss 1.3417e-02 (1.3932e-02)	Acc@1 100.00 ( 99.49)


  0%|          | 0/25 [00:00<?, ?it/s]

 * Acc@1 99.800
Epoch:  4
[   0/1091]	Time  0.262 ( 0.262)	Loss 1.2434e-03 (1.2434e-03)	Acc@1 100.00 (100.00)
[ 100/1091]	Time  0.044 ( 0.045)	Loss 1.0617e-04 (4.5660e-03)	Acc@1 100.00 ( 99.85)
[ 200/1091]	Time  0.043 ( 0.044)	Loss 3.6592e-04 (3.7412e-03)	Acc@1 100.00 ( 99.93)
[ 300/1091]	Time  0.047 ( 0.044)	Loss 5.0360e-05 (4.0069e-03)	Acc@1 100.00 ( 99.93)
[ 400/1091]	Time  0.044 ( 0.044)	Loss 6.4318e-05 (5.0542e-03)	Acc@1 100.00 ( 99.86)
[ 500/1091]	Time  0.043 ( 0.044)	Loss 9.8798e-04 (4.7829e-03)	Acc@1 100.00 ( 99.85)
[ 600/1091]	Time  0.042 ( 0.044)	Loss 3.8230e-05 (4.4126e-03)	Acc@1 100.00 ( 99.86)
[ 700/1091]	Time  0.043 ( 0.044)	Loss 2.6449e-04 (5.7047e-03)	Acc@1 100.00 ( 99.84)
[ 800/1091]	Time  0.043 ( 0.044)	Loss 3.7706e-02 (1.0854e-02)	Acc@1 100.00 ( 99.66)
[ 900/1091]	Time  0.044 ( 0.044)	Loss 2.2186e-02 (1.3616e-02)	Acc@1 100.00 ( 99.55)
[1000/1091]	Time  0.043 ( 0.044)	Loss 7.5228e-03 (1.4314e-02)	Acc@1 100.00 ( 99.52)


  0%|          | 0/25 [00:00<?, ?it/s]

 * Acc@1 100.000
Epoch:  5
[   0/1091]	Time  0.273 ( 0.273)	Loss 3.7738e-03 (3.7738e-03)	Acc@1 100.00 (100.00)
[ 100/1091]	Time  0.043 ( 0.046)	Loss 2.5157e-04 (5.4340e-03)	Acc@1 100.00 ( 99.90)
[ 200/1091]	Time  0.042 ( 0.044)	Loss 5.3610e-04 (9.9741e-03)	Acc@1 100.00 ( 99.65)
[ 300/1091]	Time  0.043 ( 0.044)	Loss 9.2436e-03 (8.5162e-03)	Acc@1 100.00 ( 99.72)
[ 400/1091]	Time  0.044 ( 0.044)	Loss 7.5036e-03 (1.0029e-02)	Acc@1 100.00 ( 99.69)
[ 500/1091]	Time  0.043 ( 0.044)	Loss 2.5220e-03 (1.0865e-02)	Acc@1 100.00 ( 99.65)
[ 600/1091]	Time  0.044 ( 0.044)	Loss 2.9781e-04 (1.1034e-02)	Acc@1 100.00 ( 99.63)
[ 700/1091]	Time  0.044 ( 0.044)	Loss 1.9654e-04 (1.0148e-02)	Acc@1 100.00 ( 99.66)
[ 800/1091]	Time  0.042 ( 0.044)	Loss 9.3135e-03 (1.0139e-02)	Acc@1 100.00 ( 99.64)
[ 900/1091]	Time  0.041 ( 0.044)	Loss 2.0706e-02 (1.1158e-02)	Acc@1 100.00 ( 99.63)
[1000/1091]	Time  0.042 ( 0.044)	Loss 1.7011e-02 (1.2312e-02)	Acc@1 100.00 ( 99.60)


  0%|          | 0/25 [00:00<?, ?it/s]

 * Acc@1 99.600
Epoch:  6
[   0/1091]	Time  0.255 ( 0.255)	Loss 4.8645e-03 (4.8645e-03)	Acc@1 100.00 (100.00)
[ 100/1091]	Time  0.043 ( 0.045)	Loss 7.8462e-02 (1.0925e-02)	Acc@1  95.00 ( 99.60)
[ 200/1091]	Time  0.044 ( 0.044)	Loss 4.8311e-03 (9.4189e-03)	Acc@1 100.00 ( 99.65)
[ 300/1091]	Time  0.043 ( 0.044)	Loss 6.6987e-03 (1.0577e-02)	Acc@1 100.00 ( 99.62)
[ 400/1091]	Time  0.043 ( 0.044)	Loss 5.6467e-05 (1.1687e-02)	Acc@1 100.00 ( 99.59)
[ 500/1091]	Time  0.043 ( 0.044)	Loss 5.1439e-04 (1.3080e-02)	Acc@1 100.00 ( 99.57)
[ 600/1091]	Time  0.043 ( 0.044)	Loss 2.9773e-05 (1.3160e-02)	Acc@1 100.00 ( 99.58)
[ 700/1091]	Time  0.045 ( 0.044)	Loss 1.9180e-04 (1.2855e-02)	Acc@1 100.00 ( 99.59)
[ 800/1091]	Time  0.044 ( 0.044)	Loss 8.7255e-05 (1.3562e-02)	Acc@1 100.00 ( 99.56)
[ 900/1091]	Time  0.040 ( 0.044)	Loss 5.5155e-03 (1.3529e-02)	Acc@1 100.00 ( 99.55)
[1000/1091]	Time  0.042 ( 0.044)	Loss 9.3305e-04 (1.4700e-02)	Acc@1 100.00 ( 99.53)


  0%|          | 0/25 [00:00<?, ?it/s]

 * Acc@1 99.200
Epoch:  7
[   0/1091]	Time  0.257 ( 0.257)	Loss 1.2140e-03 (1.2140e-03)	Acc@1 100.00 (100.00)
[ 100/1091]	Time  0.043 ( 0.045)	Loss 2.5809e-06 (4.7557e-03)	Acc@1 100.00 ( 99.80)
[ 200/1091]	Time  0.043 ( 0.044)	Loss 8.4121e-04 (5.5801e-03)	Acc@1 100.00 ( 99.83)
[ 300/1091]	Time  0.044 ( 0.044)	Loss 4.5780e-05 (5.1677e-03)	Acc@1 100.00 ( 99.85)
[ 400/1091]	Time  0.044 ( 0.044)	Loss 1.2512e-03 (4.5028e-03)	Acc@1 100.00 ( 99.86)
[ 500/1091]	Time  0.044 ( 0.044)	Loss 2.1696e-06 (4.3899e-03)	Acc@1 100.00 ( 99.88)
[ 600/1091]	Time  0.044 ( 0.044)	Loss 8.2659e-05 (4.1196e-03)	Acc@1 100.00 ( 99.88)
[ 700/1091]	Time  0.046 ( 0.044)	Loss 4.6491e-06 (4.5201e-03)	Acc@1 100.00 ( 99.86)
[ 800/1091]	Time  0.044 ( 0.044)	Loss 3.7789e-06 (5.3785e-03)	Acc@1 100.00 ( 99.84)
[ 900/1091]	Time  0.044 ( 0.044)	Loss 5.0313e-05 (6.8424e-03)	Acc@1 100.00 ( 99.79)
[1000/1091]	Time  0.043 ( 0.044)	Loss 2.4302e-04 (8.0052e-03)	Acc@1 100.00 ( 99.75)


  0%|          | 0/25 [00:00<?, ?it/s]

 * Acc@1 99.800
Epoch:  8
[   0/1091]	Time  0.253 ( 0.253)	Loss 1.5648e-04 (1.5648e-04)	Acc@1 100.00 (100.00)
[ 100/1091]	Time  0.044 ( 0.045)	Loss 1.0886e-03 (5.5487e-03)	Acc@1 100.00 ( 99.85)
[ 200/1091]	Time  0.043 ( 0.044)	Loss 6.3645e-05 (3.8627e-03)	Acc@1 100.00 ( 99.90)
[ 300/1091]	Time  0.044 ( 0.044)	Loss 2.3239e-03 (3.7744e-03)	Acc@1 100.00 ( 99.90)
[ 400/1091]	Time  0.044 ( 0.044)	Loss 5.6316e-04 (3.2983e-03)	Acc@1 100.00 ( 99.91)
[ 500/1091]	Time  0.042 ( 0.044)	Loss 3.8400e-05 (3.6069e-03)	Acc@1 100.00 ( 99.89)
[ 600/1091]	Time  0.043 ( 0.044)	Loss 7.1493e-02 (3.8539e-03)	Acc@1  95.00 ( 99.88)
[ 700/1091]	Time  0.044 ( 0.044)	Loss 4.4943e-04 (3.4334e-03)	Acc@1 100.00 ( 99.90)
[ 800/1091]	Time  0.044 ( 0.044)	Loss 6.7409e-06 (3.0434e-03)	Acc@1 100.00 ( 99.91)
[ 900/1091]	Time  0.042 ( 0.044)	Loss 3.4373e-04 (3.4362e-03)	Acc@1 100.00 ( 99.89)
[1000/1091]	Time  0.044 ( 0.044)	Loss 2.3087e-04 (4.1418e-03)	Acc@1 100.00 ( 99.88)


  0%|          | 0/25 [00:00<?, ?it/s]

 * Acc@1 100.000
Epoch:  9
[   0/1091]	Time  0.258 ( 0.258)	Loss 2.6329e-04 (2.6329e-04)	Acc@1 100.00 (100.00)
[ 100/1091]	Time  0.040 ( 0.045)	Loss 9.9950e-05 (8.3957e-03)	Acc@1 100.00 ( 99.75)
[ 200/1091]	Time  0.039 ( 0.045)	Loss 1.0490e-06 (7.6670e-03)	Acc@1 100.00 ( 99.78)
[ 300/1091]	Time  0.042 ( 0.044)	Loss 1.0004e-03 (1.2978e-02)	Acc@1 100.00 ( 99.63)
[ 400/1091]	Time  0.043 ( 0.044)	Loss 6.5297e-04 (1.0959e-02)	Acc@1 100.00 ( 99.69)
[ 500/1091]	Time  0.043 ( 0.044)	Loss 1.5865e-05 (9.2654e-03)	Acc@1 100.00 ( 99.73)
[ 600/1091]	Time  0.044 ( 0.044)	Loss 1.3202e-04 (8.4265e-03)	Acc@1 100.00 ( 99.75)
[ 700/1091]	Time  0.044 ( 0.044)	Loss 2.5028e-03 (7.3998e-03)	Acc@1 100.00 ( 99.78)
[ 800/1091]	Time  0.043 ( 0.044)	Loss 5.5669e-06 (7.8045e-03)	Acc@1 100.00 ( 99.78)
[ 900/1091]	Time  0.044 ( 0.044)	Loss 7.8614e-05 (7.3023e-03)	Acc@1 100.00 ( 99.79)
[1000/1091]	Time  0.044 ( 0.044)	Loss 7.3809e-05 (6.9125e-03)	Acc@1 100.00 ( 99.80)


  0%|          | 0/25 [00:00<?, ?it/s]

 * Acc@1 99.600


In [31]:
test_loader = torch.utils.data.DataLoader(
    XFDataset(test_path, [0] * len(test_path), 
            transforms.Compose([
                        transforms.Resize((256, 256)),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    ), batch_size=40, shuffle=False, num_workers=4, pin_memory=True
)

val_label = pd.DataFrame()
val_label['img'] = [x.split('/')[-1] for x in test_path]
val_label['label'] = predict(test_loader, model, 1).argmax(1)
val_label.to_csv('submit.csv', index=None)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i, (input, target) in tqdm_notebook(enumerate(test_loader), total=len(test_loader)):


  0%|          | 0/112 [00:00<?, ?it/s]

In [30]:
val_label

Unnamed: 0,image,label
0,001.jpg,0
1,002.jpg,1
2,003.jpg,3
3,004.jpg,4
4,005.jpg,0
...,...,...
4458,995.jpg,3
4459,996.jpg,1
4460,997.jpg,0
4461,998.jpg,0
