In [1]:
# import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

### Import

In [2]:
from tqdm import tqdm; tqdm.pandas()
from collections import defaultdict
from IPython.display import clear_output

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import MinMaxScaler, StandardScaler

import torch
from pytorch_model_summary import summary

from rawdata import RawData, Preprocess
from data import DataInfo, Dataset, collate_fn
from data import NoneScaler, LogScaler, CustomLabelEncoder

device = torch.device("cuda")

### Config

In [10]:
test_mode = True

# Raw data
is_prep_data_exist = True

# Data loader
MIN_MEANINGFUL_SEQ_LEN = 100
MAX_SEQ_LEN = 365
PRED_LEN = 100

modality_info = {
    "group": ["article_id", "sales_channel_id"],
    "target": ["sales"],
    "temporal": ["day", "dow", "month", "holiday", "price"],
    "others": ["sales_channel_id", 
                "prod_name", "product_type_name", "product_group_name", 
                "graphical_appearance_name", "colour_group_name", 
                "perceived_colour_value_name", "perceived_colour_master_name", 
                "department_name", "index_name", "index_group_name", "section_name", "garment_group_name"]}

processing_info = {
    "scaling_cols": {"sales": StandardScaler, "price": StandardScaler},
    "embedding_cols": ["sales_channel_id", 
                        "day", "dow", "month", "holiday",
                        "prod_name", "product_type_name", "product_group_name", 
                        "graphical_appearance_name", "colour_group_name", 
                        "perceived_colour_value_name", "perceived_colour_master_name", 
                        "department_name", "index_name", "index_group_name", "section_name", "garment_group_name"]}

# modality_info = {
#     "group": ["article_id", "sales_channel_id"],
#     "target": ["sales"],
#     "temporal": ["day", "dow", "month", "holiday", "price"],
#     "others": []}
# processing_info = {
#     "scaling_cols": {"sales": StandardScaler, "price": StandardScaler},
#     "embedding_cols": ["day", "dow", "month", "holiday"]}

# Model
batch_size = 32
dropout = 0.1

nhead = {"encoder":4, "decoder":4}
d_model = {"encoder":64, "decoder":32}
d_ff = {"encoder":64, "decoder":32}
num_layers = {"encoder":2, "decoder":2}
remain_rto = {"temporal":0.25, "others":0.25}

# Data

### Raw data

In [11]:
if test_mode:
    df_preprocessed = pd.read_parquet("src/df_preprocessed_test.parquet")
else:
    if not is_prep_data_exist:
        rawdata = RawData()
        df_trans, df_meta, df_holiday = rawdata.get_raw_data()
        preprocess = Preprocess(df_trans, df_meta, df_holiday)
        df_preprocessed = preprocess.main()
    else:
        df_preprocessed = pd.read_parquet("src/df_preprocessed.parquet")

### Dataset

In [12]:
df_train = df_preprocessed[(df_preprocessed["meaningful_size"] >= MIN_MEANINGFUL_SEQ_LEN) & (df_preprocessed["time_idx"] <= MAX_SEQ_LEN-1)]
df_valid = df_preprocessed[(df_preprocessed["meaningful_size"] >= MIN_MEANINGFUL_SEQ_LEN) & (df_preprocessed["time_idx"] <= MAX_SEQ_LEN-1 + PRED_LEN)]

data_info = DataInfo(modality_info, processing_info)

In [13]:
train_dataset = Dataset(df_train, data_info, remain_rto)
# train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: collate_fn(x, data_info), pin_memory=True, num_workers=16, prefetch_factor=4)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: collate_fn(x, data_info))

for data in train_dataloader:
    [print(key, val.shape) for key, val in data.items() if "scaler" not in key]
    break

100%|██████████| 2/2 [00:00<00:00, 3252.66it/s]

sales torch.Size([2, 365, 1])
sales_remain_idx torch.Size([2, 91])
sales_masked_idx torch.Size([2, 274])
sales_revert_idx torch.Size([2, 365])
sales_remain_padding_mask torch.Size([2, 91])
sales_masked_padding_mask torch.Size([2, 274])
sales_revert_padding_mask torch.Size([2, 365])
day torch.Size([2, 365])
day_remain_idx torch.Size([2, 91])
day_masked_idx torch.Size([2, 274])
day_revert_idx torch.Size([2, 365])
day_remain_padding_mask torch.Size([2, 91])
day_masked_padding_mask torch.Size([2, 274])
day_revert_padding_mask torch.Size([2, 365])
dow torch.Size([2, 365])
dow_remain_idx torch.Size([2, 91])
dow_masked_idx torch.Size([2, 274])
dow_revert_idx torch.Size([2, 365])
dow_remain_padding_mask torch.Size([2, 91])
dow_masked_padding_mask torch.Size([2, 274])
dow_revert_padding_mask torch.Size([2, 365])
month torch.Size([2, 365])
month_remain_idx torch.Size([2, 91])
month_masked_idx torch.Size([2, 274])
month_revert_idx torch.Size([2, 365])
month_remain_padding_mask torch.Size([2, 91])




# Architecture

### Helper

In [24]:
def get_positional_encoding(d_model, seq_len=1000):
    position = torch.arange(seq_len).reshape(-1,1)
    i = torch.arange(d_model)//2
    exp_term = 2*i/d_model
    div_term = torch.pow(10000, exp_term).reshape(1, -1)
    pos_encoded = position / div_term

    pos_encoded[:, 0::2] = torch.sin(pos_encoded[:, 0::2])
    pos_encoded[:, 1::2] = torch.cos(pos_encoded[:, 1::2])

    return pos_encoded

class NumericalEmbedding(torch.nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.linear_embedding = torch.nn.Linear(1, d_model)
    
    def forward(self, data):
        return self.linear_embedding(data)

class CategoricalEmbedding(torch.nn.Module):
    def __init__(self, num_cls, d_model):
        super().__init__()
        self.embedding = torch.nn.Embedding(num_cls, d_model)

    def forward(self, data):
        return self.embedding(data)

class TemporalRemain(torch.nn.Module):
    def __init__(self, col, pos_enc, global_token):
        super().__init__()
        self.col, self.pos_enc, self.global_token = col, pos_enc, global_token
    
    def forward(self, data, idx_dict):
        # Apply positional encoding
        pos_enc = self.pos_enc.unsqueeze(0).repeat(data.shape[0], 1, 1)
        data += pos_enc[:, 1:data.shape[1]+1, :]

        # Apply remain
        remain_idx = idx_dict[f"{self.col}_remain_idx"]
        remain_idx = remain_idx.unsqueeze(-1).repeat(1, 1, data.shape[-1])
        data = torch.gather(data, index=remain_idx, dim=1)

        # Apply global token
        global_token = self.global_token.unsqueeze(0).repeat(data.shape[0], 1, 1)
        global_token += pos_enc[:, :1, :]

        data = torch.cat([global_token, data], dim=1)

        return data

class ModalityEmbedding(torch.nn.Module):
    def __init__(self, modality_embedding, idx, is_others):
        super().__init__()
        self.modality_embedding, self.idx, self.is_others = modality_embedding, idx, is_others
        self.idx = idx
        self.is_others = is_others

    def forward(self, data, modality):
        if not self.is_others:
            modality = modality[self.idx].unsqueeze(0).repeat(data.shape[0], data.shape[1])
        else:
            modality = modality[self.idx].unsqueeze(0).repeat(data.shape[0], 1)
        modality_embedding = self.modality_embedding(modality)
        
        return data + modality_embedding

class OthersRemain(torch.nn.Module):
    def __init__(self, data_info, remain_rto, device):
        super().__init__()
        self.data_info = data_info
        self.num_remain = int(len(self.data_info.modality_info["others"]) * remain_rto)
        self.device = device

    def forward(self, data, idx_dict):
        temporal_dict = {key:val for key, val in data.items() if key in self.data_info.modality_info["target"] + self.data_info.modality_info["temporal"]}
        others_dict = {key:val for key, val in data.items() if key in self.data_info.modality_info["others"]}

        others_cat = torch.cat(list(others_dict.values()), dim=1)
        noise = torch.rand(others_cat.shape[:2]).to(device)
        shuffle = torch.argsort(noise, dim=-1)

        remain_idx = shuffle[:, :self.num_remain]
        masked_idx = shuffle[:, self.num_remain:]
        revert_idx = torch.argsort(shuffle, dim=-1)

        others_data = torch.gather(others_cat, index=remain_idx.unsqueeze(-1).repeat(1, 1, others_cat.shape[-1]), dim=1)
        idx_dict.update({"others_remain_idx": remain_idx, "others_masked_idx":masked_idx, "others_revert_idx":revert_idx})

        return temporal_dict, others_data, idx_dict


class Encoder(torch.nn.Module):
    def __init__(self, data_info, d_model, nhead, dim_feedforward, dropout, batch_first, activation, norm_first, num_layers):
        super().__init__()
        self.data_info = data_info
        self.encoder = torch.nn.TransformerEncoder(torch.nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, batch_first=batch_first, activation=activation, norm_first=norm_first), num_layers=num_layers)

    def forward(self, temporal_dict, others_data, padding_mask_dict, idx_dict):
        padding_mask = []

        # Temporal padding mask
        for col in temporal_dict.keys():
            remain_padding_mask = padding_mask_dict[f"{col}_remain_padding_mask"]
            new_remain_padding_mask = torch.ones(remain_padding_mask.shape[0], remain_padding_mask.shape[1]+1)
            new_remain_padding_mask[:, 1:] = remain_padding_mask

            padding_mask.append(new_remain_padding_mask)
        temporal_padding_mask = torch.cat(padding_mask, dim=1)

        # Others padding mask
        if len(self.data_info.modality_info["others"]) > 0:
            others_padding_mask = padding_mask_dict["others_padding_mask"].unsqueeze(0).repeat(idx_dict["others_remain_idx"].shape[0], 1)
            others_padding_mask = others_padding_mask[:, :idx_dict["others_remain_idx"].shape[-1]]        
        
        # Total padding_mask
        total_padding_mask = torch.cat([temporal_padding_mask.to(device), others_padding_mask], dim=1)
        print(total_padding_mask.shape)


        temporal_data = torch.cat(list(temporal_dict.values()), dim=1)
        total_data = torch.cat([temporal_data, others_data], dim=1)
        print(total_data.shape)
        raise

        encoded = self.encoder(total_data, src_key_padding_mask=encoder_padding_mask)

        
        return

1==1

True

### Transformer

In [25]:
class Transformer(torch.nn.Module):
    def __init__(self, data_info, label_encoder_dict, remain_rto,
                    d_model, num_layers, nhead, d_ff, dropout, activation,
                    device):
        super().__init__()
        self.data_info, self.label_encoder_dict = data_info, label_encoder_dict

        # 1. Embedding
        self.embedding_dict = self._init_embedding(d_model["encoder"]) 
        
        # 2. Apply temporal remain
        encoder_pos_enc = torch.nn.Parameter(get_positional_encoding(d_model["encoder"]), requires_grad=True)
        global_token = torch.nn.Parameter(torch.rand(1, d_model["encoder"]), requires_grad=True)
        self.temporal_remain_dict = self._init_temporal_remain(encoder_pos_enc, global_token)

        # 3. Apply modality embedding
        n_modality = len(self.data_info.modality_info["target"] + self.data_info.modality_info["temporal"] + self.data_info.modality_info["others"])
        encoder_modality_embedding = torch.nn.Embedding(n_modality, d_model["encoder"])
        self.modality_embedding_dict = self._init_modality_embedding(encoder_modality_embedding)

        # 4. Apply others remain
        self.others_remain = OthersRemain(data_info, remain_rto["others"], device)

        # 5. Encoding
        self.encoder = Encoder(self.data_info, d_model=d_model["encoder"], nhead=nhead["encoder"], dim_feedforward=d_ff["encoder"], dropout=dropout, batch_first=True, activation=activation, norm_first=True, num_layers=num_layers["encoder"])

    def forward(self, data_dict, idx_dict, padding_mask_dict, modality):
        # 1. Embedding
        data_dict = self._apply_process(data_dict, self.embedding_dict)

        # 2. Apply temporal remain
        data_dict.update(self._apply_process(data_dict, self.temporal_remain_dict, idx_dict=idx_dict))
        
        # 3. Apply modality embedding
        data_dict.update(self._apply_process(data_dict, self.modality_embedding_dict, modality=modality))
        
        # 4. Apply others remain
        if len(self.data_info.modality_info["others"]) > 0:
            temporal_dict, others_data, idx_dict = self.others_remain(data_dict, idx_dict)
        else:
            temporal_dict = data_dict
            others_data = torch.tensor([]).to(device)

        # 5. Encoding
        self.encoder(temporal_dict, others_data, padding_mask_dict, idx_dict)
        
        return data_dict
    
    
    def _init_embedding(self, d_model):
        result_dict = {}

        # Numerical embedding
        numerical_cols = self.data_info.processing_info["scaling_cols"]
        for col in numerical_cols:
            result_dict[col] = NumericalEmbedding(d_model)

        # Categorical embedding
        categorical_cols = self.data_info.processing_info["embedding_cols"]
        for col in categorical_cols:
            num_cls = self.label_encoder_dict[col].get_num_cls()
            result_dict[col] = CategoricalEmbedding(num_cls, d_model)
        
        return torch.nn.ModuleDict(result_dict)
    
    def _init_temporal_remain(self, pos_enc, global_token):
        result_dict = {}
        target_cols = self.data_info.modality_info["target"] + self.data_info.modality_info["temporal"]
        for col in target_cols:
            result_dict[col] = TemporalRemain(col, pos_enc, global_token)
        
        return torch.nn.ModuleDict(result_dict)

    def _init_modality_embedding(self, modality_embedding):
        result_dict = {}
        target_cols = self.data_info.modality_info["target"] + self.data_info.modality_info["temporal"] + self.data_info.modality_info["others"]
        for idx, col in enumerate(target_cols):
            is_others = True if col in self.data_info.modality_info["others"] else False
            result_dict[col] = ModalityEmbedding(modality_embedding, idx, is_others)
        
        return torch.nn.ModuleDict(result_dict)


    def _apply_process(self, data_dict, module, **kwargs):
        result_dict = {}

        for key, mod in module.items():
            result_dict[key] = mod(data_dict[key], **kwargs)
        
        return result_dict

1==1

True

In [26]:
def to_gpu(data, data_info, device):
    data_dict, idx_dict, padding_mask_dict = {}, {}, {}

    for col in data.keys():
        # Temporal and others data
        if col in (data_info.modality_info["target"] + data_info.modality_info["temporal"] + data_info.modality_info["others"]):
            data_dict[col] = data[col].to(device)
        elif col.endswith("idx"):
            idx_dict[col] = data[col].to(device)
        elif col.endswith("padding_mask"):
            padding_mask_dict[col] = data[col]

        # Modality data
        n_modality = len(data_info.modality_info["target"] + data_info.modality_info["temporal"] + data_info.modality_info["others"])
        modality = torch.arange(n_modality).to(device)

        # Others padding_mask
        others_padding_mask = torch.ones(len(data_info.modality_info["others"])).to(device)
        padding_mask_dict["others_padding_mask"] = others_padding_mask
        
    return data_dict, idx_dict, padding_mask_dict, modality

model = Transformer(data_info, train_dataset.label_encoder_dict, remain_rto,
                        d_model, num_layers, nhead, d_ff, dropout, "gelu",
                        device)
model.to(device)
summary(model, *to_gpu(data, data_info, device), show_parent_layers=True, print_summary=True)

torch.Size([2, 555])
torch.Size([2, 555, 64])


RuntimeError: No active exception to reraise