In [1]:
import pandas as pd
import nltk
from collections import Counter
from nltk.tokenize import word_tokenize
from nltk.util import ngrams
nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     /home/ma/ma_ma/ma_pbhattar/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [2]:
sessions_test=pd.read_csv("../../data/test/sessions_test_task1.csv")
sessions_train=pd.read_csv("../../data/train/sessions_train.csv")

In [3]:
def convert_to_text_list(df):
    texts = []
    for index, row in df.iterrows():
        text = str(row['prev_items'])[1:-1]
        text = text.replace("'", "")
        text = text.replace("\n", "")
        text = text + ' ' + str(row['next_item'])
        texts.append(text)
    return texts

def convert_to_token_list(df):
    tokens = []
    for index, row in df.iterrows():
        text = str(row['prev_items'])[1:-1]
        text = text.replace("'", "")
        text = text.replace("\n", "")
        tokens.extend(text.split())
        tokens.append(str(row['next_item']))
    return tokens

In [4]:
from collections import defaultdict

def create_ngram_model(texts, n):
    ngrams = defaultdict(lambda: defaultdict(int))

    # Tokenize the texts
    tokens = [word_tokenize(text) for text in texts]

    # Create n-grams
    for token_list in tokens:
        for i in range(len(token_list) - n):
            ngram = tuple(token_list[i:i + n])
            next_token = token_list[i + n]
            ngrams[ngram][next_token] += 1

    # Sort the next tokens by frequency
    for ngram, next_tokens in ngrams.items():
        ngrams[ngram] = dict(sorted(next_tokens.items(), key=lambda x: x[1], reverse=True)[:100])

    return ngrams

In [5]:
# Seperate the sessions data by locale
sessions_train_DE = sessions_train[sessions_train['locale'] == 'DE']
sessions_train_JP = sessions_train[sessions_train['locale'] == 'JP']
sessions_train_UK = sessions_train[sessions_train['locale'] == 'UK']

In [6]:
# # First, convert the sessions data into text lists
# texts_DE = convert_to_text_list(sessions_train_DE)
# texts_JP = convert_to_text_list(sessions_train_JP)
# texts_UK = convert_to_text_list(sessions_train_UK)

In [7]:
# Step 1: Create a dictionary of texts by locale
texts = {
    'DE': convert_to_text_list(sessions_train_DE),
    'JP': convert_to_text_list(sessions_train_JP),
    'UK': convert_to_text_list(sessions_train_UK),
}

In [8]:
# Generate n-gram models for each locale. We create as high as 5-grams
ngram_models = {'DE': [], 'JP': [], 'UK': []}

In [9]:
# Generate 5-grams for DE, JP, and UK
de_5gram_model = create_ngram_model(texts['DE'], 5)
jp_5gram_model = create_ngram_model(texts['JP'], 5)
uk_5gram_model = create_ngram_model(texts['UK'], 5)

In [10]:
# Generate 4-grams for DE, JP, and UK
de_4gram_model = create_ngram_model(texts['DE'], 4)
jp_4gram_model = create_ngram_model(texts['JP'], 4)
uk_4gram_model = create_ngram_model(texts['UK'], 4)

In [11]:
# Generate 3-grams for DE, JP, and UK
de_3gram_model = create_ngram_model(texts['DE'], 3)
jp_3gram_model = create_ngram_model(texts['JP'], 3)
uk_3gram_model = create_ngram_model(texts['UK'], 3)

In [12]:
# Generate 2-grams for DE, JP, and UK
de_2gram_model = create_ngram_model(texts['DE'], 2)
jp_2gram_model = create_ngram_model(texts['JP'], 2)
uk_2gram_model = create_ngram_model(texts['UK'], 2)

In [13]:
# Generate 1-grams for DE, JP, and UK
de_1gram_model = create_ngram_model(texts['DE'], 1)
jp_1gram_model = create_ngram_model(texts['JP'], 1)
uk_1gram_model = create_ngram_model(texts['UK'], 1)

In [14]:
# Add all n-gram models to ngram_models
for i in range(5, 0, -1):
    ngram_models['DE'].append(eval(f'de_{i}gram_model'))
    ngram_models['JP'].append(eval(f'jp_{i}gram_model'))
    ngram_models['UK'].append(eval(f'uk_{i}gram_model'))

In [15]:
# Create allgrams
all_grams_dict = {'DE': Counter(texts['DE']),
                  'JP': Counter(texts['JP']),
                  'UK': Counter(texts['UK'])}

In [21]:
texts['DE']

['B09W9FND7K B09JSPLN1M B09M7GY217',
 'B076THCGSG B007MO8IME B08MF65MLV B001B4TKA0 B001B4THSA',
 'B0B1LGXWDS B00AZYORS2 B0B1LGXWDS B00AZYORS2 B0B1LGXWDS B0B1LGXWDS B00AZYORS2 B0B1LGXWDS B00AZYORS2 B0767DTG2Q',
 'B09XMTWDVT B0B4MZZ8MB B0B7HZ2GWX B09XMTWDVT B0B4MZZ8MB B0B7HZ2GWX B0B71CHT1L B0B4R9NN4B',
 'B09Y5CSL3T B09Y5DPTXN B09FKD61R8 B0BGVBKWGZ',
 'B0749V8TC7 B0749W93VC B0749TX4YP B0749TX4YS',
 'B09SMK3R8H B01N4ND0F9 B08YNZT93Z',
 'B09B2W5S9R B09B2YFY6M B09B2WGPRB B097CX2V3L',
 'B01MQOR80Q B095HS8R62 B09B31WTVY B09B32SSDT',
 '3649625660 B07N3SNQW5 B099JZ9L9Y B07Q2CFPGH B099KCMQ92 B07Q2CFPGH 3848520974 B08R7G53T1',
 'B07H1GQB36 B08DTZ3PTY B08G4DFMNN',
 'B0927GXJPB B003AO3LS6 B09BD7P7XZ',
 'B07YPSZ566 B08G91WFQR B08G97TPH8 B08G96F4ZF B01H6JS4IC B0B42R2GKP B0B5344X16 B0B1J5T32Y',
 'B09ZTSZV49 B09HXP3R6T B0B5QR5DJC B09SFTDW9W B0B7RD1HQQ B09JGJ43YJ',
 'B07LFQPX5L B07LFPJTVF B07LFRP5SS',
 'B085QWM3KB B001BAAV5W B002GHKPQY B010VC4AG6 B00B2IOQHM B0773V43RV B000VVMY48 B000VVSWBM',
 'B01IKAEFJS

In [None]:
# # Generate n-gram models for each locale
# onegram_DE = create_ngram_model(convert_to_text_list(sessions_train_DE), 1)

In [None]:
# onegram_JP = create_ngram_model(convert_to_text_list(sessions_train_JP), 1)

In [None]:
# onegram_UK = create_ngram_model(convert_to_text_list(sessions_train_UK), 1)

In [None]:
# # Combine ngram models and all_grams into dictionaries for easy access
# ngram_models = {'DE': onegram_DE, 'JP': onegram_JP, 'UK': onegram_UK}

In [None]:
# all_grams_dict = {'DE': Counter(convert_to_token_list(sessions_train_DE)),
#                   'JP': Counter(convert_to_token_list(sessions_train_JP)),
#                   'UK': Counter(convert_to_token_list(sessions_train_UK))}

In [17]:
# Find the top 100 items for each locale
for locale in all_grams_dict:
    all_grams_dict[locale] = all_grams_dict[locale].most_common(100)

In [20]:
de_3gram_model

defaultdict(<function __main__.create_ngram_model.<locals>.<lambda>()>,
            {('B076THCGSG', 'B007MO8IME', 'B08MF65MLV'): {'B001B4TKA0': 1},
             ('B007MO8IME', 'B08MF65MLV', 'B001B4TKA0'): {'B001B4THSA': 1},
             ('B0B1LGXWDS', 'B00AZYORS2', 'B0B1LGXWDS'): {'B00AZYORS2': 2,
              'B0B1LGXWDS': 1},
             ('B00AZYORS2', 'B0B1LGXWDS', 'B00AZYORS2'): {'B0B1LGXWDS': 1,
              'B0767DTG2Q': 1,
              'B08XXLDH2N': 1,
              'B00AZYORS2': 1},
             ('B00AZYORS2', 'B0B1LGXWDS', 'B0B1LGXWDS'): {'B00AZYORS2': 1},
             ('B0B1LGXWDS', 'B0B1LGXWDS', 'B00AZYORS2'): {'B0B1LGXWDS': 1},
             ('B09XMTWDVT', 'B0B4MZZ8MB', 'B0B7HZ2GWX'): {'B09XMTWDVT': 42,
              'B0B71CHT1L': 11,
              'B0B4MZZ8MB': 9,
              'B0B71KXCSB': 8,
              'B09XMTMC9W': 8,
              'B0B71JHMTC': 6,
              'B0B7R6C9ZW': 4,
              'B09Z25LH2G': 4,
              'B0B71GSJ2R': 4,
              'B0B7R4ZB

In [19]:
de_2gram_model

defaultdict(<function __main__.create_ngram_model.<locals>.<lambda>()>,
            {('B09W9FND7K', 'B09JSPLN1M'): {'B09M7GY217': 1,
              'B09YS6S48R': 1,
              'B07KDC7PJH': 1,
              'B07MDZ2K4F': 1},
             ('B076THCGSG', 'B007MO8IME'): {'B08MF65MLV': 1},
             ('B007MO8IME', 'B08MF65MLV'): {'B001B4TKA0': 1},
             ('B08MF65MLV', 'B001B4TKA0'): {'B001B4THSA': 1},
             ('B0B1LGXWDS', 'B00AZYORS2'): {'B0B1LGXWDS': 3,
              'B0767DTG2Q': 1,
              'B004N9BSQE': 1,
              'B096Y4VBF2': 1,
              'B08XXLDH2N': 1,
              'B07JLJ4N44': 1,
              'B09D2XNFZ2': 1,
              'B00AZYORS2': 1},
             ('B00AZYORS2', 'B0B1LGXWDS'): {'B00AZYORS2': 4,
              'B0B1LGXWDS': 1,
              'B07NJBFWFM': 1,
              'B01DBUK56Y': 1,
              'B097GVPJ74': 1},
             ('B0B1LGXWDS', 'B0B1LGXWDS'): {'B00AZYORS2': 1},
             ('B09XMTWDVT', 'B0B4MZZ8MB'): {'B0B7HZ2GWX': 15

In [18]:
all_grams_dict

{'DE': [('B00FYHVWEW B00FYHVXBE B00FYHVXQY', 25),
  ('B00J0BA7AI B00J0BA8A2 B00J0BAB9K', 24),
  ('B0166QJF98 B0166QJISG B08JQKXXKC', 23),
  ('B00HFDYUEW B00IFHWUQ2 B00HFDYUV0', 22),
  ('B00PWKKZJ4 B00PWKL390 B00PWKL868', 22),
  ('B07WFQP2NB B07WGQPLSC B07WGQMZQ7', 22),
  ('B00HDPPTDS B00HUYY54Q B00HWVTAQ0', 22),
  ('B078RSDRBN B078RVB22Z B08JH6X5JY', 21),
  ('B00G0OFK46 B00G0OFKKU B00HWWHD4K', 21),
  ('B00FYUQ9IS B00FYUQC08 B00H39YASE', 21),
  ('B09KJ4XK97 B09KJ7Z6D2 B09KJ7ZRW5', 21),
  ('B0B27HDSXG B0B27TJQ5T B0B27W863Y', 21),
  ('B07NDNJYYM B07NDPS4BF B07NDP7R6S', 20),
  ('B07Y9Y9PJ8 B07Y9ZVF36 B07YBBJD93', 20),
  ('B09K9MMGCX B09KFMJY3L B09KFRLW1D', 20),
  ('B00EXHJKT8 B00EXHJMQY B00FFIK5YS', 20),
  ('B00HXJQFXM B00I8W77OO B00I8W91Y8', 20),
  ('B01MRCF2Y0 B01MTFHO95 B01MRCGW3C', 19),
  ('B0B8P8WT53 B0B8PG7FNN B0B8P9445G', 19),
  ('B00EVAU17W B00EVAU29E B00H3APMIA', 19),
  ('B0766FMZMB B0766LWYZ5 B0766H2WVH', 19),
  ('B00HX2J39W B00HX2J48C B00HX2J658', 19),
  ('B09T3VHDMG B09T3ZQLJ8 

In [16]:
sessions_test.head(3)

Unnamed: 0,prev_items,locale
0,['B08V12CT4C' 'B08V1KXBQD' 'B01BVG1XJS' 'B09VC...,DE
1,['B00R9R5ND6' 'B00R9RZ9ZS' 'B00R9RZ9ZS'],DE
2,['B07YSRXJD3' 'B07G7Q5N6G' 'B08C9Q7QVK' 'B07G7...,DE


In [22]:
# Make predictions based on locale
for index, row in sessions_test.iterrows():
    predictions = []
    text = str(row['prev_items'])[1:-1]
    text = text.replace("'", "")
    text = text.replace("\n", "")
    tokens = word_tokenize(text)
    locale = row['locale']

    for n in range(5, 0, -1):  # start with 4-gram model and go down to 1-gram
        if len(tokens) >= n:
            gram = tuple(tokens[-n:])
            for k, v in ngram_models[locale][n-1][gram].items():
                predictions.append(k)
                if len(predictions) == 100:
                    break
            if len(predictions) == 100:
                break
    if len(predictions) < 100:  # fill up with popular items if less than 100 predictions
        for k, v in all_grams_dict[locale]:
            if k not in predictions:  # don't repeat predictions
                predictions.append(k)
                if len(predictions) == 100:
                    break
    sessions_test.at[index, 'next_item_prediction'] = str(predictions)

In [23]:
sessions_test.head(15)

Unnamed: 0,prev_items,locale,next_item_prediction
0,['B08V12CT4C' 'B08V1KXBQD' 'B01BVG1XJS' 'B09VC...,DE,"['B00FYHVWEW B00FYHVXBE B00FYHVXQY', 'B00J0BA7..."
1,['B00R9R5ND6' 'B00R9RZ9ZS' 'B00R9RZ9ZS'],DE,"['B00FYHVWEW B00FYHVXBE B00FYHVXQY', 'B00J0BA7..."
2,['B07YSRXJD3' 'B07G7Q5N6G' 'B08C9Q7QVK' 'B07G7...,DE,"['B08C9Q7QVK', 'B07G7Q5N6G', 'B0BF585GQ7', 'B0..."
3,['B08KQBYV43' '3955350843' '3955350843' '39553...,DE,"['B00FYHVWEW B00FYHVXBE B00FYHVXQY', 'B00J0BA7..."
4,['B09FPTCWMC' 'B09FPTQP68' 'B08HMRY8NG' 'B08TB...,DE,"['B09J8T6TTH', 'B09J8TWRV3', 'B09J8V18FL', 'B0..."
5,['B0BHT75TPQ' 'B0BHT7X2R6' 'B0BK5VMHND' 'B0BHT...,DE,"['B00FYHVWEW B00FYHVXBE B00FYHVXQY', 'B00J0BA7..."
6,['B071P9DVF6' 'B07BGHDRZH' 'B09S37TD4N'],DE,"['B00FYHVWEW B00FYHVXBE B00FYHVXQY', 'B00J0BA7..."
7,['B0B8D1V4QW' 'B0813KJ832' 'B099XL3VS4' 'B09V1...,DE,"['B00FYHVWEW B00FYHVXBE B00FYHVXQY', 'B00J0BA7..."
8,['B0B3BZFMCH' 'B0B3BW437K' 'B0B3C5P8N8' 'B0B3C...,DE,"['B00FYHVWEW B00FYHVXBE B00FYHVXQY', 'B00J0BA7..."
9,['B08F9GMLXM' 'B0B8D4CWZ4' 'B08L9CZ7BW' 'B08FB...,DE,"['B00FYHVWEW B00FYHVXBE B00FYHVXQY', 'B00J0BA7..."


In [None]:
sessions_test.iloc[0].next_item_prediction

In [24]:
for index, row in sessions_test.iterrows():
    sessions_test.at[index, 'next_item_prediction']=str(row['next_item_prediction'])[1:-1]

In [None]:
sessions_test.iloc[0].next_item_prediction

In [25]:
for index, row in sessions_test.iterrows():
    sessions_test.at[index, 'next_item_prediction']=row['next_item_prediction'].replace("'","")

In [26]:
sessions_test.iloc[0].next_item_prediction

'B00FYHVWEW B00FYHVXBE B00FYHVXQY, B00J0BA7AI B00J0BA8A2 B00J0BAB9K, B0166QJF98 B0166QJISG B08JQKXXKC, B00HFDYUEW B00IFHWUQ2 B00HFDYUV0, B00PWKKZJ4 B00PWKL390 B00PWKL868, B07WFQP2NB B07WGQPLSC B07WGQMZQ7, B00HDPPTDS B00HUYY54Q B00HWVTAQ0, B078RSDRBN B078RVB22Z B08JH6X5JY, B00G0OFK46 B00G0OFKKU B00HWWHD4K, B00FYUQ9IS B00FYUQC08 B00H39YASE, B09KJ4XK97 B09KJ7Z6D2 B09KJ7ZRW5, B0B27HDSXG B0B27TJQ5T B0B27W863Y, B07NDNJYYM B07NDPS4BF B07NDP7R6S, B07Y9Y9PJ8 B07Y9ZVF36 B07YBBJD93, B09K9MMGCX B09KFMJY3L B09KFRLW1D, B00EXHJKT8 B00EXHJMQY B00FFIK5YS, B00HXJQFXM B00I8W77OO B00I8W91Y8, B01MRCF2Y0 B01MTFHO95 B01MRCGW3C, B0B8P8WT53 B0B8PG7FNN B0B8P9445G, B00EVAU17W B00EVAU29E B00H3APMIA, B0766FMZMB B0766LWYZ5 B0766H2WVH, B00HX2J39W B00HX2J48C B00HX2J658, B09T3VHDMG B09T3ZQLJ8 B09T3VTXRM, B0711XR9DD B072C2LYJN B071744SLH, B00GWUSGU8 B00GWUSE1O B07BC647XF, B00GWUSE1O B00GWUSGU8 B00OGU88KE, B09GKQT8SD B09GKT4Q3W B09GKTFJ3W, B00FZ29TNM B00FZ29U04 B00H3AO2YA, B00GWUSE1O B00GWUSGU8 B00GWUSDII, B00JO9JNHY B0

In [None]:
sessions_test.head(2)

In [27]:
sessions_test.drop('prev_items', axis=1, inplace=True)

In [28]:
# Reorganize next_item_prediction column to match submission format
for index, row in sessions_test.iterrows():
    sessions_test.at[index, 'next_item_prediction']=row['next_item_prediction'].split(', ')

In [None]:
sessions_test

In [None]:
sessions_test.iloc[0].next_item_prediction

In [29]:
sessions_test.head(2)

Unnamed: 0,locale,next_item_prediction
0,DE,"[B00FYHVWEW B00FYHVXBE B00FYHVXQY, B00J0BA7AI ..."
1,DE,"[B00FYHVWEW B00FYHVXBE B00FYHVXQY, B00J0BA7AI ..."


In [30]:
'''
locale	next_item_prediction
0	DE	[B00FYHVWEW B00FYHVXBE B00FYHVXQY, B00J0BA7AI ...
1	DE	[B00FYHVWEW B00FYHVXBE B00FYHVXQY, B00J0BA7AI ...
'''
# Remove all existing commas from next_item_prediction column
for index, row in sessions_test.iterrows():
    sessions_test.at[index, 'next_item_prediction']=str(row['next_item_prediction']).replace(",","")

In [31]:
sessions_test.head(2)

Unnamed: 0,locale,next_item_prediction
0,DE,['B00FYHVWEW B00FYHVXBE B00FYHVXQY' 'B00J0BA7A...
1,DE,['B00FYHVWEW B00FYHVXBE B00FYHVXQY' 'B00J0BA7A...


In [None]:
output_path = '../../outputs/' + 'task1_predictions_adv.parquet'

import pyarrow.parquet as pq
import pyarrow as pa

# Save predictions to parquet
table = pa.Table.from_pandas(sessions_test)
pq.write_table(table, output_path, compression='gzip')