In [65]:
import h5py
import numpy as np
import pandas as pd

In [23]:
dataset = h5py.File('multimodal_imdb.hdf5', 'r')


In [24]:
list(dataset.keys())

['features',
 'genres',
 'images',
 'imdb_ids',
 'sequences',
 'three_grams',
 'vgg_features',
 'word_grams']

In [95]:
def jaccard(a, b):
    a = np.array(a)  
    b = np.array(b)
    
    x = a * b

    inter = np.sum(x, 1)
    union = (np.sum(a, 1) + np.sum(b, 1)) - inter
    
    j = inter / union
    
    dis = 1 - j
    return dis
    

In [93]:
a = dataset['genres'][0:1]
b = dataset['genres'][0:3]
print(a)
print(b)

[[1 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1]]
[[1 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1]
 [1 0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1]
 [1 0 0 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1]]


In [94]:
jaccard(a, b)

Inter [4 2 4]
Union [4 6 5]


array([0.        , 0.66666667, 0.2       ])

In [187]:
def gen_triplet(n_batch, n_trip, filename):
    df = pd.DataFrame({'id': dataset['imdb_ids'], 'index': np.arange(0, len(dataset['imdb_ids']))})
    
    triplets = []
    for i in range(n_trip):
        while (True):
            batch = np.random.choice(dataset['imdb_ids'], n_batch, replace = False)

            b = df.loc[df['id'].isin(batch)]
            index = np.array(b['index'])

            sub_g = dataset['genres'][index]

            anchor = sub_g[0:1]
            sub_y = sub_g[1:]
            j = jaccard(anchor, sub_y)
            close = np.min(j)
            far = np.max(j)

            #Discuss threshold
            if close < 0.25 and far > 0.5:
                anchor_id = index[0]
                pos_id = index[np.argmin(j)+1]
                neg_id = index[np.argmax(j)+1]
                triplet = [anchor_id, pos_id, neg_id]
                triplets.append(triplet)
                break

    triplets = np.array(triplets)
    np.save(filename, triplets)

In [193]:
n_batch = 100
gen_triplet(n_batch, 100, "triplets100.npy")

In [197]:
print(dataset['genres'][40])
print(dataset['genres'][166])
print(dataset['genres'][235])

[1 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1]
[1 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1]
[0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 1 0]


In [195]:
triplets = np.load("triplets100.npy")
triplets

array([[  343, 18337,   580],
       [   40,   166,   235],
       [   35,    50,   361],
       [   22,    64,   250],
       [   82,    93,   593],
       [  488,   590,  1203],
       [  145, 15581,  1120],
       [   51, 21582,   311],
       [   78, 18259,  1012],
       [  829, 18514,  1509],
       [  617, 19062,  3002],
       [   40,    50,   466],
       [   10,    35,  4368],
       [    2, 18168,   343],
       [   59,   155,  1489],
       [   51,  6912,   578],
       [   81,   119,   452],
       [   25,   160,   412],
       [  103, 18191,   746],
       [  105,  6551,  1413],
       [  119, 15580,   282],
       [   59,    94,  1555],
       [  436, 18309,  1678],
       [  182, 15565,   207],
       [    2,   121,   889],
       [   80, 23201,   658],
       [  558, 15684,  1190],
       [   69,   134,   333],
       [   80, 10483,   264],
       [  111,   154,   351],
       [   39, 18249,   350],
       [    9,   194,   433],
       [  585,   823,   980],
       [  