In [None]:

import torch
import numpy as np
import torchvision.transforms as T
import torchvision.models as models
import os
from PIL import Image
import json,requests
import matplotlib.pyplot as plt
%matplotlib inline
# True when a CUDA device is usable; later cells move the model and the input
# tensors to the GPU based on this flag.
use_cuda = torch.cuda.is_available()


In [None]:
import base64,re
def img_to_base64(img_path):
    """Return the base64 encoding (as ``bytes``) of the file at ``img_path``.

    The file is read in binary mode in one go, so the whole content is held
    in memory while it is encoded.
    """
    with open(img_path, mode='rb') as image_file:
        raw_bytes = image_file.read()
    return base64.b64encode(raw_bytes)

In [None]:
from socr import ocr
# Build `images`: one (image_path, ocr_text, ocr_origin) tuple per picture.
# OCR results are cached in ocr.txt (one '####'-separated record per line), so
# the OCR pass over the data directories only runs when no cache file exists.
images=[]
paths=[('xx','data')]
img_txt=dict()
if os.path.exists('ocr.txt'):
    with open('ocr.txt','r') as reader:
        for line in reader:
            # Strip first: a blank line still carries its newline, so testing
            # the raw line for truthiness would not filter it out.
            line=line.strip()
            if not line:
                continue
            img_path,txt,origin=line.split('####')
            images.append((img_path,txt,origin))
else:
    # The context manager guarantees the cache is flushed and closed even if
    # an OCR call raises part-way through.
    with open('ocr.txt','w') as writer:
        for _path in paths:
            for sub_dir,_sub_dirs,sub_files in os.walk(_path[1]):
                for im in sub_files:
                    if im.endswith('.DS_Store'):
                        continue
                    real_path=os.path.join(sub_dir,im)
                    real_words,origin=ocr(img_to_base64(real_path))
                    print(real_path,real_words,origin)
                    # Same 3-tuple shape as the records loaded from the cache.
                    images.append((real_path,real_words,origin))
                    writer.write(f'{real_path}####{real_words}####{origin}\n')
print(len(images))

In [None]:
# Build one FAISS L2 index per OCR label (images[i][1]) from CNN features:
#   faiss_index[label] -> faiss.IndexFlatL2 holding one feature row per image
#   image_index[label] -> image paths, in the same order as the index rows
# The feature extractor is the pretrained ResNeXt with its last child module
# removed.
feature_extractor = torch.nn.Sequential(*list(models.resnext101_32x8d(pretrained=True).children())[:-1])
# Inference mode: normalisation layers use their pretrained statistics instead
# of per-batch statistics (every batch here is a single image). The query
# cells reuse this same module, so index and query features stay consistent.
feature_extractor.eval()
if use_cuda:
    feature_extractor=feature_extractor.cuda()
op_resize=T.Resize(size=(256,256))
op_norm=T.Normalize(
       mean=[0.485, 0.456, 0.406],
       std=[0.229, 0.224, 0.225]
   )
convert_tensor = T.ToTensor()
image_index=dict()
import faiss                   # make faiss available
faiss_index=dict()
feature_dim=2048               # width of the flattened feature row stored in each index
# No gradients are needed for feature extraction; skipping autograd avoids
# building a graph for every image.
with torch.no_grad():
    for image_path in images:
        label=image_path[1]
        # Close the file handle as soon as the pixels are decoded.
        with Image.open(image_path[0]) as image_file:
            img=image_file.convert('RGB')
        img=op_resize(img)
        img=convert_tensor(img).view(1,3,256,256)
        if use_cuda:
            img=img.cuda()
        img=op_norm(img)
        img_feature=feature_extractor(img)
        if label not in faiss_index:
            faiss_index[label]=faiss.IndexFlatL2(feature_dim)   # build the index
        faiss_index[label].add(img_feature.cpu().view(1,-1).numpy())
        image_index.setdefault(label,[]).append(image_path[0])
# Sanity check on one known label; guarded so a data set without that label
# does not abort the cell.
sample_label='一年级下册'
if sample_label in faiss_index:
    print(faiss_index[sample_label].ntotal)
    print(image_index[sample_label])
else:
    print(f'label {sample_label!r} not found among {len(faiss_index)} indexed labels')

In [None]:

import random
import torchvision.transforms as T
# Query demo: OCR a single picture to pick the label-specific index, extract
# its CNN feature and look up the nearest indexed images for that label.
perspective_transformer = T.RandomPerspective(distortion_scale=0.2, p=1.0)
id=random.randint(0,len(images)-1)
# Alternative input: a random picture from the indexed set.
# img1 = Image.open(images[id][0]).convert('RGB')
# real_words=ocr(img_to_base64(images[id][0]))
img_path='1.jpg'
img1 = Image.open(img_path).convert("RGB")
real_words=ocr(img_to_base64(img_path))
# Patch specific character substitutions in the OCR text before it is used as
# the index key.
real_words=(real_words[0].replace('乐','六').replace('小','八').replace('书册','上册'),real_words[1])
print('输入---',real_words[0],real_words[1])
img1=op_resize(img1)
# img1=perspective_transformer(img1)
img=convert_tensor(img1).view(1,3,256,256)
if use_cuda:
    img=img.cuda()
img=op_norm(img)
# Gradients are not needed for a lookup.
with torch.no_grad():
    img_feature=feature_extractor(img).cpu().view(1,-1).numpy()
print(f'img_feature={img_feature.shape}')
if real_words[0] not in faiss_index:
    raise KeyError(f'no index was built for OCR label {real_words[0]!r}')
# Never ask for more neighbours than the index holds: FAISS pads missing
# results with id -1, which would later index image_index from the end.
NN=min(10,faiss_index[real_words[0]].ntotal)
D, I = faiss_index[real_words[0]].search(img_feature, NN)
pred_ids=list(I[0][:NN])
print(pred_ids,D,I)
def get_concat_h(im1,img):
    """Return a new RGB image with ``im1`` on the left and ``img`` on the right.

    The canvas is as tall as the taller input and as wide as both widths
    combined; both inputs are pasted at the top edge.
    """
    canvas_size = (im1.width + img.width, max(im1.height, img.height))
    canvas = Image.new('RGB', canvas_size)
    for picture, x_offset in ((im1, 0), (img, im1.width)):
        canvas.paste(picture, (x_offset, 0))
    return canvas
# Show the retrieved images side by side (best match first), then the query.
# FAISS reports id -1 for result slots it could not fill; such ids are skipped
# because a negative list index would silently select an unrelated image.
matched_paths=[image_index[real_words[0]][i] for i in pred_ids if i >= 0]
dst=None
for matched_path in matched_paths:
    match_img=Image.open(matched_path).convert("RGB")
    dst=match_img if dst is None else get_concat_h(dst,match_img)
if dst is not None:
    plt.imshow(dst)
    plt.show()
plt.imshow(img1)
plt.show()

In [None]:
import random
# Retrieval accuracy check: draw random indexed pictures, apply a random
# perspective distortion, and test whether the undistorted original comes
# back within the top TOP_K results (correct1) and at rank 1 (correct2).
perspective_transformer = T.RandomPerspective(distortion_scale=0.2, p=1.)
TOP_K=5
correct1=0
correct2=0
count=0
image_paths=[_[0] for _ in images]
n_trials=100 if images else 0   # random.randint would raise on an empty image list
with torch.no_grad():           # inference only; no autograd graph needed
    for i in range(n_trials):
        id=random.randint(0,len(images)-1)
        sample_path=images[id][0]
        real_words=ocr(img_to_base64(sample_path))[0]
        # Samples that cannot be scored (no index for the OCR label, or the
        # picture is not stored under that label) are skipped and not counted.
        label_paths=image_index.get(real_words)
        if real_words not in faiss_index or not label_paths or sample_path not in label_paths:
            continue
        true_pos=label_paths.index(sample_path)
        img1 = Image.open(sample_path).convert('RGB')
        img1=op_resize(img1)
        img1=perspective_transformer(img1)
        img=convert_tensor(img1).view(1,3,256,256)
        if use_cuda:
            img=img.cuda()
        img=op_norm(img)
        img_feature=feature_extractor(img).cpu().view(1,-1).numpy()
        D, I = faiss_index[real_words].search(img_feature, TOP_K)
        pred_ids=list(I[0][:TOP_K])
        distances=list(D[0][:TOP_K])
        # print('输入---',real_words,pred_ids[0])
        if true_pos in pred_ids:
            correct1+=1
        if true_pos == pred_ids[0]:
            correct2+=1
        count+=1
if count:
    # The label reflects the actual k used for the search.
    print(f'top{TOP_K}准确率：',correct1/count)
    print('top1准确率：',correct2/count)
else:
    print('no evaluable samples; accuracy is undefined')

