In [1]:
import os
import pickle
import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import random

In [2]:
def find_positive(simi_scores, simi_path_list):
    """Draw POSITIVE_MAX_NUM distinct positives, weighted by similarity score.

    Args:
        simi_scores: similarity scores, one per entry of ``simi_path_list``.
        simi_path_list: candidate file names to sample from.

    Returns:
        The result of ``np.random.choice``: POSITIVE_MAX_NUM entries of
        ``simi_path_list`` drawn without replacement, where each candidate's
        probability is its score divided by the sum of all scores.
    """
    # Scale factor that turns the raw scores into probabilities summing to 1.
    normaliser = 1 / np.sum(simi_scores)
    weights = [score * normaliser for score in simi_scores]
    return np.random.choice(
        simi_path_list,
        POSITIVE_MAX_NUM,
        replace=False,
        p=weights,
    )

def find_negative_inner(neg_scores, neg_path_list):
    """Draw NEGATIVE_INNER_MAX_NUM distinct negatives, weighted by score.

    Args:
        neg_scores: similarity scores, one per entry of ``neg_path_list``.
        neg_path_list: candidate file names to sample from.

    Returns:
        The result of ``np.random.choice``: NEGATIVE_INNER_MAX_NUM entries of
        ``neg_path_list`` drawn without replacement, where each candidate's
        probability is its score divided by the sum of all scores.
    """
    # Scale factor that turns the raw scores into probabilities summing to 1.
    normaliser = 1 / np.sum(neg_scores)
    weights = [score * normaliser for score in neg_scores]
    return np.random.choice(
        neg_path_list,
        NEGATIVE_INNER_MAX_NUM,
        replace=False,
        p=weights,
    )

def find_negative_outter(cur_category, category_info):
    """Sample negatives from a randomly chosen category other than the anchor's.

    Args:
        cur_category: the anchor's category, normally a key of
            ``category_info``.
        category_info: dict mapping a category name to the name of a text file
            (relative to ROOT_DATA_PATH) that lists one file name per line.

    Returns:
        np.ndarray of at most NEGATIVE_OUTTER_MAX_NUM file names (line endings
        stripped), picked at random from the chosen category's list. An empty
        array is returned when there is no other category to draw from.
    """
    # Exclude the anchor's own category. The key comparison is the real check;
    # the value comparison is kept so a caller that passes the file-list name
    # instead of the category name is still excluded correctly.
    other_keys = [
        k for k in category_info
        if k != cur_category and category_info[k] != cur_category
    ]
    if not other_keys:
        # Nothing to contrast against (zero or one category available).
        return np.array([], dtype=str)

    neg_cat_key = random.choice(other_keys)
    neg_filelist_path = category_info[neg_cat_key]

    with open(os.path.join(ROOT_DATA_PATH, neg_filelist_path), 'r') as neg_f:
        neg_filelist = neg_f.readlines()

    # Shuffle, then keep the head: a uniform sample without replacement that
    # also copes with lists shorter than NEGATIVE_OUTTER_MAX_NUM.
    random.shuffle(neg_filelist)
    picked = [line.split('\n')[0] for line in neg_filelist[:NEGATIVE_OUTTER_MAX_NUM]]
    return np.array(picked, dtype=str)

def make_triplet_list(anchor, positive, negative):
    """Combine one anchor with every (positive, negative) pair.

    Args:
        anchor: anchor file name; must be exactly of type ``str``.
        positive: ``np.ndarray`` of positive file names.
        negative: ``np.ndarray`` of negative file names.

    Returns:
        A list of ``[anchor, positive_item, negative_item]`` lists. Negatives
        vary slowest: all positives are paired with the first negative, then
        all positives with the second negative, and so on. The list is empty
        when either array is empty.
    """
    # Exact type checks (not isinstance) on purpose: subclasses are rejected.
    assert type(anchor) is str
    assert type(positive) is np.ndarray
    assert type(negative) is np.ndarray

    return [
        [anchor, positive[pos_idx], negative[neg_idx]]
        for neg_idx in range(len(negative))
        for pos_idx in range(len(positive))
    ]

In [3]:
def debug_visualization(triplet_dataset,
                        image_root=os.path.join('cropped_images',
                                                'Category-Attribute-Prediction-Benchmark')):
    """Show each triplet as a row of three images: anchor, positive, negative.

    Args:
        triplet_dataset: iterable of triplets; every triplet is a sequence of
            image paths relative to ``image_root``.
        image_root: directory the triplet paths are resolved against. Defaults
            to the location this notebook has always used.

    Raises:
        IOError: if ``cv2.imread`` cannot read one of the images.
    """
    panel_titles = ('anc', 'pos', 'neg')

    for triplet in triplet_dataset:
        fig = plt.figure(figsize=(3, 3))
        grid = gridspec.GridSpec(1, 3)
        grid.update(wspace=0.01, hspace=0.02)  # keep the three panels close together

        for im_idx, im_path in enumerate(triplet):
            full_path = os.path.join(image_root, im_path)
            im = cv2.imread(full_path)
            if im is None:
                # cv2.imread returns None instead of raising on a missing or
                # unreadable file; fail here with the offending path rather
                # than with an opaque error inside cvtColor.
                plt.close(fig)
                raise IOError('could not read image: %s' % full_path)
            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR, matplotlib expects RGB

            # One axes per panel, taken from the grid so the spacing applies.
            ax = fig.add_subplot(grid[im_idx])
            ax.set_title(panel_titles[min(im_idx, 2)])
            ax.imshow(im)
            ax.axis('off')

        plt.show()
        # Release the figure so a long triplet list does not pile up memory.
        plt.close(fig)

In [6]:
# ---------------------------------------------------------------------------
# Triplet-mining configuration
# ---------------------------------------------------------------------------
# An image is a positive candidate when its score is strictly above this value.
POS_THRESHOLD_SCORE = 0.80
# In-category negatives satisfy NEG_MIN_THRESHOLD_SCORE < score <= NEG_MAX_THRESHOLD_SCORE.
NEG_MAX_THRESHOLD_SCORE = 0.6
NEG_MIN_THRESHOLD_SCORE = 0.3
# Per-anchor caps (the trailing comments record previously used values).
POSITIVE_MAX_NUM = 8#20
NEGATIVE_INNER_MAX_NUM = 4#10
NEGATIVE_OUTTER_MAX_NUM = 2#10

# Seed both generators: numpy drives the weighted sampling helpers, while the
# stdlib `random` module drives the cross-category negative selection.
np.random.seed(222)
random.seed(222)

ROOT_DATA_PATH = 'vectorMatrix'
# os.listdir order is arbitrary; sort so the run is reproducible across machines.
filelist = sorted(os.listdir(ROOT_DATA_PATH))

# Map category name -> name of its "<category>_filename.txt" listing.
category_info = {}
for filename in filelist:
    # Match on the suffix only, so stray files such as "x_filename.txt.bak"
    # are not mistaken for a category listing.
    if filename.endswith('_filename.txt'):
        cate = filename[:-len('_filename.txt')]
        category_info[cate]=filename
        
# Mine (anchor, positive, negative) triplets for every category and write them
# to the output file, one triplet per line with tab-separated file names.
# The `with` block guarantees the file is flushed and closed even if a
# category fails part-way through.
with open('triplet_dataset_v2_080.txt','w') as fff:
    keys = category_info.keys()
    for category in keys:
        # File names of this category; row i of the matrix below is indexed
        # with the same position i (presumably the two files are aligned --
        # verify against whatever produced them).
        with open(os.path.join(ROOT_DATA_PATH,category+'_filename.txt'),'r') as f:
            lines = f.readlines()
        file_order = [line.split('\n')[0] for line in lines]

        # Pickles are binary, so the file must be opened with 'rb'.
        # SECURITY: pickle.load can execute arbitrary code; only load matrix
        # files that you generated yourself.
        with open(os.path.join(ROOT_DATA_PATH,category+'_matrix.pkl'),'rb') as f:
            matrix = pickle.load(f)

        file_order_inds = list(range(len(file_order)))

        for file_order_ind in file_order_inds:
            anchor_filename = file_order[file_order_ind]
            scores = matrix[file_order_ind]

            ############
            # positive #
            ############
            # Candidates score above the positive threshold; the upper bound
            # drops scores of (numerically) 1.0 -- presumably the anchor's
            # match with itself, TODO confirm.
            threshold_inds = np.where(np.logical_and(scores <= 0.999999999999, scores > POS_THRESHOLD_SCORE))[0]
            simi_scores = scores[threshold_inds]
            simi_path_list = [file_order[v] for v in threshold_inds]

            # Keep the POSITIVE_MAX_NUM highest-scoring candidates.
            sorted_inds = np.argsort(simi_scores, axis=0)[::-1][:POSITIVE_MAX_NUM]
            simi_scores = simi_scores[sorted_inds]
            simi_path_list = [simi_path_list[v] for v in sorted_inds]

            size_simi_path_list = len(simi_path_list)
            if size_simi_path_list == 0:
                # No positive means no triplet can be formed for this anchor.
                continue

            # NOTE(review): after the top-K slice above there are never more
            # than POSITIVE_MAX_NUM candidates, so the sampling branch looks
            # unreachable; kept as-is to preserve the dataset definition.
            if POSITIVE_MAX_NUM < len(simi_scores):
                positive_filenames = find_positive(simi_scores, simi_path_list)
            else:
                positive_filenames = np.array(simi_path_list)

            ##################
            # negative inner #
            ##################
            threshold_inds = np.where(np.logical_and(scores <= NEG_MAX_THRESHOLD_SCORE, scores > NEG_MIN_THRESHOLD_SCORE))[0]
            neg_scores = scores[threshold_inds]
            neg_path_list = [file_order[v] for v in threshold_inds]

            # Keep the NEGATIVE_INNER_MAX_NUM lowest-scoring candidates.
            sorted_inds = np.argsort(neg_scores, axis=0)[:NEGATIVE_INNER_MAX_NUM]
            neg_scores = neg_scores[sorted_inds]
            neg_path_list = [neg_path_list[v] for v in sorted_inds]

            # NOTE(review): same situation as for the positives -- the slice
            # above appears to make the sampling branch unreachable.
            if NEGATIVE_INNER_MAX_NUM < len(neg_scores):
                negative_inner_filenames = find_negative_inner(neg_scores, neg_path_list)
            else:
                negative_inner_filenames = np.array(neg_path_list)

            triplet_datase_inner = make_triplet_list(anchor_filename, positive_filenames, negative_inner_filenames)

            ###################
            # negative outter #
            ###################
            negative_outter_filenames = find_negative_outter(category, category_info)
            triplet_dataset_outter = make_triplet_list(anchor_filename, positive_filenames, negative_outter_filenames)

            triplet_datase = triplet_datase_inner + triplet_dataset_outter
            # To eyeball the result, call debug_visualization(triplet_datase) here.
            for triplet in triplet_datase:
                if len(triplet) != 3:
                    print(triplet)
                    assert(False)
                fff.write('\t'.join(triplet)+'\n')
        