In [None]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

## Preprocess images and search them with a self-supervised rotation-classifier backbone

In [None]:
import cv2
import pandas as pd

In [None]:
from fastai.vision import *
from repr.models.resnet import resnet_vec
from repr.models.encoders import Encoder
from repr.search.input_utils import init_transforms
from repr.search import indexer

## Generate dataset

In [None]:
# Everything lives under data/; the label_engine subtree holds the retrieval data.
path = Path('data')
label_engine = path / 'label_engine'

weights = path.joinpath('weights', 'a-rot2-f2.pth')  # presumably the rotation-model checkpoint
src = label_engine / 'ground_true'                   # images that get indexed
qur = label_engine / 'query_cropped'                 # query images
dst = label_engine.joinpath('db', 'vectors.pkl')     # index file passed to index_dir/search_dir
gt = label_engine / 'gt.csv'                         # ground truth: omniaz_id -> product_id

In [None]:
# Rotation-prediction network: ResNet-50 body plus a 4-way head
# (presumably one logit per 90-degree rotation class).
head = nn.Sequential(nn.AdaptiveAvgPool2d(1), Flatten(), nn.Linear(2048, 4))
# Bound to `model` (not `backbone`): the next cell builds the feature extractor
# from `model[0]` and `model[1]`, which raised NameError on a fresh kernel.
# NOTE(review): the `weights` path defined in the config cell is never used, so the
# network is built with weights=None; presumably the rotation-trained checkpoint
# should be loaded before indexing — confirm resnet_vec's `weights` semantics.
model = resnet_vec('resnet50', head=head, weights=None)

In [None]:
# Preprocessing applied to every image before encoding (512x512 target size;
# crop_center presumably centre-crops — confirm in repr.search.input_utils).
transforms = init_transforms(h=512, w=512, percnt=0.1, crop_center=True)
# Feature extractor: network body plus the head without its last layer.
# Assumes model[1] is the head built above, so this stops at the pooled
# 2048-d features instead of the rotation logits — TODO confirm.
backbone = nn.Sequential(model[0], model[1][:-1])
encoder = Encoder(backbone, transforms)
# Expose a `vec` entry point that simply forwards to the encoder itself.
encoder.vec = lambda batch: encoder(batch)

In [None]:
# Display the encoder object (its repr) to inspect the assembled pipeline.
encoder

In [None]:
# Build the search index: presumably encodes every image under `src` with `encoder`
# and writes the vectors to `dst` (a .pkl file) — confirm in repr.search.indexer.
indexer.index_dir(encoder, src, dst)

In [None]:
# Collect the query image paths found under `qur` (query_cropped).
paths = indexer.img_paths(qur)

In [None]:
# Search the index at `dst` with every query in `paths`. The inspection cells below
# unpack each entry as (query image, [(score, result path), ...], query path);
# this structure is assumed from that usage — TODO confirm against indexer.search_dir.
result_data = indexer.search_dir(encoder, paths, dst)

In [None]:
# Ground-truth table; later cells filter it by `omniaz_id` and read `product_id`.
df = pd.read_csv(gt)

In [None]:
# Inspect ground-truth rows with product_id above 1122 (threshold not explained — TODO).
df.loc[df['product_id'] > 1122]

In [None]:
# Spot check: ground-truth product_id for a single query (omniaz_id 15003).
# `.iloc[0, 0]` takes the first column of the first matching row (assumed to be
# product_id). `int()` accepts both numpy scalars and plain Python ints, whereas
# `.values[0][0].item()` fails when mixed dtypes turn `.values` into an object array.
product_id = int(df[df.omniaz_id == 15003].iloc[0, 0])

In [None]:
# Sanity check: the lookup above should yield a plain Python `int`, so it compares
# cleanly against the ids parsed from file names with int(...) further down.
type(product_id)

In [None]:
# Detailed look at a single query: show it, look up its ground-truth product_id,
# then show each retrieved image and print whether its id matches.
query_idx = 1  # which entry of `result_data` to inspect
img, reslt, query_path = result_data[query_idx]
plt.figure(figsize=(10, 10))
plt.title('Query')
plt.imshow(img)
plt.show()

# Query files are named by omniaz_id; gt.csv maps that to the product_id that
# names the indexed images. `.iloc[0, 0]` takes the first column of the first
# matching row (assumed to be product_id); `int()` accepts numpy and Python
# scalars alike, unlike `.item()`.
omniaz_id = int(query_path.stem)
product_id = int(df[df.omniaz_id == omniaz_id].iloc[0, 0])
print(product_id)
print('Results')
# Separate loop name so the query path is not shadowed.
for ds, res_path in reslt:
    res_img = cv2.imread(str(res_path), cv2.IMREAD_ANYCOLOR)
    if res_img is None:  # cv2.imread returns None instead of raising on unreadable files
        print(f'could not read {res_path}')
        continue
    if res_img.ndim == 3:
        res_img = res_img[..., ::-1]  # OpenCV decodes colour as BGR; matplotlib expects RGB
    result_id = int(res_path.stem)
    print(result_id == product_id)
    plt.figure(figsize=(10, 10))
    plt.title(f'{ds} {res_path.stem}')
    plt.imshow(res_img, cmap='gray')  # cmap is ignored for 3-channel images
    plt.show()

In [None]:
# Overview of every query: show the query image, then each retrieved image titled
# with its score `ds`. `plt.show()` after each figure keeps the prints and images
# interleaved; without it the inline backend renders every figure at the end of
# the cell, after all of the 'Results' lines.
# Entries are unpacked with `*_` because they carry a trailing query path (see the
# single-query cell above), which this overview does not need.
for img, reslt, *_ in result_data:
    plt.figure(figsize=(10, 10))
    plt.title('Query')
    plt.imshow(img)
    plt.show()
    print('Results')
    for ds, pt in reslt:
        res_img = cv2.imread(str(pt), cv2.IMREAD_ANYCOLOR)
        if res_img is None:  # cv2.imread returns None instead of raising on unreadable files
            print(f'could not read {pt}')
            continue
        if res_img.ndim == 3:
            res_img = res_img[..., ::-1]  # OpenCV decodes colour as BGR; matplotlib expects RGB
        plt.figure(figsize=(10, 10))
        plt.title(f'{ds}')
        plt.imshow(res_img, cmap='gray')  # cmap is ignored for 3-channel images
        plt.show()