In [23]:
%load_ext autoreload
%autoreload 2

import re
from functools import partial

import numpy as np
import polars as pl
import torch
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm

from mts_ml_cup import urls as u

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Preprocessing steps handed to `u.clean_url`, in this order. Judging by their
# names they decode punycode, lower-case the host and turn hyphens into dots --
# NOTE(review): confirm the exact semantics in `mts_ml_cup.urls`.
URL_PREPROCESSORS = [
    u.decode_from_punycode,
    u.lower,
    u.replace_hyphens_with_points,
]

# Single-argument cleaner so it can be passed straight to `.apply(...)`.
url_cleaner = partial(u.clean_url, preprocessors=URL_PREPROCESSORS)

In [3]:
%%time
# Per-host popularity statistics computed over the raw session log.
popularity_aggs = [
    pl.col("user_id").n_unique().alias("n_users"),    # distinct users per host
    pl.col("request_cnt").sum().alias("n_requests"),  # total requests per host
    pl.col("url_host").count().alias("n_rows"),       # number of session rows per host
]
# Most popular hosts first; the host name (ascending) is the final tie-breaker.
popularity_order = ["n_users", "n_rows", "n_requests", "url_host"]

urls_popularity = (
    pl.read_parquet(
        "../data/processed/sessions.pq",
        columns=["user_id", "url_host", "request_cnt"],
    )
    .groupby("url_host")
    .agg(popularity_aggs)
    .sort(popularity_order, reverse=[True, True, True, False])
    # The statistics stay keyed by the raw host; the cleaned form is added alongside.
    .with_columns(pl.col("url_host").apply(url_cleaner).alias("url_cleaned"))
)

CPU times: user 1min 23s, sys: 23.8 s, total: 1min 46s
Wall time: 40.6 s


In [11]:
def mean_pooling(model_output, attention_mask):
    """Average token embeddings along dim 1, ignoring masked-out positions.

    Args:
        model_output: Indexable model output whose first element holds the
            per-token embeddings (assumed to be (batch, seq_len, hidden) --
            TODO confirm against the model in use).
        attention_mask: Tensor with 1 for real tokens and 0 for padding
            (assumed to be (batch, seq_len)).

    Returns:
        The mask-weighted mean of the token embeddings with dim 1 reduced away.
        A row whose mask is all zeros yields zeros rather than NaN, because the
        divisor is clamped to at least 1e-9.
    """
    hidden_states = model_output[0]
    # A trailing singleton axis lets the mask broadcast across the hidden dimension.
    mask = attention_mask.unsqueeze(-1).float()
    masked_total = (hidden_states * mask).sum(dim=1)
    token_count = mask.sum(dim=1).clamp(min=1e-9)
    return masked_total / token_count


def url_to_emb(url: str, tokenizer, model) -> torch.Tensor:
    """Embed every dot-separated part of ``url`` and mean-pool each part's tokens.

    Args:
        url: Host string; it is split on "." and the pieces are tokenized
            together as one padded batch.
        tokenizer: Callable tokenizer returning a mapping that includes
            "attention_mask" (called with ``return_tensors="pt"``).
        model: Model called with the tokenizer output as keyword arguments.

    Returns:
        The result of ``mean_pooling`` over the model output -- presumably one
        row per URL part (TODO confirm the output shape for the model in use).
    """
    batch = tokenizer(url.split("."), padding=True, truncation=True, return_tensors="pt")
    # Forward pass only; no gradients are needed.
    with torch.no_grad():
        output = model(**batch)
    return mean_pooling(output, batch["attention_mask"])

In [5]:
%%time
# Tokenizer and encoder are loaded from the same pretrained checkpoint.
MINI_LM_CHECKPOINT = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"

tokenizer = AutoTokenizer.from_pretrained(MINI_LM_CHECKPOINT)
model = AutoModel.from_pretrained(MINI_LM_CHECKPOINT)

Downloading (…)okenizer_config.json: 100%|█████| 480/480 [00:00<00:00, 56.6kB/s]
Downloading (…)lve/main/config.json: 100%|█████| 645/645 [00:00<00:00, 81.8kB/s]
Downloading (…)"tokenizer.json";: 100%|████| 9.08M/9.08M [00:00<00:00, 36.8MB/s]
Downloading (…)cial_tokens_map.json: 100%|█████| 239/239 [00:00<00:00, 82.0kB/s]
Downloading (…)"pytorch_model.bin";: 100%|███| 471M/471M [00:16<00:00, 28.4MB/s]


CPU times: user 3.02 s, sys: 2.01 s, total: 5.02 s
Wall time: 29.1 s


In [14]:
# Embed every distinct cleaned URL once. The dict is filled incrementally so
# partial results survive if this long-running loop is interrupted.
unique_urls = urls_popularity["url_cleaned"].unique()

url_embs = {}
for url in tqdm(unique_urls):
    part_embs = url_to_emb(url, tokenizer, model).numpy()
    # NOTE(review): only index 0 is kept, i.e. the entry for the first
    # dot-separated part of the URL -- confirm the remaining parts are meant
    # to be discarded.
    url_embs[url] = part_embs[0]

100%|█████████████████████████████████| 199508/199508 [1:04:07<00:00, 51.85it/s]


In [21]:
# Reduce each stored value to a single 1-D vector by taking index 0, but only
# when the value still has more than one dimension. Values that are already 1-D
# are left untouched, so running this cell after the vectors have been reduced
# (or running it twice) no longer collapses every embedding to a single scalar.
url_embs = {
    url: (emb[0] if np.ndim(emb) > 1 else emb)
    for url, emb in url_embs.items()
}

In [None]:
# Persist the URL -> embedding mapping. `np.save` receives a dict here, which
# NumPy stores as a pickled object array; read it back with
# `np.load(path, allow_pickle=True).item()`.
np.save(file="../data/mini-lm/url-embs.npy", arr=url_embs)

In [36]:
%%time
# Total requests per (user, host), with hosts normalised by `url_cleaner`.
user_host_keys = ["user_id", "url_host"]
requests_total = pl.col("request_cnt").sum()

users_urls = (
    pl.read_parquet(
        "../data/processed/sessions.pq",
        columns=["user_id", "url_host", "request_cnt"],
    )
    # First pass: collapse the session log to one row per (user, raw host), so
    # the Python-level cleaner below runs on fewer rows.
    .groupby(user_host_keys)
    .agg(requests_total)
    # Overwrite `url_host` with its cleaned form.
    .with_columns(pl.col("url_host").apply(url_cleaner))
    # Second pass: merge rows whose hosts coincide after cleaning.
    .groupby(user_host_keys)
    .agg(requests_total)
    # Per user: most-requested hosts first, host name as tie-breaker.
    .sort(["user_id", "request_cnt", "url_host"], reverse=[False, True, False])
)

CPU times: user 5min 32s, sys: 33.3 s, total: 6min 5s
Wall time: 3min 26s


In [32]:
# Re-sort with the default (ascending) order on all three keys.
# NOTE(review): this replaces the descending `request_cnt` order set when
# `users_urls` was built, and the cell's execution count (32) predates that
# cell's (36) -- confirm whether this cell is still wanted.
users_urls = users_urls.sort(by=["user_id", "request_cnt", "url_host"])

In [42]:
# One embedding per user: the sum of the embeddings of the hosts they visited,
# each weighted by that host's share of the user's total requests.
n_users = users_urls["user_id"].n_unique()

user_embs = {}
for user_id, urls in tqdm(users_urls.groupby("user_id"), total=n_users):
    total_requests = urls["request_cnt"].sum()
    # Rows are (user_id, url_host, request_cnt); `sum` starts from 0 and adds
    # the weighted vectors in row order.
    user_embs[user_id] = sum(
        requests / total_requests * url_embs[url]
        for _, url, requests in urls.iter_rows()
    )

100%|█████████████████████████████████| 415317/415317 [04:36<00:00, 1500.51it/s]


In [44]:
# Persist the user_id -> weighted embedding mapping. `np.save` receives a dict
# here, which NumPy stores as a pickled object array; read it back with
# `np.load(path, allow_pickle=True).item()`.
np.save(file="../data/mini-lm/weighted-user-embs.npy", arr=user_embs)

In [47]:
import pandas as pd

In [54]:
# Turn the {user_id: embedding} dict into a two-column frame
# (`user_id`, `mini_lm_embeddings`), going through pandas on the way to polars.
embs = pl.from_pandas(
    pd.Series(user_embs, name="mini_lm_embeddings")
    # Name the index first so `reset_index` emits it directly as `user_id`.
    .rename_axis("user_id")
    .reset_index()
)

In [56]:
# Write the per-user embedding features to parquet.
# NOTE(review): the file name says "ems", not "embs"; it is left unchanged
# because other code may read this exact path.
embs.write_parquet(file="../data/features/mini-lm/weighted_ems.pq")