In [None]:
# Move the working directory up one level so the relative `data/...` paths
# used in the cells below resolve.
# NOTE: not idempotent - re-running this cell moves up one more directory.
%cd ../

In [2]:
import sys
from pathlib import Path
from collections import defaultdict

import polars as pl
from loguru import logger
from polars import DataFrame

In [None]:
# Drop every handler currently registered on the loguru logger, then attach a
# single stderr sink that only emits records at INFO level and above.
logger.remove()
logger.add(sink=sys.stderr, level="INFO")

# Read dataset

In [None]:
# Raw interaction log, stored as a tab-separated file.
path = "data/raw/ml-1m/ml-1m.inter"

interactions_raw = pl.read_csv(source=path, separator="\t")
interactions_raw.head(5)

In [None]:
# Raw item metadata, stored as a tab-separated file.
path = "data/raw/ml-1m/ml-1m.item"

items_raw = pl.read_csv(source=path, separator="\t")
items_raw.head(5)

# Process

## Process `interactions`

In [None]:
# Raw `name:type` column headers -> plain column names to keep.
cols = {
    'user_id:token': 'user_id',
    'item_id:token': 'item_id',
    'timestamp:float': 'timestamp',
}

# Select only the mapped columns, renaming each one on the way out.
inters = interactions_raw.select(
    [pl.col(raw_name).alias(new_name) for raw_name, new_name in cols.items()]
)

inters.head()

### Apply `k`-core filtering

In [7]:
class KCoreFilter:
    """Keep only interactions whose user AND item both have at least ``k`` interactions.

    The interactions are treated as a bipartite user-item graph. Vertices with
    fewer than ``k`` incident interactions are peeled off repeatedly, because
    removing one vertex can push its neighbours below ``k`` as well. Every
    interaction row counts as one edge, so duplicated (user, item) rows each
    contribute to the degree.

    Args:
        inters: Interaction frame containing the ``user_str`` and ``item_str`` columns.
        k: Minimum number of interactions every remaining user and item must have.
        user_str: Name of the user id column.
        item_str: Name of the item id column.

    Note:
        Ids are parsed back with ``int(...)``, so both id columns are assumed
        to hold integers.
    """

    def __init__(
        self,
        inters: "DataFrame",
        k: int = 5,
        user_str: str = "user_id",
        item_str: str = "item_id",
    ):
        self.k = k
        self.inters = inters
        self.user_str, self.item_str = user_str, item_str

        # Vertex key -> keys of its neighbours (one entry per interaction row).
        self.adj = defaultdict(list)

        self._build_graph()

    def _conv_v2id(self, idx: int, is_user: bool = True) -> str:
        """Encode a raw id as a vertex key: ``"<id>-u"`` for users, ``"<id>-i"`` for items."""
        v_type = "u" if is_user else "i"
        return f"{idx}-{v_type}"

    def _conv_id2v(self, idx: str) -> tuple:
        """Decode a vertex key back into ``(raw_id, is_user)``."""
        # Split on the LAST dash only, so ids that themselves contain a dash
        # (e.g. negative integers) are decoded correctly.
        v_id, _, v_type = idx.rpartition("-")
        return int(v_id), v_type == "u"

    def __getitem__(self, key: str):
        """Return the neighbour list of vertex ``key``, or ``None`` if it is unknown.

        The adjacency always describes the full, unfiltered graph.
        """
        return self.adj.get(key, None)

    def _build_graph(self):
        """Populate ``self.adj`` with one undirected edge per interaction row."""
        # Plain tuples are enough here and avoid building a dict for every row.
        pairs = self.inters.select(self.user_str, self.item_str)
        for user_id, item_id in pairs.iter_rows():
            v_user = self._conv_v2id(user_id)
            v_item = self._conv_v2id(item_id, False)

            self.adj[v_user].append(v_item)
            self.adj[v_item].append(v_user)

    def invalid_vertices(self) -> tuple:
        """Find the users and items that do not belong to the ``k``-core.

        Works on a private copy of the degrees, so ``self.adj`` is left
        untouched and the method can be called any number of times.

        Returns:
            ``(users_invalid, items_invalid)``: two lists of raw ids, each in
            the order the vertices were first seen in the interactions.
        """
        degree = {v: len(neighbours) for v, neighbours in self.adj.items()}

        # Worklist of vertices already known to be below the threshold.
        pending = [v for v, d in degree.items() if d < self.k]
        removed = set(pending)

        while pending:
            v = pending.pop()
            for neighbour in self.adj[v]:
                if neighbour in removed:
                    continue
                # Removing `v` costs this neighbour one interaction.
                degree[neighbour] -= 1
                if degree[neighbour] < self.k:
                    removed.add(neighbour)
                    pending.append(neighbour)

        users_invalid, items_invalid = [], []
        for v in self.adj:
            if v not in removed:
                continue
            v_id, is_user = self._conv_id2v(v)
            if is_user:
                users_invalid.append(v_id)
            else:
                items_invalid.append(v_id)

        return users_invalid, items_invalid

    def filter(self) -> "DataFrame":
        """Return the interactions restricted to the ``k``-core.

        Returns:
            A new frame without any row whose user or item was peeled off.
            The stored frame and graph are not modified, so repeated calls
            return the same result.
        """
        users_invalid, items_invalid = self.invalid_vertices()

        logger.debug(users_invalid)
        logger.debug(items_invalid)

        # Remove invalid users and items from `inters`
        inters = self.inters.filter(
            (~pl.col(self.user_str).is_in(users_invalid))
            & (~pl.col(self.item_str).is_in(items_invalid))
        )

        return inters
    
# Run the filter with k=5 and continue with the filtered interactions.
kcore_filter = KCoreFilter(inters, k=5)
inters = kcore_filter.filter()

## Create negative samples

In [None]:
N_NEG_TEST = 10
N_NEG_TRAIN = 100
NEG_SAMPLING_SEED = 42  # fixed so the sampled negatives are reproducible

# `maintain_order=True` keeps the candidate order stable between runs, which
# the seeded shuffle below needs in order to be reproducible.
users = inters.select('user_id').unique(maintain_order=True)
items = inters.select('item_id').unique(maintain_order=True)

inters_negative = (
    # List out all possible interactions (both positive and negative)
    users
    .join(items, how='cross')

    # Keep negative interactions
    .join(inters, on=['user_id', 'item_id'], how='anti')

    # Shuffle the candidates so each user's negatives are a random sample.
    # Ranking by `item_id` would always pick the smallest unseen item ids.
    .sample(fraction=1.0, shuffle=True, seed=NEG_SAMPLING_SEED)

    # Number each user's candidates 1..n in the shuffled order and keep the first N.
    # `rank` is kept on purpose: later cells use it to divide the negatives
    # between the test split (rank <= N_NEG_TEST) and the train split.
    .with_columns(
        pl.int_range(1, pl.len() + 1).over('user_id').alias('rank')
    )
    .filter(pl.col('rank') <= N_NEG_TEST + N_NEG_TRAIN)
)

inters_negative.head()

## Process `items`

In [None]:
# Raw `name:type` column headers -> plain column names to keep.
cols = {
    'item_id:token': 'item_id',
    'genre:token_seq': 'genre',
}

# Select only the mapped columns, renaming each one on the way out.
items = items_raw.select(
    [pl.col(raw_name).alias(new_name) for raw_name, new_name in cols.items()]
)
items.head()

In [None]:
# Genre vocabulary: one row per distinct genre plus an integer `genre_id`.
genres = (
    items
    # `genre` holds space-separated genre names; turn it into one row per name.
    .with_columns(pl.col('genre').str.split(' '))
    .explode('genre')
    # Sort the distinct names so every run assigns the same `genre_id` to the
    # same genre (the order returned by `unique()` alone is not guaranteed).
    # Nulls are dropped because they could never match in a join on `genre`.
    .select(pl.col('genre').drop_nulls().unique().sort())
    .with_row_index('genre_id')
)

genres.head()

In [11]:
# Replace each item's space-separated genre string with a list of genre ids.
items = (
    items
    .with_columns(pl.col('genre').str.split(' '))
    .explode('genre')
    # Inner join: rows whose genre is missing from `genres` are dropped.
    .join(genres, on='genre')
    # Keep items in first-seen order so the saved file is stable across runs.
    .group_by('item_id', maintain_order=True)
    # Aggregating a column already collects each group's values into a list.
    .agg(pl.col('genre_id'))
)

# Split train-val-test

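Split with **temporal leave-one-out (LOO)**: each user's most recent interaction is held out for testing, and all earlier interactions are used for training.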

In [12]:
# Rank each user's interactions from most recent (rank 1) to oldest.
# `ordinal` assigns distinct ranks even when timestamps tie, so exactly one
# interaction per user gets rank 1. With `min`, every interaction sharing the
# latest timestamp would get rank 1.
inters = (
    inters
    .with_columns(
        pl.col('timestamp')
        .rank(method='ordinal', descending=True)
        .over('user_id')
        .alias('rank')
    )
)

## Craft test split

In [None]:
# Positives: the rank-1 interaction(s) of every user.
inters_pos = (
    inters
    .filter(pl.col('rank') == 1)
    .drop('rank', 'timestamp')
    .with_columns(is_positive=pl.lit(True))
)

# Negatives: the first N_NEG_TEST negative samples of every user.
inters_neg = (
    inters_negative
    .filter(pl.col('rank') <= N_NEG_TEST)
    .drop('rank')
    .with_columns(is_positive=pl.lit(False))
)

# Stack both parts and attach the per-item columns from `items`.
inters_test = (
    pl.concat([inters_pos, inters_neg], how='vertical')
    .join(items, on='item_id', how='left')
)
inters_test.head()

## Craft train split

In [None]:
# Positives: every interaction except the rank-1 one(s) of each user.
inters_pos = (
    inters
    .filter(pl.col('rank') > 1)
    .drop('rank', 'timestamp')
    .with_columns(is_positive=pl.lit(True))
)

# Negatives: the negative samples not used by the test split.
inters_neg = (
    inters_negative
    .filter(pl.col('rank') > N_NEG_TEST)
    .drop('rank')
    .with_columns(is_positive=pl.lit(False))
)

# Stack both parts and attach the per-item columns from `items`.
inters_train = (
    pl.concat([inters_pos, inters_neg], how='vertical')
    .join(items, on='item_id', how='left')
)
inters_train.head()

# Save things

In [15]:
# Destination file -> frame to store there, written in this order.
outputs = {
    Path("data/processed/ml-1m/train_temporal-loo.parquet"): inters_train,
    Path("data/processed/ml-1m/test_temporal-loo.parquet"): inters_test,
    Path("data/interim/ml-1m_items.parquet"): items,
}

for path, frame in outputs.items():
    # Create the target directory on demand; a no-op when it already exists.
    path.parent.mkdir(parents=True, exist_ok=True)
    frame.write_parquet(path)