### Import

In [1]:
from tqdm import tqdm; tqdm.pandas()
from collections import defaultdict
from IPython.display import clear_output

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import MinMaxScaler, StandardScaler

import torch
from pytorch_model_summary import summary

from rawdata import RawData, Preprocess
from data import DataInfo, Dataset, collate_fn
from data import NoneScaler, LogScaler, CustomLabelEncoder

# Use the GPU when one is available; otherwise fall back to the CPU so that
# `model.to(device)` and the per-batch `.to(device)` calls still work.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Config

In [2]:
# When True, the raw-data cell loads the small test parquet instead of the full data.
test_mode = True

# Raw data
# Only consulted when test_mode is False:
# True -> read the cached preprocessed parquet, False -> rebuild it from the raw tables.
is_prep_data_exist = True

# Data loader
MIN_MEANINGFUL_SEQ_LEN = 100  # rows whose "meaningful_size" is below this are dropped
MAX_SEQ_LEN = 365             # the train split keeps time_idx <= MAX_SEQ_LEN - 1
PRED_LEN = 100                # the valid split extends that time_idx bound by PRED_LEN

# Column roles handed to DataInfo.
modality_info = {
    "group": [
        "article_id",
        "sales_channel_id",
    ],
    "target": [
        "sales",
    ],
    "temporal": [
        "day",
        "dow",
        "month",
        "holiday",
        "price",
    ],
}

# Per-column preprocessing handed to DataInfo: scaler class for numeric columns,
# and the list of columns that go through an embedding table.
processing_info = {
    "scaling_cols": {
        "sales": StandardScaler,
        "price": StandardScaler,
    },
    "embedding_cols": [
        "day",
        "dow",
        "month",
        "holiday",
    ],
}

# Model
batch_size = 32
nhead = 4
dropout = 0.1
is_identical = False  # NOTE(review): not used by any visible cell - confirm its meaning

# Encoder / decoder sizes.
d_model = {
    "encoder": 64,
    "decoder": 32,
}
d_ff = {
    "encoder": 64,
    "decoder": 32,
}
num_layers = {
    "encoder": 2,
    "decoder": 2,
}

# Passed to Dataset and to the model's forward call.
# Presumably the fraction of positions kept (not masked) per modality - TODO confirm.
remain_rto = {
    "target": 0.25,
    "temporal": 0.25,
    "cat": 0.25,
}

# Data

### Raw data

In [3]:
# Choose where the preprocessed frame comes from:
#   test_mode           -> small test parquet
#   cached file exists  -> full preprocessed parquet
#   otherwise           -> rebuild from the raw tables
if test_mode:
    df_preprocessed = pd.read_parquet("src/df_preprocessed_test.parquet")
elif is_prep_data_exist:
    df_preprocessed = pd.read_parquet("src/df_preprocessed.parquet")
else:
    rawdata = RawData()
    df_trans, df_meta, df_holiday = rawdata.get_raw_data()
    preprocess = Preprocess(df_trans, df_meta, df_holiday)
    df_preprocessed = preprocess.main()

### Dataset

In [4]:
# Both splits share the same series filter and differ only in the time_idx upper bound.
has_enough_history = df_preprocessed["meaningful_size"] >= MIN_MEANINGFUL_SEQ_LEN
time_idx = df_preprocessed["time_idx"]

# Train: time_idx up to MAX_SEQ_LEN - 1.
df_train = df_preprocessed[has_enough_history & (time_idx <= MAX_SEQ_LEN-1)]
# Valid: same lower range, with the upper bound pushed out by PRED_LEN
# (so it also contains every row of the train split).
df_valid = df_preprocessed[has_enough_history & (time_idx <= MAX_SEQ_LEN-1 + PRED_LEN)]

data_info = DataInfo(modality_info, processing_info)

In [5]:
train_dataset = Dataset(df_train, data_info, remain_rto)

# For single-process debugging set num_workers=0 and drop prefetch_factor
# (prefetch_factor is only accepted together with worker processes).
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=lambda x: collate_fn(x, data_info),
    pin_memory=True,
    num_workers=16,
    prefetch_factor=4,
)

# Sanity check: pull a single batch and print the shape of every entry whose key
# does not contain "scaler".
# `data` deliberately stays bound after the loop: later cells pass this batch to the model.
for data in train_dataloader:
    for key, val in data.items():
        if "scaler" not in key:
            print(key, val.shape)
    break

100%|██████████| 2/2 [00:00<00:00, 2442.10it/s]


sales torch.Size([2, 365, 1])
sales_remain_idx torch.Size([2, 91])
sales_masked_idx torch.Size([2, 274])
sales_revert_idx torch.Size([2, 365])
sales_remain_padding_mask torch.Size([2, 91])
sales_masked_padding_mask torch.Size([2, 274])
sales_revert_padding_mask torch.Size([2, 365])
day torch.Size([2, 365])
day_remain_idx torch.Size([2, 91])
day_masked_idx torch.Size([2, 274])
day_revert_idx torch.Size([2, 365])
day_remain_padding_mask torch.Size([2, 91])
day_masked_padding_mask torch.Size([2, 274])
day_revert_padding_mask torch.Size([2, 365])
dow torch.Size([2, 365])
dow_remain_idx torch.Size([2, 91])
dow_masked_idx torch.Size([2, 274])
dow_revert_idx torch.Size([2, 365])
dow_remain_padding_mask torch.Size([2, 91])
dow_masked_padding_mask torch.Size([2, 274])
dow_revert_padding_mask torch.Size([2, 365])
month torch.Size([2, 365])
month_remain_idx torch.Size([2, 91])
month_masked_idx torch.Size([2, 274])
month_revert_idx torch.Size([2, 365])
month_remain_padding_mask torch.Size([2, 91])

# Train

In [124]:
class Embedding(torch.nn.Module):
    """Project one input column into a ``d_model``-wide vector.

    The layer type is picked from ``data_info.processing_info``:

    * column listed in ``"scaling_cols"``   -> ``torch.nn.Linear(1, d_model)``
    * column listed in ``"embedding_cols"`` -> ``torch.nn.Embedding(num_cls, d_model)``,
      where ``num_cls`` comes from ``label_encoder_dict[col].get_num_cls()``

    Args:
        col: Name of the column this module embeds.
        data_info: Object exposing a ``processing_info`` mapping with the
            ``"scaling_cols"`` and ``"embedding_cols"`` entries.
        label_encoder_dict: Mapping from column name to an encoder that provides
            ``get_num_cls()``; only read for embedding columns.
        d_model: Output width of the embedding.

    Raises:
        ValueError: If ``col`` is in neither ``"scaling_cols"`` nor
            ``"embedding_cols"``. Previously this case left ``self.embedding``
            undefined and only failed later, inside ``forward``.
    """

    def __init__(self, col, data_info, label_encoder_dict, d_model):
        super().__init__()
        processing_info = data_info.processing_info

        if col in processing_info["scaling_cols"]:
            # Scaled numeric column: the linear layer consumes a trailing feature dim of size 1.
            self.embedding = torch.nn.Linear(1, d_model)
        elif col in processing_info["embedding_cols"]:
            # Label-encoded column: one learned vector per class id.
            num_cls = label_encoder_dict[col].get_num_cls()
            self.embedding = torch.nn.Embedding(num_cls, d_model)
        else:
            raise ValueError(
                f"Column {col!r} is neither in processing_info['scaling_cols'] "
                f"nor in processing_info['embedding_cols']"
            )

    def forward(self, data):
        """Apply the column's embedding layer to ``data`` and return the result."""
        return self.embedding(data)

### Model

In [125]:
class Transformer(torch.nn.Module):
    """Model skeleton that currently only embeds the target and temporal columns.

    Args:
        data_info: Object exposing ``modality_info`` (with ``"target"`` and
            ``"temporal"`` column lists); it is also forwarded to each ``Embedding``.
        label_encoder_dict: Mapping from column name to label encoder, forwarded
            to each ``Embedding``.
        d_model: Mapping of model widths; only ``d_model["encoder"]`` is read here.
        num_layers, nhead, d_ff, dropout: Accepted for the full architecture but
            not used by the code in this class yet.
    """

    def __init__(self, data_info, label_encoder_dict,
                d_model, num_layers, nhead, d_ff, dropout):
        super().__init__()
        self.data_info = data_info
        self.temporal_cols = self.data_info.modality_info["target"] + self.data_info.modality_info["temporal"]

        # 1. Embedding: one Embedding module per target/temporal column.
        self.embedding_dict = self.init_process(
            self.temporal_cols,
            Embedding,
            data_info=data_info,
            label_encoder_dict=label_encoder_dict,
            d_model=d_model["encoder"],
        )

    def forward(self, data_input, remain_rto, device):
        """Embed every target/temporal column of a batch.

        Args:
            data_input: Batch mapping produced by the data loader.
            remain_rto: Accepted for interface compatibility; not used yet.
            device: Device the batch tensors are moved to before embedding.

        Returns:
            dict mapping each column in ``self.temporal_cols`` to its embedded tensor.
        """
        # 0. To gpu (index and padding-mask tensors are moved too, but not consumed yet)
        data_dict, _idx_dict, _padding_mask_dict = self.to_gpu(data_input, device)

        # 1. Embedding
        embedding_dict = self.apply_process(data_dict, self.temporal_cols, self.embedding_dict)

        return embedding_dict

    def to_gpu(self, data_input, device):
        """Move the tensors of ``data_input`` to ``device``, split by role.

        Keys are routed as follows (the checks are independent, so a key may land
        in more than one dict); any other key is ignored:

        * key is a target/temporal column -> ``data_dict``
        * key ends with ``"idx"``          -> ``idx_dict``
        * key ends with ``"padding_mask"`` -> ``padding_mask_dict``

        Returns:
            tuple ``(data_dict, idx_dict, padding_mask_dict)``.
        """
        data_dict, idx_dict, padding_mask_dict = {}, {}, {}
        value_cols = set(self.temporal_cols)

        # Iterate the batch that was passed in, not a notebook-level global.
        for col, value in data_input.items():
            if col in value_cols:
                data_dict[col] = value.to(device)
            if col.endswith("idx"):
                idx_dict[col] = value.to(device)
            if col.endswith("padding_mask"):
                padding_mask_dict[col] = value.to(device)

        return data_dict, idx_dict, padding_mask_dict

    def init_process(self, target_cols, mod, **kwargs):
        """Build one ``mod(col, **kwargs)`` module per column.

        Args:
            target_cols: Column names to build modules for.
            mod: Module class (or factory) called as ``mod(col, **kwargs)``.
            **kwargs: Extra keyword arguments forwarded to ``mod``.

        Returns:
            ``torch.nn.ModuleDict`` keyed by column name.
        """
        return torch.nn.ModuleDict({col: mod(col, **kwargs) for col in target_cols})

    def apply_process(self, data, target_cols, mod, **kwargs):
        """Apply ``mod[col]`` to ``data[col]`` for every column in ``target_cols``.

        Returns:
            dict mapping each column name to ``mod[col](data[col], **kwargs)``.
        """
        return {col: mod[col](data[col], **kwargs) for col in target_cols}


In [126]:
# Build the model and place it on the configured device (Module.to returns the module itself).
model = Transformer(
    data_info=data_info,
    label_encoder_dict=train_dataset.label_encoder_dict,
    d_model=d_model,
    num_layers=num_layers,
    nhead=nhead,
    d_ff=d_ff,
    dropout=dropout,
).to(device)

# `data` is the sample batch left bound by the data-loader sanity-check loop above.
# NOTE(review): the recorded output lists no layers and 0 parameters even though the
# model owns embedding layers - the summary depth/hook settings likely need adjusting; confirm.
summary(model, data, remain_rto, device, show_parent_layers=True, print_summary=True)

---------------------------------------------------------------------------------------
   Parent Layers       Layer (type)        Output Shape         Param #     Tr. Param #
Total params: 0
Trainable params: 0
Non-trainable params: 0
---------------------------------------------------------------------------------------


