In [1]:
import cv2
import pandas as pd
import numpy as np
import json
import jieba
import glob
import torch
from tqdm import tqdm

In [2]:
# Region hierarchy with centre coordinates, from
# https://github.com/boyan01/ChinaRegionDistrict/blob/master/region.json
# Coordinates can be spot-checked at https://lbsyun.baidu.com/jsdemo/demo/yLngLatLocation.htm

# Context manager closes the file; explicit UTF-8 because the names are Chinese and the
# platform default encoding (e.g. GBK/cp1252 on Windows) would mis-decode or fail.
with open('region.json', encoding='utf-8') as region_file:
    region_data = json.load(region_file)

# Top-level entries (provinces / SARs) and their direct children.
provinces = region_data['districts']
# Nested comprehension flattens in linear time, unlike sum(list_of_lists, []) which is quadratic.
districts_data = [child for province in provinces for child in province['districts']]


def _name_lon_lat(region: dict) -> list:
    """Return [name, longitude, latitude] for one region entry of region.json."""
    center = region['center']
    return [region['name'], center['longitude'], center['latitude']]


# Children first, then the top-level entries (same row order as before). Columns stay
# positional (0 = name, 1 = longitude, 2 = latitude) because later cells index them by integer.
city_data = pd.DataFrame([_name_lon_lat(region) for region in districts_data + provinces])

city_data

Unnamed: 0,0,1,2
0,花王堂区,113.548961,22.199207
1,望德堂区,113.550183,22.193721
2,大堂区,113.553647,22.188539
3,风顺堂区,113.541928,22.187368
4,花地玛堂区,113.552896,22.207870
...,...,...,...
424,四川省,104.075809,30.651239
425,西藏自治区,91.117525,29.647535
426,新疆维吾尔自治区,87.627704,43.793026
427,云南省,102.710002,25.045806


In [3]:
# Inspect the region-name column (column 0) of the lookup table.
city_data.loc[:, 0]

0          花王堂区
1          望德堂区
2           大堂区
3          风顺堂区
4         花地玛堂区
         ...   
424         四川省
425       西藏自治区
426    新疆维吾尔自治区
427         云南省
428         浙江省
Name: 0, Length: 429, dtype: object

In [4]:
# Sanity check: the row(s) whose region name is exactly 上海市.
city_data.loc[city_data[0] == '上海市']

Unnamed: 0,0,1,2
406,上海市,121.473662,31.230372


In [5]:
def _parse_leading_float(value) -> float:
    """Return the first whitespace-separated token of ``value`` converted to float.

    Uses ``str.split()`` (not ``split(' ')``) so a value with leading whitespace still parses
    instead of yielding an empty token.
    """
    return float(str(value).split()[0])


def load_image_centers(path: str) -> pd.DataFrame:
    """Read a headerless, comma-separated image-centre file and clean columns 0 and 1.

    Columns 0 and 1 may carry trailing text after a space; only the leading number is kept.
    NOTE(review): columns 0/1 are presumably longitude/latitude (matching city_data's 1/2) — confirm.
    """
    centers = pd.read_csv(path, sep=',', header=None)
    for col in (0, 1):
        centers[col] = centers[col].apply(_parse_leading_float)
    return centers


train_img_locations = load_image_centers('训练集/图片中心经纬度.txt')
test_img_locations = load_image_centers('初赛测试集/图片中心经纬度.txt')

In [6]:
def sorted_image_paths(pattern: str) -> np.ndarray:
    """Glob ``pattern`` and return the matching paths in lexicographic order as a NumPy array."""
    return np.array(sorted(glob.glob(pattern)))


# NOTE(review): row i of the matching location file is assumed to describe path i in this
# lexicographic order. Unpadded numeric names (1.jpg, 10.jpg, 2.jpg) would break that — confirm.
train_imgs = sorted_image_paths('./训练集/图片/*')
test_imgs = sorted_image_paths('./初赛测试集/图片/*')

In [7]:
from sklearn.neighbors import NearestNeighbors

# How many geographically nearest test images are re-ranked by CLIP for each question.
N_CANDIDATES = 40

# kneighbors() returns row positions of test_img_locations, and those positions are used to
# index test_imgs; a length mismatch would silently pair images with the wrong locations.
assert len(test_imgs) == len(test_img_locations), (
    f'{len(test_imgs)} test images vs {len(test_img_locations)} location rows'
)

# NOTE(review): distances are Euclidean on raw degree values (not geodesic). That is presumably
# fine for ranking nearby points within one city; use metric='haversine' on radians if not.
nbrs = NearestNeighbors(
    # Clamp so a small test set does not make every kneighbors() query fail.
    n_neighbors=min(N_CANDIDATES, len(test_img_locations)),
    algorithm='ball_tree',
).fit(test_img_locations)

In [8]:
import os

# Point the Hugging Face Hub client at the mirror. This is set before importing transformers
# because huggingface_hub reads HF_ENDPOINT when it is first imported.
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

from PIL import Image
import requests
from transformers import ChineseCLIPModel, ChineseCLIPProcessor

# Single source of truth so the processor and the model always come from the same checkpoint.
CLIP_CHECKPOINT = "OFA-Sys/chinese-clip-vit-base-patch16"

model = ChineseCLIPModel.from_pretrained(CLIP_CHECKPOINT)
processor = ChineseCLIPProcessor.from_pretrained(CLIP_CHECKPOINT)



In [9]:
QUESTIONS_PATH = './数据集更新/初赛测试集/问题.txt'
TOP_K = 5  # image indices written per question


def extract_city(question: str) -> str:
    """Return the first jieba token longer than one character that is not all digits.

    NOTE(review): assumes the city name is the first such token in every question — confirm.
    """
    words = [w for w in jieba.lcut(question) if len(w) > 1 and not w.isdigit()]
    if not words:
        raise ValueError(f'no candidate city token in question {question!r}')
    return words[0]


def lookup_city_center(name: str) -> np.ndarray:
    """Return [longitude, latitude] of the first city_data row whose name equals ``name``."""
    matches = city_data.loc[city_data[0] == name, [1, 2]]
    if matches.empty:
        raise ValueError(f'city {name!r} not found in region.json')
    return matches.to_numpy(dtype=float)[0]


def load_rgb(path: str) -> Image.Image:
    """Load an image fully into memory as RGB and close the underlying file."""
    with Image.open(path) as img:
        return img.convert('RGB')


# Explicit UTF-8 for the Chinese question text; context manager closes the file.
with open(QUESTIONS_PATH, encoding='utf-8') as questions_file:
    questions = questions_file.readlines()

results = []
for question in tqdm(questions):
    city = extract_city(question)
    # Candidates: test images whose centre is closest to the city's centre.
    # NOTE(review): assumes city_data [lon, lat] matches test_img_locations' column order.
    _, neighbour_idx = nbrs.kneighbors([lookup_city_center(city)])
    city_pic_index = neighbour_idx[0]

    with torch.no_grad():
        # L2-normalised image embeddings of the candidates.
        image_inputs = processor(
            images=[load_rgb(path) for path in test_imgs[city_pic_index]], return_tensors="pt"
        )
        image_features = model.get_image_features(**image_inputs)
        image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)

        # L2-normalised embedding of the question text.
        text_inputs = processor(text=[question], padding=True, return_tensors="pt")
        text_features = model.get_text_features(**text_inputs)
        text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)

        # Cosine similarity of the question to each candidate; kept as the global `ids`
        # because a later cell inspects the last question's scores.
        ids = torch.matmul(text_features, image_features.T)[0].cpu().numpy()

    # Highest-scoring candidates first; row indices into the test location table.
    ranked = city_pic_index[ids.argsort()[::-1]][:TOP_K]
    results.append(','.join(str(i) for i in ranked))

  0%|          | 0/49 [00:00<?, ?it/s]Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.617 seconds.
Prefix dict has been built successfully.
100%|██████████| 49/49 [02:29<00:00,  3.05s/it]


In [10]:
# Submission file: one line per question, each the comma-separated ranked image indices.
with open('res.csv', 'w') as submission:
    submission.writelines(line + '\n' for line in results)

In [11]:
# Debug: cosine similarities between the last question and its candidate images.
ids

array([0.41078353, 0.42506945, 0.39543888, 0.41330916, 0.382412  ,
       0.39154297, 0.3923232 , 0.39476427, 0.40610224, 0.41434634,
       0.40460896, 0.42052194, 0.40704408, 0.3934614 , 0.40922046,
       0.42150655, 0.41470432, 0.436787  , 0.39174995, 0.38255706,
       0.37405667, 0.38224554, 0.4008456 , 0.39214906, 0.41689512,
       0.4153441 , 0.38414806, 0.45356262, 0.4016958 , 0.43198863,
       0.41783026, 0.42809063, 0.43769678, 0.4302791 , 0.41259   ,
       0.4011598 , 0.39734104, 0.40664643, 0.3797541 , 0.40003294],
      dtype=float32)

In [12]:
# Debug: similarity of the last question to its first (geographically nearest) candidate image.
ids[0]

0.41078353