In [1]:
# pip install ttach

In [11]:
import torch
import torch.nn as nn
import ttach as tta

from utils.dataloader import CustomImageFolder, get_transform
from utils.utils import accuracy
from utils.utils import seed_everything

from model.mymodel import ConvMixer
from tqdm import tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

In [14]:
CONFIG = {'seed' : 42,
          'batch_size' : 64,
          'epochs' : 300,
          'img_size' : 128,
          'train_path' : './data/Food_dataset/train',
          'valid_path' : './data/Food_dataset/val',
          'test_path' : './data/Food_dataset/val',
          'save_path' : './saved/models',
          'learning_rate' : 1e-3,
          'nclasses' : 100,
          'dim' : 1024,
          'depth' : 7,
          'kernel' : 3,
          'patch' : 3,
          'cutmix_prob' : 0.3
          }

In [15]:
seed_everything(CONFIG['seed'])

In [16]:
test_dataset = CustomImageFolder(root = CONFIG['test_path'],
                            transform=get_transform(train_mode='valid'))
    
    
test_dataloader = torch.utils.data.DataLoader(
                        test_dataset,
                        batch_size=CONFIG['batch_size'], shuffle=True)

In [5]:
transforms = tta.Compose([
    tta.HorizontalFlip(),
    tta.VerticalFlip(),
    ])

In [17]:
model = ConvMixer(dim=CONFIG['dim'], depth=CONFIG['depth'], kernel_size=CONFIG['kernel'], patch_size=CONFIG['patch'], n_classes=CONFIG['nclasses']).to(device)
model_state_dict = torch.load('./saved/models/best_model_0927.pt')
model.load_state_dict(model_state_dict)

<All keys matched successfully>

In [21]:
tta_model = tta.ClassificationTTAWrapper(model, transforms, merge_mode='sum').to(device)

In [22]:
len(test_dataloader.dataset)

4000

In [23]:
def evaluate(modeln: nn.Module, dataloader):
    accsum = 0
    
    modeln.eval()
    with torch.no_grad(): 
        for X, y in tqdm(dataloader): 
            X, y = X.to(device), y.to(device)
            yhat = modeln(X)
            acc = accuracy(y.cpu().data.numpy(), yhat.cpu().data.numpy().argmax(-1))
            accsum += (acc * len(y) / len(dataloader.dataset)) 

    return accsum

In [20]:
evaluate(model, test_dataloader)

100% 63/63 [00:28<00:00,  2.24it/s]


0.8977499999999999

In [24]:
evaluate(tta_model, test_dataloader)

100% 63/63 [01:03<00:00,  1.01s/it]


0.8942500000000001

In [81]:
evaluate(tta_model, test_dataloader)

100% 63/63 [00:38<00:00,  1.62it/s]


0.8967500000000002

In [12]:
def evaluate2(model: nn.Module, dataloader):
    accsum = 0

    model.eval()
    
    l1 = []
    with torch.no_grad(): 
        for X, y in tqdm(dataloader): 
            X, y = X.to(device), y.to(device)
            yhat = model(X)
            acc = accuracy(y.cpu().data.numpy(), yhat.cpu().data.numpy().argmax(-1))
            accsum += (acc * len(y) / len(dataloader.dataset))
            for i,j in zip(y.cpu().data.numpy(), yhat.cpu().data.numpy().argmax(-1)):
                if i != j:
                    l1.append((i,j))

    return accsum, l1

In [45]:
a,b = evaluate2(tta_model, test_dataloader)

100% 63/63 [00:40<00:00,  1.56it/s]


In [59]:
new = [tuple(sorted(i)) for i in b]

In [60]:
new

[(57, 86),
 (52, 98),
 (78, 89),
 (27, 81),
 (28, 58),
 (16, 45),
 (17, 79),
 (65, 96),
 (59, 72),
 (54, 67),
 (33, 56),
 (19, 65),
 (51, 99),
 (27, 81),
 (61, 68),
 (7, 89),
 (24, 78),
 (73, 91),
 (41, 92),
 (8, 47),
 (9, 84),
 (21, 59),
 (9, 84),
 (18, 71),
 (59, 86),
 (21, 53),
 (4, 51),
 (28, 91),
 (9, 84),
 (18, 49),
 (65, 67),
 (78, 84),
 (18, 98),
 (11, 83),
 (11, 83),
 (97, 98),
 (22, 41),
 (18, 49),
 (91, 94),
 (18, 98),
 (45, 57),
 (61, 85),
 (18, 57),
 (31, 53),
 (0, 39),
 (61, 68),
 (61, 68),
 (81, 84),
 (41, 56),
 (49, 73),
 (28, 67),
 (78, 89),
 (16, 74),
 (53, 58),
 (65, 77),
 (89, 92),
 (1, 86),
 (30, 59),
 (72, 95),
 (21, 27),
 (23, 77),
 (91, 93),
 (4, 92),
 (20, 61),
 (2, 99),
 (19, 73),
 (61, 68),
 (16, 48),
 (46, 99),
 (11, 83),
 (14, 28),
 (18, 97),
 (2, 52),
 (38, 98),
 (15, 28),
 (9, 84),
 (8, 67),
 (54, 94),
 (32, 54),
 (34, 67),
 (3, 36),
 (39, 65),
 (11, 83),
 (49, 99),
 (69, 86),
 (65, 77),
 (46, 78),
 (58, 91),
 (32, 58),
 (33, 87),
 (47, 69),
 (2, 24),
 (6

In [51]:
from collections import Counter

In [70]:
k = Counter(b)

In [69]:
sorted(k.items(), key = lambda x : x[1],reverse=True)

[((61, 68), 14),
 ((84, 9), 9),
 ((68, 61), 6),
 ((83, 11), 6),
 ((78, 24), 5),
 ((9, 84), 5),
 ((98, 18), 5),
 ((84, 81), 5),
 ((21, 59), 4),
 ((11, 83), 4),
 ((77, 65), 4),
 ((91, 94), 3),
 ((18, 98), 3),
 ((81, 84), 3),
 ((93, 91), 3),
 ((73, 19), 3),
 ((89, 93), 3),
 ((18, 17), 3),
 ((48, 16), 3),
 ((92, 59), 3),
 ((92, 87), 3),
 ((89, 78), 2),
 ((58, 28), 2),
 ((45, 16), 2),
 ((19, 65), 2),
 ((81, 27), 2),
 ((53, 21), 2),
 ((4, 51), 2),
 ((28, 67), 2),
 ((99, 2), 2),
 ((28, 14), 2),
 ((2, 52), 2),
 ((36, 3), 2),
 ((78, 46), 2),
 ((2, 24), 2),
 ((89, 24), 2),
 ((43, 69), 2),
 ((84, 80), 2),
 ((78, 89), 2),
 ((49, 38), 2),
 ((57, 14), 2),
 ((47, 69), 2),
 ((2, 27), 2),
 ((94, 91), 2),
 ((21, 1), 2),
 ((70, 19), 2),
 ((92, 21), 2),
 ((9, 46), 2),
 ((86, 30), 2),
 ((73, 1), 2),
 ((88, 57), 2),
 ((2, 37), 2),
 ((54, 56), 2),
 ((99, 82), 2),
 ((91, 92), 2),
 ((38, 49), 2),
 ((32, 53), 2),
 ((98, 24), 2),
 ((96, 19), 2),
 ((19, 96), 2),
 ((57, 86), 1),
 ((98, 52), 1),
 ((27, 81), 1),
 ((

In [71]:
c = Counter(new)

In [67]:
sorted(c.items(), key = lambda x : x[1],reverse=True)

[((61, 68), 20),
 ((9, 84), 14),
 ((11, 83), 10),
 ((18, 98), 8),
 ((81, 84), 8),
 ((24, 78), 5),
 ((91, 94), 5),
 ((65, 77), 5),
 ((78, 89), 4),
 ((21, 59), 4),
 ((16, 48), 4),
 ((89, 93), 4),
 ((38, 49), 4),
 ((19, 96), 4),
 ((27, 81), 3),
 ((91, 93), 3),
 ((19, 73), 3),
 ((47, 69), 3),
 ((21, 92), 3),
 ((43, 69), 3),
 ((14, 57), 3),
 ((17, 18), 3),
 ((30, 86), 3),
 ((59, 92), 3),
 ((87, 92), 3),
 ((28, 58), 2),
 ((16, 45), 2),
 ((19, 65), 2),
 ((21, 53), 2),
 ((4, 51), 2),
 ((28, 91), 2),
 ((18, 49), 2),
 ((22, 41), 2),
 ((45, 57), 2),
 ((28, 67), 2),
 ((2, 99), 2),
 ((14, 28), 2),
 ((2, 52), 2),
 ((3, 36), 2),
 ((39, 65), 2),
 ((46, 78), 2),
 ((2, 24), 2),
 ((24, 89), 2),
 ((80, 84), 2),
 ((2, 27), 2),
 ((22, 75), 2),
 ((1, 21), 2),
 ((19, 70), 2),
 ((9, 46), 2),
 ((1, 73), 2),
 ((57, 88), 2),
 ((2, 37), 2),
 ((54, 56), 2),
 ((82, 99), 2),
 ((91, 92), 2),
 ((29, 51), 2),
 ((91, 97), 2),
 ((32, 53), 2),
 ((24, 98), 2),
 ((57, 86), 1),
 ((52, 98), 1),
 ((17, 79), 1),
 ((65, 96), 1),


In [66]:
test_dataloader.dataset.class_to_idx

{'간자장': 0,
 '갈치구이': 1,
 '감자구이': 2,
 '감자밥': 3,
 '경단': 4,
 '고구마맛탕': 5,
 '골드키위': 6,
 '골뱅이국수무침': 7,
 '김치우동': 8,
 '까망베르치즈': 9,
 '깨죽': 10,
 '깻잎나물볶음': 11,
 '낙지탕': 12,
 '느타리버섯볶음': 13,
 '달걀찜': 14,
 '달걀후라이': 15,
 '닭가슴살샐러드': 16,
 '닭강정': 17,
 '닭날개': 18,
 '더덕구이': 19,
 '도가니탕': 20,
 '도넛': 21,
 '동태조림': 22,
 '돼지고기메추리알장조림': 23,
 '땅콩버터': 24,
 '랍스타': 25,
 '마늘쫑무침': 26,
 '망고': 27,
 '메쉬드포테이토': 28,
 '모듬회': 29,
 '무국': 30,
 '무지개떡': 31,
 '뮤즐리': 32,
 '미역국': 33,
 '미역초무침': 34,
 '바지락조개국': 35,
 '밥': 36,
 '번데기': 37,
 '베이컨구이': 38,
 '볶음우동': 39,
 '부추전': 40,
 '북어조림': 41,
 '브로콜리': 42,
 '뼈해장국': 43,
 '산딸기': 44,
 '살사소스': 45,
 '삶은달걀': 46,
 '삼선짬뽕': 47,
 '새싹샐러드': 48,
 '새우': 49,
 '새우튀김롤': 50,
 '생선초밥': 51,
 '생선튀김': 52,
 '설탕': 53,
 '쇠고기볶음': 54,
 '수수부꾸미': 55,
 '순대볶음': 56,
 '스프': 57,
 '시리얼바': 58,
 '식혜': 59,
 '아몬드': 60,
 '액상요구르트': 61,
 '양배추': 62,
 '어죽': 63,
 '오이': 64,
 '오징어채볶음': 65,
 '와사비': 66,
 '우동': 67,
 '우유': 68,
 '육개장': 69,
 '육회': 70,
 '잡곡식빵': 71,
 '잣': 72,
 '장어구이': 73,
 '전주비빔밥': 74,
 '조기조림': 75,
 '조기찜': 76,
 '쥐치채': 77,
 '쨈빵': 78,

In [13]:
a,b = evaluate2(model, test_dataloader)

100% 63/63 [00:27<00:00,  2.30it/s]


In [14]:
from collections import Counter

new = [tuple(sorted(i)) for i in b]
k = Counter(b)
sorted(k.items(), key = lambda x : x[1],reverse=True)

[((61, 68), 15),
 ((84, 9), 10),
 ((84, 81), 7),
 ((18, 98), 6),
 ((68, 61), 5),
 ((77, 65), 4),
 ((11, 83), 4),
 ((91, 94), 4),
 ((89, 93), 4),
 ((9, 84), 4),
 ((2, 52), 3),
 ((83, 11), 3),
 ((21, 59), 3),
 ((93, 78), 3),
 ((41, 75), 3),
 ((98, 18), 3),
 ((99, 18), 3),
 ((92, 59), 3),
 ((78, 89), 3),
 ((19, 96), 3),
 ((92, 87), 3),
 ((18, 17), 2),
 ((89, 78), 2),
 ((9, 68), 2),
 ((93, 92), 2),
 ((65, 39), 2),
 ((58, 78), 2),
 ((86, 1), 2),
 ((43, 69), 2),
 ((99, 49), 2),
 ((67, 8), 2),
 ((93, 89), 2),
 ((96, 90), 2),
 ((8, 47), 2),
 ((18, 70), 2),
 ((2, 5), 2),
 ((73, 1), 2),
 ((73, 19), 2),
 ((39, 74), 2),
 ((45, 4), 2),
 ((2, 37), 2),
 ((70, 19), 2),
 ((21, 51), 2),
 ((92, 71), 2),
 ((78, 91), 2),
 ((4, 59), 2),
 ((93, 91), 2),
 ((61, 15), 2),
 ((46, 15), 2),
 ((48, 16), 2),
 ((19, 65), 2),
 ((46, 84), 2),
 ((65, 77), 2),
 ((34, 83), 2),
 ((87, 91), 2),
 ((49, 73), 2),
 ((51, 29), 2),
 ((22, 41), 2),
 ((4, 27), 2),
 ((47, 69), 2),
 ((2, 28), 1),
 ((97, 71), 1),
 ((56, 41), 1),
 ((22

In [15]:
a

0.8942500000000002

In [17]:
len(b)

423

In [10]:
import yaml
import torch
import torch.nn as nn
import ttach as tta
from utils.dataloader import CustomImageFolder, get_transform
from utils.utils import accuracy
from model.mymodel import ConvMixer
from utils.utils import seed_everything

from tqdm import tqdm


def main(model_name):

    config = yaml.load(open('./config/' + str(model_name) + '.yaml', 'r'), Loader=yaml.FullLoader)

    seed_everything(7)

    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

    test_dataset = CustomImageFolder(root = config['test_path'],
                            transform=get_transform(train_mode='valid'))
    
    test_dataloader = torch.utils.data.DataLoader(
                            test_dataset,
                            batch_size=config['batch_size'], shuffle=True)

    transforms = tta.Compose([
        tta.HorizontalFlip(),
        tta.VerticalFlip(),
        ])    

    model = ConvMixer(dim=config['dim'], depth=config['depth'], kernel_size=config['kernel'], patch_size=config['patch'], n_classes=config['nclasses']).to(device)
    model_state_dict = torch.load('./saved/models/best_model_0928.pt')
    model.load_state_dict(model_state_dict)

    tta_model = tta.ClassificationTTAWrapper(model, transforms, merge_mode='sum').to(device)

    accsum = 0

    tta_model.eval()

    with torch.no_grad(): 
        for X, y in tqdm(test_dataloader): 
            X, y = X.to(device), y.to(device)
            yhat = tta_model(X)
            acc = accuracy(y.cpu().data.numpy(), yhat.cpu().data.numpy().argmax(-1))
            accsum += (acc * len(y) / len(test_dataloader.dataset)) 

    print("Accuracy of test images: {:.4f}".format(accsum))

if __name__ == "__main__":
    main("mymodel")

100% 63/63 [00:47<00:00,  1.32it/s]

Accuracy of test images: 0.8867





In [9]:
import yaml
import torch
import torch.nn as nn
import ttach as tta
from utils.dataloader import CustomImageFolder, get_transform
from utils.utils import accuracy
from model.mymodel import ConvMixer
from utils.utils import seed_everything

from tqdm import tqdm


def main(model_name):

    config = yaml.load(open('./config/' + str(model_name) + '.yaml', 'r'), Loader=yaml.FullLoader)

    seed_everything(7)

    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

    test_dataset = CustomImageFolder(root = config['test_path'],
                            transform=get_transform(train_mode='valid'))
    
    test_dataloader = torch.utils.data.DataLoader(
                            test_dataset,
                            batch_size=config['batch_size'], shuffle=True)

    transforms = tta.Compose([
        tta.HorizontalFlip(),
        tta.VerticalFlip(),
        ])    

    model = ConvMixer(dim=config['dim'], depth=config['depth'], kernel_size=config['kernel'], patch_size=config['patch'], n_classes=config['nclasses']).to(device)
    model_state_dict = torch.load('./saved/models/best_model_0928.pt')
    model.load_state_dict(model_state_dict)

    accsum = 0

    model.eval()

    with torch.no_grad(): 
        for X, y in tqdm(test_dataloader): 
            X, y = X.to(device), y.to(device)
            yhat = model(X)
            acc = accuracy(y.cpu().data.numpy(), yhat.cpu().data.numpy().argmax(-1))
            accsum += (acc * len(y) / len(test_dataloader.dataset)) 

    print("Accuracy of test images: {:.4f}".format(accsum))

if __name__ == "__main__":
    main("mymodel")

100% 63/63 [00:25<00:00,  2.47it/s]

Accuracy of test images: 0.8748



