In [1]:
!pip install 'neptune-contrib[monitoring]' &> /dev/null
from neptunecontrib.monitoring.pytorch_lightning import NeptuneLogger

In [2]:
import collections

from collections.abc import MutableMapping


def flatten(d, parent_key='', sep='_'):
    """Flatten nested mutable mappings into a single-level dict.

    Nested keys are joined with ``sep``: ``{'a': {'b': 1}}`` -> ``{'a_b': 1}``.
    Top-level keys (no ``parent_key``) are kept as-is; joined keys are built
    with string formatting, so non-string nested keys no longer raise TypeError.

    Args:
        d: mapping to flatten.
        parent_key: prefix prepended to every key (used by the recursion).
        sep: separator placed between key levels.

    Returns:
        A new dict whose values are never mutable mappings.
    """
    items = []
    for k, v in d.items():
        new_key = f'{parent_key}{sep}{k}' if parent_key else k
        # `collections.MutableMapping` was removed in Python 3.10; the ABC lives
        # in `collections.abc`.
        if isinstance(v, MutableMapping):
            items.extend(flatten(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)

In [3]:
import pandas as pd
import numpy as np
import torch
from matplotlib import pyplot as plt
from kaggle_secrets import UserSecretsClient
# Read the credential stored as the Kaggle secret "neptune" (presumably the API
# token for the NeptuneLogger imported above) instead of hard-coding it.
user_secrets = UserSecretsClient()
neptune_key = user_secrets.get_secret("neptune")
# Use the GPU only when one is actually usable. The previous check,
# hasattr(torch._C, '_cuda_getDeviceCount'), only reflects whether this torch
# build was compiled with CUDA, so it chose 'cuda' on GPU-less machines too.
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
# Load the three interaction splits (every column cast to int32), in order.
train, val, test = (
    pd.read_parquet(path).astype('int32')
    for path in (
        '../input/movielens20m-v2/train.parquet',
        '../input/movielens20m-v2/val.parquet',
        '../input/movielens20m-v2/test.parquet',
    )
)

In [5]:
# Optional debug subsample: set to an int (e.g. 10000) to keep only rows with
# userId < MAX_USER_ID in every split. None (default) uses the full dataset.
MAX_USER_ID = None
if MAX_USER_ID is not None:
    train = train[train.userId < MAX_USER_ID]
    val = val[val.userId < MAX_USER_ID]
    test = test[test.userId < MAX_USER_ID]

In [6]:
# Size each id space as (largest id seen in any split) + 1, since ids are used
# directly as row indices.
NUM_USER = max(split.userId.max() for split in (train, val, test)) + 1
NUM_ITEM = max(split.movieId.max() for split in (train, val, test)) + 1

In [7]:
import pickle
# Per-user item-id histories (presumably one sequence per userId, in userId
# order - the history matrix below indexes rows by userId).
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
# `with` closes the handle, which the bare open(...) never did.
with open('../input/movielens20m-v2/user.pickle', 'rb') as fh:
    user = pickle.load(fh)

In [8]:
# Item ids also appear inside the user histories; grow the id space to cover
# them. Empty histories are skipped (max() of an empty sequence would raise).
NUM_ITEM = max(NUM_ITEM, max((max(history) for history in user if len(history)), default=-1) + 1)


In [9]:
# Size of the user id space (max userId over all splits + 1).
NUM_USER

129797

In [10]:
# History length cap: twice the (truncated) mean number of items per user.
MAX_USER_ITEM = 2 * int(np.mean([len(history) for history in user]))
print(MAX_USER_ITEM)

76


In [11]:
# Padded history matrix: row i holds user i's items in reverse order (most
# recent first if `user[i]` is chronological - TODO confirm), truncated to
# MAX_USER_ITEM. Ids are shifted by +1 so that 0 marks an empty slot.
users = np.zeros((len(user), MAX_USER_ITEM), dtype='int32')
for row, history in enumerate(user):
    kept = list(reversed(history))[:MAX_USER_ITEM]
    users[row, :len(kept)] = np.asarray(kept) + 1


In [12]:
# Expected to be (len(user), MAX_USER_ITEM).
users.shape

(129797, 76)

In [13]:
print(len(user))
print(NUM_USER)
print(NUM_ITEM)
# Every userId is used as a row index into the `users` history matrix (one row
# per entry of `user`), so the histories must cover the whole user id space.
assert len(user) >= NUM_USER, (len(user), NUM_USER)

129797
129797
20709


In [14]:
from torch.utils.data import Dataset, DataLoader
from numpy.random import randint
class UserMovieDataset(Dataset):
    """Training dataset that samples a fresh random batch on every access.

    ``__getitem__`` ignores ``idx`` and draws, with replacement:
      * ``batch_size`` positive rows of ``df`` (label 1), and
      * ``batch_size * neg_sample`` negatives: the user of a random ``df`` row
        paired with a uniformly random item id in ``[0, num_items)`` (label 0;
        this may occasionally hit a true positive).
    It returns ``(user[user_ids], item_ids, labels)`` with positives first.
    NOTE(review): each item is already a full batch, so presumably the
    DataLoader is used with automatic batching disabled - confirm.
    """
    def __init__(self, df, user, neg_sample=0, batch_size=1, num_items=1):
        self.neg_sample = neg_sample
        self.query = df.userId.to_numpy()
        self.item = df.movieId.to_numpy()
        self.batch_size = batch_size
        self.length = len(df)
        self.user = user
        self.num_items = num_items

    def __len__(self):
        """Batches per epoch: ``len(df) // batch_size``, but at least 1 when ``df`` is non-empty.

        Previously a ``batch_size`` larger than ``len(df)`` gave 0, i.e. an empty epoch.
        """
        if self.length == 0:
            return 0
        return max(1, self.length // self.batch_size)

    def __getitem__(self, idx):
        pos = randint(0, self.length, self.batch_size)
        q1 = self.query[pos]
        i1 = self.item[pos]
        l1 = np.ones_like(i1)
        n_neg = self.batch_size * self.neg_sample
        q2 = self.query[randint(0, self.length, n_neg)]
        i2 = randint(0, self.num_items, n_neg)
        l2 = np.zeros_like(i2)
        return self.user[np.hstack([q1, q2])], np.hstack([i1, i2]), np.hstack([l1, l2])
    
class UserMovieEvalDataset(Dataset):
    """Evaluation dataset pairing every eval user with every item seen in ``df``.

    Builds the cross product (distinct ``eval_df`` users x distinct ``df``
    items), outer-joined with ``eval_df`` on (userId, movieId). Pairs present
    in ``eval_df`` get label 1.0 and all others 0.0; eval pairs whose item is
    absent from ``df`` are kept as well. Rows are served in contiguous slices
    of ``batch_size``. ``eval_df`` itself is not modified.
    """
    def __init__(self, df, eval_df, user, batch_size=1000000):
        eval_users = eval_df.userId.drop_duplicates()
        candidate_items = df.movieId.drop_duplicates()
        candidates = pd.merge(eval_users, candidate_items, how='cross')
        # assign() returns a copy; the old `eval_df['label'] = 1.0` added a
        # column to the caller's frame (e.g. val / test) as a side effect.
        positives = eval_df.assign(label=1.0)
        d = pd.merge(candidates, positives, how='outer', on=['userId', 'movieId'])
        d['label'] = d['label'].fillna(0)
        self.d = d
        self.query = d.userId.to_numpy()
        self.item = d.movieId.to_numpy()
        self.label = d.label.to_numpy()
        self.batch_size = batch_size
        self.user = user
        self.length = (len(d) + self.batch_size - 1) // self.batch_size

    def __len__(self):
        """Number of ``batch_size`` slices; the last one may be shorter."""
        return self.length

    def __getitem__(self, idx):
        """Return ``(user_ids, user[user_ids], item_ids, labels)`` for slice ``idx``."""
        rows = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        q1 = self.query[rows]
        return q1, self.user[q1], self.item[rows], self.label[rows]

In [15]:
# Peek at one sampled training batch: (user history rows, item ids, labels);
# with the default batch_size=1 that is 1 positive followed by 10 negatives.
UserMovieDataset(train, users, neg_sample=10, num_items=NUM_ITEM)[0]

(array([[ 1491,   732,  1689,  1838,   895,   709,  1712,   258,   524,
          1124,  1822,  2387,   548,   886,   538,   316,  1221,  1563,
          1187,  1668,     1,    50,  1057,   294,   588,  2222,  1170,
          1172,  2450,  1016,  1807,   291,   353,   454,  2269,  1570,
          1074,  1660,    47,  1192,  2200,   603,  2109,  1107,  1183,
          1249,  1196,  1241,  1236,   583,  2264,   377,  1112,  1876,
          2303,  1262,  1624,   584,   590,  1193,  1175,  1361,  1067,
          1217,  2280,  1229,  1556,  2108,  1220,   477,   361,   765,
          1188,  1671,  1371,  1456],
        [ 2333,  2109,  1074,  1918,  2646,   843,   603,   316,  3017,
           588,   902,  1269,  1912,  1180,  1247,  2650,   505,  1355,
          1230,  1883,   622,  1303,  1876,  3001,  2702,  3008,  1984,
          1314,  2892,   474,  1100,  2623,  3018,   244,  2822,  2146,
          2341,  1008,  3007,     0,     0,     0,     0,     0,     0,
             0,     0,    

In [16]:
import pickle
import numbers
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.nn import functional as F
import math
import torchmetrics

In [17]:
class ExtraLoss:
    """Mixin that accumulates auxiliary (e.g. regularisation) losses.

    The accumulator lives on the class (``ExtraLoss._payload``), so every module
    mixing this in adds into one shared running total. Whoever calls
    ``total_loss()`` receives the total and resets it.
    """
    _payload = None

    def __init__(self, **kwargs):
        super(ExtraLoss, self).__init__(**kwargs)

    def _clear_loss(self):
        """Reset the shared accumulator to ``None``."""
        ExtraLoss._payload = None

    def total_loss(self):
        """Return the accumulated loss (``None`` if nothing was added) and reset it."""
        v = ExtraLoss._payload
        self._clear_loss()
        return v

    def add_loss(self, v):
        """Add ``v`` to the shared accumulator.

        The first value is stored as-is, so accumulation must be out-of-place:
        an in-place ``+=`` would silently modify that caller's tensor and can
        break autograd if the tensor was saved for backward.
        """
        if ExtraLoss._payload is None:
            ExtraLoss._payload = v
        else:
            ExtraLoss._payload = ExtraLoss._payload + v

In [18]:
class ActivePayload:
    """Mixin giving modules a shared, class-level store of per-key probabilities.

    ``track_prob`` stores the row-wise max of a tensor under a key;
    ``get_prob`` returns it, or ``None`` if the key was never tracked.
    """
    _prob = {}

    def __init__(self, **kwargs):
        super(ActivePayload, self).__init__(**kwargs)

    def track_prob(self, k, v):
        """Store ``max(v, dim=1, keepdim=True)`` under ``k`` (shape (N, 1) for 2-D ``v``).

        The value is detached. It is only read as a statistic (e.g. to build
        loss weights under ``torch.no_grad``), and keeping it attached would hold
        a reference to the forward pass's autograd graph until overwritten.
        """
        ActivePayload._prob[k] = torch.max(v, 1, keepdim=True)[0].detach()

    def get_prob(self, k):
        """Return the tensor last stored under ``k``, or ``None``."""
        return ActivePayload._prob.get(k)

In [19]:
class LoggerPayload:
    """Mixin that buffers scalar metrics in a class-level dict and logs their means.

    Every module mixing this in writes into the same buffer. ``dump`` requires
    the host class to provide ``self.log`` (Experiment inherits it from
    ``pl.LightningModule``).
    """
    _payload = {}

    def __init__(self, **kwargs):
        super(LoggerPayload, self).__init__(**kwargs)

    def clear(self):
        """Drop every buffered metric by rebinding the buffer to a fresh dict."""
        LoggerPayload._payload = {}

    def dump(self, prefix):
        """Log ``f'{prefix}_{key}'`` = mean of each metric's buffered values, then clear."""
        buffered = LoggerPayload._payload
        for metric_name in buffered:
            self.log(f'{prefix}_{metric_name}', np.mean(buffered[metric_name]))
        self.clear()

    def track_metric(self, k, v):
        """Append one observation ``v`` to the buffer for metric ``k``."""
        LoggerPayload._payload.setdefault(k, []).append(v)

In [20]:
class DNN(nn.Module):
    """Feed-forward stack built from consecutive widths in ``shape``.

    For each pair ``(prev, nxt)`` it appends ``Linear(prev, nxt)``, then
    optionally BatchNorm1d, the activation and Dropout. Note the optional layers
    follow every Linear, including the last one.

    Args:
        shape: layer widths, e.g. ``[in, hidden, ..., out]``.
        use_bn: add ``nn.BatchNorm1d(nxt)`` after each Linear.
        dropout: dropout probability; falsy (None / 0) disables it.
        activation: falsy for none, or one of 'relu', 'tanh', 'sigmoid',
            'prelu', 'silu'.
        **kwargs: ignored (lets ``bootstrap`` pass the config's ``name``).

    Raises:
        ValueError: if ``activation`` is set but not supported. Previously an
            unknown name silently produced a purely linear network.
    """
    _ACTIVATIONS = {
        'relu': nn.ReLU,
        'tanh': nn.Tanh,
        'sigmoid': nn.Sigmoid,
        'prelu': nn.PReLU,
        'silu': nn.SiLU,
    }

    def __init__(self, shape, use_bn=False, dropout=None, activation=None, **kwargs):
        super(DNN, self).__init__()
        if activation and activation not in self._ACTIVATIONS:
            raise ValueError(
                f'Unsupported activation {activation!r}; expected one of '
                f'{sorted(self._ACTIVATIONS)} or None')
        layers = []
        for prev, nxt in zip(shape, shape[1:]):
            layers.append(nn.Linear(prev, nxt))
            if use_bn:
                layers.append(nn.BatchNorm1d(nxt))
            if activation:
                layers.append(self._ACTIVATIONS[activation]())
            if dropout:
                layers.append(nn.Dropout(p=dropout))
        self.dnn = nn.Sequential(*layers)

    def forward(self, inputs):
        """Apply the stack to ``inputs['x']`` and return ``{'x': output}``."""
        return {'x': self.dnn(inputs['x'])}

In [21]:
class RDN(nn.Module):
    """Densely-residual stack of ``depth`` equal-width Linear layers.

    Layer i computes ``fc_i(prev)`` plus the sum of every earlier layer's output.
    It then optionally applies the activation and BatchNorm1d (in that order).
    The last layer's output is returned under ``'x'``.
    Supported activations: 'relu', 'tanh', 'sigmoid', 'silu'; anything else
    means no activation.
    """
    def __init__(self, use_bn=False, width=16, depth=1, activation=None, **kwargs):
        super(RDN, self).__init__()
        self.use_bn = use_bn
        # Plain lists are used for indexing in forward(). The setattr calls below
        # are what register the Linear / BatchNorm layers with nn.Module, so
        # state_dict keys are fc_<i> / bn_<i>.
        self.bn_layers, self.act_layers, self.fc_layers = [], [], []
        self.depth = depth
        activation_factory = {
            'relu': nn.ReLU,
            'tanh': nn.Tanh,
            'sigmoid': nn.Sigmoid,
            'silu': nn.SiLU,
        }.get(activation)
        for layer_idx in range(depth):
            if use_bn:
                norm = nn.BatchNorm1d(width)
                self.bn_layers.append(norm)
                setattr(self, f'bn_{layer_idx}', norm)
            linear = nn.Linear(width, width)
            self.fc_layers.append(linear)
            setattr(self, f'fc_{layer_idx}', linear)
            if activation_factory is not None:
                self.act_layers.append(activation_factory())

    def forward(self, inputs):
        current = inputs['x']
        outputs = []
        for layer_idx in range(self.depth):
            hidden = self.fc_layers[layer_idx](current)
            for earlier in outputs:
                hidden += earlier  # in-place on the fresh Linear output
            if self.act_layers:
                hidden = self.act_layers[layer_idx](hidden)
            if self.bn_layers:
                hidden = self.bn_layers[layer_idx](hidden)
            outputs.append(hidden)
            current = hidden
        return {'x': current}

In [22]:
class DotProduct(nn.Module, LoggerPayload):
    """Element-wise product mixer: ``logit = sum(query * item, dim=1, keepdim=True)``.

    Args:
        norm_query / norm_item: L2-normalise that side before the product.
        track_key: if set, log co-activation statistics of the raw (pre-norm)
            embeddings under ``<track_key>_*`` metric names.
        track_samples: rows sampled from each side for those statistics.
    """
    def __init__(self, norm_query=False, norm_item=False, track_key=None, track_samples=100, **kwargs):
        super(DotProduct, self).__init__()
        self.norm_query = norm_query
        self.norm_item = norm_item
        self.track_key = track_key
        self.track_samples = track_samples

    def forward(self, inputs):
        """Return the (possibly normalised) query/item, their product ``dot`` and ``logit``."""
        query = inputs['query']
        item = inputs['item']
        self._track_embedding(query, item)
        if self.norm_query:
            query = nn.functional.normalize(query)
        if self.norm_item:
            item = nn.functional.normalize(item)
        dot = query * item
        logit = torch.sum(dot, dim=(1,), keepdim=True)
        return {
            'query': query,
            'item': item,
            'dot': dot,
            'logit': logit,
        }

    def _track_embedding(self, query, item):
        """Log how many dimensions sampled query/item rows have active (non-zero) in common.

        With 0/1 activity indicators, ``prod[i, j]`` counts dimensions non-zero in
        both query i and item j.
        *_prod is the mean of that count over sampled pairs.
        *_ret is the average number of sampled partners sharing at least one
        active dimension.
        """
        if self.track_key is None or len(query) == 0 or len(item) == 0:
            return
        with torch.no_grad():
            query_idx = torch.randint(len(query), (self.track_samples,))
            item_idx = torch.randint(len(item), (self.track_samples,))
            # Exact non-zero indicator. The previous x / (x + 1e-4) was only ~1
            # for x >> 1e-4, went large/negative for small negative x and
            # divided by zero at x == -1e-4.
            query_active = (query[query_idx] != 0).double()
            item_active = (item[item_idx] != 0).double()
            prod = torch.mm(query_active, item_active.T)
            query_hits = torch.count_nonzero(prod, dim=1).double()
            item_hits = torch.count_nonzero(prod, dim=0).double()
        self.track_metric(f'{self.track_key}_query_prod', torch.mean(torch.mean(prod, dim=1)).item())
        self.track_metric(f'{self.track_key}_query_ret', torch.mean(query_hits).item())
        self.track_metric(f'{self.track_key}_item_prod', torch.mean(torch.mean(prod, dim=0)).item())
        self.track_metric(f'{self.track_key}_item_ret', torch.mean(item_hits).item())




In [23]:
class QueryDNN(nn.Module, LoggerPayload):
    """Mixer that scores ``[query, query * item]`` with a DNN.

    ``dot`` (DotProduct by default) supplies the element-wise product under
    ``'dot'``. Its returned dict is reused, with ``'logit'`` replaced by the DNN
    output. The default head ``DNN([1024, 1])`` therefore only fits 512-dim
    (batch, dim) embeddings; pass ``dnn`` for other widths. ``track_key`` is
    stored but not used by this class.
    """
    def __init__(self, dot=None, dnn=None, track_key=None, **kwargs):
        super(QueryDNN, self).__init__()
        self.dot = DotProduct() if dot is None else bootstrap(dot)
        self.dnn = DNN([1024, 1]) if dnn is None else bootstrap(dnn)
        self.track_key = track_key

    def forward(self, inputs):
        query = inputs['query']
        ret = self.dot(inputs)
        features = torch.hstack((query, ret['dot']))
        ret['logit'] = self.dnn({'x': features})['x']
        return ret


In [24]:
class PassThrough(nn.Module):
    """Identity stage: ``forward`` returns the very same input object."""
    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, inputs):
        return inputs

In [25]:
class DSSM(nn.Module):
    """Two-tower model: encode query and item separately, then combine them with ``mixer``.

    Each tower is called with ``{'x': tensor}`` and must return a dict with
    ``'x'``. The mixer gets ``{'query': ..., 'item': ...}``, and its output is
    returned unchanged.
    """
    def __init__(self, query_tower, item_tower, mixer, **kwargs):
        super(DSSM, self).__init__()
        self.query_tower = bootstrap(query_tower)
        self.item_tower = bootstrap(item_tower)
        self.mixer = bootstrap(mixer)

    def forward(self, inputs):
        # Query tower runs before the item tower (relevant for stateful/random stages).
        query_repr = self.query_tower({'x': inputs['query']})['x']
        item_repr = self.item_tower({'x': inputs['item']})['x']
        return self.mixer({'query': query_repr, 'item': item_repr})

In [26]:
class MultiDSSM(nn.Module):
    """Sum the logits of two sub-models built from config.

    Each tower may return a dict with a ``'logit'`` entry (as DSSM does) or a
    bare logit. The result is ``{'logit': first + next}``. Previously the two
    returned dicts were added directly, which raises TypeError.
    """
    def __init__(self, first_tower, next_tower, **kwargs):
        super(MultiDSSM, self).__init__()
        self.first_tower = bootstrap(first_tower)
        self.next_tower = bootstrap(next_tower)

    @staticmethod
    def _logit(out):
        """Extract the logit from a tower output (dict or bare value)."""
        return out['logit'] if isinstance(out, dict) else out

    def forward(self, inputs):
        first = self._logit(self.first_tower(inputs))
        nxt = self._logit(self.next_tower(inputs))
        return {'logit': first + nxt}

In [27]:
class HardConcrete(ActivePayload, ExtraLoss, nn.Module, LoggerPayload):
    """Hard-concrete gate: maps logits ``x`` to gates in [0, 1] (exact zeros possible).

    A sigmoid sample ``s`` is stretched to ``(low, high)`` and clipped to [0, 1]:
      * training: ``s = sigmoid((logit + logistic_noise) / beta)``
      * eval, by ``eval_method``:
          'hard'   -> ``sigmoid(logit / 0.001)`` (near step function)
          'sample' -> noisy sample with temperature 0.001
          'median' -> ``sigmoid(logit / beta)`` (no noise)
    Every forward also adds ``reg_weight * max over rows of sum_j get_prob(logit)_j``
    to the shared ExtraLoss accumulator.

    Args:
        beta: temperature of the relaxed sample.
        high, low: stretch interval. With low < 0 and high > 1, clipping yields
            exact 0 / 1 gates.
        track_key: prefix for logged metrics and key for the tracked
            probability. ``None`` disables all tracking.
        reg_weight: weight of the sparsity penalty.
        eval_method: 'hard', 'sample' or 'median'.

    Raises:
        ValueError: in eval mode, for an unknown ``eval_method``.
    """
    def __init__(self, beta = 0.1, high = 1.1, low= -0.1, track_key = None, reg_weight = 0.0, eval_method='hard', **kwargs):
        super(HardConcrete, self).__init__()
        self.beta = beta
        self.high = high
        self.low = low
        self.track_key = track_key
        self.reg_weight = reg_weight
        self.eval_method = eval_method

    def forward(self, inputs):
        logits = inputs['x']
        weight = self.sample_attention(logits)
        self._track_reg_loss(logits)
        self._track_embedding(weight)
        self._track_prob(torch.sigmoid(logits / self.beta))
        return {
            'x': weight
        }

    def get_prob(self, weights):
        """``sigmoid(weights - beta * log(-low / high))``: per-gate probability used by the penalty.

        NOTE(review): this shadows ``ActivePayload.get_prob(k)`` on this class,
        so ActivePayload-style lookups must not be made through a HardConcrete
        instance.
        """
        return torch.sigmoid(weights - self.beta * math.log(- self.low / self.high))

    @staticmethod
    def _logistic_noise(weights):
        """Logistic noise ``log(u) - log(1 - u)`` with ``u ~ U[0, 1)``, shaped like ``weights``."""
        eps = torch.rand_like(weights)
        return torch.log(eps) - torch.log(1.0 - eps)

    def sample_attention(self, weights):
        """Return gates in [0, 1] for logits ``weights`` (see the class docstring)."""
        if self.training:
            s = torch.sigmoid((self._logistic_noise(weights) + weights) / self.beta)
        elif self.eval_method == 'hard':
            s = torch.sigmoid(weights / 0.001)
        elif self.eval_method == 'sample':
            s = torch.sigmoid((self._logistic_noise(weights) + weights) / 0.001)
        elif self.eval_method == 'median':
            s = torch.sigmoid(weights / self.beta)
        else:
            # Was `assert False`, which vanishes under `python -O`.
            raise ValueError(
                f"Unknown eval_method {self.eval_method!r}; expected 'hard', 'sample' or 'median'")
        s = s * (self.high - self.low) + self.low
        return F.hardtanh(s, min_val=0, max_val=1)

    def _track_prob(self, prob):
        if self.track_key is None:
            return
        self.track_prob(self.track_key, prob)

    def _track_reg_loss(self, weight):
        """Add the sparsity penalty to ExtraLoss; log it only when tracking is enabled.

        Previously the metric was logged even without a track_key, under the
        literal name 'None_reg_loss'.
        """
        loss = torch.max(torch.sum(self.get_prob(weight), 1)) * self.reg_weight
        if self.track_key is not None:
            self.track_metric(f'{self.track_key}_reg_loss', loss.item())
        self.add_loss(loss)

    def _track_embedding(self, weight):
        """Log gate sparsity: mean active gates per row, and mean / max active fraction per column."""
        if self.track_key is None:
            return
        n_rows = weight.size()[0]
        per_row = torch.count_nonzero(weight, dim=1).double()
        per_col = torch.count_nonzero(weight, dim=0).double()
        self.track_metric(f'{self.track_key}_width_avg', torch.mean(per_row).item())
        self.track_metric(f'{self.track_key}_density_avg', torch.mean(per_col).item() / n_rows)
        self.track_metric(f'{self.track_key}_density_max', torch.max(per_col).item() / n_rows)

        



In [28]:
class RegMeanNorm(ExtraLoss, nn.Module):
    """Non-affine BatchNorm plus a learnable shift, followed by a gate decoder.

    The shift ``self.mean`` is an ``nn.Parameter`` initialised to ``mean``. Each
    forward adds ``reg_weight * self.mean`` to the ExtraLoss accumulator, so
    (for positive ``reg_weight``) training pushes the shift, and with it the gate
    logits, downward.
    """
    def __init__(self, dim, decoder=None, mean=3.0, reg_weight=0.001, **kwargs):
        super(RegMeanNorm, self).__init__()
        if decoder is None:
            self.decoder = HardConcrete()
        else:
            self.decoder = bootstrap(decoder)
        self.norm = nn.BatchNorm1d(dim, affine=False)
        # Create the parameter on the default device and let module.to(...) /
        # Lightning place it. The old `if self.cuda:` tested the bound method
        # (always truthy), so it always called .cuda() and crashed without a GPU.
        # float() keeps an int `mean` from producing a non-differentiable parameter.
        self.mean = nn.Parameter(torch.tensor(float(mean)))
        self.reg_weight = reg_weight

    def forward(self, inputs):
        alpha = self.norm(inputs['x']) + self.mean
        self.add_loss(self.reg_weight * self.mean)
        return self.decoder({'x': alpha})
    

In [29]:
class MeanNorm(nn.Module):
    """Standardise with a non-affine BatchNorm1d, shift, scale, then decode.

    ``forward`` returns ``decoder({'x': (BN(x) + mean) * var})``. Despite its
    name, ``var`` is a multiplicative scale, not a variance.
    """
    def __init__(self, dim, decoder=None, mean=0.0, var=1.0, **kwargs):
        super(MeanNorm, self).__init__()
        self.decoder = HardConcrete() if decoder is None else bootstrap(decoder)
        self.norm = nn.BatchNorm1d(dim, affine=False)
        self.mean = mean
        self.var = var

    def forward(self, inputs):
        shifted = self.norm(inputs['x']) + self.mean
        return self.decoder({'x': shifted * self.var})

In [30]:
class ShiftSoftmax(nn.Module):
    """Softmax along ``dim`` plus a constant offset: ``softmax(x) + mean``.

    Args:
        mean: constant added after the softmax.
        dim: softmax axis (default -1). ``nn.Softmax()`` without ``dim`` is
            deprecated and infers the axis from the input rank (dim 0 for 3-D
            input). For 2-D (batch, features) input, -1 matches that implicit
            choice.
    """
    def __init__(self, mean=0.0, dim=-1, **kwargs):
        super(ShiftSoftmax, self).__init__()
        self.mean = mean
        self.softmax = nn.Softmax(dim=dim)

    def forward(self, inputs):
        return {'x': self.softmax(inputs['x']) + self.mean}

In [31]:
class VariantionTower(nn.Module):
    """Tower chaining five dict -> dict stages:
    preproc -> encoder -> sampler -> decoder -> postproc.

    Every stage defaults to PassThrough except ``decoder``, which defaults to a
    HardConcrete gate. Stage configs are built with ``bootstrap``.
    """
    def __init__(self, decoder=None, preproc=None, sampler=None, encoder=None, postproc=None, **kwargs):
        super(VariantionTower, self).__init__()

        def build(cfg, default_cls):
            return default_cls() if cfg is None else bootstrap(cfg)

        # Construction (and thus parameter-initialisation) order is kept:
        # decoder, preproc, sampler, postproc, encoder.
        self.decoder = build(decoder, HardConcrete)
        self.preproc = build(preproc, PassThrough)
        self.sampler = build(sampler, PassThrough)
        self.postproc = build(postproc, PassThrough)
        self.encoder = build(encoder, PassThrough)

    def forward(self, inputs):
        embedding = inputs
        for stage in (self.preproc, self.encoder, self.sampler, self.decoder, self.postproc):
            embedding = stage(embedding)
        return embedding

In [32]:
class SparseEmbedTower(nn.Module):
    """Tower whose embedding is gated element-wise by a learned (sparse) mask.

    The pipeline:
      * ``shared = shared(inputs)`` and ``embedding = preproc(shared)``.
      * The gate is ``decoder(encoder(embedding))`` when ``use_stacking``,
        else ``decoder(encoder(shared))``.
      * The output is ``embedding`` with ``'x'`` multiplied by ``gate['x']``.

    For the first ``warmup`` forward calls (counted in both train and eval mode)
    the ungated embedding is returned. The gate is still computed, so its side
    effects (e.g. a HardConcrete decoder's extra loss and metrics) still happen.
    """
    def __init__(self, warmup=0, decoder=None, preproc=None, encoder=None, shared=None, use_stacking=False, **kwargs):
        super(SparseEmbedTower, self).__init__()
        if decoder is None:
            self.decoder = HardConcrete()
        else:
            self.decoder = bootstrap(decoder)
        if shared is None:
            self.shared = PassThrough()
        else:
            self.shared = bootstrap(shared)
        if preproc is None:
            self.preproc = PassThrough()
        else:
            self.preproc = bootstrap(preproc)
        if encoder is None:
            self.encoder = PassThrough()
        else:
            self.encoder = bootstrap(encoder)
        self.use_stacking = use_stacking
        self.warmup = warmup

    def forward(self, inputs):
        shared = self.shared(inputs)
        embedding = self.preproc(shared)
        weight = self.decoder(self.encoder(embedding if self.use_stacking else shared))
        if self.warmup > 0:
            self.warmup -= 1
            return embedding
        # Build a new dict. With PassThrough stages `embedding` is the caller's
        # own `inputs` dict, which the old in-place assignment overwrote.
        gated = dict(embedding)
        gated['x'] = embedding['x'] * weight['x']
        return gated

In [33]:
class Reweight(ActivePayload, LoggerPayload, nn.Module):
    """Per-example BCE weights chosen by each prediction's confusion-matrix cell.

    A prediction counts as positive when ``score > 0.5``; a steep sigmoid makes
    the indicator ~0/1. Each example gets the weight of its cell:
      * tp: predicted positive, label 1
      * fp: predicted positive, label 0
      * fn: predicted negative, label 1
      * tn: predicted negative, label 0
    For each key ``k`` in ``var_weights`` it also gets ``w * p * (1 - p)``,
    where ``p`` is the probability tracked under ``k`` via ActivePayload
    (skipped if nothing was tracked). Everything is computed under
    ``torch.no_grad``, so the weights are constants for the loss. ``var_weight``
    is stored but not used here.
    """
    def __init__(self, tn_weight=1.0, tp_weight=1.0, fp_weight=1.0, fn_weight=1.0, var_weight=0.0,
                 var_weights=None, track_key = None, **kwargs):
        super(Reweight, self).__init__()
        self.tn_weight = tn_weight
        self.tp_weight = tp_weight
        self.fp_weight = fp_weight
        self.fn_weight = fn_weight
        self.var_weight = var_weight
        self.track_key = track_key
        if var_weights:
            self.var_weights = var_weights
        else:
            self.var_weights = {}

    def forward(self, score, label):
        with torch.no_grad():
            pred_pos = torch.sigmoid((score - 0.5) * 1000)
            pred_neg = 1.0 - pred_pos
            # fp/fn were previously attached to the swapped cells; results are
            # unchanged whenever fp_weight == fn_weight.
            total_weight = self._tracked('tp_weight', pred_pos * label) * self.tp_weight
            total_weight += self._tracked('fp_weight', pred_pos * (1.0 - label)) * self.fp_weight
            total_weight += self._tracked('fn_weight', pred_neg * label) * self.fn_weight
            total_weight += self._tracked('tn_weight', pred_neg * (1.0 - label)) * self.tn_weight

            for k, w in self.var_weights.items():
                prob = self.get_prob(k)
                if prob is not None:
                    total_weight += self._tracked(f'{k}_weight', prob * (1.0 - prob) * w)

        return total_weight

    def _tracked(self, key, value):
        """Log ``mean(value)`` as ``<track_key>_<key>`` when tracking is on; return ``value``."""
        if self.track_key:
            self.track_metric(f'{self.track_key}_{key}', torch.mean(value).item())
        return value


In [34]:
class Recommender(ExtraLoss, LoggerPayload, nn.Module):
    """Embeds a user's item history and a candidate item, scores the pair, computes the loss.

    Args:
        tower: config for the scoring model (built with ``bootstrap``). It is
            called with ``{'query': pooled history, 'item': item embedding}``
            and must return a dict whose ``'logit'`` lines up with the
            (batch, 1) labels.
        query_emb_size / item_emb_size: embedding widths. Both tables have
            ``num_items + 1`` rows.
        num_items: size of the item id space.
        use_clamp: clamp probabilities to [0.001, 0.999].
        neg_sample_weight: BCE weight of negatives when ``reweight`` is unset.
        pos_clamp_thr / neg_clamp_thr: when ``reweight`` is unset, positives'
            probabilities are capped at ``pos_clamp_thr`` and negatives' floored
            at ``neg_clamp_thr`` before the weighted BCE.
        max_item_num: history length L (needed by ``emb_agg='weighted'``).
        share_embedding: one table for both query and item ids.
        emb_agg: 'avg' (mean over non-padding history items) or 'weighted'
            (learned Conv1d over the L history slots).
        reweight: optional config of a per-example weighting module.

    ``forward`` returns a dict with 'query', 'item', 'logit', 'prob', 'loss' and
    'ce' (unweighted BCE, for logging). Losses accumulated through ExtraLoss are
    added to 'loss'.
    """
    def __init__(self, tower, query_emb_size, item_emb_size, num_items, use_clamp=True, neg_sample_weight=1.0, pos_clamp_thr = 1.0,
                 max_item_num=0, neg_clamp_thr=0.0, share_embedding=False, emb_agg='avg', reweight=None, **kwargs):
        super(Recommender, self).__init__()
        if share_embedding:
            assert item_emb_size == query_emb_size
            self.query_embedding = nn.Embedding(num_items + 1, query_emb_size)
            self.item_embedding = self.query_embedding
        else:
            self.query_embedding = nn.Embedding(num_items + 1, query_emb_size)
            self.item_embedding = nn.Embedding(num_items + 1, item_emb_size)
        self.tower = bootstrap(tower)
        self.output = nn.Sequential(nn.Conv1d(1, 1, 1), nn.Sigmoid())
        self.neg_clamp_thr = neg_clamp_thr
        self.pos_clamp_thr = pos_clamp_thr
        self.neg_sample_weight = neg_sample_weight
        self.emb_agg = emb_agg
        self.use_clamp = use_clamp
        if emb_agg == 'weighted':
            self.weighted_avg = nn.Conv1d(max_item_num, 1, 1)
        if reweight:
            self.reweight = bootstrap(reweight)
        else:
            self.reweight = None

    def _pool_history(self, history):
        """Pool a (batch, L) matrix of history ids into one (batch, query_emb_size) vector."""
        embedded = self.query_embedding(history)  # (batch, L, query_emb_size)
        if self.emb_agg == 'avg':
            # The history matrix built in this notebook stores item id + 1, with 0
            # as padding. Average over real items only; the old code divided by
            # the count of positive *embedding values*, not the number of items.
            mask = (history > 0).unsqueeze(-1).to(embedded.dtype)
            return (embedded * mask).sum(1) / (mask.sum(1) + 1e-10)
        if self.emb_agg == 'weighted':
            # Conv1d treats the L history slots as channels -> (batch, 1, emb).
            # squeeze(1), not squeeze(): keep the batch dim when batch == 1.
            return self.weighted_avg(embedded).squeeze(1)
        raise ValueError(f"Unsupported emb_agg {self.emb_agg!r}; expected 'avg' or 'weighted'")

    def forward(self, inputs):
        ret = {
            'query': self._pool_history(inputs['query']),
            'item': self.item_embedding(inputs['item']),
        }
        logit = self.tower(ret)['logit']
        prob = self.output(logit.unsqueeze(1)).squeeze(1)
        if self.use_clamp:
            prob = torch.clamp(prob, min=0.001, max=0.999)
        labels = torch.unsqueeze(inputs['label'], 1).float()
        if self.reweight is not None:
            weights = self.reweight(prob, labels)
            loss = F.binary_cross_entropy(prob, labels, weight=weights)
        else:
            trunked_prob = torch.max(self.neg_clamp_thr - labels, prob)
            trunked_prob = torch.min(self.pos_clamp_thr + 1.0 - labels, trunked_prob)
            weights = (1.0 - labels) * self.neg_sample_weight + labels
            loss = F.binary_cross_entropy(trunked_prob, labels, weight=weights)
        ce = F.binary_cross_entropy(prob, labels)
        extra_loss = self.total_loss()
        if extra_loss is not None:
            self.track_metric('extra', extra_loss.item())
            loss = loss + extra_loss
        ret['logit'] = logit
        ret['prob'] = prob
        ret['loss'] = loss
        ret['ce'] = ce
        self.track_metric('loss', loss.item())
        self.track_metric('ce', ce.item())
        return ret

In [35]:
class Experiment(LoggerPayload, pl.LightningModule):
    """LightningModule wrapper around a config-built model.

    Training returns the model's 'loss'. Validation ranks each query's
    candidates by predicted probability and logs, for M in (10, 20, 40):
      * precision@M = mean hits in the top M / M
      * recall@M = mean of (hits / positives)
      * their F-measure
    Buffered LoggerPayload metrics are dumped after every step.
    """
    def __init__(self, model, lr=0.01, **kwargs):
        super(Experiment, self).__init__()
        self.model = bootstrap(model)
        self.lr = lr

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def forward(self, inputs):
        return self.model(inputs)

    @staticmethod
    def _as_inputs(query, item, label):
        """Pack a batch into the dict the model expects."""
        return {'query': query, 'item': item, 'label': label}

    def training_step(self, batch, batch_idx):
        query, item, label = batch
        ret = self.forward(self._as_inputs(query, item, label))
        self.dump('train')
        return ret['loss']

    def validation_step(self, batch, batch_idx):
        query_ids, query, item, label = batch
        ret = self.forward(self._as_inputs(query, item, label))
        self.dump('val')

        def to_numpy(tensor):
            return tensor.detach().cpu().numpy()

        # Rows: predicted prob, query id, label.
        return {
            'pred': np.vstack([
                np.squeeze(to_numpy(ret['prob'])),
                to_numpy(query_ids),
                np.squeeze(to_numpy(label)),
            ])
        }

    @staticmethod
    def _precision_recall_at(df, totals, M):
        """Precision and recall of each query's top-M ranked candidates."""
        top_m = df[df['rank'] < M + 1][['query', 'label']]
        hits = top_m.groupby('query').sum().rename(columns={'label': 'overlap'})
        per_query = pd.merge(hits, totals, on='query')
        precision = np.mean(per_query.overlap) / M
        recall = np.mean(per_query.overlap / per_query.total)
        return precision, recall

    def validation_epoch_end(self, validation_step_outputs):
        stacked = np.hstack([out['pred'] for out in validation_step_outputs])
        df = pd.DataFrame(stacked.T, columns=['pred', 'query', 'label'])
        df['rank'] = df.groupby('query')['pred'].rank('first', ascending=False)
        totals = df[['query', 'label']].groupby('query').sum().rename(columns={'label': 'total'})
        for M in [10, 20, 40]:
            precision, recall = self._precision_recall_at(df, totals, M)
            fmeasure = 2.0 * precision * recall / (recall + precision + 0.00001)
            self.log(f'precision{M}', precision)
            self.log(f'recall{M}', recall)
            self.log(f'fmeasure{M}', fmeasure, prog_bar=True)
    


In [36]:
import sys, inspect
def bootstrap(config):
    """Instantiate the class named ``config['name']`` from this module's namespace.

    The whole config dict, including ``'name'``, is passed as keyword arguments,
    which is why the component classes accept ``**kwargs``.

    Raises:
        ValueError: if no class with that ``__name__`` is visible in this
            module. Previously this silently returned ``None``, deferring the
            failure to a confusing error elsewhere.
    """
    name = config['name']
    for _, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass):
        if obj.__name__ == name:
            return obj(**config)
    raise ValueError(f'bootstrap: no class named {name!r} found in module {__name__!r}')

In [37]:
# Baseline: dense two-tower model (identical DNN towers) scored by a plain dot product.
dense_config = dict(
    model=dict(
        name='Recommender',
        tower=dict(
            name='DSSM',
            query_tower=dict(name='DNN', shape=[24, 128, 64, 512], use_bn=True),
            item_tower=dict(name='DNN', shape=[24, 128, 64, 512], use_bn=True),
            mixer=dict(name='DotProduct'),
        ),
        query_emb_size=24,
        item_emb_size=24,
        num_items=NUM_ITEM,
        share_embedding=False,
        emb_agg='weighted',
        max_item_num=MAX_USER_ITEM,
    ),
    neg_sample=4,
    batch_size=80000,
    max_epochs=50,
)

In [38]:
# Item tower: RDN -> DNN 24->20 -> HardConcrete gate on the 20-dim code
# -> DNN 20->24 -> RDN (VariantionTower stage order); query tower is a plain DNN.
vae_config = dict(
    model=dict(
        name='Recommender',
        tower=dict(
            name='DSSM',
            query_tower=dict(name='DNN', shape=[24, 128, 64, 24], use_bn=True),
            item_tower=dict(
                name='VariantionTower',
                preproc=dict(name='RDN', width=24, depth=5, use_bn=True, activation='silu'),
                encoder=dict(name='DNN', shape=[24, 20]),
                decoder=dict(name='DNN', shape=[20, 24]),
                sampler=dict(name='HardConcrete', track_key='item', reg_weight=0.0, beta=0.01),
                postproc=dict(name='RDN', width=24, depth=5, use_bn=True, activation='silu'),
            ),
            mixer=dict(name='DotProduct'),
        ),
        query_emb_size=24,
        item_emb_size=24,
        neg_sample_weight=5,
        num_items=NUM_ITEM,
        share_embedding=False,
        emb_agg='weighted',
        max_item_num=MAX_USER_ITEM,
    ),
    neg_sample=4,
    batch_size=80000,
    max_epochs=50,
)

In [39]:
HIDDEN_DIM = 48


def _sparse_embed_tower_config(track_key, mean, eval_method='sample', hidden_dim=HIDDEN_DIM):
    """Config for a SparseEmbedTower gated by MeanNorm -> HardConcrete.

    ``preproc`` and ``encoder`` are identical 24 -> 64 -> hidden_dim DNNs.
    ``mean`` shifts the normalised gate logits (more negative means fewer open
    gates in expectation). ``track_key`` names the tower's logged gate metrics.
    """
    return {
        'name': 'SparseEmbedTower',
        'warmup': 0,
        'shared': {'name': 'PassThrough'},
        'preproc': {'name': 'DNN', 'shape': [24, 64, hidden_dim]},
        'encoder': {'name': 'DNN', 'shape': [24, 64, hidden_dim]},
        'decoder': {
            'name': 'MeanNorm',
            'mean': mean,
            'dim': hidden_dim,
            'decoder': {
                'name': 'HardConcrete',
                'eval_method': eval_method,
                'track_key': track_key,
                'beta': 0.1,
            },
        },
    }


# Sparse-gated towers (the two specs differed only in mean / track_key) scored
# by a QueryDNN head on [query, query * item] (2 * HIDDEN_DIM inputs).
sparse_config = {
    'model': {
        'name': 'Recommender',
        'tower': {
            'name': 'DSSM',
            'query_tower': _sparse_embed_tower_config('user', mean=-1.2),
            'item_tower': _sparse_embed_tower_config('item', mean=-1.6),
            'mixer': {
                'name': 'QueryDNN',
                'dot': {'name': 'DotProduct', 'track_key': 'prod'},
                'dnn': {
                    'name': 'DNN',
                    'shape': [HIDDEN_DIM * 2, 128, 64, 24, 1],
                    'activation': 'prelu',
                    'use_bn': True,
                },
            },
        },
        'reweight': {
            'name': 'Reweight',
            'fp_weight': 5.0,
            'fn_weight': 5.0,
            'tn_weight': 0.0,
            'tp_weight': 1.0,
            'track_key': 'al',
            'var_weights': {'user': 0.0, 'item': 0.0},
        },
        'query_emb_size': 24,
        'item_emb_size': 24,
        'num_items': NUM_ITEM,
        'emb_agg': 'weighted',
        'max_item_num': MAX_USER_ITEM,
    },
    'neg_sample': 4,
    'batch_size': 8000,
    'max_epochs': 70,
}

In [40]:
import copy

# Same experiment as `sparse_config` except:
#   * gates are evaluated with the noise-free 'median' rule instead of 'sample';
#   * towers are scored by a plain DotProduct instead of the QueryDNN head.
# Deriving it keeps the two configs from drifting apart (this cell used to be
# an 85-line copy).
dot_config = copy.deepcopy(sparse_config)
for _tower in ('query_tower', 'item_tower'):
    dot_config['model']['tower'][_tower]['decoder']['decoder']['eval_method'] = 'median'
dot_config['model']['tower']['mixer'] = {'name': 'DotProduct', 'track_key': 'prod'}

In [41]:
import copy

# Select the experiment config. Work on a private deep copy: the training cell
# adds keys to `config` (e.g. 'n_parameters'), and a plain alias would mutate
# `dot_config` too, leaking state into any later re-use of it.
config = copy.deepcopy(dot_config)

In [42]:
# Build the train/validation pipelines, instantiate the model from `config`,
# and fit it with Neptune logging.
TRAIN_DS = UserMovieDataset(
    train,
    user=users,
    neg_sample=config['neg_sample'],
    batch_size=config['batch_size'],
    num_items=NUM_ITEM,
)
# batch_size=None disables DataLoader's automatic batching. The dataset
# presumably yields ready-made batches itself, since it receives its own
# batch_size. TODO confirm in UserMovieDataset.
loader = DataLoader(TRAIN_DS, batch_size=None, num_workers=2)
# NOTE(review): the eval batch size looks chosen to match one training step
# (batch_size positives x (neg_sample + 1) rows). Confirm against the datasets.
VAL_DS = UserMovieEvalDataset(
    train,
    val,
    user=users,
    batch_size=config['batch_size'] * (config['neg_sample'] + 1),
)
val_loader = DataLoader(VAL_DS, batch_size=None, num_workers=1)

# 'n_parameters' is written into `config` below for logging only. Exclude it
# here, otherwise re-running this cell passes an unexpected keyword argument
# to Experiment.
model = Experiment(**{k: v for k, v in config.items() if k != 'n_parameters'})
n_parameters = sum(p.numel() for p in model.parameters())
print('  + Number of params: {}'.format(n_parameters))
print(model)
config['n_parameters'] = n_parameters

neptune_logger = NeptuneLogger(
    api_key=neptune_key,
    project_name="pickedmelon/movielens2",
    description=str(model),
    params=flatten(config),
)
trainer = pl.Trainer(
    max_epochs=config['max_epochs'],
    gradient_clip_val=0.5,
    gpus=1 if torch_device == 'cuda' else 0,
    # The setup cell only ever sets torch_device to 'cpu' or 'cuda', so this
    # stays None unless that cell is changed to select a TPU.
    tpu_cores=8 if torch_device == 'tpu' else None,
    logger=neptune_logger,
    # Early stopping is currently disabled; to enable it, pass
    # callbacks=[EarlyStopping(monitor='val_loss')]
)
trainer.fit(model, loader, val_loader)

  + Number of params: 1013039
Experiment(
  (model): Recommender(
    (query_embedding): Embedding(20710, 24)
    (item_embedding): Embedding(20710, 24)
    (tower): DSSM(
      (query_tower): SparseEmbedTower(
        (decoder): MeanNorm(
          (decoder): HardConcrete()
          (norm): BatchNorm1d(48, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
        )
        (shared): PassThrough()
        (preproc): DNN(
          (dnn): Sequential(
            (0): Linear(in_features=24, out_features=64, bias=True)
            (1): Linear(in_features=64, out_features=48, bias=True)
          )
        )
        (encoder): DNN(
          (dnn): Sequential(
            (0): Linear(in_features=24, out_features=64, bias=True)
            (1): Linear(in_features=64, out_features=48, bias=True)
          )
        )
      )
      (item_tower): SparseEmbedTower(
        (decoder): MeanNorm(
          (decoder): HardConcrete()
          (norm): BatchNorm1d(48, eps=1e-05, momen

  import sys


https://ui.neptune.ai/pickedmelon/movielens2/e/MOV1-299


Validation sanity check: 0it [00:00, ?it/s]



Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]