In [10]:
import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms
import numpy as np
from PIL import Image
import os
import glob

In [6]:
from geo_data import create_anchor_transform
from libs.ConvNeXt.models.convnext import ConvNeXt
import matplotlib.pyplot as plt
# Image loaded from "europe.png"; presumably a map of Europe for plotting results (not used in the cells shown).
imap = plt.imread("europe.png")
from tqdm import tqdm_notebook
from geo_data import build_geo_dataset

In [7]:
# Per-channel RGB normalisation statistics (these are the standard ImageNet values).
mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

# Evaluation preprocessing applied to each PIL image before it reaches the model:
# centre 224x224 crop -> float tensor -> per-channel normalisation.
trans = transforms.Compose(
    [transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean, std)]
)

# Reference coordinate pair; presumably (lat, lng) — not used in the cells shown.
COORD_REF = np.array([50, 10])

In [32]:
# Use the first GPU when available; fall back to CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing here.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], num_classes=46)

# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
# strict=False silently skips parameters whose names do not match the
# architecture; report them so a wrong checkpoint/architecture pairing
# (e.g. different depths/dims/num_classes) does not go unnoticed.
load_result = model.load_state_dict(
    torch.load("logs/cls_final/checkpoint-best.pth", map_location='cpu')['model'],
    strict=False,
)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"checkpoint mismatch:\n  missing keys: {load_result.missing_keys}\n"
          f"  unexpected keys: {load_result.unexpected_keys}")

model = model.to(device)
model.eval()

ConvNeXt(
  (downsample_layers): ModuleList(
    (0): Sequential(
      (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm()
    )
    (1): Sequential(
      (0): LayerNorm()
      (1): Conv2d(128, 256, kernel_size=(2, 2), stride=(2, 2))
    )
    (2): Sequential(
      (0): LayerNorm()
      (1): Conv2d(256, 512, kernel_size=(2, 2), stride=(2, 2))
    )
    (3): Sequential(
      (0): LayerNorm()
      (1): Conv2d(512, 1024, kernel_size=(2, 2), stride=(2, 2))
    )
  )
  (stages): ModuleList(
    (0): Sequential(
      (0): Block(
        (dwconv): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
        (norm): LayerNorm()
        (pwconv1): Linear(in_features=128, out_features=512, bias=True)
        (act): GELU(approximate=none)
        (pwconv2): Linear(in_features=512, out_features=128, bias=True)
        (drop_path): Identity()
      )
      (1): Block(
        (dwconv): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), p

In [11]:
def unique_country_codes(paths):
    """Return the distinct country codes encoded in image file names.

    Each file name must look like ``<country>_<coords>.<ext>`` with exactly one
    underscore (a ValueError is raised otherwise). Codes are returned in order
    of first appearance in ``paths``.
    """
    seen = {}  # dict keeps first-seen order and gives O(1) membership checks
    for path in paths:
        stem = os.path.splitext(os.path.basename(path))[0]
        country, _ = stem.split("_")
        seen.setdefault(country, None)
    return list(seen)


# sorted() makes the order independent of the filesystem's listing order, so
# `all_country` (and anything indexed or tie-broken by it) is reproducible.
images = sorted(glob.glob('./data/*.png'))
all_country = unique_country_codes(images)

In [33]:
@torch.no_grad()
def pred_cls(img_path, data_dir=None):
    """Predict the class id of one image with the notebook-level ``model``.

    Parameters
    ----------
    img_path : str
        File name relative to the data directory, e.g. ``"<country>_<lat>,<lng>.png"``.
        The country and coordinates encoded in the name are not used here.
    data_dir : str, optional
        Directory containing the image. Defaults to the notebook-level
        ``data_path`` (looked up at call time), as before.

    Returns
    -------
    torch.Tensor
        0-d tensor on ``device`` holding the argmax class index.
    """
    if data_dir is None:
        data_dir = data_path  # notebook global, set in the evaluation cell
    # `with` releases the file handle; convert("RGB") makes RGBA/greyscale PNGs
    # match the 3-channel Normalize and the model's 3-channel input conv.
    with Image.open(os.path.join(data_dir, img_path)) as img:
        img_tensor = trans(img.convert("RGB")).unsqueeze(0).to(device)

    # Batch of one -> squeeze to a 1-D score vector, then take the best class.
    logits = model(img_tensor).squeeze()
    return logits.argmax()

In [62]:
data_path = "./data"
# Exact number of images classified per country. The previous `count>300`
# check only stopped after the 301st image (the saved id_map output shows
# fractions with denominator 101, consistent with that off-by-one).
MAX_PER_COUNTRY = 300

assert all_country, f"no country images found in {data_path}"
test = all_country[0]

# List the directory once rather than once per country. Sorting makes the
# sampled subset reproducible across filesystems. Keep only PNGs, matching the
# glob that built `all_country`.
data_files = sorted(f for f in os.listdir(data_path) if f.endswith(".png"))

# pred_fre: country code -> 1-D array of predicted class ids, one per sampled image.
pred_fre = {}
for c in all_country:
    # Exact match on the "<country>_" prefix instead of a substring test,
    # which could also match the code appearing elsewhere in a file name.
    country_files = [f for f in data_files if f.split("_", 1)[0] == c]
    preds = [pred_cls(fname).cpu().numpy() for fname in country_files[:MAX_PER_COUNTRY]]
    pred_fre[c] = np.array(preds)

In [60]:
from collections import Counter

# Invert the per-country predictions: for every class id (0..45), collect the
# countries whose 5 most frequent predicted classes include it, together with
# the share of that country's sampled images assigned to the class.
id_map = {class_id: [] for class_id in range(46)}
for country, country_preds in pred_fre.items():
    total = len(country_preds)
    for class_id, hits in Counter(country_preds).most_common(5):
        id_map[class_id].append((country, hits / total))

# Highest share first; the sort is stable, so ties keep insertion order.
for entries in id_map.values():
    entries.sort(key=lambda pair: pair[1], reverse=True)

In [63]:
# class id -> [(country, share of that country's sampled images predicted as this class), ...],
# highest share first; a country is listed only for classes in its top-5 predictions.
id_map

{0: [('DK', 0.5247524752475248), ('SE', 0.04950495049504951)],
 1: [('NO', 0.9504950495049505),
  ('XK', 0.4),
  ('IS', 0.288135593220339),
  ('FI', 0.24752475247524752),
  ('AT', 0.2079207920792079),
  ('AD', 0.1935483870967742),
  ('BA', 0.14285714285714285),
  ('PL', 0.09900990099009901),
  ('SE', 0.0891089108910891),
  ('SI', 0.0891089108910891),
  ('LT', 0.07920792079207921),
  ('MK', 0.06976744186046512),
  ('IE', 0.06930693069306931),
  ('DK', 0.0594059405940594),
  ('BE', 0.04950495049504951),
  ('DE', 0.04950495049504951),
  ('GB', 0.04950495049504951),
  ('RU', 0.04950495049504951),
  ('FR', 0.039603960396039604),
  ('GR', 0.009900990099009901)],
 2: [('ES', 0.9504950495049505),
  ('MC', 0.625),
  ('PT', 0.4752475247524752),
  ('XK', 0.4),
  ('SM', 0.3333333333333333),
  ('VA', 0.3333333333333333),
  ('CY', 0.3023255813953488),
  ('IT', 0.25742574257425743),
  ('AD', 0.24193548387096775),
  ('MK', 0.20930232558139536),
  ('HR', 0.19801980198019803),
  ('AL', 0.177777777777777