In [1]:
import os
import random
import pickle
import torch
import torchvision
import pandas as pd
import numpy as np

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [3]:
random.seed(123)
np.random.seed(123)
torch.manual_seed(123)
torch.cuda.manual_seed(123)
torch.backends.cudnn.deterministic = True

In [41]:
with open('../pickles_for_git/main_dict.pickle', 'rb') as f:
    main_dict = pickle.load(f)

In [42]:
main_dict[2]

{'artist_name': 'AWOL',
 'track_title': 'Food',
 'artist_tags': [('post-punk', '100'),
  ('hardcore', '80'),
  ('rock', '60'),
  ('punk', '60'),
  ('hardcore punk', '60'),
  ('Progressive metal', '40'),
  ('Bumblefoot', '40'),
  ('Ron Thal', '40'),
  ('metalcore', '20'),
  ('Hip-Hop', '20')],
 'track_tags': [],
 'similar_tracks': [],
 'similar_list': [],
 'genre': 'Hip-Hop',
 'all_genres': '[21]',
 'fma_tags': '[]',
 'all_tags': ['hip_hop',
  'detroit',
  'rap',
  'gangsta',
  'midwest',
  'midwest rap',
  'gangsta',
  'black',
  'black music',
  'flow']}

In [43]:
def get_score(x, y):
    answer_tags = main_dict[x]['all_tags']
    query_tags = main_dict[y]['all_tags']
    score = 0
    i = 0
    for tag in answer_tags:
        if tag in query_tags:
            i += 1
    if i == 0:
        return 0

    precision = i / len(answer_tags)
    recall = i / len(query_tags)
    score = 2 * precision * recall / (precision + recall)
    return score
        

In [44]:
get_score(190, 621)

0.7407407407407407

In [89]:
positiv_pairs = dict()
main_list = list(main_dict.keys())
for i in range(len(main_dict)):
    idx_1 = main_list[i]
    if len(main_dict[idx_1]['all_tags']) < 3:
        continue
    positiv_pairs[idx_1] = list()
    for j in range(len(main_dict)):
        idx_2 = main_list[j]
        if idx_2 == idx_1:
            continue
        if set(main_dict[idx_1]['all_genres']) != set(main_dict[idx_2]['all_genres']):
            continue
        score = get_score(idx_1, idx_2)
        if score >= 0.4:
            positiv_pairs[idx_1].append(idx_2)
    if len(positiv_pairs[idx_1]) == 0:
        positiv_pairs.pop(idx_1)
            
            
        
            

In [91]:
len(positiv_pairs)

6560

In [96]:
with open('../pickles_for_git/positiv.p', 'wb') as f:
    pickle.dump(positiv_pairs, f)

In [97]:
negativ_pairs = dict()
main_list = list(main_dict.keys())
for i in range(len(main_dict)):
    idx_1 = main_list[i]
    if len(main_dict[idx_1]['all_tags']) < 3:
        continue
    negativ_pairs[idx_1] = list()
    for j in range(len(main_dict)):
        idx_2 = main_list[j]
        if idx_2 == idx_1:
            continue
        if len(set(main_dict[idx_1]['all_tags']) & set(main_dict[idx_2]['all_tags'])) == 0:
            negativ_pairs[idx_1].append(idx_2)
    if len(negativ_pairs[idx_1]) == 0:
        negativ_pairs.pop(idx_1)

In [98]:
len(negativ_pairs)

6843

In [99]:
with open('../pickles_for_git/negativ.p', 'wb') as f:
    pickle.dump(negativ_pairs, f)

In [100]:
df_train = pd.read_csv('../data/csv/train.csv')
df_train.shape

(7897, 2049)

In [103]:
df_train.sample(5)

Unnamed: 0,id,0,1,2,3,4,5,6,7,8,...,2038,2039,2040,2041,2042,2043,2044,2045,2046,2047
4395,44780,0.021842,3.180349,3.496266,3.066512,2.382225,1.643121,2.91235,3.459743,0.107625,...,1.150135,9.767879,4.304694,0.760868,2.013256,2.232111,0.0,4.238208,2.349661,1.868011
6896,143303,0.0,0.0,7.167964,0.0,0.0,0.0,0.0,0.0,0.0,...,0.050632,0.225793,5.762378,1.829636,3.967636,2.907128,0.0,2.200015,1.308771,2.882554
5979,127996,0.577016,0.0,4.992959,1.591914,0.385836,0.0,0.0,0.0,0.0,...,0.0,0.0,6.769474,1.655986,4.63858,4.655748,0.0,3.559631,2.663797,5.568622
854,98549,6.648037,9.589069,7.298048,2.744189,3.121323,0.419662,4.266093,4.569069,0.545057,...,6.130232,5.913578,0.0,0.0,4.360894,0.0,0.0,5.975849,0.102119,1.756808
912,99364,1.581861,2.432273,5.68658,1.292741,0.429636,0.02454,0.915842,2.869539,0.521775,...,0.143558,8.438079,2.256808,3.168467,3.419915,3.590786,0.0,4.209648,2.885381,2.650276


In [104]:
train_idx = df_train.id.values

In [107]:
train_idx

array([     5,     10,    140, ..., 153337, 154303, 154306])

In [108]:
np.random.choice(train_idx, 1, replace=False)[0]

136322

In [141]:
i = 0
triplet = list()
while i < 10000:
    query_id = np.random.choice(train_idx, 1, replace=False)[0]
    try:
        p_id = np.random.choice(positiv_pairs[query_id], 1, replace=False)[0]
        n_id = np.random.choice(negativ_pairs[query_id], 1, replace=False)[0]
    except KeyError:
        continue
    
    if p_id not in train_idx or n_id not in train_idx:
        continue
    if (query_id, p_id, n_id) not in triplet:
        triplet.append((query_id, p_id, n_id))
        i += 1

In [142]:
len(triplet)

10000

In [143]:
len(set([x[0] for x in triplet]))

5112

In [144]:
triplet[0]

(141139, 111793, 46842)

In [145]:
def find_img_path(idx):
    name = f'{idx:06}.png'
    base_path = '../data/spectrograms/train/'
    for address, dirs, files in os.walk(base_path):
        for img in files:
            if img == name:
                img_path = address + '/' + img
                return img_path
    print(name)

In [147]:
trpl_with_path = [(find_img_path(x[0]), find_img_path(x[1]), find_img_path(x[2])) for x in triplet]

In [148]:
trpl_with_path[0]

('../data/spectrograms/train/train/Instrumental/141139.png',
 '../data/spectrograms/train/train/Instrumental/111793.png',
 '../data/spectrograms/train/train/Pop/046842.png')

In [152]:
with open('../pickles_for_git/triplets.p', 'wb') as f:
    pickle.dump(trpl_with_path, f)

In [153]:
with open('../pickles_for_git/triplets.p', 'rb') as f:
    tripls = pickle.load(f)

In [154]:
tripls[0]

('../data/spectrograms/train/train/Instrumental/141139.png',
 '../data/spectrograms/train/train/Instrumental/111793.png',
 '../data/spectrograms/train/train/Pop/046842.png')