In [1]:
import json
import numpy as np
# Load the label metadata; `with` closes the file handle instead of leaking it.
with open("cs_label.json") as label_file:
    dd = json.load(label_file)

In [16]:
from pathlib import Path
from tqdm import tqdm
import cv2, os
import numpy as np
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, ReduceLROnPlateau
from metrics import scheduler, padding_resize

def create_gamma_img(gamma, img):
    """Apply gamma correction to an image through a 256-entry lookup table.

    Each input level v in [0, 255] maps to 255 * (v / 255) ** (1 / gamma).
    The result is truncated to uint8, and the table is applied with ``cv2.LUT``.

    Args:
        gamma: gamma value. Values > 1 brighten mid-tones; values < 1 darken them.
        img: image handed straight to ``cv2.LUT`` (presumably uint8 -- confirm).

    Returns:
        The image returned by ``cv2.LUT``.
    """
    exponent = 1.0 / gamma
    levels = [255 * (level / 255) ** exponent for level in range(256)]
    # astype(np.uint8) truncates toward zero, exactly like storing each float
    # into a uint8 array element one at a time.
    table = np.array(levels).astype(np.uint8).reshape(256, 1)
    return cv2.LUT(img, table)


class DataLoad:
    """Image/label loader and Keras callback factory driven by a config object.

    The config must expose the attributes read below: ``W``, ``H``, ``x_img``,
    ``WEIGHT_DIR``, ``train_batch``, ``epochs`` and ``lr``.
    """

    def __init__(self, config, cosine_annealing=True):
        self.cfg = config
        self.cosine_annealing = cosine_annealing
        self.width, self.height = config.W, config.H

    def create_callbacks(self, min_lr=1e-3):
        """Build the training callbacks.

        Args:
            min_lr: ``eta_min`` passed to the cosine-annealing scheduler. It is only
                used when ``self.cosine_annealing`` is true. The default is the
                previously hard-coded 1e-3.

        Returns:
            list: [ReduceLROnPlateau, ModelCheckpoint, TensorBoard], plus a
            CosineAnnealingScheduler when ``self.cosine_annealing`` is true.
        """
        checkpoint_path = os.path.join(self.cfg.WEIGHT_DIR, "arcface_model_{epoch:02d}.hdf5")
        checkpoint_dir = os.path.dirname(checkpoint_path)
        if checkpoint_dir:
            # Ensure the weights directory exists before the first checkpoint save.
            os.makedirs(checkpoint_dir, exist_ok=True)

        target_monitor = 'val_loss'
        cp_callback = ModelCheckpoint(checkpoint_path, monitor=target_monitor, verbose=1,
                                      save_best_only=True, mode='min')
        reduce_lr = ReduceLROnPlateau(monitor=target_monitor, factor=0.1, patience=3, verbose=1)
        tb_callback = TensorBoard(log_dir='logs/',
                                  update_freq=self.cfg.train_batch * 5,
                                  profile_batch=0)
        callbacks = [reduce_lr, cp_callback, tb_callback]
        if self.cosine_annealing:
            # NOTE(review): a per-epoch scheduler likely overrides any LR reduction
            # made by ReduceLROnPlateau. Confirm both are really wanted together,
            # and that min_lr < self.cfg.lr.
            callbacks.append(scheduler.CosineAnnealingScheduler(
                T_max=self.cfg.epochs, eta_max=self.cfg.lr, eta_min=min_lr, verbose=1))
        return callbacks

    @staticmethod
    def _imread(p, flags=None):
        """Call cv2.imread, but raise instead of silently returning None."""
        img = cv2.imread(p) if flags is None else cv2.imread(p, flags)
        if img is None:
            raise FileNotFoundError(f"could not read image: {p}")
        return img

    @staticmethod
    def _augment(img):
        """Return ``img`` flipped along both spatial axes (a 180-degree rotation).

        The channel axis is left untouched. The previous bare ``np.flip(img)``
        also reversed the channels, swapping R and B in augmented images.
        """
        return np.flip(img, axis=(0, 1))

    def preprocess(self, p, clannel, valid=None):
        """Read one image, resize it to (self.width, self.height) and scale it to [0, 1].

        Args:
            p: path of the image file.
            clannel: 3 for an RGB colour image, 1 for a grayscale image.
            valid: only consulted when ``clannel`` is neither 1 nor 3. A truthy
                value then selects the gamma-corrected (gamma=1.8) colour branch.
                NOTE(review): img_load passes clannel=3 with valid=True, which never
                reaches that branch. Confirm whether validation images were meant
                to be gamma-corrected.

        Returns:
            float32 array of shape (self.height, self.width, C) with values in [0, 1].

        Raises:
            ValueError: if ``clannel`` is not 1 or 3 and ``valid`` is falsy.
            FileNotFoundError: if OpenCV cannot read ``p``.
        """
        if clannel not in (1, 3) and not valid:
            raise ValueError(f"unsupported channel count {clannel!r}; expected 1 or 3")

        # cv2.resize takes dsize as (width, height) and returns (height, width[, C]).
        # Reshaping that to (width, height, C) scrambled rows whenever width != height.
        dsize = (self.width, self.height)
        if clannel == 3:
            x = self._imread(p)
            x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
            x = cv2.resize(x, dsize, interpolation=cv2.INTER_NEAREST)
        elif clannel == 1:
            x = self._imread(p, 0)
            x = cv2.resize(x, dsize, interpolation=cv2.INTER_NEAREST)
            x = x[..., np.newaxis]  # grayscale resize yields 2-D; add the channel axis
        else:
            # Gamma-corrected colour branch. Note: no BGR->RGB conversion here.
            x = self._imread(p)
            x = cv2.resize(x, dsize, interpolation=cv2.INTER_NEAREST)
            x = create_gamma_img(1.8, x)
        return x.astype(np.float32) / 255

    def img_load(self, valid=False):
        """Load every image in ``cfg.x_img`` with its class label.

        File names must split on "_" into exactly five parts:
        ``<prefix>_<y>_<color>_<shape>_<suffix>``. Only ``y`` is used here.

        Returns:
            (X, y_labels, X_aug). X_aug holds one augmented copy per image and is
            empty when ``valid`` is true.
        """
        X, y_labels, X_aug = [], [], []
        img_dir = self.cfg.x_img
        for image_name in tqdm(sorted(os.listdir(img_dir))):
            _, y, _color, _shape, _ = image_name.split("_")
            path = os.path.join(img_dir, image_name)
            if valid:
                X.append(self.preprocess(path, 3, valid=True))
            else:
                img = self.preprocess(path, 3, valid=None)
                X.append(img)
                X_aug.append(self._augment(img))
            y_labels.append(int(y))
        return X, y_labels, X_aug

    def meta_load(self, valid=False):
        """Load every image in ``cfg.x_img`` with its colour and shape labels.

        File names follow ``<prefix>_<y>_<color>_<shape>_<suffix>``.

        Returns:
            (X, X_aug, color_label, shape_label). X_aug is empty when ``valid`` is true.
        """
        X, X_aug, color_label, shape_label = [], [], [], []
        img_dir = self.cfg.x_img
        for image_name in tqdm(sorted(os.listdir(img_dir))):
            _, _y, color, shape, _ = image_name.split("_")
            img = self.preprocess(os.path.join(img_dir, image_name), 3)
            if not valid:
                X_aug.append(self._augment(img))
            X.append(img)
            color_label.append(int(color))
            shape_label.append(int(shape))
        return X, X_aug, color_label, shape_label

122

In [9]:
# Load the saved per-item colour, shape and cosine-similarity scores (printed shapes: (122, 1) each).
pred_color, pred_shape, cossims = (
    np.load(path)
    for path in ("npy/color_sim600.npy", "npy/shape_sim600.npy", "npy/cossim600.npy")
)
print(pred_color.shape, pred_shape.shape, cossims.shape)

(122, 1) (122, 1) (122, 1)


In [14]:
# Flat index of the item with the highest summed shape + colour score.
(pred_shape + pred_color).argmax()

107

In [22]:
import json
import numpy as np

# NOTE(review): `src` is unseeded and not used anywhere in the visible notebook.
src = list(np.random.rand(122))
with open("cs_label.json") as label_file:
    dd = json.load(label_file)

# Relative weight of the shape score when fused with the colour score.
SHAPE_WEIGHT = 0.2
# Loop-invariant: compute once instead of on every iteration.
best_shape = np.argmax(pred_shape)
best_color = np.argmax(pred_color)

# NOTE(review): assumes dd's iteration order matches the row order of
# pred_shape/pred_color, and that 'category' is stored as an integer index
# comparable to an argmax over pred_shape rows. Confirm both.
confs = []
for i, (k, v) in enumerate(dd.items()):
    sl, cl = v['category'], v['color']
    if sl == best_shape:
        print(i, best_color, cl)
    confs.append((pred_shape[i] * SHAPE_WEIGHT) + pred_color[i])

print(len(confs))

122


In [23]:
# Fused score per item: weighted shape/colour confidence plus cosine similarity.
embbed = list(map(np.add, confs, cossims))
np.argmax(embbed)

66

In [24]:
# Show each fused shape/colour confidence next to its index.
for idx, score in enumerate(confs):
    print(idx, score)

0 [-0.1477156]
1 [0.00837296]
2 [0.34579316]
3 [-0.14134282]
4 [0.1932905]
5 [-0.051756]
6 [0.17528285]
7 [0.21936014]
8 [0.0626115]
9 [-0.30198845]
10 [0.22636962]
11 [-0.3307628]
12 [-0.3095765]
13 [-0.16245145]
14 [0.46833035]
15 [0.05316379]
16 [-0.0782958]
17 [-0.11439661]
18 [0.2790215]
19 [0.05798183]
20 [-0.07273694]
21 [0.15680245]
22 [-0.26901302]
23 [0.05137862]
24 [-0.06630017]
25 [0.12773818]
26 [0.0301045]
27 [-0.10238506]
28 [0.08562538]
29 [0.03194359]
30 [-0.20603576]
31 [-0.30501246]
32 [0.05595013]
33 [-0.15500867]
34 [0.06122018]
35 [-0.10516082]
36 [0.22743085]
37 [-0.08551983]
38 [-0.02875526]
39 [-0.00813287]
40 [0.20177984]
41 [-0.28815055]
42 [-0.06198273]
43 [-0.12186307]
44 [-0.24242666]
45 [0.16600457]
46 [-0.05184956]
47 [0.14960802]
48 [0.15704751]
49 [-0.25318635]
50 [-0.00072899]
51 [-0.20344305]
52 [0.17674802]
53 [-0.20594147]
54 [0.36672533]
55 [-0.1817515]
56 [-0.42737293]
57 [0.07109027]
58 [-0.02236333]
59 [-0.18906096]
60 [-0.0986145]
61 [0.341168

In [25]:
# Show each combined score (confidence + cosine similarity) next to its index.
for idx, fused in enumerate(embbed):
    print(idx, fused)

0 [0.08354673]
1 [0.49708793]
2 [0.7447586]
3 [-0.87227935]
4 [-0.09029198]
5 [-0.36084297]
6 [-0.11023344]
7 [0.8065659]
8 [0.19519725]
9 [-0.7130016]
10 [-0.08067849]
11 [-0.5395456]
12 [-1.1163455]
13 [-0.581322]
14 [0.9415533]
15 [0.28540197]
16 [-0.20924255]
17 [-0.4953149]
18 [0.5788449]
19 [-0.00083416]
20 [-0.3786801]
21 [0.91441655]
22 [-1.0307572]
23 [-0.25395483]
24 [-0.1910785]
25 [-0.5055326]
26 [-0.27590704]
27 [-0.32350725]
28 [0.80940604]
29 [0.7015418]
30 [-0.72120595]
31 [-0.9840054]
32 [0.4146764]
33 [-0.85271835]
34 [0.2161165]
35 [-0.6443541]
36 [0.24379116]
37 [-0.04121071]
38 [-0.00157987]
39 [0.2535609]
40 [-0.36679363]
41 [-0.94922745]
42 [-0.28153002]
43 [-0.6873743]
44 [-1.0064137]
45 [0.22612633]
46 [0.09633476]
47 [0.19262956]
48 [0.31728512]
49 [0.15719575]
50 [0.508461]
51 [-1.071148]
52 [0.84461105]
53 [-1.0038089]
54 [0.9392015]
55 [-0.04647185]
56 [-1.030688]
57 [-0.03428504]
58 [-0.37830594]
59 [-0.8470496]
60 [-0.6157347]
61 [0.35650864]
62 [-0.18092