In [1]:
import json
import torch
import numpy as np
from models import model as AE
from torch.utils.data import  Dataset, DataLoader
from tqdm import tqdm_notebook
import copy

In [2]:
with open('./dataset/val.json', encoding='UTF-8') as json_file:
    data = json.load(json_file)

In [54]:
len(data)

23015

In [3]:
class InferDataset(Dataset):    

    def __init__(self, data, max_song):
        
        self.data = data
        self.max_song = max_song
        self.vector = np.zeros(max_song, dtype=np.float32)
        
    def get_vector(self, vec):
        cop_vector = copy.deepcopy(self.vector)
        if vec:
            cop_vector[vec] = 1            
        return cop_vector        
    
    def __getitem__(self, idx):
        
        x = self.data[idx]
        song = x['songs']
        
        vec = self.get_vector(song)
        
        return str(x['id']), vec
        
    def __len__(self):
        return len(self.data)


In [4]:
"""
[
    {
        'id': 플레이리스트 id,
        'songs': 예측한 곡 100개 리스트,
        'tags': 예측한 태그 10개 리스트
    },
    ... (n개)
     results.json
]
"""

"\n[\n    {\n        'id': 플레이리스트 id,\n        'songs': 예측한 곡 100개 리스트,\n        'tags': 예측한 태그 10개 리스트\n    },\n    ... (n개)\n     results.json\n]\n"

In [64]:
checkpoint = torch.load("./dataset/model/model_5.pth")
model = AE.AE(707989)
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [65]:
dataset = InferDataset(data, 707989) 
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

In [66]:
def rcmd_rank(x, output):
    target = x == 1
    output[target] = -99999
#     output_sort_ind = output.argsort()

    result_args = np.argpartition(-output, (1, 101))
    return result_args[:,:100]

In [67]:
results = []
model.eval()
for idx, (m_id, x) in enumerate(tqdm_notebook(dataloader)): #model.train()
    
    output = model(x).data.numpy()
    rcmd_ind = rcmd_rank(x, output)
    results += [{ "id":int(_id), "songs": x.tolist(), "tags" :['피아노', '이루마', '메로디', '상큼한', '출근길', '트렌디', '팝', '신나는', "아이유", "BTS"]} for _id, x in zip(m_id, rcmd_ind)]

In [75]:
x[0]

tensor([0., 0., 0.,  ..., 0., 0., 0.])

In [72]:
output[0]

array([-0.36983344,  0.25769582,  0.04547113, ...,  1.1993856 ,
        0.19470575,  0.02063469], dtype=float32)

In [70]:
rcmd_ind

array([[384438, 513236, 522440, 312194, 587821, 364301, 129221, 677156,
        340954,   1150, 439393, 231447, 190711, 504762, 214882,  56476,
        550489, 175287, 522229, 484944,   2287, 268713,  14645, 487174,
        405968, 465841, 170125, 243551, 469049, 514321, 427955, 362831,
        678433, 239628, 664854, 656635, 393590, 501061, 197782, 549851,
        229363, 552610, 237630,  35728, 620574, 398868, 605733,  49685,
        390942, 145158, 405892, 141157,    749, 272157, 353715, 169171,
           760,  91394, 108379, 613272, 361066, 266501, 558878, 528737,
        201030, 571106, 255113, 113047, 293735, 379173, 152072, 225733,
        229848, 121454, 567796, 164365, 184864, 668076, 421264, 399587,
        105656, 595945, 135060, 637638, 641387, 203978, 593704, 173628,
         16566, 692009, 361411, 330611, 208106, 121521,  58747, 498913,
        252229, 210314, 577389, 174250],
       [364301, 513236,    760, 362831, 243551, 210526,  91394, 567796,
        393590, 492342,

In [76]:
for x in range(len(results)):
    results[x]['id'] = int(results[x]['id'])
    
    

In [62]:
results[x]['id']

65189

In [77]:
with open("results.json", "w") as json_file:
    json.dump(results, json_file)
