In [6]:
import sys
sys.path.append("training")
import torch
import warnings
torch.autograd.set_detect_anomaly(True)
warnings.simplefilter("ignore")
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import os
import json
import argparse
import training.timm
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import tqdm
import pandas as pd
import pathlib


from training.utils.config_utils import load_yaml
from training.vis_utils import ImgLoader

In [7]:
def build_model(pretrainewd_path: str,
                img_size: int,
                fpn_size: int,
                num_classes: int,
                num_selects: dict,
                use_fpn: bool = True,
                use_selection: bool = True,
                use_combiner: bool = True,
                comb_proj_size: int = None):
    """Build a ``PluginMoodel`` for inference and optionally load trained weights.

    Args:
        pretrainewd_path: Path of a checkpoint whose ``'model'`` entry holds the
            state dict. An empty string or ``None`` skips weight loading.
            (The misspelled name is kept for backward compatibility.)
        img_size: Forwarded to ``PluginMoodel`` as ``img_size``.
        fpn_size: Forwarded to ``PluginMoodel`` as ``fpn_size``.
        num_classes: Forwarded to ``PluginMoodel`` as ``num_classes``.
        num_selects: Forwarded to ``PluginMoodel`` as ``num_selects``.
        use_fpn: Forwarded to ``PluginMoodel`` as ``use_fpn``.
        use_selection: Forwarded to ``PluginMoodel`` as ``use_selection``.
        use_combiner: Forwarded to ``PluginMoodel`` as ``use_combiner``.
        comb_proj_size: Forwarded to ``PluginMoodel`` as ``comb_proj_size``.

    Returns:
        The model in eval mode. It is left on the CPU; the caller moves it to
        the target device.

    Raises:
        KeyError: If the checkpoint has no ``'model'`` entry.
    """
    # Imported lazily so the notebook's import cell does not depend on the
    # model package being importable.
    from training.models.pim_module.pim_module_eval import PluginMoodel

    model = PluginMoodel(img_size = img_size,
                         use_fpn = use_fpn,
                         fpn_size = fpn_size,
                         proj_type = "Linear",
                         upsample_type = "Conv",
                         use_selection = use_selection,
                         num_classes = num_classes,
                         num_selects = num_selects,
                         use_combiner = use_combiner,
                         comb_proj_size = comb_proj_size)

    # Truthiness check: both "" and None mean "no checkpoint".
    if pretrainewd_path:
        # map_location="cpu" lets a checkpoint saved on a GPU be restored on a
        # CPU-only host (or a host with a different GPU index).
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        ckpt = torch.load(pretrainewd_path, map_location="cpu")
        model.load_state_dict(ckpt['model'])

    model.eval()

    return model

@torch.no_grad()
def sum_all_out(out, sum_type="softmax", use_combiner=True):
    """Add up the scores of every prediction head into a single tensor.

    Args:
        out: Mapping from head name to tensor. It must contain ``layer1``..
            ``layer4`` and ``FPN1_layer1``..``FPN1_layer4``, plus ``comb_outs``
            when ``use_combiner`` is true.
        sum_type: ``"softmax"`` applies a softmax over the last dimension of
            each head before summing; any other value sums the raw values.
        use_combiner: Whether the ``comb_outs`` head takes part in the sum.

    Returns:
        The element-wise sum of the per-head scores.
    """
    head_names = ('layer1', 'layer2', 'layer3', 'layer4',
                  'FPN1_layer1', 'FPN1_layer2', 'FPN1_layer3', 'FPN1_layer4')
    if use_combiner:
        head_names = head_names + ('comb_outs',)

    def _head_scores(head):
        # Every head except the combiner is averaged over dim 1 first.
        scores = out[head] if head == "comb_outs" else out[head].mean(1)
        return torch.softmax(scores, dim=-1) if sum_type == "softmax" else scores

    total = None
    for head in head_names:
        scores = _head_scores(head)
        # Out-of-place add on purpose: '+=' would trigger an in-place error.
        total = scores if total is None else total + scores
    return total

In [None]:
# Load the run configuration into a plain namespace.
args = argparse.Namespace()
load_yaml(args, "training/config.yaml")

# ===== 1. build model =====
model = build_model(pretrainewd_path =  "best.pt",
                    img_size = args.data_size,
                    fpn_size = args.fpn_size,
                    num_classes = args.num_classes,
                    num_selects = args.num_selects,
                    use_combiner = args.use_combiner,
                    use_selection = args.use_selection
                    )
# Use the GPU when one is visible, otherwise stay on the CPU instead of crashing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# ===== 2. collect test images =====
img_loader = ImgLoader(img_size = args.data_size)
root = pathlib.Path("training/datasets/test")
# Sorted string paths keep the order of the submission rows deterministic.
files = sorted(str(f) for f in root.rglob("*.jpg"))
# Fail early on a wrong/missing dataset path rather than writing an empty CSV.
assert files, f"no .jpg files found under {root}"

# Accumulators filled by the inference loop.
imgs = []
names = []
logits = []
# Class names looked up by predicted class index: label_name[i] is written to
# the submission for prediction i, so the ORDER of this list is significant.
# Entries follow an "<3-digit id>.<species>" pattern.
# NOTE(review): presumably this must match the label ordering used at training
# time and args.num_classes (200 entries here) — confirm against the dataset.
label_name = [
    "001.Black_footed_Albatross",
    "002.Laysan_Albatross",
    "003.Sooty_Albatross",
    "004.Groove_billed_Ani",
    "005.Crested_Auklet",
    "006.Least_Auklet",
    "007.Parakeet_Auklet",
    "008.Rhinoceros_Auklet",
    "009.Brewer_Blackbird",
    "010.Red_winged_Blackbird",
    "011.Rusty_Blackbird",
    "012.Yellow_headed_Blackbird",
    "013.Bobolink",
    "014.Indigo_Bunting",
    "015.Lazuli_Bunting",
    "016.Painted_Bunting",
    "017.Cardinal",
    "018.Spotted_Catbird",
    "019.Gray_Catbird",
    "020.Yellow_breasted_Chat",
    "021.Eastern_Towhee",
    "022.Chuck_will_Widow",
    "023.Brandt_Cormorant",
    "024.Red_faced_Cormorant",
    "025.Pelagic_Cormorant",
    "026.Bronzed_Cowbird",
    "027.Shiny_Cowbird",
    "028.Brown_Creeper",
    "029.American_Crow",
    "030.Fish_Crow",
    "031.Black_billed_Cuckoo",
    "032.Mangrove_Cuckoo",
    "033.Yellow_billed_Cuckoo",
    "034.Gray_crowned_Rosy_Finch",
    "035.Purple_Finch",
    "036.Northern_Flicker",
    "037.Acadian_Flycatcher",
    "038.Great_Crested_Flycatcher",
    "039.Least_Flycatcher",
    "040.Olive_sided_Flycatcher",
    "041.Scissor_tailed_Flycatcher",
    "042.Vermilion_Flycatcher",
    "043.Yellow_bellied_Flycatcher",
    "044.Frigatebird",
    "045.Northern_Fulmar",
    "046.Gadwall",
    "047.American_Goldfinch",
    "048.European_Goldfinch",
    "049.Boat_tailed_Grackle",
    "050.Eared_Grebe",
    "051.Horned_Grebe",
    "052.Pied_billed_Grebe",
    "053.Western_Grebe",
    "054.Blue_Grosbeak",
    "055.Evening_Grosbeak",
    "056.Pine_Grosbeak",
    "057.Rose_breasted_Grosbeak",
    "058.Pigeon_Guillemot",
    "059.California_Gull",
    "060.Glaucous_winged_Gull",
    "061.Heermann_Gull",
    "062.Herring_Gull",
    "063.Ivory_Gull",
    "064.Ring_billed_Gull",
    "065.Slaty_backed_Gull",
    "066.Western_Gull",
    "067.Anna_Hummingbird",
    "068.Ruby_throated_Hummingbird",
    "069.Rufous_Hummingbird",
    "070.Green_Violetear",
    "071.Long_tailed_Jaeger",
    "072.Pomarine_Jaeger",
    "073.Blue_Jay",
    "074.Florida_Jay",
    "075.Green_Jay",
    "076.Dark_eyed_Junco",
    "077.Tropical_Kingbird",
    "078.Gray_Kingbird",
    "079.Belted_Kingfisher",
    "080.Green_Kingfisher",
    "081.Pied_Kingfisher",
    "082.Ringed_Kingfisher",
    "083.White_breasted_Kingfisher",
    "084.Red_legged_Kittiwake",
    "085.Horned_Lark",
    "086.Pacific_Loon",
    "087.Mallard",
    "088.Western_Meadowlark",
    "089.Hooded_Merganser",
    "090.Red_breasted_Merganser",
    "091.Mockingbird",
    "092.Nighthawk",
    "093.Clark_Nutcracker",
    "094.White_breasted_Nuthatch",
    "095.Baltimore_Oriole",
    "096.Hooded_Oriole",
    "097.Orchard_Oriole",
    "098.Scott_Oriole",
    "099.Ovenbird",
    "100.Brown_Pelican",
    "101.White_Pelican",
    "102.Western_Wood_Pewee",
    "103.Sayornis",
    "104.American_Pipit",
    "105.Whip_poor_Will",
    "106.Horned_Puffin",
    "107.Common_Raven",
    "108.White_necked_Raven",
    "109.American_Redstart",
    "110.Geococcyx",
    "111.Loggerhead_Shrike",
    "112.Great_Grey_Shrike",
    "113.Baird_Sparrow",
    "114.Black_throated_Sparrow",
    "115.Brewer_Sparrow",
    "116.Chipping_Sparrow",
    "117.Clay_colored_Sparrow",
    "118.House_Sparrow",
    "119.Field_Sparrow",
    "120.Fox_Sparrow",
    "121.Grasshopper_Sparrow",
    "122.Harris_Sparrow",
    "123.Henslow_Sparrow",
    "124.Le_Conte_Sparrow",
    "125.Lincoln_Sparrow",
    "126.Nelson_Sharp_tailed_Sparrow",
    "127.Savannah_Sparrow",
    "128.Seaside_Sparrow",
    "129.Song_Sparrow",
    "130.Tree_Sparrow",
    "131.Vesper_Sparrow",
    "132.White_crowned_Sparrow",
    "133.White_throated_Sparrow",
    "134.Cape_Glossy_Starling",
    "135.Bank_Swallow",
    "136.Barn_Swallow",
    "137.Cliff_Swallow",
    "138.Tree_Swallow",
    "139.Scarlet_Tanager",
    "140.Summer_Tanager",
    "141.Artic_Tern",
    "142.Black_Tern",
    "143.Caspian_Tern",
    "144.Common_Tern",
    "145.Elegant_Tern",
    "146.Forsters_Tern",
    "147.Least_Tern",
    "148.Green_tailed_Towhee",
    "149.Brown_Thrasher",
    "150.Sage_Thrasher",
    "151.Black_capped_Vireo",
    "152.Blue_headed_Vireo",
    "153.Philadelphia_Vireo",
    "154.Red_eyed_Vireo",
    "155.Warbling_Vireo",
    "156.White_eyed_Vireo",
    "157.Yellow_throated_Vireo",
    "158.Bay_breasted_Warbler",
    "159.Black_and_white_Warbler",
    "160.Black_throated_Blue_Warbler",
    "161.Blue_winged_Warbler",
    "162.Canada_Warbler",
    "163.Cape_May_Warbler",
    "164.Cerulean_Warbler",
    "165.Chestnut_sided_Warbler",
    "166.Golden_winged_Warbler",
    "167.Hooded_Warbler",
    "168.Kentucky_Warbler",
    "169.Magnolia_Warbler",
    "170.Mourning_Warbler",
    "171.Myrtle_Warbler",
    "172.Nashville_Warbler",
    "173.Orange_crowned_Warbler",
    "174.Palm_Warbler",
    "175.Pine_Warbler",
    "176.Prairie_Warbler",
    "177.Prothonotary_Warbler",
    "178.Swainson_Warbler",
    "179.Tennessee_Warbler",
    "180.Wilson_Warbler",
    "181.Worm_eating_Warbler",
    "182.Yellow_Warbler",
    "183.Northern_Waterthrush",
    "184.Louisiana_Waterthrush",
    "185.Bohemian_Waxwing",
    "186.Cedar_Waxwing",
    "187.American_Three_toed_Woodpecker",
    "188.Pileated_Woodpecker",
    "189.Red_bellied_Woodpecker",
    "190.Red_cockaded_Woodpecker",
    "191.Red_headed_Woodpecker",
    "192.Downy_Woodpecker",
    "193.Bewick_Wren",
    "194.Cactus_Wren",
    "195.Carolina_Wren",
    "196.House_Wren",
    "197.Marsh_Wren",
    "198.Rock_Wren",
    "199.Winter_Wren",
    "200.Common_Yellowthroat",
]
# ===== batched inference and submission export =====
# Number of images sent through the model per forward pass.
BATCH_SIZE = 32

# Columns of the submission file: image id (file name without extension) and
# predicted class name.
res = {
    "id": [],
    "label": []
}

# Follow whatever device the model already lives on, so this cell works for
# both GPU and CPU models.
device = next(model.parameters()).device

with torch.no_grad():
    for start in range(0, len(files), BATCH_SIZE):
        batch_files = files[start:start + BATCH_SIZE]

        # img_loader.load returns (model-ready image, original image); only the
        # former is needed. Each image gains a leading batch dimension.
        batch = torch.cat(
            [img_loader.load(f)[0].unsqueeze(0) for f in batch_files],
            dim=0,
        ).to(device)

        outs = model(batch)
        # Sum of the per-head softmax scores.
        # NOTE(review): an earlier comment mentioned five-crop averaging, but no
        # averaging over crops happens here — confirm ImgLoader yields one crop.
        sum_outs = sum_all_out(outs, sum_type="softmax", use_combiner=args.use_combiner)
        logits.append(sum_outs.data)

        _, preds = torch.max(sum_outs.data, 1)
        # Path.stem drops the directory and the final extension on every OS and
        # leaves a ".jpg" occurring inside the name untouched.
        res["id"].extend(pathlib.Path(f).stem for f in batch_files)
        res["label"].extend(label_name[i] for i in preds.cpu().tolist())

df = pd.DataFrame(res)
df.to_csv("submission.csv", index=False)