### Import

In [1]:
from tqdm import tqdm; tqdm.pandas()
from collections import defaultdict
from IPython.display import clear_output

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import MinMaxScaler, StandardScaler

import torch
from pytorch_model_summary import summary
from transformers import ViTConfig, ViTModel

from rawdata import RawData, Preprocess
from data import DataInfo, Dataset, collate_fn
from data import NoneScaler, LogScaler, CustomLabelEncoder

# Prefer the GPU, but fall back to CPU so the notebook still runs (slowly)
# on machines without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


### Config

In [2]:
# When True, the raw-data cell loads the small "*_test.parquet" file.
test_mode = True

# Raw data
# When True (and test_mode is False), the cached preprocessed parquet is read
# instead of re-running RawData/Preprocess.
is_prep_data_exist = True

# Data loader
# Rows are kept only when `meaningful_size` >= MIN_MEANINGFUL_SEQ_LEN.
MIN_MEANINGFUL_SEQ_LEN = 100
# Training rows are limited to `time_idx` <= MAX_SEQ_LEN - 1.
MAX_SEQ_LEN = 365
# The validation split extends the training window by PRED_LEN time steps.
PRED_LEN = 100

# Column names grouped by role; "target" and "temporal" columns are the
# per-time-step series handled by the model, "img" holds the image column.
modality_info = {
    "group": ["article_id"],
    "target": ["sales"],
    "temporal": ["day", "dow", "month", "holiday", "price"],
    "img": ["img_path"]
}
# How each column is prepared: scaler class per numeric column, columns that
# are label-encoded and embedded, and image columns.
processing_info = {
    "scaling_cols": {"sales": StandardScaler, "price": StandardScaler},
    "embedding_cols": ["day",  "dow", "month", "holiday"],
    "img_cols": ["img_path"]
}

# Model
batch_size = 32
nhead = 4
dropout = 0.1
# When True the Transformer swaps most layers for pass-through NoneEmbedding
# modules (see Transformer.__init__).
is_identical = False

# Per-stage sizes; the Transformer indexes these dicts with "encoder"/"decoder".
d_model = {"encoder":64, "decoder":32}
d_ff = {"encoder":64, "decoder":32}
num_layers = {"encoder":2, "decoder":2}
# Fraction of tokens kept (not masked); the "img" entry is read by
# Transformer._apply_img_remain. The whole dict is handed to Dataset.
remain_rto = {"target": 0.25, "temporal":0.25, "img":0.25}

# Data

### Raw data

In [3]:
# Choose the preprocessed frame: the small test file, the cached full file,
# or a fresh preprocessing run from the raw sources.
if test_mode:
    df_preprocessed = pd.read_parquet("src/df_preprocessed_test.parquet")
elif is_prep_data_exist:
    df_preprocessed = pd.read_parquet("src/df_preprocessed.parquet")
else:
    rawdata = RawData()
    df_trans, df_meta, df_holiday = rawdata.get_raw_data()
    preprocess = Preprocess(df_trans, df_meta, df_holiday)
    df_preprocessed = preprocess.main()

### Dataset

In [4]:
# Both splits share the "enough meaningful history" filter; they differ only
# in the last time index they include.
_has_enough_history = df_preprocessed["meaningful_size"] >= MIN_MEANINGFUL_SEQ_LEN
_last_train_time_idx = MAX_SEQ_LEN - 1

df_train = df_preprocessed[_has_enough_history & (df_preprocessed["time_idx"] <= _last_train_time_idx)]
df_valid = df_preprocessed[_has_enough_history & (df_preprocessed["time_idx"] <= _last_train_time_idx + PRED_LEN)]

data_info = DataInfo(modality_info, processing_info)

7


In [5]:
train_dataset = Dataset(df_train, data_info, remain_rto)
# For debugging, drop pin_memory / num_workers / prefetch_factor to load in
# the main process.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=lambda samples: collate_fn(samples, data_info),
    pin_memory=True,
    num_workers=24,
    prefetch_factor=4,
)

# Peek at the first batch only. `data` is intentionally left bound: the model
# summary cell further down feeds it to the model.
for data in train_dataloader:
    for key, val in data.items():
        if "scaler" in key or "img_raw" in key:
            continue
        print(key, val.shape)
    break

100%|██████████| 2/2 [00:00<00:00, 355.19it/s]


sales torch.Size([2, 365, 1])
sales_remain_idx torch.Size([2, 91])
sales_masked_idx torch.Size([2, 274])
sales_revert_idx torch.Size([2, 365])
sales_remain_padding_mask torch.Size([2, 91])
sales_masked_padding_mask torch.Size([2, 274])
sales_revert_padding_mask torch.Size([2, 365])
day torch.Size([2, 365])
day_remain_idx torch.Size([2, 91])
day_masked_idx torch.Size([2, 274])
day_revert_idx torch.Size([2, 365])
day_remain_padding_mask torch.Size([2, 91])
day_masked_padding_mask torch.Size([2, 274])
day_revert_padding_mask torch.Size([2, 365])
dow torch.Size([2, 365])
dow_remain_idx torch.Size([2, 91])
dow_masked_idx torch.Size([2, 274])
dow_revert_idx torch.Size([2, 365])
dow_remain_padding_mask torch.Size([2, 91])
dow_masked_padding_mask torch.Size([2, 274])
dow_revert_padding_mask torch.Size([2, 365])
month torch.Size([2, 365])
month_remain_idx torch.Size([2, 91])
month_masked_idx torch.Size([2, 274])
month_revert_idx torch.Size([2, 365])
month_remain_padding_mask torch.Size([2, 91])

# Train

### Model

In [6]:
import math
def get_positional_encoding(d_model, seq_len=1000):
    """Build a sinusoidal positional-encoding table.

    Args:
        d_model: Width of each encoding vector.
        seq_len: Number of positions to encode.

    Returns:
        Float tensor of shape (seq_len, d_model) with sine in the even
        columns and cosine in the odd columns.
    """
    # Columns 2k and 2k+1 share the same frequency exponent 2k / d_model.
    pair_index = torch.arange(d_model) // 2
    wavelengths = torch.pow(10000, 2 * pair_index / d_model).reshape(1, -1)

    table = torch.arange(seq_len).reshape(-1, 1) / wavelengths
    table[:, 0::2] = torch.sin(table[:, 0::2])
    table[:, 1::2] = torch.cos(table[:, 1::2])
    return table

# def get_positional_encoding(d_hidn, n_seq=1000):
#     def cal_angle(position, i_hidn):
#         return position / np.power(10000, 2 * (i_hidn // 2) / d_hidn)
#     def get_posi_angle_vec(position):
#         return [cal_angle(position, i_hidn) for i_hidn in range(d_hidn)]

#     sinusoid_table = np.array([get_posi_angle_vec(i_seq) for i_seq in range(n_seq)])
#     sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # even index sin 
#     sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # odd index cos

#     res = torch.tensor(sinusoid_table).to(torch.float32)
#     print(res.dtype)
#     return res

class NoneEmbedding(torch.nn.Module):
    """Pass-through stand-in for a layer: casts to float, adding no parameters."""

    def forward(self, data):
        """Return `data` as float; non-3-D input gains a trailing axis of size 1."""
        as_float = data.to(torch.float)
        if data.dim() == 3:
            return as_float
        return as_float.unsqueeze(-1)

class NumericalEmbedding(torch.nn.Module):
    """Embed a scalar feature (last axis of size 1) into `d_model` dimensions."""

    def __init__(self, d_model, is_identical):
        super().__init__()
        # The projection is always constructed (even when it is discarded) so
        # that random-number consumption does not depend on `is_identical`.
        projection = torch.nn.Linear(1, d_model)
        self.linear_embedding = NoneEmbedding() if is_identical else projection

    def forward(self, data):
        """Apply the configured embedding to `data`."""
        embed = self.linear_embedding
        return embed(data)

class CategoricalEmbedding(torch.nn.Module):
    """Look up a learned `d_model`-wide vector for each class id."""

    def __init__(self, num_cls, d_model, is_identical):
        super().__init__()
        # The table is always constructed (even when it is discarded) so that
        # random-number consumption does not depend on `is_identical`.
        table = torch.nn.Embedding(num_cls, d_model)
        self.embedding = NoneEmbedding() if is_identical else table

    def forward(self, data):
        """Apply the configured embedding to `data`."""
        embed = self.embedding
        return embed(data)

class ImgEmbedding(torch.nn.Module):
    """Encode images with a pretrained ViT and project tokens to `d_model`.

    Args:
        d_model: Output width of each token embedding.
    """

    def __init__(self, d_model):
        super().__init__()
        self.img_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        # Read the backbone width from its config (768 for this checkpoint)
        # rather than hard-coding it.
        self.linear = torch.nn.Linear(self.img_model.config.hidden_size, d_model)

    def forward(self, data):
        """Return projected ViT token embeddings for a batch of images.

        The backbone output is read by field name. Unpacking `.values()`
        positionally breaks as soon as the number of returned fields changes,
        and requesting attention maps that are never used only costs memory.
        """
        hidden_states = self.img_model(data).last_hidden_state
        return self.linear(hidden_states)

class DynamicOutput(torch.nn.Module):
    """Output head for one column, sized by how the column is processed.

    Args:
        col: Column name.
        data_info: Object exposing `processing_info` with "scaling_cols" and
            "embedding_cols".
        label_encoder_dict: Mapping from column name to an encoder providing
            `get_num_cls()`; only consulted for embedding columns.
        d_model: Width of the incoming decoder features.

    Raises:
        ValueError: If `col` is neither a scaling nor an embedding column.
    """

    def __init__(self, col, data_info, label_encoder_dict, d_model):
        super().__init__()
        processing_info = data_info.processing_info

        if col in processing_info["scaling_cols"]:
            out_features = 1  # single regression value
        elif col in processing_info["embedding_cols"]:
            out_features = label_encoder_dict[col].get_num_cls()  # class logits
        else:
            # Without this, `self.output` would never be set and the failure
            # would only surface as an AttributeError in forward().
            raise ValueError(
                f"Column {col!r} is neither in scaling_cols nor in embedding_cols"
            )

        # NOTE: the two linear layers have no activation in between.
        self.output = torch.nn.Sequential(
            torch.nn.Linear(d_model, d_model),
            torch.nn.Linear(d_model, out_features),
        )

    def forward(self, data):
        """Project decoder features to this column's output space."""
        return self.output(data)

class MultiheadBlockAttention(torch.nn.Module):
    """Multi-head attention applied independently at every time step.

    Inputs are 4-D, (batch, seq_len, slots, d_model). For each time step the
    query slots attend over the key/value slots of that same time step.

    Args:
        d_model: Feature width; must be divisible by `num_heads`.
        num_heads: Number of attention heads.
        dropout: Dropout probability stored in `self.dropout`.

    Raises:
        ValueError: If `d_model` is not divisible by `num_heads`.
    """

    def __init__(self, d_model, num_heads, dropout):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        # Use the constructor argument; the previous code silently read a
        # notebook-level global called `nhead`.
        self.d_model, self.nhead = d_model, num_heads

        self.q_linear = torch.nn.Linear(d_model, d_model)
        self.k_linear = torch.nn.Linear(d_model, d_model)
        self.v_linear = torch.nn.Linear(d_model, d_model)

        # NOTE(review): this module is constructed but forward() never applies
        # it; left untouched so training behaviour does not change here.
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, query, key, value):
        """Attend `query` slots over `key`/`value` slots per time step.

        Returns:
            (attn_output, attn). `attn_output` is (batch, seq_len, d_model)
            when there is a single query slot, otherwise
            (batch, seq_len, query_slots, d_model). `attn` is
            (batch, heads, seq_len, query_slots, key_slots).
        """
        # Linear transformation
        Q = self.q_linear(query)
        K = self.k_linear(key)
        V = self.v_linear(value)

        # Split heads: (batch, seq, slots, d_model) -> (batch, head, seq, slots, head_dim)
        batch_size, seq_len, _, d_model = Q.shape
        head_dim = d_model // self.nhead
        Q = Q.view(batch_size, seq_len, -1, self.nhead, head_dim).permute(0, 3, 1, 2, 4)
        K = K.view(batch_size, seq_len, -1, self.nhead, head_dim).permute(0, 3, 1, 2, 4)
        V = V.view(batch_size, seq_len, -1, self.nhead, head_dim).permute(0, 3, 1, 2, 4)

        # Scaled dot-product attention
        scores = Q @ K.transpose(-1, -2)
        attn = torch.nn.functional.softmax(scores / math.sqrt(self.d_model // self.nhead), dim=-1)
        attn_output = attn @ V

        # Concat heads: back to (batch, seq, query_slots, d_model). Keeping the
        # slot axis in its permuted position avoids scrambling elements when
        # there is more than one query slot.
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, seq_len, -1, d_model)
        # Drop only the slot axis (when it has size 1); a bare squeeze() would
        # also drop batch or sequence axes of size 1.
        attn_output = attn_output.squeeze(-2)

        return attn_output, attn

class FeedForward(torch.nn.Module):
    """Position-wise feed-forward block: Linear -> activation -> Linear.

    Args:
        d_model: Input and output width.
        d_ff: Hidden width.
        activation: Either "relu" or "gelu".
        dropout: Dropout probability. Defaults to 0.5, the value the block
            used before the probability became configurable.

    Raises:
        ValueError: If `activation` is not a supported name.
    """

    def __init__(self, d_model, d_ff, activation, dropout=0.5):
        super().__init__()
        self.linear1 = torch.nn.Linear(d_model, d_ff)
        self.linear2 = torch.nn.Linear(d_ff, d_model)
        self.dropout = torch.nn.Dropout(dropout)

        if activation == "relu":
            self.activation = torch.nn.ReLU()
        elif activation == "gelu":
            self.activation = torch.nn.GELU()
        else:
            # Previously an unknown name left `self.activation` unset and the
            # failure only appeared as an AttributeError inside forward().
            raise ValueError(
                f"Unsupported activation {activation!r}; expected 'relu' or 'gelu'"
            )

    def forward(self, x):
        """Apply the block; dropout follows both the activation and the output."""
        hidden = self.dropout(self.activation(self.linear1(x)))
        return self.dropout(self.linear2(hidden))


class CustomDecoder(torch.nn.Module):
    """Decoder layer for one series column.

    Each time step of the target first cross-attends over the other
    modalities at that time step, then the sequence self-attends over time,
    then a feed-forward block is applied.

    Args:
        d_model: Feature width.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward block.
        activation: Activation name forwarded to `FeedForward`.
        dropout: Dropout probability for both attention modules.
            `FeedForward` is constructed without it and uses its own default.
    """

    def __init__(self, d_model, num_heads, d_ff, activation, dropout):
        super().__init__()
        self.cross_attn = MultiheadBlockAttention(d_model, num_heads, dropout)
        # `mlp` and `layernorm3` are not used by forward(); they are kept so
        # parameter names and initialisation order stay unchanged.
        self.mlp = torch.nn.Linear(d_model, d_model)
        self.mha = torch.nn.MultiheadAttention(d_model, num_heads, dropout, batch_first=True)
        self.ff = FeedForward(d_model, d_ff, activation)

        self.layernorm_tgt = torch.nn.LayerNorm(d_model)
        self.layernorm_mem = torch.nn.LayerNorm(d_model)

        self.layernorm1 = torch.nn.LayerNorm(d_model)
        self.layernorm2 = torch.nn.LayerNorm(d_model)
        self.layernorm3 = torch.nn.LayerNorm(d_model)

    def _get_padding_mask(self, padding_mask, memory, device):
        """Concatenate the revert padding masks of every column in `memory`.

        Returns an additive mask: 0 where the source mask is 1, -inf elsewhere.
        Not called by forward().
        """
        masks = [padding_mask[f"{col}_revert_padding_mask"] for col in memory.keys()]
        result = torch.cat(masks, dim=1).to(device)
        return torch.where(result == 1, 0, -torch.inf)

    def forward(self, col, tgt, memory, padding_mask, device):
        """Decode one column.

        Args:
            col: Column name; selects `"{col}_revert_padding_mask"`.
            tgt: Target sequence; a slot axis is inserted before its last axis.
            memory: Other modalities, passed as key/value to the block attention.
            padding_mask: Dict of padding masks (1 = valid).
            device: Unused; kept for interface compatibility.

        Returns:
            (output, self_attn_weight, cross_attn_weight).
        """
        tgt = self.layernorm_tgt(tgt.unsqueeze(-2))
        memory = self.layernorm_mem(memory)

        cross_attn, cross_attn_weight = self.cross_attn(query=tgt, key=memory, value=memory)

        # Remove only the slot axis that was inserted above. A bare squeeze()
        # would also drop a batch or sequence axis of size 1 and corrupt the
        # residual sum. The reshape makes the attention output line up with
        # the residual even if it arrives with such axes already dropped.
        residual = tgt.squeeze(-2)
        cross_attn = residual + cross_attn.reshape(residual.shape)

        # Additive key-padding mask: 0 for valid steps, -inf for padding.
        padding_mask = padding_mask[f"{col}_revert_padding_mask"]
        padding_mask = torch.where(padding_mask == 1, 0, -torch.inf)

        cross_attn = self.layernorm1(cross_attn)
        self_attn, self_attn_weight = self.mha(cross_attn, cross_attn, cross_attn, key_padding_mask=padding_mask)

        self_attn = self_attn + cross_attn

        # Feed-forward with residual
        self_attn = self.layernorm2(self_attn)
        ff = self_attn + self.ff(self_attn)

        return ff, self_attn_weight, cross_attn_weight

1==1

True

In [25]:
class Transformer(torch.nn.Module):
    def __init__(self, data_info, label_encoder_dict,
                    d_model, num_layers, nhead, d_ff, dropout,
                    is_identical=False):
        """Build every sub-module of the masked multimodal model.

        Args:
            data_info: Object exposing `modality_info` and `processing_info`.
            label_encoder_dict: Per-column encoders providing `get_num_cls()`.
            d_model: Dict with "encoder" and "decoder" feature widths.
            num_layers: Dict with "encoder" and "decoder" layer counts.
            nhead: Number of attention heads.
            d_ff: Dict with "encoder" and "decoder" feed-forward widths.
            dropout: Dropout probability for the attention/transformer layers.
            is_identical: When True, most layers are replaced by pass-through
                `NoneEmbedding` modules and the positional tables by zeros.
        """
        super().__init__()
        activation = "gelu"
        self.data_info, self.label_encoder_dict, self.is_identical = data_info, label_encoder_dict, is_identical

        # 1. Embedding
        self.numerical_embedding_dict = self._init_numerical_embedding_dict(d_model["encoder"])
        self.categorical_embedding_dict = self._init_categorical_embedding_dict(d_model["encoder"])
        self.img_embedding = ImgEmbedding(d_model["encoder"])

        # 3. Apply remain positional encoding
        # Fixed (non-trainable) sinusoidal table, stored as a Parameter.
        self.encoder_pos_enc = torch.nn.Parameter(get_positional_encoding(d_model["encoder"]), requires_grad=False)  if not self.is_identical else torch.nn.Parameter(torch.zeros(1000, 1))

        # 4. Apply modality embedding
        # One modality id per target/temporal/img column.
        self.num_modality = len(self.data_info.modality_info["target"] + self.data_info.modality_info["temporal"] + self.data_info.modality_info["img"])
        self.encoder_mod_emb = torch.nn.Embedding(self.num_modality, d_model["encoder"])  if not self.is_identical else NoneEmbedding()

        # 5. Encoding
        self.encoder = torch.nn.TransformerEncoder(torch.nn.TransformerEncoderLayer(d_model=d_model["encoder"], nhead=nhead, dim_feedforward=d_ff["encoder"], dropout=dropout, batch_first=True, activation=activation, norm_first=True), num_layers["encoder"])  if not self.is_identical else NoneEmbedding()
        # Projects encoder features to the decoder width.
        self.to_decoder_dim = torch.nn.Linear(d_model["encoder"], d_model["decoder"])  if not self.is_identical else NoneEmbedding()

        # 7. Revert
        # Learnable token substituted for every masked/padded position.
        self.mask_token = torch.nn.Parameter(torch.rand(1, d_model["decoder"])) if not self.is_identical else torch.nn.Parameter(torch.zeros(1, 1) + 99)

        # 8. Apply revert positional encoding
        self.decoder_pos_enc = torch.nn.Parameter(get_positional_encoding(d_model["decoder"]), requires_grad=False)  if not self.is_identical else torch.nn.Parameter(torch.zeros(1000, 1))

        # 9. Apply modality embedding
        self.decoder_mod_emb = torch.nn.Embedding(self.num_modality, d_model["decoder"])  if not self.is_identical else NoneEmbedding()

        # 10. Decoding
        # One stack of CustomDecoder layers per series column, plus a standard
        # transformer decoder for the image tokens.
        self.custom_decoder = self._init_custom_deocder_dict(d_model["decoder"], nhead, d_ff["decoder"], activation, dropout)
        self.img_decoder = torch.nn.TransformerDecoder(torch.nn.TransformerDecoderLayer(d_model=d_model["decoder"], nhead=nhead, dim_feedforward=d_ff["decoder"], dropout=dropout, batch_first=True, activation=activation, norm_first=True), num_layers["decoder"])
        
        # 11. Output
        self.temporal_output = self._init_temporal_output(d_model["decoder"])
        # Each image token is mapped to 3 * patch_size * patch_size values.
        patch_size = 16
        self.img_output = torch.nn.Sequential(torch.nn.Linear(d_model["decoder"], d_model["decoder"]), torch.nn.Linear(d_model["decoder"], d_model["decoder"]), torch.nn.Linear(d_model["decoder"], 3*patch_size*patch_size))
        
    def forward(self, input_data_dict, remain_rto, device):
        """Encode the kept tokens of every modality and reconstruct all of them.

        Args:
            input_data_dict: Collated batch holding data tensors, `*_idx`
                tensors, `*_padding_mask` tensors and "img_input".
            remain_rto: Not read by this method; kept for interface
                compatibility.
            device: Device the batch is moved to.

        Returns:
            (temporal_output, img_output, self_attn_dict, cross_attn_dict,
            idx_dict), where `idx_dict` also carries the image remain/masked/
            revert indices drawn during this call.
        """
        # 0. Data to gpu
        data_dict, idx_dict, padding_mask_dict = self._to_gpu(input_data_dict, device)

        # 1. Embedding
        embedding_dict = {}
        embedding_dict.update(self._apply_numerical_embedding(data_dict))
        embedding_dict.update(self._apply_categorical_embedding(data_dict))
        embedding_dict.update(self._apply_img_embedding(data_dict))

        # 2. Apply remain
        temporal_remain = self._apply_temporal_remain(embedding_dict, idx_dict)
        img_remain, img_remain_idx, img_masked_idx, img_revert_idx = self._apply_img_remain(embedding_dict["img_input"])
        idx_dict["img_remain_idx"] = img_remain_idx
        idx_dict["img_masked_idx"] = img_masked_idx
        idx_dict["img_revert_idx"] = img_revert_idx

        # 3. Apply temporal remain positional encoding
        temporal_remain = self._apply_remain_temporal_positional_encoding(temporal_remain, idx_dict, self.encoder_pos_enc)

        # 4. Apply modality embedding
        temporal_remain, img_remain = self._apply_modality_embedding(temporal_remain, img_remain, self.encoder_mod_emb, device)

        # 5. Encoding
        temporal_encoder_padding_mask = self._get_temporal_padding_mask(padding_mask_dict, temporal_remain, device, mode="remain")
        # Validity mask (1 = valid) for the kept image tokens; the revert step
        # looks it up under this key.
        img_remain_padding_mask = torch.ones(img_remain.shape[:-1]).to(device)
        padding_mask_dict["img_remain_padding_mask"] = img_remain_padding_mask
        # The encoder mask is additive (0 = attend, -inf = ignore). Image tokens
        # are never padded, so their entries must be 0; reusing the ones-valued
        # validity mask here would add +1 to every image key's attention logit.
        img_encoder_padding_mask = torch.zeros_like(img_remain_padding_mask)
        encoder_padding_mask = torch.cat([temporal_encoder_padding_mask, img_encoder_padding_mask], dim=1)

        concat = torch.cat(list(temporal_remain.values()), dim=1)
        concat = torch.cat([concat, img_remain], dim=1)

        encoded = self.encoder(concat, src_key_padding_mask=encoder_padding_mask)
        encoded = self.to_decoder_dim(encoded)

        # 6. Split
        temporal_encoded, img_encoded = self._split_modalities(encoded, temporal_remain, img_remain)

        # 7. Revert
        temporal_reverted = self._apply_temporal_revert(temporal_encoded, idx_dict, padding_mask_dict)
        img_reverted = self._apply_temporal_revert({"img":img_encoded}, idx_dict, padding_mask_dict)["img"]

        # 8. Apply revert positional encoding
        temporal_reverted = self._apply_reverted_temporal_positional_encoding(temporal_reverted, idx_dict, self.decoder_pos_enc)
        # Reuse the table built in __init__ instead of rebuilding it on every
        # call from a notebook-level `d_model` global.
        img_reverted = img_reverted + self.decoder_pos_enc[:img_reverted.shape[1], :]

        # 9. Apply modality embedding
        temporal_reverted, img_reverted = self._apply_modality_embedding(temporal_reverted, img_reverted, self.decoder_mod_emb, device)

        # 10. Decoding
        temporal_decoded, img_decoded, self_attn_dict, cross_attn_dict = self._apply_custom_decoder(temporal_reverted, img_reverted, padding_mask_dict, device)

        # 11. Output
        temporal_output = self._apply_temporal_output(temporal_decoded)
        img_output = self.img_output(img_decoded)

        return temporal_output, img_output, self_attn_dict, cross_attn_dict, idx_dict

    
    def _to_gpu(self, data, device):
        data_dict = {}
        idx_dict = {}
        padding_mask_dict = {}

        for col in data.keys():
            if col.endswith("padding_mask"):
                padding_mask_dict[col] = data[col].to(device)
            
            elif col.endswith("idx"):
                idx_dict[col] = data[col].to(device)

            elif col in self.data_info.modality_info["target"] + self.data_info.modality_info["temporal"]:
                data_dict[col] = data[col].to(device)
            
            elif col == "img_input":
                img_input = data["img_input"]
                data_dict["img_input"] = img_input.to(device)
                padding_mask_dict["img_input"] = torch.ones(img_input.shape).to(device)
        
        return data_dict, idx_dict, padding_mask_dict
    
    def _init_numerical_embedding_dict(self, d_model):
        result_dict = {}
        target_cols = self.data_info.processing_info["scaling_cols"]
        for col in target_cols:
            result_dict[col] = NumericalEmbedding(d_model, self.is_identical) if not self.is_identical else NoneEmbedding()
        
        return torch.nn.ModuleDict(result_dict)

    def _init_categorical_embedding_dict(self, d_model):
        result_dict = {}
        target_cols = self.data_info.processing_info["embedding_cols"]
        for col in target_cols:
            num_cls = self.label_encoder_dict[col].get_num_cls()
            result_dict[col] = CategoricalEmbedding(num_cls, d_model, self.is_identical) if not self.is_identical else NoneEmbedding()
        
        return torch.nn.ModuleDict(result_dict)

    def _init_temporal_output(self, d_model):
        result_dict = {}
        for col in self.data_info.modality_info["target"] + self.data_info.modality_info["temporal"]:
            result_dict[col] = DynamicOutput(col, self.data_info, self.label_encoder_dict, d_model) if not self.is_identical else NoneEmbedding()
        
        return torch.nn.ModuleDict(result_dict)

    def _init_custom_deocder_dict(self, d_model, num_heads, d_ff, activation, dropout):
        result_dict = {}
        target_cols = self.data_info.modality_info["target"] + self.data_info.modality_info["temporal"]
        for col in target_cols:
            result_dict[col] = torch.nn.ModuleList([CustomDecoder(d_model, num_heads, d_ff, activation, dropout),
                                                    CustomDecoder(d_model, num_heads, d_ff, activation, dropout),
                                                    # CustomDecoder(d_model, num_heads, d_ff, activation, dropout),
                                                    # CustomDecoder(d_model, num_heads, d_ff, activation, dropout)
                                                    ])
        return torch.nn.ModuleDict(result_dict)


    def _apply_numerical_embedding(self, data):
        result_dict = {}
        target_cols = self.data_info.processing_info["scaling_cols"]
        
        for col in target_cols:
            result_dict[col] = self.numerical_embedding_dict[col](data[col])
        
        return result_dict

    def _apply_categorical_embedding(self, data):
        result_dict = {}
        target_cols = self.data_info.processing_info["embedding_cols"]
        for col in target_cols:
            result_dict[col] = self.categorical_embedding_dict[col](data[col])
        
        return result_dict

    def _apply_img_embedding(self, data):
        result_dict = {}
        result_dict["img_input"] = self.img_embedding(data["img_input"])
        return result_dict

    def _apply_temporal_remain(self, embedding_dict, idx_dict):
        result_dict = {}
        target_cols = self.data_info.modality_info["target"] + self.data_info.modality_info["temporal"]
        for col, val in {key:val for key, val in embedding_dict.items() if key in target_cols}.items():
            # Get remain data
            remain_idx = idx_dict[f"{col}_remain_idx"].unsqueeze(-1).repeat(1, 1, val.shape[-1])
            val = torch.gather(val, index=remain_idx, dim=1)

            result_dict[col] = val
        return result_dict

    def _apply_img_remain(self, data):
        # Positional encoding
        data = data[:, 1:, :]
        pos_enc = get_positional_encoding(d_model["encoder"]).to(device)
        data += pos_enc[:data.shape[1], :]

        # Apply mask
        num_remain = int(data.shape[1] * remain_rto["img"])
        
        noise = torch.rand(data.shape[:-1]).to(device)
        shuffle_idx = torch.argsort(noise, dim=1)
        
        remain_idx = shuffle_idx[:, :num_remain]
        masked_idx = shuffle_idx[:, num_remain:]
        revert_idx = torch.argsort(shuffle_idx, dim=1)

        remain_data = torch.gather(data, index=remain_idx.unsqueeze(-1).repeat(1, 1, data.shape[-1]), dim=1)

        return remain_data, remain_idx, masked_idx, revert_idx

    def _apply_remain_temporal_positional_encoding(self, temporal_embedding, idx_dict, pos_enc):
        result_dict = {}
        
        for col, val in temporal_embedding.items():
            # Get remain pos_enc
            remain_idx = idx_dict[f"{col}_remain_idx"].unsqueeze(-1).repeat(1, 1, pos_enc.shape[-1])
            remain_pos_enc = torch.gather(pos_enc.unsqueeze(0).repeat(val.shape[0], 1, 1), index=remain_idx, dim=1)
            
            result_dict[col] = val + remain_pos_enc
        
        return result_dict

    def _apply_modality_embedding(self, temporal_embedding, img_embedding, mod_emb, device):
        temporal_embedding_result = {}
        modality_idx = 0

        # Temporal
        for col, val in temporal_embedding.items():
            modality = torch.zeros(val.shape[1]).to(device) + modality_idx
            modality = mod_emb(modality.to(torch.int))
            temporal_embedding_result[col] = val + modality

            modality_idx += 1
        
        # Img
        modality = torch.zeros(img_embedding.shape[1]).to(device) + modality_idx
        modality = mod_emb(modality.to(torch.int))
        img_embedding_result = img_embedding + modality
        modality_idx += 1

        assert modality_idx == self.num_modality
        return temporal_embedding_result, img_embedding_result

    def _get_temporal_padding_mask(self, padding_mask_dict, temporal_data, device, mode):
        padding_mask_li = []
        for col, val in temporal_data.items():
            padding_mask_li.append(padding_mask_dict[f"{col}_{mode}_padding_mask"])
        
        result = torch.cat(padding_mask_li, dim=1).to(device)
        result = torch.where(result == 1, 0, -torch.inf)
        return result   

    def _split_modalities(self, data, temporal_data, img_data):
        temporal_result_dict = {}
        start_idx = 0
        
        # Temporal
        for col, val in temporal_data.items():
            length = val.shape[1]
            temporal_result_dict[col] = data[:, start_idx: start_idx+length, :]
            start_idx += length
        
        # Img
        length = img_data.shape[1]
        img_result = data[:, start_idx:start_idx+length, :]
        start_idx += length

        
        assert start_idx == data.shape[1]
        return temporal_result_dict, img_result

    def _apply_temporal_revert(self, temporal_encoded, idx_dict, padding_mask_dict):
        result_dict = {}

        for col, val in temporal_encoded.items():
            # Replace remain padding to mask token
            remain_padding_mask = padding_mask_dict[f"{col}_remain_padding_mask"].unsqueeze(-1).repeat(1, 1, val.shape[-1])
            val = torch.where(remain_padding_mask==1, val, self.mask_token)
            
            # Append mask token
            revert_idx = idx_dict[f"{col}_revert_idx"].unsqueeze(-1).repeat(1, 1, val.shape[-1])
            mask_tokens = self.mask_token.unsqueeze(0).repeat(val.shape[0],
                                                    revert_idx.shape[1] - val.shape[1],
                                                    1)
            val_with_mask_token = torch.cat([val, mask_tokens], dim=1)
            assert val_with_mask_token.shape == revert_idx.shape

            # Apply revert
            reverted_val = torch.gather(val_with_mask_token, index=revert_idx, dim=1)

            result_dict[col] = reverted_val

        return result_dict

    def _apply_reverted_temporal_positional_encoding(self, temporal_reverted, idx_dict, pos_enc):
        result_dict = {}
        
        for col, val in temporal_reverted.items():
            result_dict[col] = val + pos_enc.unsqueeze(0).repeat(val.shape[0], 1, 1)[:, :val.shape[1], :]
        
        return result_dict

    def _apply_temporal_output(self, temporal_decoded):
        result_dict = {}
        for col, val in temporal_decoded.items():
            result_dict[col] = self.temporal_output[col](val)

        return result_dict

    def _apply_custom_decoder(self, temporal, img, padding_mask_dict, device):
        temporal_result_dict = {}
        img_result_dict = {}
        self_attn_dict, cross_attn_dict = {}, {}
        img_result = None
        target_cols = self.data_info.modality_info["target"] + self.data_info.modality_info["temporal"] + self.data_info.modality_info["img"]

        for col in target_cols:
            if col in self.data_info.modality_info["target"] + self.data_info.modality_info["temporal"]:
                tgt = temporal[col]
                temporal_memory = {key:val for key, val in temporal.items() if col != key}
                img_memory = img
                memory = torch.stack(list(temporal_memory.values()), dim=-2)
                memory = torch.cat([memory, img.unsqueeze(1).repeat(1, memory.shape[1], 1, 1)], dim=-2)

                for mod in self.custom_decoder[col]:
                    tgt, self_attn_weight, cross_attn_weight = mod(col, tgt, memory, padding_mask_dict, device)
                
                temporal_result_dict[col] = tgt
                self_attn_dict[col] = self_attn_weight
                cross_attn_dict[col] = cross_attn_weight

            elif col in self.data_info.modality_info["img"]:
                tgt = img
                memory = torch.cat(list(temporal.values()), dim=-2)
                padding_mask = self._get_temporal_padding_mask(padding_mask_dict, temporal, device, mode="revert")
                decoding = self.img_decoder(tgt=img, memory=memory, memory_key_padding_mask=padding_mask)
                img_result = decoding

        return temporal_result_dict, img_result, self_attn_dict, cross_attn_dict


    def _____apply_custom_decoder(self, reverted_dict, padding_mask_dict, device):
        result_dict = {}
        self_attn_dict, cross_attn_dict = {}, {}
        for col in reverted_dict.keys():

            tgt = reverted_dict[col]
            memory = {key:val for key, val in reverted_dict.items() if col != key}
            # print(f"{col} - {memory.keys()}")
            for mod in self.custom_decoder[col]:
                tgt, self_attn_weight, cross_attn_weight = mod(col, tgt, memory, padding_mask_dict, device)
            result_dict[col] = tgt
            self_attn_dict[col] = self_attn_weight
            cross_attn_dict[col] = cross_attn_weight

        return result_dict, self_attn_dict, cross_attn_dict

1==1

True

In [26]:
# Build the model from the config cell and move it to the target device.
model = Transformer(
    data_info=data_info,
    label_encoder_dict=train_dataset.label_encoder_dict,
    d_model=d_model,
    num_layers=num_layers,
    nhead=nhead,
    d_ff=d_ff,
    dropout=dropout,
    is_identical=is_identical,
).to(device)

# `data` is the sample batch left bound by the dataloader cell above.
summary(model, data, remain_rto, device, show_parent_layers=True, print_summary=True)

--------------------------------------------------------------------------------------------
   Parent Layers            Layer (type)        Output Shape         Param #     Tr. Param #
     Transformer          ImgEmbedding-1        [2, 197, 64]      86,438,464      86,438,464
     Transformer             Embedding-2            [91, 64]             448             448
     Transformer    TransformerEncoder-3        [2, 595, 64]          50,432          50,432
     Transformer                Linear-4        [2, 595, 32]           2,080           2,080
     Transformer             Embedding-5           [365, 32]             224             224
     Transformer    TransformerDecoder-6        [2, 196, 32]          21,504          21,504
     Transformer                Linear-7        [2, 196, 32]           1,056           1,056
     Transformer                Linear-8        [2, 196, 32]           1,056           1,056
     Transformer                Linear-9       [2, 196, 768]          



In [31]:
# Optimiser over every model parameter, with a cosine-annealed learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
# reduction="none" keeps per-element losses so temporal_loss can apply the
# padding mask before averaging.
mse_loss = torch.nn.MSELoss(reduction="none")
ce_loss = torch.nn.CrossEntropyLoss(reduction="none")

def temporal_loss(col, pred, y, data, device):
    """Masked reconstruction loss for one series column.

    Only the positions listed in `data["{col}_masked_idx"]` are scored, and
    positions whose `data["{col}_masked_padding_mask"]` entry is 0 are
    excluded from the average.

    Args:
        col: Column name; must be an embedding column (cross-entropy) or a
            scaling column (squared error) according to `data_info`.
        pred: Predictions, (batch, seq, classes) for embedding columns or
            (batch, seq[, 1]) for scaling columns.
        y: Ground truth, (batch, seq[, 1]).
        data: Batch dict holding the masked indices and padding mask.
        device: Device on which the loss is computed.

    Returns:
        Scalar tensor: the summed loss divided by the number of unpadded
        masked positions.

    Raises:
        ValueError: If `col` is neither an embedding nor a scaling column.
    """
    processing_info = data_info.processing_info
    is_categorical = col in processing_info["embedding_cols"]
    if not is_categorical and col not in processing_info["scaling_cols"]:
        # Previously this fell through and failed with UnboundLocalError.
        raise ValueError(f"Column {col!r} is neither in embedding_cols nor in scaling_cols")

    pred = pred.to(device)
    batch = pred.shape[0]

    # Flatten trailing singleton axes explicitly. A bare squeeze() would also
    # drop the batch axis for a batch of one and break the gathers on dim=1.
    y = y.to(device).reshape(batch, -1)
    masked_idx = data[f"{col}_masked_idx"].to(device).reshape(batch, -1)
    padding_mask = data[f"{col}_masked_padding_mask"].to(device).reshape(batch, -1)

    masked_y = torch.gather(y, index=masked_idx, dim=1)

    if is_categorical:
        class_idx = masked_idx.unsqueeze(-1).repeat(1, 1, pred.shape[-1])
        masked_pred = torch.gather(pred, index=class_idx, dim=1)

        masked_pred = masked_pred.reshape(-1, masked_pred.shape[-1])
        masked_y = masked_y.reshape(-1)

        loss = ce_loss(masked_pred, masked_y.to(torch.long)).reshape(padding_mask.shape)
    else:
        masked_pred = torch.gather(pred.reshape(batch, -1), index=masked_idx, dim=1)
        loss = mse_loss(masked_pred, masked_y)

    loss = loss * padding_mask
    # Tensor reductions instead of the builtin sum(), which iterates over the
    # tensor one element at a time in Python.
    return loss.sum() / padding_mask.sum()

def plot_epoch_loss():
    """Placeholder for an epoch-level loss plot; currently does nothing."""
    return None

def plot_sample_loss(model, data, remain_rto, device, train_mean_loss_dict, eval_mean_loss_dict):
    """Evaluate the model on one batch, record eval losses and draw diagnostics.

    The model is run in eval mode without gradients. For columns listed in
    ``scaling_cols`` both targets and predictions are passed through the
    per-sample scaler's ``inverse_transform`` before the loss is computed.
    The figure shows train/eval loss curves plus, for every temporal output,
    one sample's prediction and its self/cross attention maps. The caller is
    expected to finish the figure (e.g. ``plt.tight_layout(); plt.show()``).

    Args:
        model: Model called as ``model(data, remain_rto, device)``.
        data: Batch dictionary; it is not modified by this function.
        remain_rto: Forwarded to the model unchanged.
        device: Device used for the forward pass and the losses.
        train_mean_loss_dict: Training loss histories (``"total_loss"`` and
            one entry per column); only read here, for plotting.
        eval_mean_loss_dict: Eval loss histories; this call's values are
            appended to it.

    Returns:
        ``eval_mean_loss_dict`` with this call's losses appended.
    """
    # Remember the current mode so evaluation does not leave the model in eval mode.
    was_training = model.training
    model.eval()

    # Get eval prediction
    with torch.no_grad():
        temporal_sample_output, img_output, self_attn_dict, cross_attn_dict, idx_dict = model(data, remain_rto, device)

        # Temporal eval loss
        loss_dict = {}
        target_dict = {}
        for col, val in temporal_sample_output.items():
            # Work on a copy: writing the inverse-transformed values into
            # data[col] itself would silently alter the caller's batch.
            y = data[col].clone()

            if col in data_info.processing_info["scaling_cols"]:
                scaler = data[f"{col}_scaler"]

                # One scaler per sample in the batch.
                for n, (s, _, _) in enumerate(zip(scaler, y, val)):
                    y[n] = torch.tensor(s.inverse_transform(y[n].clone().detach()))
                    val[n] = torch.tensor(s.inverse_transform(val[n].clone().detach().cpu()))

            loss_dict[col] = temporal_loss(col, val, y, data, device)
            target_dict[col] = y

        temporal_loss_val = torch.nansum(torch.stack(list(loss_dict.values())))

        # Img loss
        y = patchify(data["img_input"].squeeze()).to(device)
        masked_idx = idx_dict["img_masked_idx"].squeeze().to(device).unsqueeze(-1).repeat(1, 1, y.shape[-1])
        masked_y = torch.gather(y, index=masked_idx, dim=1)

        masked_pred = torch.gather(img_output, index=masked_idx, dim=1)
        img_loss = torch.mean(mse_loss(masked_pred, masked_y))

        # Total loss mirrors the training objective (image + temporal). The
        # previous code recorded only the last temporal column's loss here.
        total_loss = img_loss + temporal_loss_val

        # Dictionary for plot
        eval_mean_loss_dict["total_loss"].append(total_loss.item())
        for key, val in loss_dict.items():
            eval_mean_loss_dict[key].append(val.item())

    model.train(was_training)

    # Plot
    idx = 0  # index of the sample within the batch that is visualised
    plt.figure(figsize=(25,13))
    nrows, ncols = len(data_info.modality_info["target"] + data_info.modality_info["temporal"])+1, 5
    clear_output(wait=True)

    # Plot train total loss
    plt.subplot(nrows, ncols, 1)
    plt.plot(train_mean_loss_dict["total_loss"])
    plt.title(f'Train total loss: {train_mean_loss_dict["total_loss"][-1]}')

    # Plot eval total loss
    plt.subplot(nrows, ncols, 2)
    plt.plot(eval_mean_loss_dict["total_loss"])
    plt.title(f'Eval total loss: {eval_mean_loss_dict["total_loss"][-1]}')

    # The first row is reserved for the totals; each column then gets its own row.
    plot_idx = 6
    for key, val in temporal_sample_output.items():
        # Plot train loss
        plt.subplot(nrows, ncols, plot_idx)
        plt.plot(train_mean_loss_dict[key])
        plt.title(f"{key} train loss: {train_mean_loss_dict[key][-1]}")
        plot_idx += 1

        # Plot eval loss
        plt.subplot(nrows, ncols, plot_idx)
        plt.plot(eval_mean_loss_dict[key])
        plt.title(f"{key} eval loss: {eval_mean_loss_dict[key][-1]}")
        plot_idx += 1

        # Plot sample
        masked_idx = data[f"{key}_masked_idx"][idx]
        y = target_dict[key][idx].squeeze()
        sample = val[idx].squeeze().detach().cpu()
        if key in data_info.processing_info["embedding_cols"]:
            # Class logits -> predicted class index.
            sample = torch.argmax(sample, dim=-1).to(torch.float)

        # Predictions exist only at masked positions; everything else stays NaN
        # so that it is left out of the plot.
        sample_full = torch.zeros(y.shape) + torch.nan
        sample = torch.gather(sample, index=masked_idx, dim=-1)
        sample_full[masked_idx] = sample

        plt.subplot(nrows, ncols, plot_idx)
        plt.plot(y)
        plt.plot(sample_full)
        plt.scatter(torch.arange(sample_full.shape[0]), sample_full, color="red", s=15)
        plt.title(f"{key} sample")
        plot_idx += 1

        # Plot self attention
        plt.subplot(nrows, ncols, plot_idx)
        sns.heatmap(self_attn_dict[key][idx].detach().cpu())
        plt.title(f"{key} self attention weight")
        plot_idx += 1

        # Plot cross attention (averaged over the first axis of the sample's weights)
        plt.subplot(nrows, ncols, plot_idx)
        cross_attn = cross_attn_dict[key][idx].squeeze().detach().cpu()
        cross_attn = cross_attn.mean(dim=0)
        sns.heatmap(cross_attn)
        plt.title(f"{key} cross attention weight")
        plot_idx += 1

    return eval_mean_loss_dict
        
    
# NOTE(review): leftover scratch expression -- it only makes the cell display `True` and can be removed.
1==1

True

In [32]:
def patchify(imgs, patch_size=16):
    """
    Split a batch of square images into flattened, non-overlapping patches.

    imgs: (N, C, H, W) with H == W and H divisible by patch_size
    patch_size: side length of each square patch (default 16)
    x: (N, L, patch_size**2 * C) with L = (H // patch_size) ** 2; the values
       inside a patch are ordered (row, column, channel)
    """
    p = patch_size
    assert imgs.ndim == 4, f"expected a (N, C, H, W) tensor, got {imgs.ndim} dimension(s)"
    assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

    # Take the batch and channel sizes from the input instead of assuming 3 channels.
    n, c = imgs.shape[0], imgs.shape[1]
    h = w = imgs.shape[2] // p

    # (N, C, h, p, w, p) -> (N, h, w, p, p, C) -> (N, h*w, p*p*C)
    x = imgs.reshape(n, c, h, p, w, p)
    x = torch.einsum('nchpwq->nhwpqc', x)
    x = x.reshape(n, h * w, p**2 * c)
    return x

In [33]:
def train(e):
    """Run one training epoch over ``train_dataloader``.

    For every batch the per-column temporal losses (NaN entries ignored via
    ``nansum``) and the masked-patch image loss are summed, back-propagated and
    applied with the global ``optimizer``. Every 20th batch the diagnostics
    figure is refreshed with ``plot_sample_loss`` on the current batch.

    Args:
        e: Epoch index; currently not used inside the function.

    Returns:
        None.
    """
    pbar = tqdm(train_dataloader)
    # loss_history holds plain floats only. Keeping the loss tensors themselves
    # would retain every step's autograd graph for the whole epoch.
    loss_history, train_mean_loss_dict, eval_mean_loss_dict = defaultdict(list), defaultdict(list), defaultdict(list)

    for n, data in enumerate(pbar):
        optimizer.zero_grad()
        model.train()
        temporal_output, img_output, self_attn_dict, cross_attn_dict, idx_dict = model(data, remain_rto, device)

        # Temporal train loss (this step only)
        step_loss = {}
        for col, val in temporal_output.items():
            step_loss[col] = temporal_loss(col, val, data[col], data, device)

        temporal_loss_val = torch.nansum(torch.stack(list(step_loss.values())))

        # Img train loss
        y = patchify(data["img_input"].squeeze()).to(device)
        masked_idx = idx_dict["img_masked_idx"].squeeze().to(device).unsqueeze(-1).repeat(1, 1, y.shape[-1])
        masked_y = torch.gather(y, index=masked_idx, dim=1)

        masked_pred = torch.gather(img_output, index=masked_idx, dim=1)
        img_loss = torch.mean(mse_loss(masked_pred, masked_y))

        loss = img_loss + temporal_loss_val

        loss.backward()
        optimizer.step()

        # Dictionary for plot: total loss per step, running epoch mean per column
        train_mean_loss_dict["total_loss"].append(loss.item())
        for key, val in step_loss.items():
            loss_history[key].append(val.item())
            train_mean_loss_dict[key].append(np.mean(loss_history[key]))

        if n % 20 == 0:
            plot_epoch_loss()
            eval_mean_loss_dict = plot_sample_loss(model, data, remain_rto, device, train_mean_loss_dict, eval_mean_loss_dict)
            plt.tight_layout()
            plt.show()

    return

# Train for a fixed number of epochs, stepping the LR scheduler once per epoch.
# NOTE(review): the scheduler is created with T_max=100 while only 10 epochs
# run here, so only the first tenth of the cosine period is covered -- confirm
# this is intended.
epoch = 10
for e in range(epoch):
    train(e)
    scheduler.step()

  0%|          | 0/1 [00:05<?, ?it/s]


KeyboardInterrupt: 