### Import

In [1]:
import copy

from tqdm import tqdm; tqdm.pandas()
from collections import defaultdict
from IPython.display import clear_output

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import MinMaxScaler, StandardScaler

import torch
from pytorch_model_summary import summary
from transformers import ViTModel, AutoImageProcessor, BertModel, AutoTokenizer

from rawdata import RawData, Preprocess
from data import DataInfo, Dataset, collate_fn
from data import NoneScaler, LogScaler, CustomLabelEncoder

from architecture import Transformer

import cv2

device = torch.device("cuda")

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

  from .autonotebook import tqdm as notebook_tqdm


### Config

In [2]:
test_mode = True

# Raw data
is_prep_data_exist = True

# Data loader
MIN_MEANINGFUL_SEQ_LEN = 100
MAX_SEQ_LEN = 100
PRED_LEN = 100

modality_info = {
    "group": ["article_id", "sales_channel_id"],
    "target": ["sales"],
    "temporal": ["day", "dow", "month", "holiday", "price"],
    "img": ["img_path"],
    "nlp": ["detail_desc"]
}
processing_info = {
    "scaling_cols": {"sales": StandardScaler, "price": StandardScaler},
    "embedding_cols": ["day",  "dow", "month", "holiday"],
    "img_cols": ["img_path"],
    "nlp_cols": ["detail_desc"]
}

# Model
batch_size = 16
nhead = 4
dropout = 0.1
patch_size = 16

d_model = {"encoder":256, "decoder":128}
d_ff = {"encoder":256, "decoder":128}
num_layers = {"encoder":2, "decoder":2}
remain_rto = {"temporal":0.75, "img":0.25, "nlp":0.25}

# Data

### Raw data

In [3]:
if test_mode:
    df_preprocessed = pd.read_parquet("src/df_preprocessed_test.parquet")
else:
    if not is_prep_data_exist:
        rawdata = RawData()
        df_trans, df_meta, df_holiday = rawdata.get_raw_data()
        preprocess = Preprocess(df_trans, df_meta, df_holiday)
        df_preprocessed = preprocess.main()
    else:
        df_preprocessed = pd.read_parquet("src/df_preprocessed.parquet")

### Dataset

In [4]:
df_train = df_preprocessed[(df_preprocessed["meaningful_size"] >= MIN_MEANINGFUL_SEQ_LEN) & (df_preprocessed["time_idx"] <= MAX_SEQ_LEN-1)]
df_train = df_train[~pd.isna(df_train["detail_desc"])]
df_valid = df_preprocessed[(df_preprocessed["meaningful_size"] >= MIN_MEANINGFUL_SEQ_LEN) & (df_preprocessed["time_idx"] <= MAX_SEQ_LEN-1 + PRED_LEN)]
df_valid = df_valid[~pd.isna(df_valid["detail_desc"])]

data_info = DataInfo(modality_info, processing_info)

In [5]:
train_dataset = Dataset(df_train, data_info, remain_rto)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: collate_fn(x, data_info), pin_memory=True, num_workers=16, prefetch_factor=4)
# train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: collate_fn(x, data_info))

for data in train_dataloader:
    [print(key, val.shape) for key, val in data.items() if "scaler" not in key and "raw" not in key]
    break

100%|██████████| 2/2 [00:00<00:00, 2640.42it/s]


sales torch.Size([2, 100, 1])
day torch.Size([2, 100])
dow torch.Size([2, 100])
month torch.Size([2, 100])
holiday torch.Size([2, 100])
price torch.Size([2, 100, 1])
temporal_padding_mask torch.Size([2, 100, 1])
img_path torch.Size([2, 3, 224, 224])
detail_desc torch.Size([2, 9])
detail_desc_remain_idx torch.Size([2, 2])
detail_desc_masked_idx torch.Size([2, 6])
detail_desc_revert_idx torch.Size([2, 8])
detail_desc_remain_padding_mask torch.Size([2, 2])
detail_desc_masked_padding_mask torch.Size([2, 6])
detail_desc_revert_padding_mask torch.Size([2, 8])


# Architecture

In [6]:
import math
class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = torch.permute(x, (1,0,2))
        x = x + self.pe[:x.size(0)]
        x = torch.permute(x, (1,0,2))
        return self.dropout(x)

class MultiheadBlockSelfAttention(torch.nn.Module):
    def __init__(self, d_model, nhead, dropout):
        super().__init__()
        self.nhead = nhead

        self.q_linear = torch.nn.Linear(d_model, d_model)
        self.k_linear = torch.nn.Linear(d_model, d_model)
        self.v_linear = torch.nn.Linear(d_model, d_model)

        self.dropout = torch.nn.Dropout(dropout)

        self.static_attention = torch.nn.MultiheadAttention(d_model, nhead, dropout, batch_first=True)

    def forward(self, query, key, value, temporal_shape, temporal_idx, static_idx, temporal_padding_mask, static_padding_mask):
        # Linear transformation
        Q = self.q_linear(query)
        K = self.k_linear(key)
        V = self.v_linear(value)

        temporal_attn_output, temporal_attn_weight = self.temporal_side_attention(Q, K, V, temporal_shape, temporal_idx, static_idx, temporal_padding_mask, static_padding_mask)
        static_attn_output, static_attn_weight = self.static_side_attention(Q, K, V, temporal_shape, temporal_idx, static_idx, temporal_padding_mask, static_padding_mask)

        # Concat all
        temporal_attn_output = temporal_attn_output.view(temporal_attn_output.shape[0], -1, temporal_attn_output.shape[-1])
        attn_output = torch.cat([temporal_attn_output, static_attn_output], dim=1)

        return attn_output

    def temporal_side_attention(self, Q, K, V, temporal_shape, temporal_idx, static_idx, temporal_padding_mask, static_padding_mask):
        # Obtain QKV
        temporal_idx = temporal_idx.unsqueeze(0).unsqueeze(-1).repeat(Q.shape[0], 1, Q.shape[-1])
        static_idx = static_idx.unsqueeze(0).unsqueeze(-1).repeat(Q.shape[0], 1, Q.shape[-1])

        Qt = torch.gather(Q, index=temporal_idx, dim=1).view(temporal_shape)

        Ktt = torch.gather(K, index=temporal_idx, dim=1).view(temporal_shape)
        Kts = torch.gather(K, index=static_idx, dim=1).unsqueeze(1).repeat(1, Ktt.shape[1], 1, 1)
        Kt = torch.cat([Ktt, Kts], dim=-2)

        Vtt = torch.gather(V, index=temporal_idx, dim=1).view(temporal_shape)
        Vts = torch.gather(V, index=static_idx, dim=1).unsqueeze(1).repeat(1, Vtt.shape[1], 1, 1)
        Vt = torch.cat([Vtt, Vts], dim=-2)

        # Obtain key padding mask
        static_padding_mask = static_padding_mask.unsqueeze(1).repeat(1, temporal_padding_mask.shape[1], 1)

        key_padding_mask = torch.cat([temporal_padding_mask, static_padding_mask], dim=-1)
        key_padding_mask = torch.where(key_padding_mask==1, 0, -torch.inf)

        # Attention
        temporal_attn_output, temporal_attn_weight = self.multihead_block_attention(Qt, Kt, Vt, key_padding_mask=key_padding_mask)
        return temporal_attn_output, temporal_attn_weight
    
    def static_side_attention(self, Q, K, V, temporal_shape, temporal_idx, static_idx, temporal_padding_mask, static_padding_mask):
        # Obtaion QKV
        temporal_idx = temporal_idx.unsqueeze(0).unsqueeze(-1).repeat(Q.shape[0], 1, Q.shape[-1])
        static_idx = static_idx.unsqueeze(0).unsqueeze(-1).repeat(Q.shape[0], 1, Q.shape[-1])

        Qs = torch.gather(Q, index=static_idx, dim=1)
        
        Kst = torch.gather(K, index=temporal_idx, dim=1)
        Kss = torch.gather(K, index=static_idx, dim=1)
        Ks = torch.cat([Kst, Kss], dim=1)

        Vst = torch.gather(V, index=temporal_idx, dim=1)
        Vss = torch.gather(V, index=static_idx, dim=1)
        Vs = torch.cat([Vst, Vss], dim=1)

        # Obtain key padding mask
        temporal_padding_mask = temporal_padding_mask.view(temporal_shape[0], -1)
        static_padding_mask = static_padding_mask
        
        key_padding_mask = torch.cat([temporal_padding_mask, static_padding_mask], dim=1)
        key_padding_mask = torch.where(key_padding_mask==1, 0, -torch.inf)

        # Attention
        static_attn_output, static_attn_weight = self.static_attention(Qs, Ks, Vs, key_padding_mask=key_padding_mask)

        return static_attn_output, static_attn_weight 


    def multihead_block_attention(self, Q, K, V, key_padding_mask):
        # Split head
        batch_size, seq_len, _, d_model = Q.shape
        Q = Q.view(batch_size, seq_len, -1, self.nhead, d_model//self.nhead).permute(0, 3, 1, 2, 4)
        K = K.view(batch_size, seq_len, -1, self.nhead, d_model//self.nhead).permute(0, 3, 1, 2, 4)
        V = V.view(batch_size, seq_len, -1, self.nhead, d_model//self.nhead).permute(0, 3, 1, 2, 4)

        # Scaled dot product attention
        ### 1. Q·K^t
        QK = Q @ K.permute(0,1,2,4,3)
        logits = QK / math.sqrt(d_model//self.nhead)
        
        ### #. Padding_mask
        key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(-2).repeat(1, logits.shape[1], 1, logits.shape[-2], 1)
        logits += key_padding_mask

        ### 2. Softmax
        attn_weight = torch.nn.functional.softmax(logits, dim=-1)

        ### 3. Matmul V
        attn_output = attn_weight @ V

        ### 4. Concat heads
        attn_output = attn_output.permute(0,2,3,1,4).reshape(batch_size, seq_len, -1, d_model)

        return attn_output, attn_weight

class MultiheadBlockAttention(torch.nn.Module):
    def __init__(self, d_model, nhead, dropout):
        super().__init__()
        self.nhead = nhead

        self.q_linear = torch.nn.Linear(d_model, d_model)
        self.k_linear = torch.nn.Linear(d_model, d_model)
        self.v_linear = torch.nn.Linear(d_model, d_model)

        self.dropout = torch.nn.Dropout(dropout)
    
    def forward(self, query, key, value):
        # Linear transformation
        Q = self.q_linear(query)
        K = self.k_linear(key)
        V = self.v_linear(value)

        # Split head
        batch_size, seq_len, _, d_model = Q.shape
        Q = Q.view(batch_size, seq_len, -1, self.nhead, d_model//self.nhead).permute(0, 3, 1, 2, 4)
        K = K.view(batch_size, seq_len, -1, self.nhead, d_model//self.nhead).permute(0, 3, 1, 2, 4)
        V = V.view(batch_size, seq_len, -1, self.nhead, d_model//self.nhead).permute(0, 3, 1, 2, 4)

        # Scaled dot product attention
        ### 1. Q·K^t
        QK = Q @ K.permute(0,1,2,4,3)

        ### 2. Softmax
        attn_weight = torch.nn.functional.softmax(QK / math.sqrt(d_model//self.nhead), dim=-1)

        ### 3. Matmul V
        attn_output = attn_weight @ V

        ### 4. Concat heads
        attn_output = attn_output.permute(0,2,3,1,4).reshape(batch_size, seq_len, -1, d_model)

        return attn_output, attn_weight

1==1

True

In [68]:
class TemporalEmbedding(torch.nn.Module):
    def __init__(self, col, data_info, label_encoder_dict, d_model):
        super().__init__()
        if col in data_info.processing_info["scaling_cols"]:
            self.embedding = torch.nn.Linear(1, d_model)
        
        elif col in data_info.processing_info["embedding_cols"]:
            num_cls = label_encoder_dict[col].get_num_cls()
            self.embedding = torch.nn.Embedding(num_cls, d_model)
    
    def forward(self, key, val, padding_mask_dict, device):
        return self.embedding(val)

class ImgEmbedding(torch.nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.img_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.downsize_linear = torch.nn.Linear(768, d_model)
    
    def forward(self, key, val, padding_mask_dict, device):
        embedding = self.img_model(val).last_hidden_state
        embedding = self.downsize_linear(embedding)
        return embedding

class NlpEmbedding(torch.nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.nlp_model = BertModel.from_pretrained("google-bert/bert-base-uncased")
        self.downsize_linear = torch.nn.Linear(768, d_model)   
    
    def forward(self, key, val, padding_mask_dict, device):
        # Make token_type_ids
        token_type_ids = torch.zeros(val.shape).to(torch.int).to(device)

        # Make attention mask
        attention_mask = padding_mask_dict[f"{key}_revert_padding_mask"]
        mask_for_global_token = torch.ones(attention_mask.shape[0], 1).to(device)
        attention_mask = torch.cat([attention_mask, mask_for_global_token], dim=-1)

        # Embed data
        inputs = {"input_ids":val, "token_type_ids":token_type_ids, "attention_mask":attention_mask}
        embedding = self.nlp_model(**inputs).last_hidden_state
        embedding = self.downsize_linear(embedding)
        
        return embedding


class TemporalRemain(torch.nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, data_dict, padding_mask_dict, remain_rto, temporal_cols, device):
        result_dict, idx_dict = {}, {}
        # Concat data
        concat_data_li, concat_padding_mask_li = [], []
        for col in temporal_cols:
            concat_data_li.append(data_dict[col])
            concat_padding_mask_li.append(padding_mask_dict["temporal_padding_mask"])
        
        concat_data = torch.stack(concat_data_li, dim=-2)
        concat_padding_mask = torch.cat(concat_padding_mask_li, dim=-1)
    
        # Remain mask
        num_modality = concat_data.shape[-2]
        num_remain = int(num_modality * remain_rto)
        
        noise = torch.rand(concat_data.shape[:-1]).to(device)
        shuffle_idx = torch.argsort(noise, dim=-1)

        remain_idx = shuffle_idx[:, :, :num_remain]
        masked_idx = shuffle_idx[:, :, num_remain:]
        revert_idx = torch.argsort(shuffle_idx, dim=-1)

        # Apply mask
        concat_data = torch.gather(concat_data, index=remain_idx.unsqueeze(-1).repeat(1, 1, 1, concat_data.shape[-1]), dim=-2)
        concat_padding_mask = torch.gather(concat_padding_mask, index=remain_idx, dim=-1)


        result_dict["temporal"] = concat_data
        idx_dict.update({"temporal_remain_idx":remain_idx, "temporal_masked_idx":masked_idx, "temporal_revert_idx":revert_idx})
        padding_mask_dict.update({"temporal_remain_padding_mask":concat_padding_mask})

        return result_dict, idx_dict, padding_mask_dict

class ImgRemain(torch.nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, data_dict, remain_rto, img_cols, device):
        result_dict, idx_dict, padding_mask_dict = {}, {}, {}
        # Get indexs
        for col in img_cols:
            val = data_dict[col]
            
            # Split global token
            global_token = val[:, :1, :]
            val = val[:, 1:, :]

            # Apply remain
            num_remain = int(val.shape[1] * remain_rto)
            noise = torch.rand(val.shape[0], val.shape[1]).to(device)
            shuffle_idx = torch.argsort(noise, dim=1)

            remain_idx = shuffle_idx[:, :num_remain]
            masked_idx = shuffle_idx[:, num_remain:]
            revert_idx = torch.argsort(shuffle_idx, dim=1)

            remain_padding_mask = torch.ones(remain_idx.shape[0], remain_idx.shape[1]+1).to(device)
            revert_padding_mask = torch.ones(revert_idx.shape[0], revert_idx.shape[1]+1).to(device)

            # Apply mask
            val = torch.gather(val, index=remain_idx.unsqueeze(-1).repeat(1, 1, val.shape[-1]), dim=1)
            val = torch.cat([global_token, val], dim=1)

            result_dict[col] = val
            idx_dict.update({f"{col}_remain_idx":remain_idx, f"{col}_masked_idx":masked_idx, f"{col}_revert_idx":revert_idx})
            padding_mask_dict.update({f"{col}_remain_padding_mask":remain_padding_mask, f"{col}_revert_padding_mask":revert_padding_mask})

        return result_dict, idx_dict, padding_mask_dict

class NlpRemain(torch.nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, data_dict, idx_dict, remain_rto, nlp_cols, device):
        result_dict = {}
        for col in nlp_cols:
            val = data_dict[col]
            remain_idx = idx_dict[f"{col}_remain_idx"].unsqueeze(-1).repeat(1, 1, val.shape[-1])
            val = torch.gather(val, index=remain_idx, dim=1)
            result_dict[col] = val
        return result_dict


class EncoderLayer(torch.nn.Module):
    def __init__(self, d_model, nhead, d_ff, dropout, activation):
        super().__init__()
        self.self_attn = MultiheadBlockSelfAttention(d_model, nhead, dropout)

        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model)
        self.dropout1 = torch.nn.Dropout(dropout)
        
        # Feed forward
        if activation == "relu":
            self.activation = torch.nn.ReLU()
        if activation == "gelu":
            self.activation = torch.nn.GELU()
        
        self.linear_ff1 = torch.nn.Linear(d_model, d_ff)
        self.linear_ff2 = torch.nn.Linear(d_ff, d_model)
        self.dropout_ff1 = torch.nn.Dropout(dropout)
        self.dropout_ff2 = torch.nn.Dropout(dropout)

    def forward(self, src, temporal_shape, temporal_idx, static_idx, temporal_padding_mask, static_padding_mask):
        x = src
        attn_output = self._sa_block(self.norm1(x), temporal_shape, temporal_idx, static_idx, temporal_padding_mask, static_padding_mask)
        x = x + attn_output

        x = x + self._ff_block(self.norm2(x))
        return x

    def _sa_block(self, src, temporal_shape, temporal_idx, static_idx, temporal_padding_mask, static_padding_mask):
        x = self.self_attn(src, src, src, temporal_shape, temporal_idx, static_idx, temporal_padding_mask, static_padding_mask)
        return self.dropout1(x)

    def _ff_block(self, x):
        x = self.linear_ff2(self.dropout_ff1(self.activation(self.linear_ff1(x))))
        return self.dropout_ff2(x)


class TemporalRevert(torch.nn.Module):
    def __init__(self, mask_token):
        super().__init__()
        self.mask_token = mask_token
    
    def forward(self, temporal, idx_dict, temporal_cols):
        # Append mask token
        revert_idx = idx_dict["temporal_revert_idx"]
        mask_token = self.mask_token.unsqueeze(0).unsqueeze(1).repeat(temporal.shape[0], temporal.shape[1], revert_idx.shape[-1] - temporal.shape[-2], 1)
        temporal = torch.cat([temporal, mask_token], dim=-2)

        # Apply revert
        temporal = torch.gather(temporal, index=revert_idx.unsqueeze(-1).repeat(1, 1, 1, temporal.shape[-1]), dim=-2)

        # Split to dictionary
        temporal_revert_dict = {}
        for n, col in enumerate(temporal_cols):
            temporal_revert_dict[col] = temporal[:, :, n, :]

        return temporal_revert_dict

class ImgRevert(torch.nn.Module):
    def __init__(self, mask_token):
        super().__init__()
        self.mask_token = mask_token
    
    def forward(self, img_dict, idx_dict, img_cols):
        img_revert_dict = {}
        for col in img_cols:
            img_data = img_dict[col]
            
            # Split global token
            global_token = img_data[:, :1, :]
            img_data = img_data[:, 1:, :]

            # Append mask token
            revert_idx = idx_dict[f"{col}_revert_idx"]
            mask_token = self.mask_token.unsqueeze(0).repeat(img_data.shape[0], revert_idx.shape[1]-img_data.shape[-2], 1)
            img_data = torch.cat([img_data, mask_token], dim=-2)

            # Apply revert
            img_data = torch.gather(img_data, index=revert_idx.unsqueeze(-1).repeat(1, 1, img_data.shape[-1]), dim=-2)
            img_data = torch.cat([global_token, img_data], dim=1)

            img_revert_dict[col] = img_data

        return img_revert_dict

class NlpRevert(torch.nn.Module):
    def __init__(self, mask_token):
        super().__init__()
        self.mask_token = mask_token
    
    def forward(self, nlp_dict, idx_dict, nlp_cols):
        nlp_revert_dict = {}
        for col in nlp_cols:
            nlp_data = nlp_dict[col]
            
            # Append mask token
            revert_idx = idx_dict[f"{col}_revert_idx"]
            mask_token = self.mask_token.unsqueeze(0).repeat(nlp_data.shape[0], revert_idx.shape[1]-nlp_data.shape[-2], 1)
            nlp_data = torch.cat([nlp_data, mask_token], dim=-2)

            # Apply revert
            nlp_data = torch.gather(nlp_data, index=revert_idx.unsqueeze(-1).repeat(1, 1, nlp_data.shape[-1]), dim=-2)

            nlp_revert_dict[col] = nlp_data
       
        return nlp_revert_dict


class TemporalDecoderLayer(torch.nn.Module):
    def __init__(self, d_model, nhead, d_ff, dropout, activation):
        super().__init__()
    
    def forward(self):
        return
1==1

True

In [91]:
class Embedding(torch.nn.Module):
    def __init__(self, col, data_info, label_encoder_dict, d_model):
        super().__init__()
        if col in data_info.modality_info["target"] + data_info.modality_info["temporal"]:
            self.embedding = TemporalEmbedding(col, data_info, label_encoder_dict, d_model)
        elif col in data_info.modality_info["img"]:
            self.embedding = ImgEmbedding(d_model)
        elif col in data_info.modality_info["nlp"]:
            self.embedding = NlpEmbedding(d_model)
    
    def forward(self, key, val, padding_mask, device):
        return self.embedding(key, val, padding_mask, device)

class PosModEncoding(torch.nn.Module):
    def __init__(self, col, pos_enc, num_modality, modality, d_model):
        super().__init__()
        self.pos_enc = pos_enc
        self.modality = modality[col]
        self.modality_embedding = torch.nn.Embedding(num_modality, d_model)

    def forward(self, key, val, device):
        # Positional encoding
        val = self.pos_enc(val)

        # Modality embedding
        modality = torch.zeros(val.shape[1]).to(torch.int).to(device) + self.modality
        modality = self.modality_embedding(modality)

        return val + modality

class Remain(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.temporal_remain = TemporalRemain()
        self.img_remain = ImgRemain()
        self.nlp_remain = NlpRemain()
    
    def forward(self, data_dict, idx_dict, padding_mask_dict, remain_rto, temporal_cols, img_cols, nlp_cols, device):
        temporal_dict, temporal_idx_dict, temporal_padding_mask_dict = self.temporal_remain(data_dict, padding_mask_dict, remain_rto["temporal"], temporal_cols, device)
        idx_dict.update(temporal_idx_dict)
        padding_mask_dict.update(temporal_padding_mask_dict)

        img_dict, img_idx_dict, img_padding_mask_dict = self.img_remain(data_dict, remain_rto["img"], img_cols, device)
        idx_dict.update(img_idx_dict)
        padding_mask_dict.update(img_padding_mask_dict)

        nlp_dict = self.nlp_remain(data_dict, idx_dict, remain_rto["nlp"], nlp_cols, device)

        return temporal_dict, img_dict, nlp_dict, idx_dict, padding_mask_dict

class Encoder(torch.nn.Module):
    def __init__(self, d_model, nhead, d_ff, dropout, activation, num_layers, to_decoder_dim):
        super().__init__()
        self.encoder_layers = torch.nn.ModuleList([copy.deepcopy(EncoderLayer(d_model, nhead, d_ff, dropout, activation)) for _ in range(num_layers)])
        self.to_decoder_dim = to_decoder_dim
    
    def forward(self, temporal_dict, img_dict, nlp_dict, padding_mask_dict, img_cols, nlp_cols, device):
        src, temporal_shape, temporal_idx, static_idx, temporal_padding_mask, static_padding_mask, img_idx_li, nlp_idx_li = self.get_src(temporal_dict, img_dict, nlp_dict, padding_mask_dict, img_cols, nlp_cols, device)
        for mod in self.encoder_layers:
            src = mod(src, temporal_shape, temporal_idx, static_idx, temporal_padding_mask, static_padding_mask)
        
        encoded = self.to_decoder_dim(src)
        temporal_dict, img_dict, nlp_dict = self.undo_src(encoded, temporal_shape, temporal_idx, img_idx_li, nlp_idx_li, img_cols, nlp_cols)

        return temporal_dict, img_dict, nlp_dict
    
    def get_src(self, temporal_dict, img_dict, nlp_dict, padding_mask_dict, img_cols, nlp_cols, device):
        # Temporal
        temporal = temporal_dict["temporal"]
        temporal_shape = temporal.shape
        batch_size, seq_len, _, d_model = temporal_shape
        temporal = temporal.view(batch_size, -1, d_model)
        temporal_padding_mask = padding_mask_dict["temporal_remain_padding_mask"]
        temporal_idx = torch.arange(0, temporal.shape[1]).to(device)

        # Static
        static_start_idx = temporal_idx[-1]+1
        idx = static_start_idx.clone()

        ### Img
        img_li, img_padding_mask_li, img_idx_li = [], [], []
        for col in img_cols:
            img_data = img_dict[col]
            img_padding_mask = padding_mask_dict[f"{col}_remain_padding_mask"]

            img_li.append(img_data)
            img_padding_mask_li.append(img_padding_mask)
            img_idx_li.append(torch.arange(idx, idx+img_data.shape[1]).to(device))
            idx += img_data.shape[1]
        
        img = torch.cat(img_li, dim=1)
        img_padding_mask = torch.cat(img_padding_mask_li, dim=1)

        ### Nlp
        nlp_li, nlp_padding_mask_li, nlp_idx_li = [], [], []
        for col in nlp_cols:
            nlp_data = nlp_dict[col]
            nlp_padding_mask = padding_mask_dict[f"{col}_remain_padding_mask"]

            nlp_li.append(nlp_data)
            nlp_padding_mask_li.append(nlp_padding_mask)
            nlp_idx_li.append(torch.arange(idx, idx+nlp_data.shape[1]).to(device))
            idx += nlp_data.shape[1]

        nlp = torch.cat(nlp_li, dim=1)
        nlp_padding_mask = torch.cat(nlp_padding_mask_li, dim=1)

        ### Src
        static = torch.cat([img, nlp], dim=1)
        static_padding_mask = torch.cat([img_padding_mask, nlp_padding_mask], dim=1)

        img_idx = torch.cat(img_idx_li)
        nlp_idx = torch.cat(nlp_idx_li)
        static_idx = torch.cat([img_idx, nlp_idx])

        # Src
        src = torch.cat([temporal, static], dim=1)
        
        return src, temporal_shape, temporal_idx, static_idx, temporal_padding_mask, static_padding_mask, img_idx_li, nlp_idx_li
    
    def undo_src(self, src, temporal_shape, temporal_idx, img_idx_li, nlp_idx_li, img_cols, nlp_cols):
        # Temporal
        temporal_dict = {}
        temporal_idx = temporal_idx.unsqueeze(0).unsqueeze(-1).repeat(src.shape[0], 1, src.shape[-1])
        temporal = torch.gather(src, index=temporal_idx, dim=1).view(temporal_shape[0], temporal_shape[1], temporal_shape[2], -1)
        temporal_dict["temporal"] = temporal

        # Img
        img_dict = {}
        for col, idx in zip(img_cols, img_idx_li):
            idx = idx.unsqueeze(0).unsqueeze(-1).repeat(src.shape[0], 1, src.shape[-1])
            img_dict[col] = torch.gather(src, index=idx, dim=1)

        # Nlp
        nlp_dict = {}
        for col, idx in zip(nlp_cols, nlp_idx_li):
            idx = idx.unsqueeze(0).unsqueeze(-1).repeat(src.shape[0], 1, src.shape[-1])
            nlp_dict[col] = torch.gather(src, index=idx, dim=1)
        
        return temporal_dict, img_dict, nlp_dict

class Revert(torch.nn.Module):
    def __init__(self, mask_token):
        super().__init__()
        self.temporal_revert = TemporalRevert(mask_token)
        self.img_revert = ImgRevert(mask_token)
        self.nlp_revert = NlpRevert(mask_token)
    
    def forward(self, temporal_dict, img_dict, nlp_dict, idx_dict, temporal_cols, img_cols, nlp_cols):
        temporal_dict = self.temporal_revert(temporal_dict["temporal"], idx_dict, temporal_cols)
        img_dict = self.img_revert(img_dict, idx_dict, img_cols)
        nlp_dict = self.nlp_revert(nlp_dict, idx_dict, nlp_cols)

        revert_dict = {}
        revert_dict.update(temporal_dict)
        revert_dict.update(img_dict)
        revert_dict.update(nlp_dict)
        
        return revert_dict

class Decoder(torch.nn.Module):
    def __init__(self, col, d_model, nhead, d_ff, dropout, activation, num_layers):
        super().__init__()
        self.temporal_decoder_layers = torch.nn.ModuleList([copy.deepcopy(TemporalDecoderLayer(d_model, nhead, d_ff, dropout, activation)) for _ in range(num_layers)])
    
    def forward(self, key, val, data_dict, padding_mask_dict, temporal_cols, img_cols, nlp_cols, device):
        # Temporal
        if key in temporal_cols:
            tgt = val
            memory = self.get_temporal_tgt_memory(key, val, data_dict, padding_mask_dict, temporal_cols, img_cols, nlp_cols, device)
            
        return

    def get_temporal_tgt_memory(self, key, val, data_dict, padding_mask_dict, temporal_cols, img_cols, nlp_cols, device):
        memory_li = []

        # Temporal memory
        for col in temporal_cols:
            if col == key:
                continue
            memory_li.append(data_dict[col])
        memory = torch.stack(memory_li, dim=-2)
        cross_attn_padding_mask = torch.ones(memory.shape[:-1]).to(device)
        
        # Static memory
        for col in img_cols + nlp_cols:
            static_data = data_dict[col].unsqueeze(1).repeat(1, memory.shape[1], 1, 1)
            static_padding_msak = padding_mask_dict[f"{col}_revert_padding_mask"].unsqueeze(1).repeat(1, memory.shape[1], 1)

            memory = torch.cat([memory, static_data], dim=-2)
            cross_attn_padding_mask = torch.cat([cross_attn_padding_mask, static_padding_msak], dim=-1)
        
        # Self attn padding mask
        self_attn_padding_mask = padding_mask_dict["temporal_padding_mask"]
        print(self_attn_padding_mask.shape)

        raise


1==1

True

In [92]:
class Transformer(torch.nn.Module):
    def __init__(self, data_info, label_encoder_dict,
                d_model, num_layers, nhead, d_ff, dropout, activation):
        super().__init__()
        self.data_info, self.label_encoder_dict = data_info, label_encoder_dict
        self.temporal_cols, self.img_cols, self.nlp_cols, self.total_cols = self.define_cols()
        self.self_attn_weight_dict, self.cross_attn_weight_dict = {}, {}
        
        # 1. Embedding
        self.embedding_dict = self.init_process(mod=Embedding, args=[self.data_info, self.label_encoder_dict, d_model["encoder"]])
        # 2. Pos encoding and modality embedding
        num_modality = len(self.total_cols)
        modality = {col:n for n, col in enumerate(self.total_cols)}
        encoder_pos_enc = PositionalEncoding(d_model["encoder"], dropout)
        self.enc_pos_mod_encoding_dict = self.init_process(mod=PosModEncoding, args=[encoder_pos_enc, num_modality, modality, d_model["encoder"]])
        # 3. Remain mask
        self.remain_dict = Remain()
        # 4. Encoding
        to_decoder_dim = torch.nn.Linear(d_model["encoder"], d_model["decoder"])
        self.encoding = Encoder(d_model["encoder"], nhead, d_ff["encoder"], dropout, activation, num_layers["encoder"], to_decoder_dim)
        # 5. Revert
        mask_token = torch.nn.Parameter(torch.rand(1, d_model["decoder"]))
        self.revert = Revert(mask_token)
        # 6. Pos encoding and modality embedding
        decoder_pos_enc = PositionalEncoding(d_model["decoder"], dropout)
        self.dec_pos_mod_encoding_dict = self.init_process(mod=PosModEncoding, args=[decoder_pos_enc, num_modality, modality, d_model["decoder"]])
        # 7. Decoding
        self.decoding = self.init_process(mod=Decoder, args=[d_model["decoder"], nhead, d_ff["decoder"], dropout, activation, num_layers["decoder"]])

    def forward(self, data_input, remain_rto, device):
        data_dict, self.idx_dict, self.padding_mask_dict = self.to_gpu(data_input, device)
        
        # 1. Embedding
        embedding_dict = self.apply_process(data=data_dict, mod=self.embedding_dict, args=[self.padding_mask_dict, device])
        # 2. Pos encoding and modality embedding
        enc_pos_mod_encoding_dict = self.apply_process(data=embedding_dict, mod=self.enc_pos_mod_encoding_dict, args=[device])
        # 3. Remain mask
        temporal_dict, img_dict, nlp_dict, self.idx_dict, self.padding_mask_dict = self.remain_dict(enc_pos_mod_encoding_dict, self.idx_dict, self.padding_mask_dict, remain_rto, self.temporal_cols, self.img_cols, self.nlp_cols, device)
        # 4. Encoding
        temporal_encoding_dict, img_encoding_dict, nlp_encoding_dict = self.encoding(temporal_dict, img_dict, nlp_dict, self.padding_mask_dict, self.img_cols, self.nlp_cols, device)
        # 5. Revert
        revert_dict = self.revert(temporal_encoding_dict, img_encoding_dict, nlp_encoding_dict, self.idx_dict, self.temporal_cols, self.img_cols, self.nlp_cols)
        # 6. Pos encoding and modality embedding
        dec_pos_mod_encoding_dict = self.apply_process(data=revert_dict, mod=self.dec_pos_mod_encoding_dict, args=[device])
        # 7. Decoder
        self.apply_process(data=dec_pos_mod_encoding_dict, mod=self.decoding, args=[dec_pos_mod_encoding_dict, self.padding_mask_dict, self.temporal_cols, self.img_cols, self.nlp_cols, device])



    
    def define_cols(self):
        temporal_cols = self.data_info.modality_info["target"] + self.data_info.modality_info["temporal"]
        img_cols = self.data_info.modality_info["img"]
        nlp_cols = self.data_info.modality_info["nlp"]
        total_cols = temporal_cols + img_cols + nlp_cols

        return temporal_cols, img_cols, nlp_cols, total_cols

    def to_gpu(self, data_input, device):
        data_dict, idx_dict, padding_mask_dict = {}, {}, {}
        for key, val in data_input.items():
            if key in self.temporal_cols + self.img_cols + self.nlp_cols:
                data_dict[key] = data_input[key].to(device)
            elif key.endswith("idx"):
                idx_dict[key] = data_input[key].to(device)
            elif key.endswith("padding_mask"):
                padding_mask_dict[key] = data_input[key].to(device)
            
        return data_dict, idx_dict, padding_mask_dict

    def init_process(self, mod, args=[], target_cols=None):
        result_dict = {}
        target_cols = self.total_cols if target_cols is None else target_cols
        for col in target_cols:
            result_dict[col] = mod(col, *args)
        
        return torch.nn.ModuleDict(result_dict)

    def apply_process(self, data, mod, args=[], target_cols=None, collate_fn=None):
        result_dict = {}
        target_cols = self.total_cols if target_cols is None else target_cols
        for col in target_cols:
            result_dict[col] = mod[col](col, data[col], *args)
        
        if collate_fn is not None:
            return collate_fn(result_dict)
        else:
            return result_dict

    def tidy_decoding(self, decoding_dict):
        result_dict = {}
        for key, val in decoding_dict.items():
            result_dict[key], self_attn_weight, cross_attn_weight = val
            self.self_attn_weight_dict.update({key:self_attn_weight})
            self.cross_attn_weight_dict.update({key:cross_attn_weight})

        return result_dict
1==1

True

In [93]:
model = Transformer(data_info, train_dataset.label_encoder_dict,
                        d_model, num_layers, nhead, d_ff, dropout, "gelu")
model.to(device)
summary(model, data, remain_rto, device, show_parent_layers=True, print_summary=True)

torch.Size([2, 100, 1])


RuntimeError: No active exception to reraise

# Train

In [None]:
for name, param in model.named_parameters():
    if "img_model" in name:
        param.requires_grad = False
    elif "nlp_model" in name:
        param.requires_grad = False

In [None]:
import warnings
warnings.filterwarnings("ignore")

def patchify(imgs, patch_size=16):
    """
    imgs: (N, 3, H, W)
    x: (N, L, patch_size**2 *3)
    """
    p = patch_size
    assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

    h = w = imgs.shape[2] // p
    x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
    x = torch.einsum('nchpwq->nhwpqc', x)
    x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
    return x

def unpatchify(x, patch_size=16):
    """
    x: (N, L, patch_size**2 *3)
    imgs: (N, 3, H, W)
    """
    p = patch_size
    h = w = int(x.shape[1]**.5)
    assert h * w == x.shape[1]
    
    x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
    x = torch.einsum('nhwpqc->nchpwq', x)
    imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
    return imgs