In [2]:
from akie import MetaLearner, setup_seed

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.nn import ModuleList as ML
from torch.nn import ParameterList as PL
import numpy as np

import time, pickle
from sklearn.metrics import roc_auc_score

def get_data(data_dir='/home2/zh/data/ml-1m'):
    """Load the six NumPy arrays that make up the dataset.

    Parameters
    ----------
    data_dir : str, optional
        Directory containing ``i_genre.npy``, ``i_other.npy``, ``ui.npy``,
        ``x_genre.npy``, ``x_other.npy`` and ``y.npy``. Defaults to the
        location originally hard-coded in this notebook.

    Returns
    -------
    tuple of numpy.ndarray
        ``(i_genre, i_other, ui, x_genre, x_other, y)``, in that order.

    Raises
    ------
    OSError
        Propagated from ``np.load`` when a file is missing or unreadable.
    """
    import os  # local import: keeps this cell self-contained

    # Order matters: callers unpack the tuple positionally.
    names = ('i_genre', 'i_other', 'ui', 'x_genre', 'x_other', 'y')
    return tuple(np.load(os.path.join(data_dir, name + '.npy')) for name in names)

def run0(model, data, user, phase, spt_qry_split=None):
    """Run the model on a single user's rows and return its query-set output.

    Parameters
    ----------
    model : callable module
        Called as ``model(task_data, label, spt_idx, qry_idx, phase=phase)``
        and expected to return ``(pred, loss)``. Must expose ``eval()`` and
        ``train()``.
    data : tuple
        ``(i_genre, i_other, ui, x_genre, x_other, y)`` as returned by
        ``get_data``. Rows are selected where ``ui[:, 0] == user``.
    user : scalar
        Value matched against the first column of ``ui``.
    phase : str
        ``'train'`` selects the support/query split logic below; any other
        value evaluates on all of the user's rows with an empty support set.
    spt_qry_split : str, optional
        Split rule, only consulted when ``phase == 'train'``. The only
        recognised value is ``'max(1/8, 4)'``.

    Returns
    -------
    None or tuple
        ``None`` when the user is skipped in training (10 or fewer rows, or
        the support or query labels contain nothing below 1); otherwise
        ``(pred, loss, label[qry_idx])``.

    Raises
    ------
    ValueError
        If ``phase == 'train'`` and ``spt_qry_split`` is not recognised.
        ``ValueError`` subclasses ``Exception``, so handlers written for the
        previous bare ``Exception`` still catch it.
    """
    i_genre, i_other, ui, x_genre, x_other, y = data
    idx = np.where(ui[:,0]==user)[0]
    is_train = phase == 'train'
    if is_train:
        if len(idx) <= 10: return None
        # Reject an unknown split rule before paying for any GPU transfer.
        if spt_qry_split != 'max(1/8, 4)':
            raise ValueError('Undefined spt_qry_split')

    # All columns but the last of the *_other arrays go in as "onehot" inputs;
    # the last column of x_other is passed separately as `ctns`.
    onehot_i = torch.tensor(i_other[idx,:-1]).cuda()
    onehot_x = torch.tensor(x_other[idx,:-1]).cuda()
    multihot_i = torch.tensor(i_genre[idx]).cuda()
    multihot_x = torch.tensor(x_genre[idx]).cuda()
    multihot_list = [(multihot_i, multihot_x)]
    ctns = torch.tensor(x_other[idx,-1:]).cuda()
    task_data = (onehot_i, onehot_x, multihot_list, ctns)
    label = torch.tensor(y[idx]).cuda().double()
    spt_qry_perm = torch.randperm(len(idx))

    if is_train:
        # 'max(1/8, 4)': one eighth of the user's rows, but at least 4.
        sqsplit = max(4, int(len(idx)/8))
        spt_idx = spt_qry_perm[:sqsplit]
        qry_idx = spt_qry_perm[-sqsplit:]
        # Skip the task unless both subsets contain a label below 1
        # (assuming 0/1 labels, that means at least one negative each).
        if (1-label[spt_idx]).sum() <= 0: return None
        if (1-label[qry_idx]).sum() <= 0: return None
        pred, loss = model(task_data, label, spt_idx, qry_idx, phase=phase)
    else:
        # Evaluation: empty support set, every row is a query row.
        spt_idx = spt_qry_perm[:0]
        qry_idx = spt_qry_perm[:]
        model.eval()
        try:
            pred, loss = model(task_data, label, spt_idx, qry_idx, phase=phase)
        finally:
            # Always return to training mode, even if the forward pass raised.
            model.train()
    return pred, loss, label[qry_idx]

# --- Global setup: device, seed, data, hyper-parameters, model, optimizer ---
# Pins this process to CUDA device index 5 (machine-specific; adjust per host).
torch.cuda.set_device(5)
# Seeding helper from akie — presumably seeds the numpy/torch RNGs; confirm in akie.setup_seed.
setup_seed(81192)
data = get_data()
i_genre, i_other, ui, x_genre, x_other, y = data
# Hyper-parameters handed to MetaLearner. In this section of the notebook only
# 'batchsize', 'batchnum' and 'learning_rate' are read back; the meaning of the
# remaining keys is defined by akie.MetaLearner (not shown here).
config = {'num_embeddings': 3529,
        'embedding_dim': 16,
        'dim': [16*6+1, 64, 1],
        'dropout': [0, 0],
        'embedding_dim_meta': 32,
        'userl2': 32*4+0,
        'cluster_d': 128,
        'clusternum': [1, 3, 2, 1],
        'user_part': ([0,1,2,3], [], []),
        'inner_steps': 2,
        'batchsize': 256,
        'learning_rate': [0.5, 0.5],
        'update_lr': [0.01, (0, 0)]}
# Must match the only split rule run0 accepts in the 'train' phase.
config['spt_qry_split'] = 'max(1/8, 4)'
# With batchsize 256 this evaluates to 9375.
config['batchnum'] = int(75000*32/config['batchsize'])
# int(9375/400) == 23.
config['eval_every'] = int(config['batchnum']/400)
# Model lives on the GPU in float64 (run0 also casts labels to double).
model = MetaLearner(config).cuda().double()
# Distinct values of the first column of ui (matched against `user` in run0).
user_set = np.array(list(set(ui[:,0])))
usernum = len(user_set)
# Random ordering of positions into user_set; drawn from numpy's global RNG.
perm = np.random.permutation(usernum)
# 80% and 90% thresholds over the number of users.
train_usernum = int(0.8*usernum)
valid_usernum = int(0.9*usernum)
# Two parameter groups with separate learning rates: model.net's parameters,
# and everything after the first len(model.net.parameters()) entries of
# model.parameters().
# NOTE(review): the slice assumes model.net's parameters come first in
# model.parameters() — confirm against MetaLearner's definition.
optimizer = optim.SGD([{'params': model.net.parameters(), 'lr': config['learning_rate'][0]},
                        {'params': list(model.parameters())[len(list(model.net.parameters())):], 'lr': config['learning_rate'][1]}])
# Empty accumulators; the code that fills them is not in this section.
traucs, vaaucs, trloss, valoss = [], [], [], []
labels_chkp, preds_chkp, losses_chkp = [], [], []

In [29]:
import time
from tqdm import tqdm
# Display the two thresholds computed in the setup cell: int(0.8*usernum) and int(0.9*usernum).
train_usernum, valid_usernum

(4832, 5436)

In [55]:
def val(num, checkpoint='test.try'):
    """Evaluate a saved checkpoint on up to ``num`` held-out users.

    Loads ``checkpoint`` into the global ``model``, draws a random order of
    the users at positions ``valid_usernum`` onward in ``perm``, runs ``run0``
    in the ``'valid'`` phase for the first ``num`` of them, then prints the
    elapsed seconds followed by the ROC AUC and mean loss.

    NOTE(review): the slice taken is the tail after the 0.9 threshold, which
    may be intended as the test split rather than validation — confirm.

    Parameters
    ----------
    num : int
        Maximum number of held-out users to evaluate.
    checkpoint : str, optional
        Path passed to ``torch.load``; defaults to the file name originally
        hard-coded here.

    Returns
    -------
    None
        Results are printed, not returned.
    """
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(checkpoint))
    preds, losses, labels = [], [], []
    # np.random.permutation shuffles a copy. The previous in-place
    # np.random.shuffle on a slice of `perm` reordered the global array itself.
    held_out = np.random.permutation(perm[valid_usernum:])
    # time.clock() was removed in Python 3.8; perf_counter() measures elapsed time.
    start = time.perf_counter()
    # Gradients are left enabled because the original kept torch.no_grad()
    # commented out; the model may need autograd in this phase — TODO confirm.
    for user in tqdm(user_set[held_out[:num]]):
        pred, loss, label = run0(model, data, user, 'valid')
        # Detach on append so each user's autograd graph can be freed.
        preds.append(pred.detach())
        losses.append(loss.detach())
        labels.append(label.detach())
    print(time.perf_counter()-start)
    valid_auc = roc_auc_score(torch.cat(labels).cpu().numpy(), torch.cat(preds).cpu().numpy())
    valid_loss = torch.stack(losses).mean().item()
    print(valid_auc, valid_loss)

In [57]:
val(604)

100%|██████████| 604/604 [00:03<00:00, 152.93it/s]

4.006118999999998
0.6081545958477983 0.6698012204261237



