In [1]:
import pandas as pd
import nltk
from collections import Counter
from nltk.tokenize import word_tokenize
from nltk.util import ngrams
# Fetch the Punkt tokenizer data used by nltk's word_tokenize; the captured
# output below shows the download is skipped when the package is already current.
nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     /home/ma/ma_ma/ma_pbhattar/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [2]:
# Load the raw session tables; both paths are relative to the working directory.
sessions_train = pd.read_csv("../../data/train/sessions_train.csv")
sessions_test = pd.read_csv("../../data/test/sessions_test_task1_phase2.csv")

In [3]:
def convert_to_text_list(df):
    """Render every session row as one whitespace-separated string of item IDs.

    For each row, the string form of ``prev_items`` loses its first and last
    character (the surrounding brackets), every single quote and every newline
    character is removed, and the row's ``next_item`` is appended after a
    single space.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with ``prev_items`` and ``next_item`` columns.

    Returns
    -------
    list of str
        One string per row, in row order.
    """
    def _row_to_text(record):
        history = str(record['prev_items'])[1:-1]
        for unwanted in ("'", "\n"):
            history = history.replace(unwanted, "")
        return history + ' ' + str(record['next_item'])

    return [_row_to_text(record) for _, record in df.iterrows()]

def convert_to_token_list(df):
    """Flatten all sessions into a single list of item IDs.

    For each row, the string form of ``prev_items`` is stripped of its first
    and last character, of single quotes and of newline characters, then split
    on whitespace; the row's ``next_item`` (as a string) follows those tokens.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with ``prev_items`` and ``next_item`` columns.

    Returns
    -------
    list of str
        All item IDs of all rows, concatenated in row order.
    """
    flat_items = []
    for _, record in df.iterrows():
        cleaned = str(record['prev_items'])[1:-1].replace("'", "").replace("\n", "")
        flat_items += cleaned.split()
        flat_items.append(str(record['next_item']))
    return flat_items

In [4]:
from collections import defaultdict

def create_ngram_model(texts, n, top_k=100, tokenizer=None):
    """Build a next-token frequency model keyed by n-gram context.

    Every text is tokenized, and for each window of ``n`` consecutive tokens
    the token that immediately follows the window is counted.

    Parameters
    ----------
    texts : iterable of str
        Texts to learn from (here: one string of item IDs per session).
    n : int
        Context length; must not be negative.
    top_k : int or None, optional
        Number of most frequent followers kept per context (default 100).
        ``None`` keeps all of them.
    tokenizer : callable, optional
        Function mapping a text to a list of tokens. Defaults to nltk's
        ``word_tokenize``; pass e.g. ``str.split`` for plain whitespace
        tokenization.

    Returns
    -------
    collections.defaultdict
        Maps each context tuple to a plain dict ``{next_token: count}``
        ordered by descending count (ties keep first-seen order). Indexing a
        context that was never seen yields (and stores) an empty mapping.

    Raises
    ------
    ValueError
        If ``n`` is negative.
    """
    if n < 0:
        # A negative n would index from the end of the token list and count nonsense.
        raise ValueError("n must be a non-negative integer")
    if tokenizer is None:
        tokenizer = word_tokenize

    # Called `model` (not `ngrams`) so the imported nltk.util.ngrams is not shadowed.
    model = defaultdict(lambda: defaultdict(int))

    # Tokenize one text at a time instead of materialising every token list up front.
    for text in texts:
        token_list = tokenizer(text)
        for start in range(len(token_list) - n):
            context = tuple(token_list[start:start + n])
            model[context][token_list[start + n]] += 1

    # Keep only the most frequent followers. Reassigning existing keys while
    # iterating is safe because the set of keys does not change.
    for context, followers in model.items():
        ranked = sorted(followers.items(), key=lambda pair: pair[1], reverse=True)
        model[context] = dict(ranked[:top_k])

    return model

In [5]:
# Separate the sessions data by locale
sessions_train_DE, sessions_train_JP, sessions_train_UK = (
    sessions_train[sessions_train['locale'] == locale_code]
    for locale_code in ('DE', 'JP', 'UK')
)

In [6]:
# Generate n-gram models for each locale
# (n=1: every single item maps to the counts of the items that directly follow it)
onegram_DE = create_ngram_model(
    texts=convert_to_text_list(sessions_train_DE),
    n=1,
)

In [7]:
# One-gram transition model for the JP sessions
onegram_JP = create_ngram_model(
    texts=convert_to_text_list(sessions_train_JP),
    n=1,
)

In [8]:
# One-gram transition model for the UK sessions
onegram_UK = create_ngram_model(
    texts=convert_to_text_list(sessions_train_UK),
    n=1,
)

In [9]:
# Collect the per-locale one-gram models in one dictionary keyed by locale code
ngram_models = dict(DE=onegram_DE, JP=onegram_JP, UK=onegram_UK)

In [10]:
# Per-locale occurrence counts of every item ID in the training sessions
all_grams_dict = {
    locale_code: Counter(convert_to_token_list(locale_sessions))
    for locale_code, locale_sessions in (
        ('DE', sessions_train_DE),
        ('JP', sessions_train_JP),
        ('UK', sessions_train_UK),
    )
}

In [11]:
# Find the top 100 items for each locale
# Each Counter is replaced by a list of (item, count) pairs, most frequent first.
for locale, counts in all_grams_dict.items():
    # Once this cell has run, the values are lists and no longer have
    # .most_common(); the guard turns a re-run into a no-op instead of an
    # AttributeError.
    if isinstance(counts, Counter):
        all_grams_dict[locale] = counts.most_common(100)

In [12]:
# Make predictions based on locale
def _rank_next_items(session_items, followers, popular_items, limit=100):
    """Return up to `limit` distinct candidates that do not occur in the session.

    Candidates come from `followers` first (an iterable of item IDs, expected
    to be ordered by descending frequency) and then from `popular_items`
    (an iterable of (item, count) pairs). Items already in `session_items`
    and items already selected are skipped, so the result has no duplicates.
    """
    excluded = set(session_items)
    ranked = []
    candidates = list(followers) + [item for item, _count in popular_items]
    for candidate in candidates:
        if len(ranked) >= limit:
            break
        if candidate in excluded:
            continue
        excluded.add(candidate)  # also blocks a second occurrence of the same item
        ranked.append(candidate)
    return ranked


prediction_strings = []
for _, row in sessions_test.iterrows():
    # Same clean-up as for the training data: drop the brackets, quotes and newlines.
    text = str(row['prev_items'])[1:-1].replace("'", "").replace("\n", "")
    tokens = word_tokenize(text)
    locale = row['locale']

    # tokens[-1:] yields an empty context for an empty session instead of raising
    # IndexError; .get() avoids inserting an empty entry into the defaultdict
    # model for every last item that was never seen in training.
    followers = ngram_models[locale].get(tuple(tokens[-1:]), {})

    # Followers of the last item first, then the locale's most popular items.
    # Items from prev_items are excluded from both sources.
    predictions = _rank_next_items(tokens, followers, all_grams_dict[locale])
    prediction_strings.append(str(predictions))

# Stored as the string form of the list, one entry per row, in a single column assignment.
sessions_test['next_item_prediction'] = prediction_strings

In [13]:
sessions_test.head(2)

Unnamed: 0,prev_items,locale,next_item_prediction
0,['B087VLP2RT' 'B09BRQSHYH' 'B099KW4ZLV'],DE,"['B07SDFLVKD', 'B091CK241X', 'B0BGC82WVW', 'B0..."
1,['B08XW4W667' 'B096VMCJYF' 'B096VMCJYF'],DE,"['B004P4QFJM', 'B084CB7GX9', 'B09YD8XV6M', 'B0..."


In [14]:
sessions_test.iloc[0].next_item_prediction

"['B07SDFLVKD', 'B091CK241X', 'B0BGC82WVW', 'B0B9GJLV2D', 'B093X59B31', 'B07SR4R8K1', 'B08SRMPBRF', 'B09QBR7XJD', 'B08JW624NN', 'B00ZQW91DE', 'B087VMGP5G', 'B0977MTK65', 'B0BDML9477', 'B08GWS298V', 'B0B1MPZWJG', 'B09NQGVSPD', 'B0922JX27X', 'B08H93ZRK9', 'B0BFJGXWDV', 'B07R4PN3MV', 'B0B2Q4ZRDW', 'B00GWUSE1O', 'B09C7BRP5Y', 'B09MTWFCLY', 'B07CZ4DLCP', 'B00GWUSGU8', 'B0B2Q2VVGP', 'B095Z1QGWJ', 'B07CNRN566', 'B0033Q5KU8', 'B0B61MQD58', 'B08LSNJQ1N', 'B086NF5PMC', 'B07XKBLL8F', 'B08LSL593L', 'B0B8D4CWZ4', 'B07QPV9Z7X', 'B08LJRYBP2', 'B09G9B4SH9', 'B09C6RTP2S', 'B0935DN1BN', 'B0B8NNHR5N', 'B08L5TKXW3', 'B07YSRXJD3', 'B0935JRJ59', 'B088FSHMQ3', 'B07JG9TFSB', 'B0971DDW5C', 'B00FZWPO5Y', 'B00LN803LE', 'B099DP3617', 'B09MTVJX9K', 'B008TLGIA8', 'B09XMTWDVT', 'B07YPSZ566', 'B07CRT1KJ7', 'B08C5DR9GR', 'B0B4MZZ8MB', 'B06Y12PQJ8', 'B07GDVG5FQ', 'B0B4BJG9L4', 'B09QWH3T52', 'B0B34QYWDK', 'B07GH48Q2G', 'B00NTCH52W', 'B0936KTSPV', 'B09TTZZWW6', 'B09DL9HP8W', 'B074X4W71C', 'B004605SE8', 'B00K69H85U', 'B00

In [15]:
def _strip_list_brackets(value):
    """Return str(value) without one pair of enclosing square brackets, if present."""
    text = str(value)
    if text.startswith('[') and text.endswith(']'):
        return text[1:-1]
    return text


# Remove the "[" and "]" around the stringified prediction lists. The check for
# the brackets keeps a re-run of this cell from cutting real characters off the
# first and last item IDs.
sessions_test['next_item_prediction'] = sessions_test['next_item_prediction'].map(_strip_list_brackets)

In [16]:
sessions_test.iloc[0].next_item_prediction

"'B07SDFLVKD', 'B091CK241X', 'B0BGC82WVW', 'B0B9GJLV2D', 'B093X59B31', 'B07SR4R8K1', 'B08SRMPBRF', 'B09QBR7XJD', 'B08JW624NN', 'B00ZQW91DE', 'B087VMGP5G', 'B0977MTK65', 'B0BDML9477', 'B08GWS298V', 'B0B1MPZWJG', 'B09NQGVSPD', 'B0922JX27X', 'B08H93ZRK9', 'B0BFJGXWDV', 'B07R4PN3MV', 'B0B2Q4ZRDW', 'B00GWUSE1O', 'B09C7BRP5Y', 'B09MTWFCLY', 'B07CZ4DLCP', 'B00GWUSGU8', 'B0B2Q2VVGP', 'B095Z1QGWJ', 'B07CNRN566', 'B0033Q5KU8', 'B0B61MQD58', 'B08LSNJQ1N', 'B086NF5PMC', 'B07XKBLL8F', 'B08LSL593L', 'B0B8D4CWZ4', 'B07QPV9Z7X', 'B08LJRYBP2', 'B09G9B4SH9', 'B09C6RTP2S', 'B0935DN1BN', 'B0B8NNHR5N', 'B08L5TKXW3', 'B07YSRXJD3', 'B0935JRJ59', 'B088FSHMQ3', 'B07JG9TFSB', 'B0971DDW5C', 'B00FZWPO5Y', 'B00LN803LE', 'B099DP3617', 'B09MTVJX9K', 'B008TLGIA8', 'B09XMTWDVT', 'B07YPSZ566', 'B07CRT1KJ7', 'B08C5DR9GR', 'B0B4MZZ8MB', 'B06Y12PQJ8', 'B07GDVG5FQ', 'B0B4BJG9L4', 'B09QWH3T52', 'B0B34QYWDK', 'B07GH48Q2G', 'B00NTCH52W', 'B0936KTSPV', 'B09TTZZWW6', 'B09DL9HP8W', 'B074X4W71C', 'B004605SE8', 'B00K69H85U', 'B000

In [17]:
# Remove every single-quote character from the prediction strings.
sessions_test['next_item_prediction'] = [
    joined_items.replace("'", "")
    for joined_items in sessions_test['next_item_prediction']
]

In [18]:
sessions_test.iloc[0].next_item_prediction

'B07SDFLVKD, B091CK241X, B0BGC82WVW, B0B9GJLV2D, B093X59B31, B07SR4R8K1, B08SRMPBRF, B09QBR7XJD, B08JW624NN, B00ZQW91DE, B087VMGP5G, B0977MTK65, B0BDML9477, B08GWS298V, B0B1MPZWJG, B09NQGVSPD, B0922JX27X, B08H93ZRK9, B0BFJGXWDV, B07R4PN3MV, B0B2Q4ZRDW, B00GWUSE1O, B09C7BRP5Y, B09MTWFCLY, B07CZ4DLCP, B00GWUSGU8, B0B2Q2VVGP, B095Z1QGWJ, B07CNRN566, B0033Q5KU8, B0B61MQD58, B08LSNJQ1N, B086NF5PMC, B07XKBLL8F, B08LSL593L, B0B8D4CWZ4, B07QPV9Z7X, B08LJRYBP2, B09G9B4SH9, B09C6RTP2S, B0935DN1BN, B0B8NNHR5N, B08L5TKXW3, B07YSRXJD3, B0935JRJ59, B088FSHMQ3, B07JG9TFSB, B0971DDW5C, B00FZWPO5Y, B00LN803LE, B099DP3617, B09MTVJX9K, B008TLGIA8, B09XMTWDVT, B07YPSZ566, B07CRT1KJ7, B08C5DR9GR, B0B4MZZ8MB, B06Y12PQJ8, B07GDVG5FQ, B0B4BJG9L4, B09QWH3T52, B0B34QYWDK, B07GH48Q2G, B00NTCH52W, B0936KTSPV, B09TTZZWW6, B09DL9HP8W, B074X4W71C, B004605SE8, B00K69H85U, B00006JCUB, B081FWVSG8, B07DRKMWYX, B09QLW7HS2, B0BDJ47W5B, B00CWNMV4G, B0892LX5VS, B07F16BD5N, B07JM21QHM, B09QFPZ9B7, B0B466H784, B0B7HZ2GWX, B09

In [19]:
sessions_test.head(2)

Unnamed: 0,prev_items,locale,next_item_prediction
0,['B087VLP2RT' 'B09BRQSHYH' 'B099KW4ZLV'],DE,"B07SDFLVKD, B091CK241X, B0BGC82WVW, B0B9GJLV2D..."
1,['B08XW4W667' 'B096VMCJYF' 'B096VMCJYF'],DE,"B004P4QFJM, B084CB7GX9, B09YD8XV6M, B004P4OF1C..."


In [20]:
# prev_items is not referenced by the remaining cells shown in this notebook.
# Rebinding (instead of inplace=True) with errors='ignore' lets this cell be
# re-run without a KeyError once the column is already gone.
sessions_test = sessions_test.drop(columns='prev_items', errors='ignore')

In [21]:
# Reorganize next_item_prediction column to match submission format
def _split_prediction_string(value):
    """Turn 'A, B, C' into ['A', 'B', 'C'].

    Values that are already lists are returned unchanged so the cell can be
    re-run, and an empty string becomes [] rather than [''].
    """
    if isinstance(value, list):
        return value
    return [item for item in value.split(', ') if item]


sessions_test['next_item_prediction'] = sessions_test['next_item_prediction'].map(_split_prediction_string)

In [22]:
sessions_test

Unnamed: 0,locale,next_item_prediction
0,DE,"[B07SDFLVKD, B091CK241X, B0BGC82WVW, B0B9GJLV2..."
1,DE,"[B004P4QFJM, B084CB7GX9, B09YD8XV6M, B004P4OF1..."
2,DE,"[B09Z4PZQBF, B01GS8K962, B08LLF9M11, B08KHJN9H..."
3,DE,"[B09GKJ9RRJ, B07X8MW1G1, B07QQZD49D, B0BDML947..."
4,DE,"[B0B2JY9THB, B08YK8FQJ8, B09C89S7WG, B08R9PTZ5..."
...,...,...
316967,UK,"[B07GKP2LCF, B07GKYSHB4, B006DDGCI2, B016RAAUE..."
316968,UK,"[B00M35Y326, B08B395NHL, B08CN3G4N9, B07N8QY3Y..."
316969,UK,"[B08VDHH6QF, B08VD5DC5L, B08VDSL596, B089TQLLC..."
316970,UK,"[B089CZWB4C, B08W2JJZBM, B08T1ZJYHV, B09WCQYGX..."


In [23]:
sessions_test.iloc[0].next_item_prediction

['B07SDFLVKD',
 'B091CK241X',
 'B0BGC82WVW',
 'B0B9GJLV2D',
 'B093X59B31',
 'B07SR4R8K1',
 'B08SRMPBRF',
 'B09QBR7XJD',
 'B08JW624NN',
 'B00ZQW91DE',
 'B087VMGP5G',
 'B0977MTK65',
 'B0BDML9477',
 'B08GWS298V',
 'B0B1MPZWJG',
 'B09NQGVSPD',
 'B0922JX27X',
 'B08H93ZRK9',
 'B0BFJGXWDV',
 'B07R4PN3MV',
 'B0B2Q4ZRDW',
 'B00GWUSE1O',
 'B09C7BRP5Y',
 'B09MTWFCLY',
 'B07CZ4DLCP',
 'B00GWUSGU8',
 'B0B2Q2VVGP',
 'B095Z1QGWJ',
 'B07CNRN566',
 'B0033Q5KU8',
 'B0B61MQD58',
 'B08LSNJQ1N',
 'B086NF5PMC',
 'B07XKBLL8F',
 'B08LSL593L',
 'B0B8D4CWZ4',
 'B07QPV9Z7X',
 'B08LJRYBP2',
 'B09G9B4SH9',
 'B09C6RTP2S',
 'B0935DN1BN',
 'B0B8NNHR5N',
 'B08L5TKXW3',
 'B07YSRXJD3',
 'B0935JRJ59',
 'B088FSHMQ3',
 'B07JG9TFSB',
 'B0971DDW5C',
 'B00FZWPO5Y',
 'B00LN803LE',
 'B099DP3617',
 'B09MTVJX9K',
 'B008TLGIA8',
 'B09XMTWDVT',
 'B07YPSZ566',
 'B07CRT1KJ7',
 'B08C5DR9GR',
 'B0B4MZZ8MB',
 'B06Y12PQJ8',
 'B07GDVG5FQ',
 'B0B4BJG9L4',
 'B09QWH3T52',
 'B0B34QYWDK',
 'B07GH48Q2G',
 'B00NTCH52W',
 'B0936KTSPV',
 'B09TTZZW

In [24]:
sessions_test.head(2)

Unnamed: 0,locale,next_item_prediction
0,DE,"[B07SDFLVKD, B091CK241X, B0BGC82WVW, B0B9GJLV2..."
1,DE,"[B004P4QFJM, B084CB7GX9, B09YD8XV6M, B004P4OF1..."


In [25]:
output_path = '../../outputs/' + 'task1_predictions_phase2_ngram_improved.parquet'

from pathlib import Path

import pyarrow.parquet as pq
import pyarrow as pa

# Create the output directory first so the write cannot fail, after the long
# prediction run, merely because the folder does not exist yet.
Path(output_path).parent.mkdir(parents=True, exist_ok=True)

# NOTE(review): sessions_test as displayed earlier in this notebook still has an
# 'Unnamed: 0' column, and it is written to the file as-is. Confirm that the
# submission format accepts the extra column.

# Save predictions to parquet
table = pa.Table.from_pandas(sessions_test)
pq.write_table(table, output_path, compression='gzip')