In [1]:
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
rating_df = pd.read_csv("data/ml-25m/ratings.csv")
rating_df.head()

Unnamed: 0,userId,movieId,rating,timestamp
0,1,296,5.0,1147880044
1,1,306,3.5,1147868817
2,1,307,5.0,1147868828
3,1,665,5.0,1147878820
4,1,899,3.5,1147868510


In [3]:
movies_df = pd.read_csv("data/ml-25m/movies.csv")

In [4]:
movies_df.head()

Unnamed: 0,movieId,title,genres
0,1,Toy Story (1995),Adventure|Animation|Children|Comedy|Fantasy
1,2,Jumanji (1995),Adventure|Children|Fantasy
2,3,Grumpier Old Men (1995),Comedy|Romance
3,4,Waiting to Exhale (1995),Comedy|Drama|Romance
4,5,Father of the Bride Part II (1995),Comedy


In [5]:
no_items = movies_df['movieId'].max()

In [6]:
rating_df.shape

(25000095, 4)

In [7]:
no_users = rating_df['userId'].max()

In [8]:
# sort dataset by id and timestamp
rating_df = rating_df.sort_values(['userId', 'timestamp'])

In [9]:
# split dataset in train and eval set

In [9]:
def func1(df):
    n_rows = df.shape[0]
    df['percentile'] = np.linspace(0, 100, n_rows)
    return df

rating_df = rating_df.groupby('userId').apply(func1).reset_index(drop=True)

In [10]:
train_size = 80 # percent
train_data = rating_df[rating_df['percentile'] <= train_size]
test_data = rating_df[rating_df['percentile'] > train_size]

In [11]:
train_data.shape, test_data.shape

((19968907, 5), (5031188, 5))

In [12]:
del train_data['percentile'], test_data['percentile']

In [13]:
train_data.head()

Unnamed: 0,userId,movieId,rating,timestamp
0,1,5952,4.0,1147868053
1,1,2012,2.5,1147868068
2,1,2011,2.5,1147868079
3,1,1653,4.0,1147868097
4,1,1250,4.0,1147868414


In [14]:
train_data['rating'].value_counts()

rating
4.0    5346555
3.0    3892905
5.0    3006138
3.5    2476949
4.5    1743999
2.0    1297363
2.5     976141
1.0     611705
1.5     310786
0.5     306366
Name: count, dtype: int64

In [15]:
# assuming rating more than 3 means user like the movie else he/she didn't
rating_threshold = 3
from data import Data
from tqdm import tqdm

In [16]:
trainds = Data(train_data)
testds = Data(test_data)

In [101]:
batch_size = 2**14
n_workers = 4

traindl = torch.utils.data.DataLoader(trainds, batch_size=batch_size, shuffle=True, num_workers=1)

testdl = torch.utils.data.DataLoader(testds, batch_size=batch_size, shuffle=True, num_workers=1)
print(batch_size)

16384


In [160]:
item_feature_name = "movieId"
user_feature_name = "userId"

class ItemTower(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.features = config['item']['features']
        if item_feature_name in self.features:
            self.item_module = nn.Embedding(config['item']['n_items'], config['item']['emb'])
        else:
            self.item_module = nn.Identity()
        self.mlp = nn.Sequential(
            nn.Linear(config['item']['emb'], config['item']['emb']),
            nn.GELU(),
            nn.LayerNorm(config['item']['emb'])
        )
        
    def forward(self, batch):
        item_emb = self.item_module(batch[item_feature_name].long())
        return self.mlp(item_emb)


class UserTower(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.features = config['user']['features']
        if user_feature_name in self.features:
            self.user_module = nn.Embedding(config['user']['n_users'], config['user']['emb'])
        else:
            self.item_module = nn.Identity()
        self.mlp = nn.Sequential(
            nn.Linear(config['user']['emb'], config['user']['emb']),
            nn.GELU(),
            nn.LayerNorm(config['user']['emb'])
        )
        
    def forward(self, batch):
        user_emb = self.user_module(batch[user_feature_name].long())
        return self.mlp(user_emb)

class CombinerTower(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.item_tower = ItemTower(config)
        self.user_tower = UserTower(config)
        user_dim, item_dim = config['user']['emb'], config['item']['emb']
        self.mlp = nn.Sequential(
            nn.Linear(user_dim+item_dim, config['emb']),
            nn.GELU(),
            nn.LayerNorm(config['emb'])
        )
        self.tasks = nn.ModuleDict({
            task_name: nn.Linear(config['emb'], out) for task_name, out in config['tasks']
        })
        
    def forward(self, batch):
        item_emb = self.item_tower(batch)
        user_emb = self.user_tower(batch)
        x = torch.concat([user_emb, item_emb], axis=1)
        x = self.mlp(x)
        res = {}
        for task_name, mod in self.tasks.items():
            res[task_name] = mod(x)
        return res

config = {
    'user':{
        'features':[
            'userId'
        ],
        'n_users':no_users + 2**10,
        'emb': 64
    },
    'item':{
        'features':[
            'movieId'
        ],
        'n_items':no_items + 2**10,
        'emb': 64
    },
    'emb':64,
    'tasks':[
        ['rating', 1]
    ]
}

In [148]:
model = CombinerTower(config)

In [134]:
from torchmetrics.classification import average_precision, ROC, Recall, Accuracy

class Metric:
    def __init__(self, device=torch.device("cpu"), task="binary"):
        device = device
        self.counter = 0
        self._metric = {
            "map": average_precision.AveragePrecision(task=task),
            "accuracy": Accuracy(task=task),
            "recall": Recall(task=task)
        }
        self._accelerator = {
            "map": 0,
            "accuracy": 0,
            "recall": 0
        }
        
    def __call__(self, output, actual) -> dict:
        out = {}
        for metric_name, mod in self._metric.items():
            out[metric_name] = mod(output, actual)
            self._accelerator[metric_name] += out[metric_name]
        self.counter += 1
        return out
    
    def get(self):
        return {
            k:v/self.counter for k,v in self._accelerator.items()
        }

    def to(self, device):
        for i in self._metric:
            self._metric[i].to(device)
        return self

In [135]:
metric = Metric()
output, actual = torch.randn(20, 1).view(-1), torch.randint(0, 2, size=(20, 1)).view(-1)
print(actual, output)
print(metric(output, actual))
metric.get()

tensor([0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1]) tensor([-0.4142,  0.9052, -1.1952,  1.1407,  0.1014, -1.4244, -0.9388, -0.7567,
         0.1958, -0.9591,  0.6681,  0.6016, -0.3790, -2.0306,  0.4814, -0.2623,
         1.6520,  0.8246, -0.3324,  0.3609])
{'map': tensor(0.5258), 'accuracy': tensor(0.6000), 'recall': tensor(0.6250)}


{'map': tensor(0.5258), 'accuracy': tensor(0.6000), 'recall': tensor(0.6250)}

In [136]:
class Trainer:
    def __init__(self,
                model: torch.nn.Module,
                train_dataloader: torch.utils.data.DataLoader,
                eval_dataloader: torch.utils.data.DataLoader = None,
                loss_fn = torch.nn.functional.mse_loss,
                epochs: int = 5,
                optimizer_clz = torch.optim.SGD,
                optim_params: dict = {'lr':1e-2},
                device: torch.device = torch.device("cpu"),
                metric_collection = None,
                verbose = 1,
                log_step = 500,
                *args, **kwargs
                ):
        self.log_step = log_step
        self.verbose = verbose
        self.model = model.to(device)
        self.train_dl = train_dataloader
        self.eval_dl = eval_dataloader
        self.loss_fn = loss_fn
        self.optimizer = optimizer_clz(self.model.parameters(), **optim_params)
        self.epochs = epochs
        self.device = device
        self.metric_collection = metric_collection.to(device) if metric_collection else None

    def compute_loss(self, batch, output):
        loss = {}
        for k in output:
            if k in batch:
                y = (batch[k] >= rating_threshold).float()
                loss[k] = self.loss_fn(output[k], y.view(-1, 1))
        return loss

    def compute_metric(self, batch, output, show_tensor=False):
        res = {}
        for k in output:
            if k in batch:
                y = (batch[k] >= rating_threshold).long()
                if show_tensor:
                    print(batch[k].sum(), batch[k].shape[0])
                res[k] = self.metric_collection(output[k].view(-1), y.view(-1))
        return res
    
    def set_output(self, _iter, loss, avg_loss, train=True):
        _iter.set_description(
            "{}: step loss: {:.3f}, avg_loss: {:.3f}".format(
                "Train" if train else "Eval",
                loss,
                avg_loss
            )
        )
    def train_epoch(self):
        self.model = self.model.train()
        
        _iter = tqdm(self.train_dl)
        n_batches = len(self.train_dl)
        total_loss = 0
        for _i, batch in enumerate(_iter):
            self.optimizer.zero_grad()
            output = self.model(batch)
            loss = self.compute_loss(batch, output)
            # multiple task we need better way to gives weights, ryt now it is uniform
            loss = sum([loss[l] for l in loss])/(len(loss) + 1e-6)
            
            loss.backward()
            self.optimizer.step()
            total_loss += loss.cpu().item()
            
            self.set_output(_iter, loss.item(), total_loss/(_i+1), train=True)
            show_tensor = False
            if self.verbose > 0 \
                and _i % self.log_step == 0 and _i != 0:
                print(self.metric_collection.get())
                show_tensor = True
            with torch.no_grad():
                self.compute_metric(batch, output, show_tensor)
        return {
            'loss': total_loss/(n_batches),
            'mode': 'Train'
        }
    
    @torch.no_grad()
    def eval_epoch(self):
        self.model = self.model.eval()
        
        _iter = tqdm(self.eval_dl)
        n_batches = len(self.eval_dl)
        total_loss = 0
        for _i, batch in enumerate(_iter):
            output = self.model(batch)
            loss = self.compute_loss(batch, output)
            # multiple task we need better way to gives weights, ryt now it is uniform
            loss = sum([loss[l] for l in loss])/(len(loss) + 1e-6)
            
            total_loss += loss.cpu().item()
            self.set_output(_iter, loss.item(), total_loss/(_i+1), train=False)
            
            show_tensor = False
            if self.verbose > 0 \
                and _i % self.log_step == 0 and _i != 0:
                print(self.metric_collection.get())
                show_tensor = True
            self.compute_metric(batch, output, show_tensor=show_tensor)
                
        return {
            'loss': total_loss/(n_batches),
            'mode': 'Eval'
        }

    def fit(self):
        for epoch in range(1, self.epochs+1):
            print("EPOCH:", epoch)
            self.train_epoch()
            eval_output = self.eval_epoch()
            print(eval_output)

In [137]:
trainer = Trainer(
    model=model, train_dataloader=traindl, eval_dataloader=testdl, 
    loss_fn=F.binary_cross_entropy_with_logits,
    metric_collection = Metric(),
    log_step = 100
)

In [138]:
trainer.fit()

EPOCH: 1


Train: step loss: 0.491, avg_loss: 0.562:   8%|▊         | 101/1219 [01:58<19:36,  1.05s/it]

{'map': tensor(0.8256), 'accuracy': tensor(0.7400), 'recall': tensor(0.8694)}
tensor(58244.5000, dtype=torch.float64) 16384


Train: step loss: 0.479, avg_loss: 0.524:  16%|█▋        | 201/1219 [03:44<17:53,  1.05s/it]

{'map': tensor(0.8257), 'accuracy': tensor(0.7804), 'recall': tensor(0.9319)}
tensor(58206.5000, dtype=torch.float64) 16384


Train: step loss: 0.473, avg_loss: 0.509:  25%|██▍       | 301/1219 [05:29<16:12,  1.06s/it]

{'map': tensor(0.8263), 'accuracy': tensor(0.7948), 'recall': tensor(0.9542)}
tensor(58090.5000, dtype=torch.float64) 16384


Train: step loss: 0.479, avg_loss: 0.500:  33%|███▎      | 401/1219 [07:13<14:11,  1.04s/it]

{'map': tensor(0.8267), 'accuracy': tensor(0.8022), 'recall': tensor(0.9656)}
tensor(58084., dtype=torch.float64) 16384


Train: step loss: 0.468, avg_loss: 0.494:  41%|████      | 501/1219 [09:04<13:30,  1.13s/it]

{'map': tensor(0.8273), 'accuracy': tensor(0.8067), 'recall': tensor(0.9724)}
tensor(58205., dtype=torch.float64) 16384


Train: step loss: 0.462, avg_loss: 0.491:  49%|████▉     | 601/1219 [10:57<11:36,  1.13s/it]

{'map': tensor(0.8279), 'accuracy': tensor(0.8096), 'recall': tensor(0.9770)}
tensor(58154.5000, dtype=torch.float64) 16384


Train: step loss: 0.470, avg_loss: 0.487:  58%|█████▊    | 701/1219 [13:00<12:20,  1.43s/it]

{'map': tensor(0.8284), 'accuracy': tensor(0.8118), 'recall': tensor(0.9803)}
tensor(58077., dtype=torch.float64) 16384


Train: step loss: 0.473, avg_loss: 0.485:  66%|██████▌   | 801/1219 [14:54<07:52,  1.13s/it]

{'map': tensor(0.8290), 'accuracy': tensor(0.8134), 'recall': tensor(0.9828)}
tensor(58027.5000, dtype=torch.float64) 16384


Train: step loss: 0.469, avg_loss: 0.483:  74%|███████▍  | 901/1219 [16:47<06:01,  1.14s/it]

{'map': tensor(0.8295), 'accuracy': tensor(0.8147), 'recall': tensor(0.9847)}
tensor(57979.5000, dtype=torch.float64) 16384


Train: step loss: 0.459, avg_loss: 0.482:  82%|████████▏ | 1001/1219 [18:40<04:08,  1.14s/it]

{'map': tensor(0.8299), 'accuracy': tensor(0.8157), 'recall': tensor(0.9862)}
tensor(58527., dtype=torch.float64) 16384


Train: step loss: 0.468, avg_loss: 0.480:  90%|█████████ | 1101/1219 [20:36<02:31,  1.28s/it]

{'map': tensor(0.8304), 'accuracy': tensor(0.8165), 'recall': tensor(0.9875)}
tensor(58160., dtype=torch.float64) 16384


Train: step loss: 0.466, avg_loss: 0.479:  99%|█████████▊| 1201/1219 [22:35<00:23,  1.29s/it]

{'map': tensor(0.8308), 'accuracy': tensor(0.8172), 'recall': tensor(0.9885)}
tensor(58189.5000, dtype=torch.float64) 16384


Train: step loss: 0.467, avg_loss: 0.479: 100%|██████████| 1219/1219 [23:02<00:00,  1.13s/it]
Eval: step loss: 0.492, avg_loss: 0.494:  33%|███▎      | 101/308 [02:00<04:16,  1.24s/it]

{'map': tensor(0.8298), 'accuracy': tensor(0.8165), 'recall': tensor(0.9895)}
tensor(56964., dtype=torch.float64) 16384


Eval: step loss: 0.490, avg_loss: 0.494:  65%|██████▌   | 201/308 [03:55<02:12,  1.24s/it]

{'map': tensor(0.8288), 'accuracy': tensor(0.8158), 'recall': tensor(0.9903)}
tensor(56999., dtype=torch.float64) 16384


Eval: step loss: 0.495, avg_loss: 0.494:  98%|█████████▊| 301/308 [05:53<00:08,  1.20s/it]

{'map': tensor(0.8279), 'accuracy': tensor(0.8153), 'recall': tensor(0.9909)}
tensor(56934.5000, dtype=torch.float64) 16384


Eval: step loss: 0.500, avg_loss: 0.494: 100%|██████████| 308/308 [06:04<00:00,  1.18s/it]


{'loss': 0.494243679011797, 'mode': 'Eval'}
EPOCH: 2


Train: step loss: 0.470, avg_loss: 0.466:   8%|▊         | 101/1219 [02:05<21:57,  1.18s/it]

{'map': tensor(0.8284), 'accuracy': tensor(0.8158), 'recall': tensor(0.9915)}
tensor(57890.5000, dtype=torch.float64) 16384


Train: step loss: 0.467, avg_loss: 0.466:  16%|█▋        | 201/1219 [04:06<22:14,  1.31s/it]

{'map': tensor(0.8289), 'accuracy': tensor(0.8163), 'recall': tensor(0.9920)}
tensor(58132., dtype=torch.float64) 16384


Train: step loss: 0.463, avg_loss: 0.465:  25%|██▍       | 301/1219 [06:06<18:11,  1.19s/it]

{'map': tensor(0.8294), 'accuracy': tensor(0.8168), 'recall': tensor(0.9924)}
tensor(58103.5000, dtype=torch.float64) 16384


Train: step loss: 0.466, avg_loss: 0.465:  33%|███▎      | 401/1219 [08:07<17:18,  1.27s/it]

{'map': tensor(0.8299), 'accuracy': tensor(0.8172), 'recall': tensor(0.9928)}
tensor(58106., dtype=torch.float64) 16384


Train: step loss: 0.460, avg_loss: 0.465:  41%|████      | 501/1219 [10:02<13:40,  1.14s/it]

{'map': tensor(0.8303), 'accuracy': tensor(0.8175), 'recall': tensor(0.9932)}
tensor(58275.5000, dtype=torch.float64) 16384


Train: step loss: 0.457, avg_loss: 0.465:  49%|████▉     | 601/1219 [11:52<10:45,  1.04s/it]

{'map': tensor(0.8308), 'accuracy': tensor(0.8179), 'recall': tensor(0.9935)}
tensor(58340., dtype=torch.float64) 16384


Train: step loss: 0.463, avg_loss: 0.465:  58%|█████▊    | 701/1219 [13:38<09:03,  1.05s/it]

{'map': tensor(0.8311), 'accuracy': tensor(0.8182), 'recall': tensor(0.9938)}
tensor(58215., dtype=torch.float64) 16384


Train: step loss: 0.467, avg_loss: 0.465:  66%|██████▌   | 801/1219 [15:25<07:15,  1.04s/it]

{'map': tensor(0.8315), 'accuracy': tensor(0.8185), 'recall': tensor(0.9941)}
tensor(58177.5000, dtype=torch.float64) 16384


Train: step loss: 0.473, avg_loss: 0.465:  74%|███████▍  | 901/1219 [17:12<05:34,  1.05s/it]

{'map': tensor(0.8319), 'accuracy': tensor(0.8187), 'recall': tensor(0.9943)}
tensor(57921., dtype=torch.float64) 16384


Train: step loss: 0.466, avg_loss: 0.465:  82%|████████▏ | 1001/1219 [19:00<03:57,  1.09s/it]

{'map': tensor(0.8323), 'accuracy': tensor(0.8189), 'recall': tensor(0.9945)}
tensor(58115., dtype=torch.float64) 16384


Train: step loss: 0.461, avg_loss: 0.465:  90%|█████████ | 1101/1219 [20:48<02:05,  1.06s/it]

{'map': tensor(0.8326), 'accuracy': tensor(0.8191), 'recall': tensor(0.9947)}
tensor(58383., dtype=torch.float64) 16384


Train: step loss: 0.463, avg_loss: 0.465:  99%|█████████▊| 1201/1219 [22:36<00:18,  1.05s/it]

{'map': tensor(0.8330), 'accuracy': tensor(0.8193), 'recall': tensor(0.9949)}
tensor(58136., dtype=torch.float64) 16384


Train: step loss: 0.462, avg_loss: 0.465: 100%|██████████| 1219/1219 [23:01<00:00,  1.13s/it]
Eval: step loss: 0.496, avg_loss: 0.492:  33%|███▎      | 101/308 [01:57<03:47,  1.10s/it]

{'map': tensor(0.8326), 'accuracy': tensor(0.8190), 'recall': tensor(0.9951)}
tensor(56854., dtype=torch.float64) 16384


Eval: step loss: 0.485, avg_loss: 0.492:  65%|██████▌   | 201/308 [03:56<02:14,  1.25s/it]

{'map': tensor(0.8323), 'accuracy': tensor(0.8186), 'recall': tensor(0.9953)}
tensor(57040., dtype=torch.float64) 16384


Eval: step loss: 0.504, avg_loss: 0.492:  98%|█████████▊| 301/308 [05:57<00:08,  1.25s/it]

{'map': tensor(0.8319), 'accuracy': tensor(0.8182), 'recall': tensor(0.9955)}
tensor(56577., dtype=torch.float64) 16384


Eval: step loss: 0.498, avg_loss: 0.492: 100%|██████████| 308/308 [06:09<00:00,  1.20s/it]


{'loss': 0.4917071965801251, 'mode': 'Eval'}
EPOCH: 3


Train: step loss: 0.463, avg_loss: 0.463:   8%|▊         | 101/1219 [02:11<23:22,  1.25s/it]

{'map': tensor(0.8322), 'accuracy': tensor(0.8183), 'recall': tensor(0.9956)}
tensor(58049., dtype=torch.float64) 16384


Train: step loss: 0.458, avg_loss: 0.463:  16%|█▋        | 201/1219 [04:19<22:07,  1.30s/it]

{'map': tensor(0.8326), 'accuracy': tensor(0.8186), 'recall': tensor(0.9958)}
tensor(58219., dtype=torch.float64) 16384


Train: step loss: 0.475, avg_loss: 0.463:  25%|██▍       | 301/1219 [06:26<19:44,  1.29s/it]

{'map': tensor(0.8329), 'accuracy': tensor(0.8187), 'recall': tensor(0.9959)}
tensor(58041., dtype=torch.float64) 16384


Train: step loss: 0.461, avg_loss: 0.463:  33%|███▎      | 401/1219 [08:35<17:33,  1.29s/it]

{'map': tensor(0.8332), 'accuracy': tensor(0.8189), 'recall': tensor(0.9960)}
tensor(58194.5000, dtype=torch.float64) 16384


Train: step loss: 0.467, avg_loss: 0.463:  41%|████      | 501/1219 [10:48<16:24,  1.37s/it]

{'map': tensor(0.8336), 'accuracy': tensor(0.8190), 'recall': tensor(0.9961)}
tensor(58081.5000, dtype=torch.float64) 16384


Train: step loss: 0.458, avg_loss: 0.463:  49%|████▉     | 601/1219 [13:01<13:27,  1.31s/it]

{'map': tensor(0.8339), 'accuracy': tensor(0.8192), 'recall': tensor(0.9962)}
tensor(58282., dtype=torch.float64) 16384


Train: step loss: 0.464, avg_loss: 0.463:  58%|█████▊    | 701/1219 [15:12<11:08,  1.29s/it]

{'map': tensor(0.8342), 'accuracy': tensor(0.8194), 'recall': tensor(0.9963)}
tensor(58140.5000, dtype=torch.float64) 16384


Train: step loss: 0.464, avg_loss: 0.463:  66%|██████▌   | 801/1219 [17:22<08:58,  1.29s/it]

{'map': tensor(0.8345), 'accuracy': tensor(0.8195), 'recall': tensor(0.9964)}
tensor(58015.5000, dtype=torch.float64) 16384


Train: step loss: 0.466, avg_loss: 0.463:  74%|███████▍  | 901/1219 [19:33<06:52,  1.30s/it]

{'map': tensor(0.8348), 'accuracy': tensor(0.8196), 'recall': tensor(0.9965)}
tensor(58031.5000, dtype=torch.float64) 16384


Train: step loss: 0.464, avg_loss: 0.463:  82%|████████▏ | 1001/1219 [21:43<04:46,  1.32s/it]

{'map': tensor(0.8351), 'accuracy': tensor(0.8197), 'recall': tensor(0.9966)}
tensor(58120.5000, dtype=torch.float64) 16384


Train: step loss: 0.472, avg_loss: 0.462:  90%|█████████ | 1101/1219 [23:52<02:32,  1.29s/it]

{'map': tensor(0.8354), 'accuracy': tensor(0.8199), 'recall': tensor(0.9967)}
tensor(57982., dtype=torch.float64) 16384


Train: step loss: 0.457, avg_loss: 0.462:  99%|█████████▊| 1201/1219 [26:03<00:23,  1.30s/it]

{'map': tensor(0.8357), 'accuracy': tensor(0.8200), 'recall': tensor(0.9968)}
tensor(58350.5000, dtype=torch.float64) 16384


Train: step loss: 0.467, avg_loss: 0.462: 100%|██████████| 1219/1219 [26:32<00:00,  1.31s/it]
Eval: step loss: 0.494, avg_loss: 0.490:  33%|███▎      | 101/308 [02:12<04:24,  1.28s/it]

{'map': tensor(0.8355), 'accuracy': tensor(0.8197), 'recall': tensor(0.9968)}
tensor(56917.5000, dtype=torch.float64) 16384


Eval: step loss: 0.491, avg_loss: 0.490:  65%|██████▌   | 201/308 [04:17<02:14,  1.26s/it]

{'map': tensor(0.8353), 'accuracy': tensor(0.8194), 'recall': tensor(0.9969)}
tensor(56892., dtype=torch.float64) 16384


Eval: step loss: 0.486, avg_loss: 0.490:  98%|█████████▊| 301/308 [06:23<00:08,  1.27s/it]

{'map': tensor(0.8351), 'accuracy': tensor(0.8191), 'recall': tensor(0.9970)}
tensor(57042., dtype=torch.float64) 16384


Eval: step loss: 0.537, avg_loss: 0.490: 100%|██████████| 308/308 [06:35<00:00,  1.29s/it]


{'loss': 0.4904913707793533, 'mode': 'Eval'}
EPOCH: 4


Train: step loss: 0.466, avg_loss: 0.462:   8%|▊         | 101/1219 [02:14<23:03,  1.24s/it]

{'map': tensor(0.8353), 'accuracy': tensor(0.8192), 'recall': tensor(0.9971)}
tensor(58046., dtype=torch.float64) 16384


Train: step loss: 0.461, avg_loss: 0.462:  16%|█▋        | 201/1219 [04:33<20:25,  1.20s/it]

{'map': tensor(0.8356), 'accuracy': tensor(0.8193), 'recall': tensor(0.9971)}
tensor(58098., dtype=torch.float64) 16384


Train: step loss: 0.468, avg_loss: 0.462:  25%|██▍       | 301/1219 [06:45<18:02,  1.18s/it]

{'map': tensor(0.8359), 'accuracy': tensor(0.8194), 'recall': tensor(0.9972)}
tensor(57959.5000, dtype=torch.float64) 16384


Train: step loss: 0.456, avg_loss: 0.462:  33%|███▎      | 401/1219 [08:38<15:55,  1.17s/it]

{'map': tensor(0.8362), 'accuracy': tensor(0.8195), 'recall': tensor(0.9972)}
tensor(57979.5000, dtype=torch.float64) 16384


Train: step loss: 0.465, avg_loss: 0.462:  41%|████      | 501/1219 [10:30<12:32,  1.05s/it]

{'map': tensor(0.8364), 'accuracy': tensor(0.8196), 'recall': tensor(0.9973)}
tensor(58072., dtype=torch.float64) 16384


Train: step loss: 0.459, avg_loss: 0.462:  49%|████▉     | 601/1219 [12:27<12:19,  1.20s/it]

{'map': tensor(0.8367), 'accuracy': tensor(0.8197), 'recall': tensor(0.9973)}
tensor(58150., dtype=torch.float64) 16384


Train: step loss: 0.465, avg_loss: 0.461:  58%|█████▊    | 701/1219 [14:22<10:02,  1.16s/it]

{'map': tensor(0.8369), 'accuracy': tensor(0.8198), 'recall': tensor(0.9974)}
tensor(58044.5000, dtype=torch.float64) 16384


Train: step loss: 0.466, avg_loss: 0.461:  66%|██████▌   | 801/1219 [16:15<07:20,  1.05s/it]

{'map': tensor(0.8372), 'accuracy': tensor(0.8199), 'recall': tensor(0.9974)}
tensor(57999.5000, dtype=torch.float64) 16384


Train: step loss: 0.450, avg_loss: 0.461:  74%|███████▍  | 901/1219 [18:09<06:21,  1.20s/it]

{'map': tensor(0.8375), 'accuracy': tensor(0.8200), 'recall': tensor(0.9975)}
tensor(58275., dtype=torch.float64) 16384


Train: step loss: 0.459, avg_loss: 0.461:  82%|████████▏ | 1001/1219 [20:02<03:57,  1.09s/it]

{'map': tensor(0.8377), 'accuracy': tensor(0.8201), 'recall': tensor(0.9975)}
tensor(58114., dtype=torch.float64) 16384


Train: step loss: 0.461, avg_loss: 0.461:  90%|█████████ | 1101/1219 [21:55<02:18,  1.17s/it]

{'map': tensor(0.8380), 'accuracy': tensor(0.8202), 'recall': tensor(0.9976)}
tensor(58189.5000, dtype=torch.float64) 16384


Train: step loss: 0.461, avg_loss: 0.461:  99%|█████████▊| 1201/1219 [23:58<00:24,  1.38s/it]

{'map': tensor(0.8382), 'accuracy': tensor(0.8203), 'recall': tensor(0.9976)}
tensor(58146., dtype=torch.float64) 16384


Train: step loss: 0.461, avg_loss: 0.461: 100%|██████████| 1219/1219 [24:28<00:00,  1.20s/it]
Eval: step loss: 0.494, avg_loss: 0.490:  33%|███▎      | 101/308 [02:07<03:55,  1.14s/it]

{'map': tensor(0.8381), 'accuracy': tensor(0.8200), 'recall': tensor(0.9977)}
tensor(56781.5000, dtype=torch.float64) 16384


Eval: step loss: 0.484, avg_loss: 0.489:  65%|██████▌   | 201/308 [03:57<01:59,  1.12s/it]

{'map': tensor(0.8380), 'accuracy': tensor(0.8198), 'recall': tensor(0.9977)}
tensor(57226., dtype=torch.float64) 16384


Eval: step loss: 0.488, avg_loss: 0.489:  98%|█████████▊| 301/308 [05:48<00:08,  1.15s/it]

{'map': tensor(0.8378), 'accuracy': tensor(0.8196), 'recall': tensor(0.9977)}
tensor(56937., dtype=torch.float64) 16384


Eval: step loss: 0.476, avg_loss: 0.489: 100%|██████████| 308/308 [05:59<00:00,  1.17s/it]


{'loss': 0.4892668734703745, 'mode': 'Eval'}
EPOCH: 5


Train: step loss: 0.467, avg_loss: 0.461:   8%|▊         | 101/1219 [01:58<21:44,  1.17s/it]

{'map': tensor(0.8381), 'accuracy': tensor(0.8197), 'recall': tensor(0.9978)}
tensor(57876.5000, dtype=torch.float64) 16384


Train: step loss: 0.459, avg_loss: 0.460:  16%|█▋        | 201/1219 [03:52<19:06,  1.13s/it]

{'map': tensor(0.8383), 'accuracy': tensor(0.8198), 'recall': tensor(0.9978)}
tensor(58305.5000, dtype=torch.float64) 16384


Train: step loss: 0.456, avg_loss: 0.461:  25%|██▍       | 301/1219 [05:46<16:41,  1.09s/it]

{'map': tensor(0.8385), 'accuracy': tensor(0.8198), 'recall': tensor(0.9978)}
tensor(58308.5000, dtype=torch.float64) 16384


Train: step loss: 0.460, avg_loss: 0.460:  33%|███▎      | 401/1219 [07:40<14:26,  1.06s/it]

{'map': tensor(0.8388), 'accuracy': tensor(0.8199), 'recall': tensor(0.9979)}
tensor(58034.5000, dtype=torch.float64) 16384


Train: step loss: 0.457, avg_loss: 0.460:  41%|████      | 501/1219 [09:34<12:49,  1.07s/it]

{'map': tensor(0.8390), 'accuracy': tensor(0.8200), 'recall': tensor(0.9979)}
tensor(58201.5000, dtype=torch.float64) 16384


Train: step loss: 0.459, avg_loss: 0.460:  49%|████▉     | 601/1219 [11:28<11:43,  1.14s/it]

{'map': tensor(0.8393), 'accuracy': tensor(0.8201), 'recall': tensor(0.9979)}
tensor(57958.5000, dtype=torch.float64) 16384


Train: step loss: 0.462, avg_loss: 0.460:  58%|█████▊    | 701/1219 [14:49<21:11,  2.45s/it]

{'map': tensor(0.8395), 'accuracy': tensor(0.8201), 'recall': tensor(0.9980)}
tensor(58122., dtype=torch.float64) 16384


Train: step loss: 0.466, avg_loss: 0.460:  66%|██████▌   | 801/1219 [18:16<14:32,  2.09s/it]

{'map': tensor(0.8397), 'accuracy': tensor(0.8202), 'recall': tensor(0.9980)}
tensor(58113.5000, dtype=torch.float64) 16384


Train: step loss: 0.453, avg_loss: 0.460:  74%|███████▍  | 901/1219 [22:16<08:35,  1.62s/it]

{'map': tensor(0.8399), 'accuracy': tensor(0.8202), 'recall': tensor(0.9980)}
tensor(58440.5000, dtype=torch.float64) 16384


Train: step loss: 0.458, avg_loss: 0.460:  82%|████████▏ | 1001/1219 [25:09<07:02,  1.94s/it]

{'map': tensor(0.8402), 'accuracy': tensor(0.8203), 'recall': tensor(0.9981)}
tensor(58213., dtype=torch.float64) 16384


Train: step loss: 0.461, avg_loss: 0.460:  90%|█████████ | 1101/1219 [27:28<02:34,  1.31s/it]

{'map': tensor(0.8404), 'accuracy': tensor(0.8204), 'recall': tensor(0.9981)}
tensor(58170.5000, dtype=torch.float64) 16384


Train: step loss: 0.455, avg_loss: 0.460:  99%|█████████▊| 1201/1219 [29:41<00:23,  1.29s/it]

{'map': tensor(0.8406), 'accuracy': tensor(0.8204), 'recall': tensor(0.9981)}
tensor(58221.5000, dtype=torch.float64) 16384


Train: step loss: 0.454, avg_loss: 0.460: 100%|██████████| 1219/1219 [30:11<00:00,  1.49s/it]
Eval: step loss: 0.491, avg_loss: 0.489:  33%|███▎      | 101/308 [02:14<04:25,  1.28s/it]

{'map': tensor(0.8405), 'accuracy': tensor(0.8203), 'recall': tensor(0.9981)}
tensor(56809.5000, dtype=torch.float64) 16384


Eval: step loss: 0.488, avg_loss: 0.489:  65%|██████▌   | 201/308 [04:22<02:25,  1.36s/it]

{'map': tensor(0.8404), 'accuracy': tensor(0.8201), 'recall': tensor(0.9982)}
tensor(56953.5000, dtype=torch.float64) 16384


Eval: step loss: 0.492, avg_loss: 0.488:  98%|█████████▊| 301/308 [06:35<00:09,  1.31s/it]

{'map': tensor(0.8403), 'accuracy': tensor(0.8199), 'recall': tensor(0.9982)}
tensor(56835., dtype=torch.float64) 16384


Eval: step loss: 0.495, avg_loss: 0.488: 100%|██████████| 308/308 [06:47<00:00,  1.32s/it]

{'loss': 0.4884836112524008, 'mode': 'Eval'}





#### Evaluation Metric

`{'map': tensor(0.8403), 'accuracy': tensor(0.8199), 'recall': tensor(0.9982)}`

In [139]:
torch.save(model.state_dict(), "./checkpoint/baseline_chk_v1.pt")

In [143]:
model.item_tower({"movieId":torch.tensor([100]).view(1,1)})

torch.Size([1, 1, 64])

In [144]:
#### Generate embeddings

1

In [159]:
import torch
import pandas as pd

In [161]:
model = CombinerTower(config)

model.load_state_dict(torch.load("./checkpoint/baseline_chk_v1.pt"))

<All keys matched successfully>

In [162]:
movie_df = pd.read_csv("./data/ml-25m/movies.csv")

In [163]:
movie_df.head()

Unnamed: 0,movieId,title,genres
0,1,Toy Story (1995),Adventure|Animation|Children|Comedy|Fantasy
1,2,Jumanji (1995),Adventure|Children|Fantasy
2,3,Grumpier Old Men (1995),Comedy|Romance
3,4,Waiting to Exhale (1995),Comedy|Drama|Romance
4,5,Father of the Bride Part II (1995),Comedy


In [164]:
movie_df.shape

(62423, 3)

In [169]:
emb_dim = 64
with torch.no_grad():
    movie_df['embedding'] = movie_df['movieId'].map(lambda x: model.item_tower({"movieId": torch.tensor([x])}).view(emb_dim))

In [177]:
emb = torch.stack(movie_df['embedding'].values.tolist(), axis=0)

In [184]:
emb = emb/(emb.norm(dim=1).view(-1,1) + 1e-7)

In [190]:
movie_df['norm_embedding'] = emb.unbind(dim=0)

In [191]:
movie_df.head()

Unnamed: 0,movieId,title,genres,embedding,norm_embedding
0,1,Toy Story (1995),Adventure|Animation|Children|Comedy|Fantasy,"[tensor(-0.6897), tensor(1.1606), tensor(0.659...","[tensor(-0.0859), tensor(0.1445), tensor(0.082..."
1,2,Jumanji (1995),Adventure|Children|Fantasy,"[tensor(-0.8724), tensor(0.3376), tensor(1.183...","[tensor(-0.1086), tensor(0.0420), tensor(0.147..."
2,3,Grumpier Old Men (1995),Comedy|Romance,"[tensor(-0.6817), tensor(-0.0421), tensor(-0.7...","[tensor(-0.0848), tensor(-0.0052), tensor(-0.0..."
3,4,Waiting to Exhale (1995),Comedy|Drama|Romance,"[tensor(-0.6069), tensor(-0.3672), tensor(-0.6...","[tensor(-0.0757), tensor(-0.0458), tensor(-0.0..."
4,5,Father of the Bride Part II (1995),Comedy,"[tensor(-0.7473), tensor(-0.8992), tensor(2.59...","[tensor(-0.0932), tensor(-0.1122), tensor(0.32..."


In [222]:
import random

25478

In [264]:
def get_similar(movie_df, movie_id=None, k=20):
    if movie_id is None:
        movie_id = np.random.randint(1, movie_df.shape[0])        
    query = movie_df[movie_df['movieId'] == movie_id].iloc[0]
    query_emb = query['norm_embedding']
    sim_scores = (emb @ query_emb.view(1,-1).T)
    sorted_idx = torch.argsort(sim_scores.view(-1),descending=True)
    top_k_idx = sorted_idx[1:k+1]
    res = movie_df.iloc[top_k_idx]
    titles = res['title'].values
    genres = res['genres'].values
    scores = sim_scores[top_k_idx]
    return {
        "query": query[['movieId', 'title', 'genres']].to_dict(),
        'result': [[t,g,s] for t,g,s in zip(titles, genres, scores)]
    }

In [265]:
get_similar(movie_df, movie_id=192389, k=10)

{'query': {'movieId': 192389,
  'title': 'Venom (2018)',
  'genres': 'Action|Horror|Sci-Fi|Thriller'},
 'result': [['Cynthia (1947)', 'Comedy|Drama', tensor([0.6378])],
  ['The Sandpit Generals (1971)', '(no genres listed)', tensor([0.6098])],
  ['Mask of Dust (1954)', 'Drama', tensor([0.6071])],
  ['Twist Around The Clock (1961)', '(no genres listed)', tensor([0.6045])],
  ['Claudine (1974)', 'Drama', tensor([0.6001])],
  ['Second Fiddle to a Steel Guitar (1966)', 'Comedy', tensor([0.5996])],
  ["The Last Sharknado: It's About Time (2018)",
   'Action|Adventure|Comedy|Fantasy|Sci-Fi|Thriller',
   tensor([0.5921])],
  ['Reivers, The (1969)', 'Comedy|Drama', tensor([0.5906])],
  ['Rack, The (1956)', 'Drama|War', tensor([0.5825])],
  ['Wonderful Days (a.k.a. Sky Blue) (2003)',
   'Animation|Sci-Fi',
   tensor([0.5800])]]}

In [263]:
movie_df.iloc[55643]

movieId                                                      192389
title                                                  Venom (2018)
genres                                Action|Horror|Sci-Fi|Thriller
embedding         [tensor(0.5151), tensor(-0.9153), tensor(-0.55...
norm_embedding    [tensor(0.0641), tensor(-0.1138), tensor(-0.06...
Name: 55643, dtype: object

In [260]:
movie_df[movie_df['title'].map(lambda x: "venom" in x.lower())]

Unnamed: 0,movieId,title,genres,embedding,norm_embedding
6033,6145,Venom (1982),Horror|Thriller,"[tensor(1.1496), tensor(-0.7547), tensor(-0.71...","[tensor(0.1434), tensor(-0.0941), tensor(-0.08..."
8790,26399,Five Deadly Venoms (1978),Action,"[tensor(-0.0652), tensor(-0.6690), tensor(-0.4...","[tensor(-0.0081), tensor(-0.0835), tensor(-0.0..."
10205,36531,Venom (2005),Horror|Thriller,"[tensor(-0.0295), tensor(-0.4241), tensor(1.43...","[tensor(-0.0037), tensor(-0.0528), tensor(0.17..."
22599,115727,Crippled Avengers (Can que) (Return of the 5 D...,Action|Adventure,"[tensor(0.3127), tensor(-0.1394), tensor(-0.73...","[tensor(0.0389), tensor(-0.0173), tensor(-0.09..."
32508,141070,Venom (1966),Drama,"[tensor(2.8667), tensor(0.3400), tensor(1.6905...","[tensor(0.3571), tensor(0.0424), tensor(0.2106..."
37355,152467,Venomous (2002),Horror|Sci-Fi,"[tensor(-0.9478), tensor(-1.0008), tensor(1.29...","[tensor(-0.1186), tensor(-0.1252), tensor(0.16..."
37357,152471,Silent Venom (2009),Action|Adventure|Horror|Thriller,"[tensor(-0.9665), tensor(1.5667), tensor(0.971...","[tensor(-0.1205), tensor(0.1954), tensor(0.121..."
55072,191077,Venom and Eternity (1951),(no genres listed),"[tensor(0.8810), tensor(0.1856), tensor(0.1019...","[tensor(0.1101), tensor(0.0232), tensor(0.0127..."
55643,192389,Venom (2018),Action|Horror|Sci-Fi|Thriller,"[tensor(0.5151), tensor(-0.9153), tensor(-0.55...","[tensor(0.0641), tensor(-0.1138), tensor(-0.06..."
56211,193669,Venom Islands (2012),Documentary,"[tensor(-0.0993), tensor(-0.6386), tensor(-0.4...","[tensor(-0.0124), tensor(-0.0797), tensor(-0.0..."
