In [None]:
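"""
Using market information (the sales of every product) as cross-attention context.
→ Not usable because:
 - There are too many products to attend to (more than 100,000 — the encoder input below spans 104,547 articles)
 - Full attention over N products needs an N × N attention-weight matrix, so memory grows as N^2 (~10^10 entries here), independent of the number of learnable parameters

▶ Deprecated
"""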

In [1]:
# When True, skip reading the raw CSVs and re-running preprocessing;
# the cached df_prep.pq / pd_ref.pq files are loaded instead.
is_skip = True

# Data params: encoder history length, forecast horizon, samples per batch
encode_len, pred_len, batch_size = 10, 3, 16

# Model params (transformer-style hyperparameters; not all are used in the cells shown)
d_model, nhead, d_ff = 128, 4, 256
dropout, num_layers = 0.1, 4

# Import

In [2]:
import numpy as np
import pandas as pd
from IPython.display import clear_output
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from statsmodels.tsa.seasonal import seasonal_decompose
from tqdm import tqdm

import torch
from pytorch_model_summary import summary

# Data

In [3]:
if not is_skip:
    # Raw transactions plus article metadata, joined on article_id.
    # article_id is read as str (presumably to keep leading zeros intact).
    df_trans = pd.read_csv(
        "../HnM/transactions_train.csv",
        dtype={"article_id": str},
        parse_dates=["t_dat"],
    )
    df_meta = pd.read_csv("../HnM/articles.csv", dtype={"article_id": str})
    df_raw = df_trans.merge(df_meta, on="article_id")

In [4]:
def preprocess(data):
    """Build per-article daily sales sequences aligned on a shared day index.

    Parameters
    ----------
    data : pd.DataFrame
        Transactions; must contain ``t_dat`` (datetime, day resolution assumed),
        ``article_id`` and ``customer_id``. Each row counts as one sale.

    Returns
    -------
    data : pd.DataFrame
        One row per ``article_id`` with list columns ``value`` (daily sales)
        and ``idx`` (day index). All lists have the same length.
    pd_ref : pd.DataFrame
        Wide table indexed by ``idx`` (days since the earliest ``t_dat``), one
        column per article, 0 where the article had no sales that day.
    """
    # Daily sales = number of transaction rows per (day, article)
    daily = data.groupby(["t_dat", "article_id"], as_index=False).agg(sales=("customer_id", "count"))

    # Shared calendar index: days since the earliest date. A per-article
    # cumcount would give each article its own "k-th day with sales", so the
    # same idx would mean different dates for different articles and days
    # without sales would disappear from the sequences.
    daily["idx"] = (daily["t_dat"] - daily["t_dat"].min()).dt.days

    # Pivot to (day, article); include every day in the range, missing = 0 sales
    n_days = int(daily["idx"].max()) + 1 if len(daily) else 0
    pd_ref = (
        pd.pivot(data=daily, index="idx", columns="article_id", values="sales")
        .reindex(pd.RangeIndex(n_days, name="idx"))
        .fillna(0)
    )

    # Back to long form, then collect one value/idx list per article
    long_form = pd_ref.reset_index().melt(id_vars="idx")
    data = long_form.groupby("article_id", as_index=False)[["value", "idx"]].agg(list)
    return data, pd_ref

if not is_skip:
    # Restrict to Ladieswear articles and preprocess THAT subset
    # (previously the filtered frame was discarded and df_trans was used).
    df_ladies = df_raw[df_raw["index_name"] == "Ladieswear"].reset_index(drop=True)
    df_prep, pd_ref = preprocess(df_ladies)
    df_prep.to_parquet("df_prep.pq")
    pd_ref.to_parquet("pd_ref.pq")

In [5]:
# Cached preprocessing outputs (written by the preprocessing cell when is_skip is False)
df_prep, pd_ref = pd.read_parquet("df_prep.pq"), pd.read_parquet("pd_ref.pq")

class Dataset(torch.utils.data.Dataset):
    """Sliding-window dataset of per-article sales with market-status context.

    Each sample is a window of ``encode_len + pred_len`` consecutive steps of
    one article. Windows whose encoder part or prediction part is all zeros
    are dropped.

    Parameters
    ----------
    df_prep : pd.DataFrame
        One row per article with equal-length list columns ``value`` (sales)
        and ``idx`` (row labels into ``pd_ref``).
    pd_ref : pd.DataFrame
        Wide sales table indexed by ``idx`` with one column per article.
    encode_len, pred_len : int
        Lengths of the encoder history and the prediction horizon.
    """

    def __init__(self, df_prep, pd_ref, encode_len, pred_len):
        self.pd_ref, self.encode_len, self.pred_len = pd_ref, encode_len, pred_len
        window = encode_len + pred_len

        sales = torch.as_tensor(np.stack([np.asarray(v) for v in df_prep["value"]]), dtype=torch.float32)
        # Integer day labels (not float) so they match pd_ref's index exactly
        index = torch.as_tensor(np.stack([np.asarray(i) for i in df_prep["idx"]]), dtype=torch.long)

        # (n_articles, T) -> (n_articles * n_windows, window), stride 1
        sales = sales.unfold(-1, window, 1).reshape(-1, window)
        index = index.unfold(-1, window, 1).reshape(-1, window)

        # Keep windows with at least one sale in both the encoder part and the
        # prediction part. A boolean mask keeps the original window order
        # (a set intersection yields indices in arbitrary order).
        keep = (sales[:, :encode_len].sum(dim=-1) > 0) & (sales[:, -pred_len:].sum(dim=-1) > 0)
        self.sales = sales[keep]
        self.index = index[keep]

    def __len__(self):
        return self.sales.shape[0]

    def __getitem__(self, idx):
        sales = self.sales[idx].unsqueeze(-1)  # (window, 1)
        days = self.index[idx].numpy()
        # Market status: rows of pd_ref for this window's days, transposed to
        # (n_articles, n_matched_days); equals (n_articles, window) when every
        # day of the window is present in pd_ref.
        ref = torch.tensor(self.pd_ref[self.pd_ref.index.isin(days)].values.T, dtype=torch.float32)

        # Split into encoder context, decoder input and target
        return {"encoder_marketstatus": ref[:, :self.encode_len],
                "decoder_input_sales": sales[:self.encode_len],
                "y": sales[-self.pred_len:],
                }
        

dataset = Dataset(df_prep, pd_ref, encode_len, pred_len)
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, batch_size=batch_size)

# Show the tensor shapes of the first batch only
for data in dataloader:
    print("\n".join(f"{name} {data[name].shape}"
                    for name in ("encoder_marketstatus", "decoder_input_sales", "y")))
    break

encoder_marketstatus torch.Size([16, 104547, 10])
decoder_input_sales torch.Size([16, 10, 1])
y torch.Size([16, 3, 1])


In [6]:
class Model(torch.nn.Module):
    """Minimal probe: project scalar inputs to ``d_model`` dims, then self-attend.

    Appears to be used to gauge the cost of full self-attention over many
    products.

    Parameters
    ----------
    d_model : int, default 128
        Size of the linear projection and the attention embedding.
    nhead : int, default 4
        Number of attention heads; must divide ``d_model``.
    """

    def __init__(self, d_model=128, nhead=4):
        super().__init__()
        self.linear = torch.nn.Linear(1, d_model)
        self.attn = torch.nn.MultiheadAttention(d_model, nhead)

    def forward(self, x):
        """Self-attention over ``x`` (query = key = value).

        ``x`` must end in a size-1 feature dim, e.g. ``(L, 1)`` for a single
        unbatched sequence or ``(L, N, 1)`` (sequence-first, since
        ``batch_first`` is left False). Returns the attention output with last
        dim ``d_model``; the attention weights are discarded. Note that the
        attention weights are L x L, so cost grows quadratically with L.
        """
        res = self.linear(x)
        res, _ = self.attn(res, res, res)
        return res

# Probe the attention block with one unbatched sequence of 100,000 scalar tokens.
# Self-attention over L tokens builds L x L attention weights (~10^10 scores
# per head here), so this is expected to be extremely memory-hungry.
arr = torch.rand(100000, 1)
model = Model()
summary(model, arr, show_parent_layers=True, print_summary=True)

: 