In [1]:
import os
import torch # type: ignore
import pandas as pd # type: ignore
import torch.nn.functional as F # type: ignore
from torch.utils.data import DataLoader # type: ignore
from tqdm.auto import tqdm # type: ignore
from models.models import *
from src.dataset import *
from src.transform import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Prefer CUDA when it is available, otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Training images, their annotation CSV, and the directory holding saved training runs.
data_dir = "./data/train"
data_info_file = "./data/train.csv"
save_result_path = "./train_result"
# Number of output classes passed to the classifier below.
num_classes = 500

In [4]:
# Read the annotation table and wrap it in an evaluation-mode pipeline: eval
# transforms, an inference-mode dataset, and a loader that neither shuffles nor drops
# the last partial batch. Predictions are later attached to data_info positionally,
# which presumably relies on CustomDataset indexing rows in order (TODO confirm).
data_info = pd.read_csv(data_info_file)

transform_selector = TransformSelector(transform_type="torchvision")
transforms = transform_selector.get_transform(is_train=False)

dataset = CustomDataset(
    root_dir=data_dir,
    info_df=data_info,
    transform=transforms,
    is_inference=True,
)

data_loader = DataLoader(dataset, batch_size=64, shuffle=False, drop_last=False)

In [5]:
# Function for model inference
def inference(
    model: torch.nn.Module, 
    device: torch.device, 
    test_loader: DataLoader
):
    """Run ``model`` over every batch of ``test_loader`` and collect predicted class ids.

    Args:
        model: classifier to evaluate. It is moved to ``device`` and put in eval mode
            in place.
        device: device the model and each batch are moved to.
        test_loader: iterable of input batches. Each item is passed straight to
            ``.to(device)``, so it is assumed to be a tensor of images (the loader in
            this notebook uses ``is_inference=True``); TODO confirm for other loaders.

    Returns:
        A list with one predicted class index (numpy integer scalar) per sample, in
        the order the loader yields them.
    """
    # Move the model to the target device and switch to evaluation mode.
    model.to(device)
    model.eval()

    predictions = []
    with torch.no_grad():  # no gradients are needed for inference
        for images in tqdm(test_loader):
            images = images.to(device)

            logits = model(images)
            # softmax is monotonic, so the argmax of the raw logits is the same class;
            # skipping it saves work and avoids ties introduced by float rounding.
            preds = logits.argmax(dim=1)

            # Tensors under no_grad carry no graph, so .detach() is unnecessary.
            predictions.extend(preds.cpu().numpy())

    return predictions

In [6]:
# Rebuild the architecture, load the fine-tuned checkpoint, and predict on every image.
# pretrained=False: load_state_dict (strict by default) overwrites every parameter and
# buffer below, so fetching ImageNet weights first is a wasted, network-dependent step.
# NOTE(review): assumes ModelSelector's `pretrained` flag only controls weight
# downloading and not the architecture; confirm in models/models.py.
model_selector = ModelSelector(
    model_type='timm', 
    num_classes=num_classes,
    model_name='convnext_base', 
    pretrained=False
)
model = model_selector.get_model()

checkpoint_path = os.path.join(save_result_path, "convnext_base(lr0.00015)", "best_model.pt")
# torch.load unpickles the file (arbitrary code execution on untrusted input);
# only load checkpoints you produced yourself.
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

predictions = inference(
    model=model, 
    device=device, 
    test_loader=data_loader
)

100%|██████████| 235/235 [03:10<00:00,  1.23it/s]


In [7]:
# Lookup from class folder name (the values stored in data_info['class_name']) to a
# human-readable class description; 500 entries, one per class. Used to translate
# target and predicted classes into readable labels during error analysis.
mini_imagenet_cls_map = {'n01872401': 'echidna, spiny anteater, anteater',
 'n02417914': 'ibex, Capra ibex',
 'n02106166': 'Border collie',
 'n04235860': 'sleeping bag',
 'n02056570': 'king penguin, Aptenodytes patagonica',
 'n07734744': 'mushroom',
 'n02098286': 'West Highland white terrier',
 'n02097298': 'Scotch terrier, Scottish terrier, Scottie',
 'n02403003': 'ox',
 'n04456115': 'torch',
 'n02408429': 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
 'n09472597': 'volcano',
 'n04004767': 'printer',
 'n03832673': 'notebook, notebook computer',
 'n01748264': 'Indian cobra, Naja naja',
 'n02096437': 'Dandie Dinmont, Dandie Dinmont terrier',
 'n02325366': 'wood rabbit, cottontail, cottontail rabbit',
 'n03857828': 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
 'n03481172': 'hammer',
 'n02701002': 'ambulance',
 'n01855032': 'red-breasted merganser, Mergus serrator',
 'n01698640': 'American alligator, Alligator mississipiensis',
 'n02114548': 'white wolf, Arctic wolf, Canis lupus tundrarum',
 'n01644900': 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
 'n02107574': 'Greater Swiss Mountain dog',
 'n03803284': 'muzzle',
 'n02494079': 'squirrel monkey, Saimiri sciureus',
 'n02027492': 'red-backed sandpiper, dunlin, Erolia alpina',
 'n04296562': 'stage',
 'n03584829': 'iron, smoothing iron',
 'n01843065': 'jacamar',
 'n03530642': 'honeycomb',
 'n02791124': 'barber chair',
 'n04486054': 'triumphal arch',
 'n01744401': 'rock python, rock snake, Python sebae',
 'n03063689': 'coffeepot',
 'n02110958': 'pug, pug-dog',
 'n04507155': 'umbrella',
 'n03710193': 'mailbox, letter box',
 'n01580077': 'jay',
 'n13052670': 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa',
 'n02279972': 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
 'n04336792': 'stretcher',
 'n02108915': 'French bulldog',
 'n04517823': 'vacuum, vacuum cleaner',
 'n07753592': 'banana',
 'n02992211': 'cello, violoncello',
 'n01531178': 'goldfinch, Carduelis carduelis',
 'n02396427': 'wild boar, boar, Sus scrofa',
 'n03444034': 'go-kart',
 'n01614925': 'bald eagle, American eagle, Haliaeetus leucocephalus',
 'n04039381': 'racket, racquet',
 'n03888605': 'parallel bars, bars',
 'n03425413': 'gas pump, gasoline pump, petrol pump, island dispenser',
 'n03895866': 'passenger car, coach, carriage',
 'n12998815': 'agaric',
 'n02087394': 'Rhodesian ridgeback',
 'n02097209': 'standard schnauzer',
 'n04259630': 'sombrero',
 'n03445777': 'golf ball',
 'n04040759': 'radiator',
 'n02454379': 'armadillo',
 'n02971356': 'carton',
 'n03929660': 'pick, plectrum, plectron',
 'n02690373': 'airliner',
 'n01774384': 'black widow, Latrodectus mactans',
 'n03134739': 'croquet ball',
 'n02085782': 'Japanese spaniel',
 'n04404412': 'television, television system',
 'n01514668': 'cock',
 'n04525305': 'vending machine',
 'n04560804': 'water jug',
 'n03642806': 'laptop, laptop computer',
 'n02422699': 'impala, Aepyceros melampus',
 'n01985128': 'crayfish, crawfish, crawdad, crawdaddy',
 'n04344873': 'studio couch, day bed',
 'n07716906': 'spaghetti squash',
 'n02951585': 'can opener, tin opener',
 'n03874599': 'padlock',
 'n01753488': 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
 'n02643566': 'lionfish',
 'n04081281': 'restaurant, eating house, eating place, eatery',
 'n02110806': 'basenji',
 'n02009912': 'American egret, great white heron, Egretta albus',
 'n01494475': 'hammerhead, hammerhead shark',
 'n02445715': 'skunk, polecat, wood pussy',
 'n10565667': 'scuba diver',
 'n03355925': 'flagpole, flagstaff',
 'n04204347': 'shopping cart',
 'n04591157': 'Windsor tie',
 'n03781244': 'monastery',
 'n04026417': 'purse',
 'n09288635': 'geyser',
 'n02113624': 'toy poodle',
 'n02113023': 'Pembroke, Pembroke Welsh corgi',
 'n01843383': 'toucan',
 'n04141076': 'sax, saxophone',
 'n03345487': 'fire engine, fire truck',
 'n01983481': 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
 'n01950731': 'sea slug, nudibranch',
 'n02092339': 'Weimaraner',
 'n01729322': 'hognose snake, puff adder, sand viper',
 'n03131574': 'crib, cot',
 'n04606251': 'wreck',
 'n02102177': 'Welsh springer spaniel',
 'n01616318': 'vulture',
 'n04350905': 'suit, suit of clothes',
 'n01532829': 'house finch, linnet, Carpodacus mexicanus',
 'n02321529': 'sea cucumber, holothurian',
 'n01601694': 'water ouzel, dipper',
 'n04127249': 'safety pin',
 'n03598930': 'jigsaw puzzle',
 'n02837789': 'bikini, two-piece',
 'n09193705': 'alp',
 'n02364673': 'guinea pig, Cavia cobaya',
 'n03063599': 'coffee mug',
 'n07248320': 'book jacket, dust cover, dust jacket, dust wrapper',
 'n03109150': 'corkscrew, bottle screw',
 'n01688243': 'frilled lizard, Chlamydosaurus kingi',
 'n02002556': 'white stork, Ciconia ciconia',
 'n03770439': 'miniskirt, mini',
 'n04310018': 'steam locomotive',
 'n03445924': 'golfcart, golf cart',
 'n02101388': 'Brittany spaniel',
 'n04037443': 'racer, race car, racing car',
 'n04409515': 'tennis ball',
 'n03188531': 'diaper, nappy, napkin',
 'n04120489': 'running shoe',
 'n01806567': 'quail',
 'n02109047': 'Great Dane',
 'n04131690': 'saltshaker, salt shaker',
 'n01873310': 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus',
 'n04540053': 'volleyball',
 'n02085620': 'Chihuahua',
 'n04275548': "spider web, spider's web",
 'n02641379': 'gar, garfish, garpike, billfish, Lepisosteus osseus',
 'n02268443': "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
 'n01534433': 'junco, snowbird',
 'n02105056': 'groenendael',
 'n02769748': 'backpack, back pack, knapsack, packsack, rucksack, haversack',
 'n01774750': 'tarantula',
 'n02410509': 'bison',
 'n03763968': 'military uniform',
 'n03691459': 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
 'n01882714': 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
 'n09256479': 'coral reef',
 'n02109525': 'Saint Bernard, St Bernard',
 'n02105412': 'kelpie',
 'n02437312': 'Arabian camel, dromedary, Camelus dromedarius',
 'n02939185': 'caldron, cauldron',
 'n03788365': 'mosquito net',
 'n03673027': 'liner, ocean liner',
 'n02105855': 'Shetland sheepdog, Shetland sheep dog, Shetland',
 'n03291819': 'envelope',
 'n01749939': 'green mamba',
 'n04325704': 'stole',
 'n02415577': 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
 'n03376595': 'folding chair',
 'n02793495': 'barn',
 'n02018795': 'bustard',
 'n04553703': 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
 'n04317175': 'stethoscope',
 'n04380533': 'table lamp',
 'n01833805': 'hummingbird',
 'n02895154': 'breastplate, aegis, egis',
 'n02389026': 'sorrel',
 'n03942813': 'ping-pong ball',
 'n02094114': 'Norfolk terrier',
 'n02129604': 'tiger, Panthera tigris',
 'n04428191': 'thresher, thrasher, threshing machine',
 'n03916031': 'perfume, essence',
 'n04326547': 'stone wall',
 'n02892201': 'brass, memorial tablet, plaque',
 'n04005630': 'prison, prison house',
 'n07717410': 'acorn squash',
 'n04252225': 'snowplow, snowplough',
 'n04554684': 'washer, automatic washer, washing machine',
 'n02443484': 'black-footed ferret, ferret, Mustela nigripes',
 'n01818515': 'macaw',
 'n02672831': 'accordion, piano accordion, squeeze box',
 'n03773504': 'missile',
 'n02113799': 'standard poodle',
 'n02100735': 'English setter',
 'n02104029': 'kuvasz',
 'n04370456': 'sweatshirt',
 'n07579787': 'plate',
 'n04447861': 'toilet seat',
 'n03759954': 'microphone, mike',
 'n02112350': 'keeshond',
 'n04162706': 'seat belt, seatbelt',
 'n02100236': 'German short-haired pointer',
 'n01751748': 'sea snake',
 'n02133161': 'American black bear, black bear, Ursus americanus, Euarctos americanus',
 'n02236044': 'mantis, mantid',
 'n01592084': 'chickadee',
 'n04604644': 'worm fence, snake fence, snake-rail fence, Virginia fence',
 'n02011460': 'bittern',
 'n02097047': 'miniature schnauzer',
 'n12768682': 'buckeye, horse chestnut, conker',
 'n03871628': 'packet',
 'n07716358': 'zucchini, courgette',
 'n03933933': 'pier',
 'n04479046': 'trench coat',
 'n03089624': 'confectionery, confectionary, candy store',
 'n03980874': 'poncho',
 'n03127925': 'crate',
 'n04209239': 'shower curtain',
 'n04136333': 'sarong',
 'n02443114': 'polecat, fitch, foulmart, foumart, Mustela putorius',
 'n04505470': 'typewriter keyboard',
 'n02094258': 'Norwich terrier',
 'n04229816': 'ski mask',
 'n02089867': 'Walker hound, Walker foxhound',
 'n02092002': 'Scottish deerhound, deerhound',
 'n07920052': 'espresso',
 'n03347037': 'fire screen, fireguard',
 'n02099267': 'flat-coated retriever',
 'n02802426': 'basketball',
 'n01739381': 'vine snake',
 'n03544143': 'hourglass',
 'n02096585': 'Boston bull, Boston terrier',
 'n02105162': 'malinois',
 'n04493381': 'tub, vat',
 'n02708093': 'analog clock',
 'n03891251': 'park bench',
 'n04208210': 'shovel',
 'n04033901': 'quill, quill pen',
 'n02088632': 'bluetick',
 'n04550184': 'wardrobe, closet, press',
 'n04596742': 'wok',
 'n03976657': 'pole',
 'n02091134': 'whippet',
 'n07754684': 'jackfruit, jak, jack',
 'n01978287': 'Dungeness crab, Cancer magister',
 'n02727426': 'apiary, bee house',
 'n02115913': 'dhole, Cuon alpinus',
 'n04263257': 'soup bowl',
 'n03476684': 'hair slide',
 'n04476259': 'tray',
 'n02077923': 'sea lion',
 'n03794056': 'mousetrap',
 'n01728572': 'thunder snake, worm snake, Carphophis amoenus',
 'n04579432': 'whistle',
 'n02992529': 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
 'n13054560': 'bolete',
 'n07860988': 'dough',
 'n03207941': 'dishwasher, dish washer, dishwashing machine',
 'n07615774': 'ice lolly, lolly, lollipop, popsicle',
 'n02102480': 'Sussex spaniel',
 'n01910747': 'jellyfish',
 'n04532106': 'vestment',
 'n04266014': 'space shuttle',
 'n04125021': 'safe',
 'n01530575': 'brambling, Fringilla montifringilla',
 'n04599235': 'wool, woolen, woollen',
 'n04435653': 'tile roof',
 'n02097130': 'giant schnauzer',
 'n02114712': 'red wolf, maned wolf, Canis rufus, Canis niger',
 'n03271574': 'electric fan, blower',
 'n03724870': 'mask',
 'n02096177': 'cairn, cairn terrier',
 'n04509417': 'unicycle, monocycle',
 'n01828970': 'bee eater',
 'n02100583': 'vizsla, Hungarian pointer',
 'n03496892': 'harvester, reaper',
 'n04209133': 'shower cap',
 'n02917067': 'bullet train, bullet',
 'n02906734': 'broom',
 'n02749479': 'assault rifle, assault gun',
 'n01860187': 'black swan, Cygnus atratus',
 'n01537544': 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
 'n03637318': 'lampshade, lamp shade',
 'n02132136': 'brown bear, bruin, Ursus arctos',
 'n02088094': 'Afghan hound, Afghan',
 'n03379051': 'football helmet',
 'n04201297': 'shoji',
 'n01855672': 'goose',
 'n01632777': 'axolotl, mud puppy, Ambystoma mexicanum',
 'n03249569': 'drum, membranophone, tympan',
 'n04487394': 'trombone',
 'n02892767': 'brassiere, bra, bandeau',
 'n04146614': 'school bus',
 'n02441942': 'weasel',
 'n07873807': 'pizza, pizza pie',
 'n02091467': 'Norwegian elkhound, elkhound',
 'n01807496': 'partridge',
 'n02808440': 'bathtub, bathing tub, bath, tub',
 'n02088238': 'basset, basset hound',
 'n02110185': 'Siberian husky',
 'n01641577': 'bullfrog, Rana catesbeiana',
 'n01770393': 'scorpion',
 'n02090379': 'redbone',
 'n02457408': 'three-toed sloth, ai, Bradypus tridactylus',
 'n04141975': 'scale, weighing machine',
 'n01795545': 'black grouse',
 'n03958227': 'plastic bag',
 'n04485082': 'tripod',
 'n04019541': 'puck, hockey puck',
 'n02094433': 'Yorkshire terrier',
 'n04355338': 'sundial',
 'n01695060': 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis',
 'n03956157': 'planetarium',
 'n01697457': 'African crocodile, Nile crocodile, Crocodylus niloticus',
 'n02172182': 'dung beetle',
 'n03388043': 'fountain',
 'n01518878': 'ostrich, Struthio camelus',
 'n03995372': 'power drill',
 'n04589890': 'window screen',
 'n04254777': 'sock',
 'n04584207': 'wig',
 'n04591713': 'wine bottle',
 'n04118776': 'rule, ruler',
 'n02091032': 'Italian greyhound',
 'n04429376': 'throne',
 'n02493793': 'spider monkey, Ateles geoffroyi',
 'n02999410': 'chain',
 'n10148035': 'groom, bridegroom',
 'n03124043': 'cowboy boot',
 'n02687172': 'aircraft carrier, carrier, flattop, attack aircraft carrier',
 'n02493509': 'titi, titi monkey',
 'n01775062': 'wolf spider, hunting spider',
 'n04192698': 'shield, buckler',
 'n02058221': 'albatross, mollymawk',
 'n03595614': 'jersey, T-shirt, tee shirt',
 'n04523525': 'vault',
 'n03814906': 'necklace',
 'n02412080': 'ram, tup',
 'n09835506': 'ballplayer, baseball player',
 'n04074963': 'remote control, remote',
 'n02099601': 'golden retriever',
 'n02100877': 'Irish setter, red setter',
 'n01740131': 'night snake, Hypsiglena torquata',
 'n01608432': 'kite',
 'n02037110': 'oystercatcher, oyster catcher',
 'n02488702': 'colobus, colobus monkey',
 'n02108089': 'boxer',
 'n04376876': 'syringe',
 'n02814533': 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
 'n07742313': 'Granny Smith',
 'n02111277': 'Newfoundland, Newfoundland dog',
 'n03692522': "loupe, jeweler's loupe",
 'n07718747': 'artichoke, globe artichoke',
 'n04522168': 'vase',
 'n04049303': 'rain barrel',
 'n02106550': 'Rottweiler',
 'n04418357': 'theater curtain, theatre curtain',
 'n02088364': 'beagle',
 'n03676483': 'lipstick, lip rouge',
 'n04372370': 'switch, electric switch, electrical switch',
 'n02795169': 'barrel, cask',
 'n02510455': 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
 'n02974003': 'car wheel',
 'n02281787': 'lycaenid, lycaenid butterfly',
 'n01677366': 'common iguana, iguana, Iguana iguana',
 'n01773797': 'garden spider, Aranea diademata',
 'n02123045': 'tabby, tabby cat',
 'n01968897': 'chambered nautilus, pearly nautilus, nautilus',
 'n04465501': 'tractor',
 'n03903868': 'pedestal, plinth, footstall',
 'n02799071': 'baseball',
 'n06785654': 'crossword puzzle, crossword',
 'n02699494': 'altar',
 'n02129165': 'lion, king of beasts, Panthera leo',
 'n07583066': 'guacamole',
 'n03775071': 'mitten',
 'n03761084': 'microwave, microwave oven',
 'n02814860': 'beacon, lighthouse, beacon light, pharos',
 'n03649909': 'lawn mower, mower',
 'n02093256': 'Staffordshire bullterrier, Staffordshire bull terrier',
 'n02130308': 'cheetah, chetah, Acinonyx jubatus',
 'n02102040': 'English springer, English springer spaniel',
 'n04536866': 'violin, fiddle',
 'n02730930': 'apron',
 'n11939491': 'daisy',
 'n02487347': 'macaque',
 'n02102318': 'cocker spaniel, English cocker spaniel, cocker',
 'n07714990': 'broccoli',
 'n04612504': 'yawl',
 'n03998194': 'prayer rug, prayer mat',
 'n01770081': 'harvestman, daddy longlegs, Phalangium opilio',
 'n02134084': 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
 'n02927161': 'butcher shop, meat market',
 'n04069434': 'reflex camera',
 'n03787032': 'mortarboard',
 'n02111889': 'Samoyed, Samoyede',
 'n03793489': 'mouse, computer mouse',
 'n07875152': 'potpie',
 'n02009229': 'little blue heron, Egretta caerulea',
 'n02086240': 'Shih-Tzu',
 'n07715103': 'cauliflower',
 'n02676566': 'acoustic guitar',
 'n04443257': 'tobacco shop, tobacconist shop, tobacconist',
 'n04483307': 'trimaran',
 'n03630383': 'lab coat, laboratory coat',
 'n02326432': 'hare',
 'n02124075': 'Egyptian cat',
 'n02280649': 'cabbage butterfly',
 'n02361337': 'marmot',
 'n02692877': 'airship, dirigible',
 'n04557648': 'water bottle',
 'n12267677': 'acorn',
 'n02165456': 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
 'n02112137': 'chow, chow chow',
 'n01914609': 'sea anemone, anemone',
 'n01735189': 'garter snake, grass snake',
 'n01944390': 'snail',
 'n03498962': 'hatchet',
 'n04356056': 'sunglasses, dark glasses, shades',
 'n02089973': 'English foxhound',
 'n02123597': 'Siamese cat, Siamese',
 'n04254680': 'soccer ball',
 'n02111500': 'Great Pyrenees',
 'n03394916': 'French horn, horn',
 'n07745940': 'strawberry',
 'n03709823': 'mailbag, postbag',
 'n07614500': 'ice cream, icecream',
 'n01704323': 'triceratops',
 'n03792782': 'mountain bike, all-terrain bike, off-roader',
 'n02883205': 'bow tie, bow-tie, bowtie',
 'n02101006': 'Gordon setter',
 'n01644373': 'tree frog, tree-frog',
 'n04366367': 'suspension bridge',
 'n02879718': 'bow',
 'n12144580': 'corn',
 'n07613480': 'trifle',
 'n04371430': 'swimming trunks, bathing trunks',
 'n03534580': 'hoopskirt, crinoline',
 'n03450230': 'gown',
 'n03938244': 'pillow',
 'n01443537': 'goldfish, Carassius auratus',
 'n02086079': 'Pekinese, Pekingese, Peke',
 'n02391049': 'zebra',
 'n02398521': 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
 'n02093428': 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier',
 'n02107142': 'Doberman, Doberman pinscher',
 'n03208938': 'disk brake, disc brake',
 'n04399382': 'teddy, teddy bear',
 'n04118538': 'rugby ball',
 'n04355933': 'sunglass',
 'n04330267': 'stove',
 'n03045698': 'cloak',
 'n02128385': 'leopard, Panthera pardus',
 'n04442312': 'toaster',
 'n02356798': 'fox squirrel, eastern fox squirrel, Sciurus niger',
 'n03970156': "plunger, plumber's helper",
 'n02107683': 'Bernese mountain dog',
 'n03982430': 'pool table, billiard table, snooker table',
 'n02342885': 'hamster',
 'n09246464': 'cliff, drop, drop-off',
 'n02114367': 'timber wolf, grey wolf, gray wolf, Canis lupus',
 'n07747607': 'orange',
 'n02113712': 'miniature poodle',
 'n04487081': 'trolleybus, trolley coach, trackless trolley',
 'n04238763': 'slide rule, slipstick',
 'n02088466': 'bloodhound, sleuthhound',
 'n01630670': 'common newt, Triturus vulgaris',
 'n02117135': 'hyena, hyaena',
 'n04501370': 'turnstile',
 'n02098413': 'Lhasa, Lhasa apso',
 'n02127052': 'lynx, catamount',
 'n02782093': 'balloon',
 'n04548362': 'wallet, billfold, notecase, pocketbook',
 'n01496331': 'electric ray, crampfish, numbfish, torpedo',
 'n02841315': 'binoculars, field glasses, opera glasses',
 'n02480495': 'orangutan, orang, orangutang, Pongo pygmaeus',
 'n03961711': 'plate rack',
 'n02106030': 'collie',
 'n02397096': 'warthog',
 'n04041544': 'radio, wireless',
 'n02865351': 'bolo tie, bolo, bola tie, bola',
 'n03297495': 'espresso maker',
 'n01742172': 'boa constrictor, Constrictor constrictor',
 'n04252077': 'snowmobile',
 'n02087046': 'toy terrier',
 'n07768694': 'pomegranate',
 'n01582220': 'magpie',
 'n01930112': 'nematode, nematode worm, roundworm',
 'n01558993': 'robin, American robin, Turdus migratorius',
 'n02442845': 'mink',
 'n02105641': 'Old English sheepdog, bobtail',
 'n02120079': 'Arctic fox, white fox, Alopex lagopus',
 'n02085936': 'Maltese dog, Maltese terrier, Maltese',
 'n01694178': 'African chameleon, Chamaeleo chamaeleon',
 'n04371774': 'swing',
 'n03447721': 'gong, tam-tam',
 'n04251144': 'snorkel',
 'n01484850': 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
 'n02007558': 'flamingo',
 'n07753113': 'fig',
 'n02277742': 'ringlet, ringlet butterfly',
 'n03720891': 'maraca',
 'n02051845': 'pelican',
 'n02090622': 'borzoi, Russian wolfhound',
 'n02091831': 'Saluki, gazelle hound',
 'n02843684': 'birdhouse',
 'n04357314': 'sunscreen, sunblock, sun blocker',
 'n02437616': 'llama',
 'n02108422': 'bull mastiff',
 'n03729826': 'matchstick',
 'n02870880': 'bookcase'}

In [8]:
# Display the loaded annotation table (columns include class_name, image_path, target).
data_info

Unnamed: 0,class_name,image_path,target
0,n01872401,n01872401/sketch_50.JPEG,59
1,n02417914,n02417914/sketch_11.JPEG,202
2,n02106166,n02106166/sketch_3.JPEG,138
3,n04235860,n04235860/sketch_2.JPEG,382
4,n02056570,n02056570/sketch_40.JPEG,80
...,...,...,...
15016,n02108089,n02108089/sketch_32.JPEG,143
15017,n02129604,n02129604/sketch_7.JPEG,172
15018,n07920052,n07920052/sketch_26.JPEG,484
15019,n02325366,n02325366/sketch_46.JPEG,186


In [9]:
# Attach the predictions to a copy of the annotation table and translate both the
# true and the predicted class into human-readable labels.
data_info_copy = data_info.copy()
# Positional assignment: assumes `predictions` is in data_info row order.
data_info_copy['prediction'] = predictions

# Build the target-index -> class_name lookup once. The previous per-row loop
# re-filtered the whole frame for every misclassified row (O(rows^2)).
# drop_duplicates keeps the first class_name seen for each target, like `.unique()[0]`.
target_to_class = data_info.drop_duplicates('target').set_index('target')['class_name']

data_info_copy['target_map'] = data_info_copy['class_name'].map(mini_imagenet_cls_map)
# A predicted index that never occurs as a target becomes NaN here instead of
# raising IndexError.
data_info_copy['prediction_map'] = (
    data_info_copy['prediction'].map(target_to_class).map(mini_imagenet_cls_map)
)
# NOTE: `check` is True for a *misclassified* row; later cells sum it to count errors.
data_info_copy['check'] = data_info_copy['target'] != data_info_copy['prediction']
data_info_copy

Unnamed: 0,class_name,image_path,target,prediction,target_map,prediction_map,check
0,n01872401,n01872401/sketch_50.JPEG,59,59,"echidna, spiny anteater, anteater","echidna, spiny anteater, anteater",False
1,n02417914,n02417914/sketch_11.JPEG,202,202,"ibex, Capra ibex","ibex, Capra ibex",False
2,n02106166,n02106166/sketch_3.JPEG,138,138,Border collie,Border collie,False
3,n04235860,n04235860/sketch_2.JPEG,382,382,sleeping bag,sleeping bag,False
4,n02056570,n02056570/sketch_40.JPEG,80,80,"king penguin, Aptenodytes patagonica","king penguin, Aptenodytes patagonica",False
...,...,...,...,...,...,...,...
15016,n02108089,n02108089/sketch_32.JPEG,143,143,boxer,boxer,False
15017,n02129604,n02129604/sketch_7.JPEG,172,172,"tiger, Panthera tigris","tiger, Panthera tigris",False
15018,n07920052,n07920052/sketch_26.JPEG,484,484,espresso,espresso,False
15019,n02325366,n02325366/sketch_46.JPEG,186,186,"wood rabbit, cottontail, cottontail rabbit","wood rabbit, cottontail, cottontail rabbit",False


In [11]:
from collections import defaultdict

# Count misclassified images (check == True) for every class folder found under data_dir.
class_name_list = os.listdir(data_dir)
error_count = defaultdict(int)

# One groupby pass instead of re-filtering data_info_copy once per class.
errors_by_class = data_info_copy.groupby('class_name')['check'].sum()

for class_name in class_name_list:
    # Skip hidden entries ('.DS_Store', '.ipynb_checkpoints', ...) and stray files;
    # only class directories are counted.
    if class_name.startswith('.') or not os.path.isdir(os.path.join(data_dir, class_name)):
        continue
    # A folder with no rows in data_info_copy gets 0, as the empty-filter sum did before.
    error_count[class_name] = errors_by_class.get(class_name, 0)

In [18]:
# Print every class folder with at least three misclassified images.
for wnid, n_wrong in error_count.items():
    if n_wrong < 3:
        continue
    print(wnid)

n02091032
n02088632
n04493381
n04355933
n01739381
n01532829
n07734744
n01530575
n02443114
n01833805
n02279972
n02101388
n02442845
n01983481
n02102318
n02107574
n02009912
n01728572
n02892201
n04380533
n02109047
n02281787
n02100877
n02111889
n02412080
n02091467
n02799071
n02115913
n03109150
n02106030
n02091134
n01729322
n01697457
n03297495
n04356056
n02097130
n01698640
n02096177
n04418357
n01641577
n02113799
n02085620
n04296562
n02089867
n02097209
n02808440


In [21]:
# Collect the misclassified rows of every class with at least three errors and save
# them to error_list.csv for manual inspection.
error_frames = [
    data_info_copy[(data_info_copy['class_name'] == class_name) & data_info_copy['check']]
    for class_name, n_errors in error_count.items()
    if n_errors >= 3
]
# Concatenate once. Growing a frame with pd.concat inside the loop is quadratic and
# warns about empty inputs; pd.concat also rejects an empty list, hence the guard.
df = pd.concat(error_frames, ignore_index=True) if error_frames else pd.DataFrame()

df.to_csv('error_list.csv')

In [34]:
# Sanity check: the two Canny-edge exports should contain the same class folders and,
# inside each folder, the same file names. Print every entry found on only one side.
canny1_root = './data/train_canny'
canny2_root = './data/train_canny2'
# os.listdir order is arbitrary. The old positional pairing (canny1[i] vs canny2[i])
# could compare two different folders and raised IndexError when counts differed,
# so sort the listings and compare by name instead.
canny1 = sorted(os.listdir(canny1_root))
canny2 = sorted(os.listdir(canny2_root))
canny1_set, canny2_set = set(canny1), set(canny2)

for name in sorted(canny1_set - canny2_set):
    print(f'{name} (only in {canny1_root})')
for name in sorted(canny2_set - canny1_set):
    print(f'{name} (only in {canny2_root})')

for folder in sorted(canny1_set & canny2_set):
    canny1_folder_imgs_dir = os.path.join(canny1_root, folder)
    canny2_folder_imgs_dir = os.path.join(canny2_root, folder)
    # Skip stray files such as '.DS_Store'; os.listdir on a file would raise.
    if not (os.path.isdir(canny1_folder_imgs_dir) and os.path.isdir(canny2_folder_imgs_dir)):
        continue
    canny1_folder_imgs = set(os.listdir(canny1_folder_imgs_dir))
    canny2_folder_imgs = set(os.listdir(canny2_folder_imgs_dir))
    for img in sorted(canny1_folder_imgs - canny2_folder_imgs):
        print(f'{folder} {img} (only in {canny1_root})')
    for img in sorted(canny2_folder_imgs - canny1_folder_imgs):
        print(f'{folder} {img} (only in {canny2_root})')

In [42]:
import matplotlib.pyplot as plt # type: ignore
from torchcam.methods import GradCAM # type: ignore
import cv2 # type: ignore
import numpy as np


def _to_uint8(array: np.ndarray) -> np.ndarray:
    """Min-max scale ``array`` to [0, 255] and return it as a C-contiguous uint8 array.

    NaNs are treated as 0 and a constant array maps to all zeros, so the scaling can
    never divide by zero.
    """
    array = np.nan_to_num(np.asarray(array, dtype=np.float32))
    low, high = array.min(), array.max()
    if high > low:
        array = (array - low) / (high - low)
    else:
        array = np.zeros(array.shape, dtype=np.float32)
    # OpenCV expects contiguous buffers; arrays produced by transpose are not.
    return np.ascontiguousarray((255 * array).astype(np.uint8))


def _plot_gradcam(image: torch.Tensor, cam: torch.Tensor) -> None:
    """Show ``image``, its Grad-CAM heat map and their 50/50 overlay side by side.

    Args:
        image: a single input image, assumed (C, H, W) with C == 1 or 3
            (TODO confirm against the dataset transform).
        cam: 2-D activation map for that image; it is resized to H x W.
    """
    input_image = image.detach().cpu().numpy().transpose((1, 2, 0))
    if input_image.shape[2] == 1:
        # Grey-scale input: replicate the channel so it can be blended with the RGB map.
        input_image = np.repeat(input_image, 3, axis=2)
    # Both branches end up uint8; cv2.addWeighted needs matching dtypes.
    input_image = _to_uint8(input_image)

    height, width = input_image.shape[:2]
    # cv2.resize takes the target size as (width, height).
    heatmap = cv2.resize(cam.detach().cpu().numpy().astype(np.float32), (width, height))
    heatmap = cv2.applyColorMap(_to_uint8(heatmap), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # OpenCV colour maps are BGR
    overlay = cv2.addWeighted(input_image, 0.5, heatmap, 0.5, 0)

    _, axes = plt.subplots(1, 3, figsize=(18, 6))
    panels = (
        (input_image, "Original Image"),
        (heatmap, "Grad-CAM Image"),
        (overlay, "Overlay Image"),
    )
    for ax, (picture, title) in zip(axes, panels):
        ax.imshow(picture)
        ax.set_title(title)
        ax.axis('off')
    plt.show()


def visualize_gradcam(
        model: torch.nn.Module,
        device: torch.device,
        dataloader: DataLoader,
        target_layer: str,
        image_index: int
    ):
    """Plot the Grad-CAM explanation of the model's predicted class for one image.

    Args:
        model: classifier to explain. It is moved to ``device`` and put in eval mode.
        device: device used for the forward/backward pass.
        dataloader: iterable of image batches. Each item is treated as a tensor with the
            batch on dim 0 (the inference-mode dataset here); TODO confirm for others.
        target_layer: name of the layer GradCAM hooks.
        image_index: position of the image in the order the loader yields samples.

    Raises:
        IndexError: if ``image_index`` is negative or past the end of the loader.
    """
    if image_index < 0:
        raise IndexError(f"image_index must be non-negative, got {image_index}")

    model.to(device)
    model.eval()
    cam_extractor = GradCAM(model, target_layer)
    try:
        seen = 0
        for inputs in dataloader:
            batch_size = inputs.size(0)
            if image_index >= seen + batch_size:
                # The target lies in a later batch: skip without running the model.
                seen += batch_size
                continue

            offset = image_index - seen
            # Forward only the requested image so the CAM batch holds exactly one map.
            image = inputs[offset:offset + 1].to(device)
            outputs = model(image)
            pred = outputs.argmax(dim=1).item()
            # torchcam returns a list with one CAM per target layer, batched along dim 0.
            # The old code averaged dim 0, i.e. over the whole batch, mixing in
            # other images' maps; here take the first layer, single sample.
            cam = cam_extractor(pred, outputs)[0][0]
            _plot_gradcam(image[0], cam)
            return
    finally:
        # GradCAM registers hooks on the model; remove them so repeated calls do not
        # stack hooks on the same model.
        cam_extractor.remove_hooks()

    raise IndexError(
        f"image_index {image_index} is out of range for a loader with {seen} images"
    )