# Task 1

In [6]:
import os
import sys

# Project root: the parent of the notebook's current working directory.
ROOT = os.path.abspath(
    os.path.join(os.getcwd(), '..')
)

# Expose project-local packages (e.g. `src`) to the import system, adding the
# root only once even if this cell is re-run.
if not any(entry == ROOT for entry in sys.path):
    sys.path.append(ROOT)

import polars as pl
import pandas as pd

## Paths

In [10]:
# Directory layout under the project root: data/raw, data/processed, predictions.
DATA_PATH = os.path.join(ROOT, 'data')
RAW_DATA_PATH = os.path.join(DATA_PATH, 'raw')
PROCESSED_DATA_PATH = os.path.join(DATA_PATH, 'processed')
PREDICTIONS_PATH = os.path.join(ROOT, 'predictions')

# Users: raw batch folder and the cleaned parquet file.
USERS_RAW_PATH = os.path.join(RAW_DATA_PATH, 'user_batches')
USERS_CLEAN_PATH = os.path.join(PROCESSED_DATA_PATH, 'users.parquet')

# Other raw inputs, built from RAW_DATA_PATH instead of re-joining 'raw'.
TRAIN_PATH = os.path.join(RAW_DATA_PATH, 'train.csv')
PRODUCTS_PATH = os.path.join(RAW_DATA_PATH, 'products.pkl')

# Example submission file used as the answer template.
SUBMISSION_1_PATH = os.path.join(PREDICTIONS_PATH, 'example_predictions_1.json')

In [12]:
# Load the example submission; it is the template whose answers later cells overwrite.
import json

# Use a context manager so the file handle is closed, and read as UTF-8
# (the JSON standard encoding) instead of the platform default.
with open(SUBMISSION_1_PATH, encoding='utf-8') as submission_file:
    submission = json.load(submission_file)

In [32]:
# Show the template: 'target' maps each query_N key to its answer field(s).
submission

{'target': {'query_1': {'partnumber': 17265},
  'query_2': {'user_id': 34572},
  'query_3': {'average_previous_visits': 5.52},
  'query_4': {'device_type': 23},
  'query_5': {'user_id': 123734},
  'query_6': {'unique_families': 2357},
  'query_7': {'1': 3, '2': 5, '3': 3, '4': 9, '5': 5, '6': 1}}}

## Query 1

Which product (partnumber) with color_id equal to 3 belongs to the lowest family code with a discount?

In [31]:
# NOTE(review): unpickling can execute arbitrary code; only load products.pkl
# from a trusted source.
prods = pl.from_pandas(pd.read_pickle(PRODUCTS_PATH))

# Among discounted products with color_id 3, pick the one in the lowest family
# code. An aggregate such as MIN() is not valid inside a WHERE clause, so sort by
# family instead. Null families are dropped, mirroring MIN() (which ignores
# nulls); polars sorts nulls first by default. partnumber breaks ties so the
# answer is deterministic.
# NOTE(review): assumes `discount == 1` marks discounted items, as in the
# original filter. Confirm against the data dictionary.
q1 = (prods
      .filter(
          (pl.col('color_id') == 3)
          & (pl.col('discount') == 1)
          & pl.col('family').is_not_null()
      )
      .sort(['family', 'partnumber'])
      .head(1)
      )['partnumber'].item()  # .item() raises if no product matches

submission['target']['query_1'] = {'partnumber': q1}

## Query 2

In the country where most users have made purchases totaling less than 500 (M), which is the user with:

- Lowest purchase frequency (F)
- Most recent purchase (R)
- Lowest user_id (tie-breaker)?

In [54]:
from src.data.loaders import PolarsLoader

# Project loader configured for parquet input with sampling disabled.
loader = PolarsLoader(
    file_type='parquet',
    sampling=False,
)

# Cleaned users table, read from the processed parquet path.
users = loader.load_data(USERS_CLEAN_PATH)

In [71]:
# Count users whose total spend M is below 500 in each country and keep the
# country with the most such users.
# NOTE(review): counts rows, so this presumes one row per user_id in `users`.
# Verify, or switch to counting distinct user_id values.
# group_by() does not guarantee output order, so ties on the count are broken
# explicitly by the lowest country code to keep the pick reproducible.
country = (users
           .filter(pl.col('M') < 500)
           .group_by('country')
           .len()
           .sort(by=['len', 'country'], descending=[True, False])
           .head(1)
           )['country'].item()
country

25

In [72]:
# Inspect the users of the selected country (the Query 2 candidates).
users.filter(pl.col('country').eq(country))

country,R,F,M,user_id
i64,i64,i64,f64,i64
25,30,0,0.0,430102
25,177,1,75.9,134198
25,32,61,37.694058,134207
25,74,86,11.64094,180365
25,79,5,30.283333,430101
…,…,…,…,…
25,155,9,17.423636,389294
25,62,16,45.104706,389292
25,8,74,36.052632,389298
25,15,26,20.201622,389296
