In [1]:
import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms
import numpy as np
from PIL import Image
import os

In [2]:
from geo_data import create_anchor_transform
from libs.ConvNeXt.models.convnext import ConvNeXtFeature
from models.Distancer import GeoDiscriminator


In [3]:
anchor_samples = [ 
    "AD_42.528,1.56927.png", "AL_41.32654,19.82209.png", "AT_47.73333,14.21667.png", "BA_43.91194,18.08083.png", "BE_50.78263,4.5334.png", "BG_42.71231,25.3329.png", "BY_53.0245,26.3403.png", "CH_46.90981,8.11206.png", "CY_35.119479999999996,33.28853.png", "CZ_49.73456,15.29297.png", "DE_50.39996,9.98198.png", "DK_55.80849,10.581669999999999.png", "EE_58.63053,25.55402.png", "ES_39.68888,-3.50281.png", "FI_61.929730000000006,25.15144.png", "FR_46.91745,2.49814.png", "GB_52.81773,-1.76009.png", "GR_37.97451,23.51769.png", "HR_44.655,15.95083.png", "HU_47.25,19.06667.png", "IE_53.32528000000001,-7.979439999999999.png", "IS_64.13267,-20.30651.png", "IT_43.43218,11.77323.png", "LI_47.17556,9.57287.png", "LT_55.41019,23.7299.png", "LU_49.64506,6.12932.png", "LV_57.0619,24.84465.png", "MC_43.74041,7.42311.png", "MD_47.01095,28.85176.png", "ME_42.39333,18.89028.png", "MK_41.63468,21.40268.png", "MT_35.94556,14.38972.png", "NL_52.1738,5.48497.png", "NO_62.20631,10.63725.png", "PL_51.85225,19.59197.png", "PT_39.66978,-8.9958.png", "RO_45.68811,24.97548.png", "RS_44.24947,20.39613.png", "RU_54.1766,37.8881.png", "SE_59.06565,15.337470000000001.png", "SI_46.05804,14.82515.png", "SK_48.56315,19.3029.png", "SM_43.90867,12.44808.png", "UA_48.57325,29.71874.png", "VA_41.90394,12.45401.png", "XK_42.54018,20.28793.png",
]
device = torch.device("cuda")

In [4]:
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
trans = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

In [9]:
backbone = ConvNeXtFeature(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], num_classes=46)

# checkpoint_model = torch.load("convnext_base_22k_224.pth", map_location='cpu')['model']
checkpoint_model = torch.load("./logs/cls_final/checkpoint-best.pth", map_location='cpu')['model']
for k in ['head.weight', 'head.bias']:
    del checkpoint_model[k]
backbone.load_state_dict(checkpoint_model, strict=False)
                         
model = GeoDiscriminator(1024)
# model.load_state_dict(torch.load("logs/contrast_final/checkpoint-4.pth", map_location='cpu')['model'], strict=False)
model.load_state_dict(torch.load("logs/dis_final/checkpoint-6.pth", map_location='cpu')['model'], strict=False)

backbone.to(device)
model.to(device)

backbone.eval()
model.eval()

GeoDiscriminator(
  (head): Sequential(
    (0): Linear(in_features=2048, out_features=1024, bias=True)
    (1): GELU(approximate=none)
    (2): Linear(in_features=1024, out_features=1024, bias=True)
    (3): GELU(approximate=none)
    (4): Linear(in_features=1024, out_features=1, bias=True)
  )
)

In [10]:
anchor_images = []
anchor_coords = []
data_path = "./data"

for name in anchor_samples:
    _, coord = name[:-4].split("_")
    lat, lng = coord.split(",")
    latlng = np.array([float(lat), float(lng)])
    img_path = os.path.join(data_path, name)
    img = Image.open(img_path)
    anchor_images.append(trans(img))
    anchor_coords.append(torch.Tensor(latlng))

anchor_images = torch.stack(anchor_images, 0).to(device)
anchor_coords = torch.stack(anchor_coords, 0).to(device)
with torch.no_grad():
    anchor_features = backbone(anchor_images)

In [29]:
@torch.no_grad()
def pred_distances(img_path):
    _, coord = img_path[:-4].split("_")
    lat, lng = coord.split(",")
    latlng = np.array([float(lat), float(lng)])
    
    img = Image.open(os.path.join(data_path, img_path))
    img_tensor = trans(img).to(device)
    img_tensor = img_tensor.unsqueeze(0)
    
    features = backbone(img_tensor)
    
    img_coord = torch.Tensor(latlng).unsqueeze(0).to(device)
    geo_distance = torch.pairwise_distance(anchor_coords, img_coord.repeat(anchor_coords.shape[0], 1), p=2, keepdim=True).clip(0, 10)

    feature_distance = model(torch.cat([anchor_features, features.repeat(anchor_features.shape[0], 1)], dim=-1))
    feature_distance = torch.sigmoid(feature_distance)*10
    
#     print((geo_distance - feature_distance).abs().mean().item()/46)
    fea_dis, geo_dis = feature_distance.cpu().numpy(), geo_distance.cpu().numpy()
    fea_dis, geo_dis = fea_dis[:,0], geo_dis[:,0]
    
    gt = geo_dis.argmin()
    top5 = fea_dis.argsort()[:5]
    
    correct = (gt==top5[0])
    correct5 = False
    for t in top5:
        if t==gt:
            correct5 = True
            break
    
    return correct, correct5, (geo_distance-feature_distance).mean().cpu().numpy()

In [30]:
data_path = "./eval"

correct1, correct5 = [], []

errors = []
for fname in os.listdir("./eval"):
    c1, c5, error = pred_distances(fname)
    correct1.append(c1)
    correct5.append(c5)
    errors.append(error)


In [31]:
np.array(errors).mean()

3.3535607

In [32]:
np.array(correct1).sum()

11

In [33]:
np.array(correct5).sum()

36

In [28]:
def load_img(path):
    _, coord = path[:-4].split("_")
    lat, lng = coord.split(",")
    latlng = np.array([float(lat), float(lng)])
    
    img = Image.open(os.path.join(data_path, path))
    img_tensor = trans(img).to(device)
    img_tensor = img_tensor.unsqueeze(0)
    return img_tensor, latlng

@torch.no_grad()
def pred_pair(img1, img2):
    img_tensor1, latlng1 = load_img(img1)
    img_tensor2, latlng2 = load_img(img2)
    
    feature1 = backbone(img_tensor1)
    feature2 = backbone(img_tensor2)
    
    img_coord1 = torch.Tensor(latlng1).unsqueeze(0).to(device)
    img_coord2 = torch.Tensor(latlng2).unsqueeze(0).to(device)
    geo_distance = torch.pairwise_distance(img_coord1, img_coord2, p=2, keepdim=True)

    feature_distance = model(torch.cat([feature1, feature2], dim=-1))
    feature_distance = torch.sigmoid(feature_distance)*50
    
#     print((geo_distance - feature_distance).abs().mean().item()/46)
    fea_dis, geo_dis = feature_distance.cpu().numpy(), geo_distance.cpu().numpy()
    fea_dis, geo_dis = fea_dis[:,0], geo_dis[:,0]
    
#     for t in range(len(fea_dis)):
#         print(f"{geo_dis[t]:.3f}  ----  {fea_dis[t]:.3f}")
#     print(anchor_coords[top5].mean(0).cpu().numpy())
#     print(latlng)
    print(np.linalg.norm(anchor_coords.mean(0).cpu().numpy() - latlng)/46)
    
    return fea_dis, geo_dis, latlng

In [29]:
data_path = "./data"
pred_pair("AD_42.46395,1.5126.png", "AD_42.46295,1.50926.png")

0.1741782672329096


(array([0.], dtype=float32),
 array([0.00348747], dtype=float32),
 array([42.54018, 20.28793]))