In [None]:

import glob
import os
import random

import faiss
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image
%matplotlib inline


In [None]:
# Exact (brute-force) L2 nearest-neighbour index. Its dimension must equal the width
# of the embeddings produced by `feature_extractor` in the next cell (2048 for
# ResNeXt-101 with its classifier removed).
FEATURE_DIM = 2048
index = faiss.IndexFlatL2(FEATURE_DIM)
print(index.is_trained)  # flat indexes need no training step

# glob's result order depends on the filesystem; sort it so that the mapping
# faiss id i -> images[i] is reproducible across machines and runs.
images = sorted(glob.glob('data1/*.jpg'))
print(len(images))
assert images, "no .jpg files found in data1/"


In [None]:
# Embedding model: pretrained ResNeXt-101 32x8d with its last child (the `fc`
# classifier) dropped, so each image maps to the globally pooled feature vector.
# .eval() is essential: modules are constructed in training mode. In that mode
# BatchNorm normalizes with the statistics of the current single-image batch and
# keeps overwriting its running averages, so the embeddings would not be the
# pretrained network's.
feature_extractor = torch.nn.Sequential(
    *list(models.resnext101_32x8d(pretrained=True).cuda().children())[:-1]
).eval()

# Augmentations available for building harder (distorted) queries in later cells.
jitter = T.ColorJitter(brightness=.5, hue=.3)
rotater = T.RandomRotation(degrees=(0, 180))

# Preprocessing shared by indexing and querying: fixed 256x256 resize, conversion to a
# [0, 1] float tensor, then per-channel normalization with the ImageNet mean/std.
op_resize = T.Resize(size=(256, 256))
op_norm = T.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
convert_tensor = T.ToTensor()

# Start from an empty index so re-running this cell does not append a second copy of
# every vector: faiss ids are insertion positions, and later cells map them back to
# file paths via images[id].
index.reset()
with torch.no_grad():  # inference only; no autograd graph needed
    for image_path in images:
        # convert("RGB") guards against grayscale/RGBA/palette files, which would not
        # have exactly 3 channels; the `with` closes the underlying file handle.
        with Image.open(image_path) as raw:
            img = op_resize(raw.convert("RGB"))
        img = op_norm(convert_tensor(img).unsqueeze(0).cuda())  # (1, 3, 256, 256)
        img_feature = feature_extractor(img)
        index.add(img_feature.cpu().view(1, -1).numpy())  # one row per image

In [None]:
# Demo: query the index with a perspective-warped copy of one random indexed image.
# p=1.0 applies the warp on every call.
perspective_transformer = T.RandomPerspective(distortion_scale=0.4, p=1.0)
# randrange excludes its upper bound. The previous randint(0, len(images)) is
# inclusive and could pick len(images), raising IndexError.
query_id = random.randrange(len(images))
with Image.open(images[query_id]) as raw:
    # RGB conversion keeps the tensor at 3 channels; `with` closes the file.
    img1 = op_resize(raw.convert("RGB"))
# Optional extra distortions (uncomment to make the query harder):
# img1 = jitter(img1)
# img1 = rotater(img1)
img1 = perspective_transformer(img1)
img = op_norm(convert_tensor(img1).unsqueeze(0).cuda())  # (1, 3, 256, 256)
with torch.no_grad():  # inference only
    img_feature = feature_extractor(img).cpu().view(1, -1).numpy()
# I: faiss ids (positions in `images`) of the 3 nearest neighbours, nearest first;
# D: their distances (IndexFlatL2 reports squared L2 distances).
D, I = index.search(img_feature, 3)
print(I, D)
def get_concat_h(im1, img):
    """Place two PIL images side by side on a new RGB canvas.

    `im1` goes at the left edge and `img` directly to its right. The canvas is as wide
    as both images together and as tall as the taller one. Any uncovered area keeps
    Image.new's default fill.
    """
    canvas_size = (im1.width + img.width, max(im1.height, img.height))
    canvas = Image.new('RGB', canvas_size)
    # Paste left to right: first image at x=0, second right after the first.
    for piece, x_offset in ((im1, 0), (img, im1.width)):
        canvas.paste(piece, (x_offset, 0))
    return canvas
# Show the retrieved neighbours side by side (nearest first), then the distorted query.
# faiss pads the result with id -1 when the index holds fewer than k vectors; skip
# those so images[-1] is not silently displayed as a match.
neighbor_ids = [int(j) for j in I[0] if j >= 0]
dst = None
for neighbor_id in neighbor_ids:
    # `with` closes each file; convert("RGB") returns an independent, loaded copy.
    with Image.open(images[neighbor_id]) as raw:
        neighbor = raw.convert("RGB")
    dst = neighbor if dst is None else get_concat_h(dst, neighbor)
if dst is not None:
    plt.imshow(dst)
    plt.title(f'top-{len(neighbor_ids)} matches')
    plt.show()
plt.imshow(img1)
plt.title('query')
plt.show()

In [None]:
# Evaluation: estimate top-1 retrieval accuracy. Each trial picks a random indexed
# image and applies a random perspective warp (with probability 0.8). It counts a
# hit when the index returns that same image as its nearest neighbour.
N_TRIALS = 200
SEED = 0
random.seed(SEED)        # controls which images are sampled
torch.manual_seed(SEED)  # torchvision's random transforms draw from torch's global RNG
perspective_transformer = T.RandomPerspective(distortion_scale=0.2, p=.8)
correct = 0
count = 0
with torch.no_grad():  # inference only; no autograd graph needed
    for _ in range(N_TRIALS):
        # randrange(n) samples from 0..n-1, i.e. every valid position in `images`.
        query_id = random.randrange(len(images))
        with Image.open(images[query_id]) as raw:
            # RGB conversion keeps the tensor at 3 channels; `with` closes the file.
            img1 = op_resize(raw.convert("RGB"))
        img1 = perspective_transformer(img1)
        img = op_norm(convert_tensor(img1).unsqueeze(0).cuda())  # (1, 3, 256, 256)
        img_feature = feature_extractor(img).cpu().view(1, -1).numpy()
        D, I = index.search(img_feature, 1)  # only the top-1 hit is scored
        pred_id = I[0][0]
        if query_id == pred_id:
            correct += 1
        count += 1
# The printed label "准确率" means "accuracy".
print(f'准确率：{correct/count}')
