In [1]:
import torch
# Reciporal gather
arr = torch.arange(0, 8).unsqueeze(0).repeat(4, 1)
idx = torch.tensor([0,0,1,1]).unsqueeze(-1)
print(arr.shape)
print(idx.shape)

m = torch.ones(arr.shape).scatter(1, idx, 0)
m = m.nonzero()[:, 1].reshape(-1, arr.size(1) - idx.size(1))

torch.gather(arr, index=m, dim=-1)

torch.Size([4, 8])
torch.Size([4, 1])


tensor([[1, 2, 3, 4, 5, 6, 7],
        [1, 2, 3, 4, 5, 6, 7],
        [0, 2, 3, 4, 5, 6, 7],
        [0, 2, 3, 4, 5, 6, 7]])

### Import

In [2]:
import copy

from tqdm import tqdm; tqdm.pandas()
from collections import defaultdict
from IPython.display import clear_output

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import MinMaxScaler, StandardScaler

import torch
from pytorch_model_summary import summary
from transformers import ViTModel, AutoImageProcessor, BertModel, AutoTokenizer

from rawdata import RawData, Preprocess
from data import DataInfo, Dataset, collate_fn
from data import NoneScaler, LogScaler, CustomLabelEncoder

from architecture import Transformer
from architecture_detail import *

import cv2

device = torch.device("cuda")

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

  from .autonotebook import tqdm as notebook_tqdm


### Config

In [3]:
test_mode = True

# Raw data
is_prep_data_exist = True

# Data loader
MIN_MEANINGFUL_SEQ_LEN = 10
MAX_SEQ_LEN = 50
PRED_LEN = 100

modality_info = {
    "group": ["article_id", "sales_channel_id"],
    "target": ["sales"],
    "temporal": ["day", "dow", "month", "holiday", "price"],
    "img": ["img_path"],
    "nlp": ["detail_desc"]
}
processing_info = {
    "scaling_cols": {"sales": StandardScaler, "price": StandardScaler},
    "embedding_cols": ["day",  "dow", "month", "holiday"],
    "img_cols": ["img_path"],
    "nlp_cols": ["detail_desc"]
}

# Model
batch_size = 32
nhead = 4
dropout = 0.1
patch_size = 16

d_model = {"encoder":256, "decoder":128}
d_ff = {"encoder":256, "decoder":128}
num_layers = {"encoder":2, "decoder":2}
remain_rto = {"target":0.25, "temporal":0.25, "img":0.25, "nlp":0.25}

# Data

### Raw data

In [4]:
if test_mode:
    df_preprocessed = pd.read_parquet("src/df_preprocessed_test.parquet")
else:
    if not is_prep_data_exist:
        rawdata = RawData()
        df_trans, df_meta, df_holiday = rawdata.get_raw_data()
        preprocess = Preprocess(df_trans, df_meta, df_holiday)
        df_preprocessed = preprocess.main()
    else:
        df_preprocessed = pd.read_parquet("src/df_preprocessed.parquet")

### Dataset

In [5]:
df_train = df_preprocessed[(df_preprocessed["meaningful_size"] >= MIN_MEANINGFUL_SEQ_LEN) & (df_preprocessed["time_idx"] <= MAX_SEQ_LEN-1)]
df_train = df_train[~pd.isna(df_train["detail_desc"])]
df_valid = df_preprocessed[(df_preprocessed["meaningful_size"] >= MIN_MEANINGFUL_SEQ_LEN) & (df_preprocessed["time_idx"] <= MAX_SEQ_LEN-1 + PRED_LEN)]
df_valid = df_valid[~pd.isna(df_valid["detail_desc"])]

data_info = DataInfo(modality_info, processing_info)

In [6]:
train_dataset = Dataset(df_train, data_info, remain_rto)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: collate_fn(x, data_info), pin_memory=True, num_workers=16, prefetch_factor=4)
# train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: collate_fn(x, data_info))

for data in train_dataloader:
    [print(key, val.shape) for key, val in data.items() if "scaler" not in key and "raw" not in key]
    break

100%|██████████| 2/2 [00:00<00:00, 2799.00it/s]


sales torch.Size([2, 50, 1])
day torch.Size([2, 50])
dow torch.Size([2, 50])
month torch.Size([2, 50])
holiday torch.Size([2, 50])
price torch.Size([2, 50, 1])
target_fcst_mask torch.Size([2, 50, 1])
temporal_padding_mask torch.Size([2, 50])
img_path torch.Size([2, 3, 224, 224])
detail_desc torch.Size([2, 9])
detail_desc_remain_idx torch.Size([2, 2])
detail_desc_masked_idx torch.Size([2, 6])
detail_desc_revert_idx torch.Size([2, 8])
detail_desc_remain_padding_mask torch.Size([2, 2])
detail_desc_masked_padding_mask torch.Size([2, 6])
detail_desc_revert_padding_mask torch.Size([2, 8])


# Architecture

In [7]:
arr = torch.arange(0, 5).unsqueeze(0).repeat(5, 1)
print(arr); print("_"*100)

noise = torch.rand(5,5)
shuffle_idx = torch.argsort(noise, dim=-1)
revert_idx = torch.argsort(shuffle_idx, dim=-1)

remain = torch.gather(arr, index=shuffle_idx, dim=-1)
print(remain); print("_"*100)

revert = torch.gather(remain, index=revert_idx, dim=-1)
print(revert)


tensor([[0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4]])
____________________________________________________________________________________________________
tensor([[2, 4, 0, 3, 1],
        [0, 2, 3, 4, 1],
        [4, 2, 0, 1, 3],
        [2, 1, 0, 4, 3],
        [0, 4, 3, 2, 1]])
____________________________________________________________________________________________________
tensor([[0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4]])


In [8]:
class MultiheadBlockSelfAttention(torch.nn.Module):
    def __init__(self, d_model, nhead, dropout):
        super().__init__()
        self.nhead = nhead

        self.q_linear = torch.nn.Linear(d_model, d_model)
        self.k_linear = torch.nn.Linear(d_model, d_model)
        self.v_linear = torch.nn.Linear(d_model, d_model)

        self.dropout = torch.nn.Dropout(dropout)
    
    def forward(self, query, key, value, temporal_block_shape, temporal_unblock_idx, static_unblock_idx, temporal_block_padding_mask, static_unblock_padding_mask):
        # Linear transformation
        Q = self.q_linear(query)
        K = self.k_linear(key)
        V = self.v_linear(value)

        # Split head
        batch_size, seq_len, d_model = Q.shape
        Q = Q.view(batch_size, -1, self.nhead, d_model//self.nhead).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.nhead, d_model//self.nhead).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.nhead, d_model//self.nhead).permute(0, 2, 1, 3)

        temporal_attn_output, temporal_attn_weight = self.temporal_side_attention(Q, K, V, temporal_block_shape, temporal_unblock_idx, static_unblock_idx, temporal_block_padding_mask, static_unblock_padding_mask)
        static_attn_output, static_attn_weight = self.static_side_attention(Q, K, V, temporal_block_shape, temporal_unblock_idx, static_unblock_idx, temporal_block_padding_mask, static_unblock_padding_mask)

        # Concat all
        temporal_attn_output = temporal_attn_output.view(temporal_attn_output.shape[0], -1, temporal_attn_output.shape[-1])
        attn_output = torch.cat([temporal_attn_output, static_attn_output], dim=1)

        return attn_output
    
    def temporal_side_attention(self, Q, K, V, temporal_block_shape, temporal_unblock_idx, static_unblock_idx, temporal_block_padding_mask, static_unblock_padding_mask):
        # Obtain QKV
        batch_size, seq_len, num_modality, _ = temporal_block_shape
        temporal_unblock_idx = temporal_unblock_idx.unsqueeze(0).unsqueeze(1).unsqueeze(-1).repeat(Q.shape[0], Q.shape[1], 1, Q.shape[-1])
        static_unblock_idx = static_unblock_idx.unsqueeze(0).unsqueeze(1).unsqueeze(-1).repeat(Q.shape[0], Q.shape[1], 1, Q.shape[-1])
        
        Qt = torch.gather(Q, index=temporal_unblock_idx, dim=-2).view(batch_size, Q.shape[1], seq_len, num_modality, Q.shape[-1])
        Kt = torch.gather(K, index=temporal_unblock_idx, dim=-2).view(batch_size, K.shape[1], seq_len, num_modality, K.shape[-1])
        Ks = torch.gather(K, index=static_unblock_idx, dim=-2).unsqueeze(-3)
        Vt = torch.gather(V, index=temporal_unblock_idx, dim=-2).view(batch_size, Q.shape[1], seq_len, num_modality, Q.shape[-1])
        Vs = torch.gather(V, index=static_unblock_idx, dim=-2).unsqueeze(-3)
        
        # Scaled dot product attention
        ### 1. Q @ K^t
        QtKt = Qt @ Kt.permute(0,1,2,4,3)
        QtKs = Qt @ Ks.permute(0,1,2,4,3)
        QK = torch.cat([QtKt, QtKs], dim=-1)
        logits = QK / math.sqrt(QK.shape[-1]//self.nhead)
        
        ### 2. Padding mask
        static_unblock_padding_mask = static_unblock_padding_mask.unsqueeze(1).repeat(1, temporal_block_padding_mask.shape[1], 1)
        key_padding_mask = torch.cat([temporal_block_padding_mask, static_unblock_padding_mask], dim=-1)
        key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(-2).repeat(1, logits.shape[1], 1, logits.shape[-2], 1)
        logits += key_padding_mask

        ### 3. Softmax
        attn_weight = torch.nn.functional.softmax(logits, dim=-1)
        QtKt_attn_weight = attn_weight[:, :, :, :, :QtKt.shape[-1]]
        QtKs_attn_weight = attn_weight[:, :, :, :, QtKt.shape[-1]:]

        ### 4. Matmul V
        QtKtVt = QtKt_attn_weight @ Vt
        QtKsVs = QtKs_attn_weight @ Vs

        temporal_attn_output = QtKtVt + QtKsVs
        
        ### 5. Concat heads
        temporal_attn_output = temporal_attn_output.permute(0,2,3,1,4).reshape(batch_size, seq_len, num_modality, -1)

        return temporal_attn_output, attn_weight

    def static_side_attention(self, Q, K, V, temporal_block_shape, temporal_unblock_idx, static_unblock_idx, temporal_block_padding_mask, static_unblock_padding_mask):
        # Obtain QKV
        batch_size, seq_len, num_modality, _ = temporal_block_shape
        temporal_unblock_idx = temporal_unblock_idx.unsqueeze(0).unsqueeze(1).unsqueeze(-1).repeat(Q.shape[0], Q.shape[1], 1, Q.shape[-1])
        static_unblock_idx = static_unblock_idx.unsqueeze(0).unsqueeze(1).unsqueeze(-1).repeat(Q.shape[0], Q.shape[1], 1, Q.shape[-1])

        Qs = torch.gather(Q, index=static_unblock_idx, dim=-2)
        
        # Scaled dot product attention
        ### 1. Q @ K^t
        QK = Qs @ K.permute(0,1,3,2)
        logits = QK / math.sqrt(QK.shape[-1]//self.nhead)

        ### 2. Padding mask
        temporal_block_padding_mask = temporal_block_padding_mask.reshape(temporal_block_padding_mask.shape[0], -1)
        key_padding_mask = torch.cat([temporal_block_padding_mask, static_unblock_padding_mask], dim=-1)
        key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(-2).repeat(1, QK.shape[1], QK.shape[-2], 1)
        logits += key_padding_mask

        ### 3. Softmax
        attn_weight = torch.nn.functional.softmax(logits, dim=-1)

        ### 4. Matmul V
        attn_output = attn_weight @ V
        
        ### 5. Concat heads
        attn_output = attn_output.permute(0,2,1,3).reshape(batch_size, Qs.shape[-2], -1)

        return attn_output, attn_weight


class MultiheadBlockAttention(torch.nn.Module):
    def __init__(self, d_model, nhead, dropout):
        super().__init__()
        self.nhead = nhead

        self.q_linear = torch.nn.Linear(d_model, d_model)
        self.k_linear = torch.nn.Linear(d_model, d_model)
        self.v_linear = torch.nn.Linear(d_model, d_model)

        self.dropout = torch.nn.Dropout(dropout)
    
    def forward(self, query, key, value, padding_mask):
        # Linear transformation
        Q = self.q_linear(query)
        K = self.k_linear(key)
        V = self.v_linear(value)

        # Split head
        batch_size, seq_len, _, d_model = Q.shape
        Q = Q.view(batch_size, seq_len, -1, self.nhead, d_model//self.nhead).permute(0, 3, 1, 2, 4)
        K = K.view(batch_size, seq_len, -1, self.nhead, d_model//self.nhead).permute(0, 3, 1, 2, 4)
        V = V.view(batch_size, seq_len, -1, self.nhead, d_model//self.nhead).permute(0, 3, 1, 2, 4)

        # Scaled dot product attention
        ### 1. Q·K^t
        QK = Q @ K.permute(0,1,2,4,3)
        logits = QK / math.sqrt(d_model//self.nhead)
        
        ### #. Padding_mask
        padding_mask = padding_mask.unsqueeze(1).unsqueeze(-2).repeat(1, logits.shape[1], 1, logits.shape[-2], 1)
        logits += padding_mask

        ### 2. Softmax
        attn_weight = torch.nn.functional.softmax(logits, dim=-1)

        ### 3. Matmul V
        attn_output = attn_weight @ V

        ### 4. Concat heads
        attn_output = attn_output.permute(0,2,3,1,4).reshape(batch_size, seq_len, -1, d_model)

        return attn_output, attn_weight

1==1

True

In [31]:
class TemporalEmbedding(torch.nn.Module):
    def __init__(self, col, data_info, label_encoder_dict, d_model):
        super().__init__()
        if col in data_info.processing_info["scaling_cols"]:
            self.embedding = torch.nn.Linear(1, d_model)
        
        elif col in data_info.processing_info["embedding_cols"]:
            num_cls = label_encoder_dict[col].get_num_cls()
            self.embedding = torch.nn.Embedding(num_cls, d_model)
        
    def forward(self, key, val, padding_mask_dict, device):
        embedding = self.embedding(val)
        return embedding

class ImgEmbedding(torch.nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.img_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.downsize_linear = torch.nn.Linear(768, d_model)
    
    def forward(self, key, val, padding_mask_dict, device):
        embedding = self.img_model(val).last_hidden_state
        embedding = self.downsize_linear(embedding)
        return embedding

class NlpEmbedding(torch.nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.nlp_model = BertModel.from_pretrained("google-bert/bert-base-uncased")
        self.downsize_linear = torch.nn.Linear(768, d_model)   
    
    def forward(self, key, val, padding_mask_dict, device):
        # Make token_type_ids
        token_type_ids = torch.zeros(val.shape).to(torch.int).to(device)

        # Make attention mask
        attention_mask = padding_mask_dict[f"{key}_revert_padding_mask"]
        mask_for_global_token = torch.ones(attention_mask.shape[0], 1).to(device)
        attention_mask = torch.cat([attention_mask, mask_for_global_token], dim=-1)

        # Embed data
        inputs = {"input_ids":val, "token_type_ids":token_type_ids, "attention_mask":attention_mask}
        embedding = self.nlp_model(**inputs).last_hidden_state
        embedding = self.downsize_linear(embedding)
        
        return embedding


class TemporalRemain(torch.nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, data_dict, idx_dict, padding_mask_dict, remain_rto, temporal_cols, device):
        # Concat_data
        concat_data_li, concat_padding_mask_li = [], []
        for col in temporal_cols:
            temporal_data = data_dict[col]
            temporal_padding_mask = padding_mask_dict["temporal_padding_mask"]
            
            concat_data_li.append(temporal_data)
            concat_padding_mask_li.append(temporal_padding_mask)
        
        concat_data = torch.stack(concat_data_li, dim=-2) # Block shaped
        concat_padding_mask = torch.stack(concat_padding_mask_li, dim=-1) # Block shaped with all 1s in terms of -1 dimension

        # Split token and valid data
        global_token_data = concat_data[:, :, :1, :]
        global_token_padding_mask = concat_padding_mask[:, :, :1]

        valid_data = concat_data[:, :, 1:, :]
        valid_padding_mask = concat_padding_mask[:, :, 1:]

        # Remain masking for valid data
        num_modality = valid_data.shape[-2]
        num_remain = int(num_modality * remain_rto)

        noise = torch.rand(valid_data.shape[:-1]).to(device)
        shuffle_idx = torch.argsort(noise, dim=-1)

        remain_idx = shuffle_idx[:, :, :num_remain]
        masked_idx = shuffle_idx[:, :, num_remain:]
        revert_idx = torch.argsort(shuffle_idx, dim=-1)

        # Apply mask
        valid_remain_data = torch.gather(valid_data, index=remain_idx.unsqueeze(-1).repeat(1, 1, 1, valid_data.shape[-1]), dim=-2)
        valid_remain_padding_mask = torch.gather(valid_padding_mask, index=remain_idx, dim=-1)

        concat_remain_data = torch.cat([global_token_data, valid_remain_data], dim=-2)
        concat_remain_padding_mask = torch.cat([global_token_padding_mask, valid_remain_padding_mask], dim=-1)
        
        # Obtain revert padding mask
        revert_padding_mask = torch.ones(revert_idx.shape[0], revert_idx.shape[1], revert_idx.shape[2]+1).to(device)

        # Finalize
        result_dict = {"temporal_remain": concat_remain_data}
        idx_dict.update({"temporal_remain_idx": remain_idx, "temporal_masked_idx":masked_idx, "temporal_revert_idx":revert_idx})
        padding_mask_dict.update({"temporal_remain_padding_mask":concat_remain_padding_mask, "temporal_revert_padding_mask": revert_padding_mask})
        return result_dict, idx_dict, padding_mask_dict

class ImgRemain(torch.nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, data_dict, idx_dict, padding_mask_dict, remain_rto, img_cols, device):
        result_dict = {}
        # Get indexs
        for col in img_cols:
            val = data_dict[col]
            
            # # Add global token
            # global_token = self.global_token.unsqueeze(0).repeat(val.shape[0], 1, 1)
            # val = torch.cat([global_token, val], dim=1)
            
            # Split global token
            global_token = val[:, :1, :]
            val = val[:, 1:, :]

            # Apply remain
            num_remain = int(val.shape[1] * remain_rto)
            noise = torch.rand(val.shape[0], val.shape[1]).to(device)
            shuffle_idx = torch.argsort(noise, dim=1)

            remain_idx = shuffle_idx[:, :num_remain]
            masked_idx = shuffle_idx[:, num_remain:]
            revert_idx = torch.argsort(shuffle_idx, dim=1)

            remain_padding_mask = torch.ones(remain_idx.shape[0], remain_idx.shape[1]+1).to(device)
            revert_padding_mask = torch.ones(revert_idx.shape[0], revert_idx.shape[1]+1).to(device)

            # Apply mask
            val = torch.gather(val, index=remain_idx.unsqueeze(-1).repeat(1, 1, val.shape[-1]), dim=1)
            val = torch.cat([global_token, val], dim=1)

            result_dict[col] = val
            idx_dict.update({f"{col}_remain_idx":remain_idx, f"{col}_masked_idx":masked_idx, f"{col}_revert_idx":revert_idx})
            padding_mask_dict.update({f"{col}_remain_padding_mask":remain_padding_mask, f"{col}_revert_padding_mask":revert_padding_mask})

        return result_dict, idx_dict, padding_mask_dict

class NlpRemain(torch.nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, data_dict, idx_dict, padding_mask_dict, remain_rto, nlp_cols, device):
        result_dict = {}
        for col in nlp_cols:
            val = data_dict[col]
            
            # Split global token
            global_token = val[:, :1, :]
            val = val[:, 1:, :]

            # Apply remain mask
            remain_idx = idx_dict[f"{col}_remain_idx"].unsqueeze(-1).repeat(1, 1, val.shape[-1])
            val = torch.gather(val, index=remain_idx, dim=1)
            val = torch.cat([global_token, val], dim=1)
            result_dict[col] = val

            # Update padding_mask
            padding_mask_for_global_token = torch.ones(val.shape[0], 1).to(device)

            remain_padding_mask = padding_mask_dict[f"{col}_remain_padding_mask"]
            remain_padding_mask = torch.cat([padding_mask_for_global_token, remain_padding_mask], dim=-1)
            padding_mask_dict[f"{col}_remain_padding_mask"] = remain_padding_mask

            revert_padding_mask = padding_mask_dict[f"{col}_revert_padding_mask"]
            revert_padding_mask = torch.cat([padding_mask_for_global_token, revert_padding_mask], dim=-1)
            padding_mask_dict[f"{col}_revert_padding_mask"] = revert_padding_mask
        return result_dict, padding_mask_dict


class EncoderLayer(torch.nn.Module):
    def __init__(self, d_model, nhead, d_ff, dropout, activation):
        super().__init__()
        self.self_attn = MultiheadBlockSelfAttention(d_model, nhead, dropout)
        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model)
        self.dropout1 = torch.nn.Dropout(dropout)
        
        # Feed forward
        if activation == "relu":
            self.activation = torch.nn.ReLU()
        if activation == "gelu":
            self.activation = torch.nn.GELU()
        
        self.linear_ff1 = torch.nn.Linear(d_model, d_ff)
        self.linear_ff2 = torch.nn.Linear(d_ff, d_model)
        self.dropout_ff1 = torch.nn.Dropout(dropout)
        self.dropout_ff2 = torch.nn.Dropout(dropout)
    
    def forward(self, src, temporal_block_shape, 
                        temporal_unblock_idx, static_unblock_idx,
                        temporal_block_padding_mask, static_unblock_padding_mask):
        x = src
        attn_output = self._sa_block(self.norm1(x), temporal_block_shape,
                                        temporal_unblock_idx, static_unblock_idx,
                                        temporal_block_padding_mask, static_unblock_padding_mask)
        x = x + attn_output

        x = x + self._ff_block(self.norm2(x))
        return x
    
    def _sa_block(self, src, temporal_block_shape, temporal_unblock_idx, static_unblock_idx, temporal_block_padding_mask, static_unblock_padding_mask):
        x = self.self_attn(src, src, src, temporal_block_shape, temporal_unblock_idx, static_unblock_idx, temporal_block_padding_mask, static_unblock_padding_mask)
        return self.dropout1(x)

    def _ff_block(self, x):
        x = self.linear_ff2(self.dropout_ff1(self.activation(self.linear_ff1(x))))
        return self.dropout_ff2(x)


class TemporalRevert(torch.nn.Module):
    def __init__(self, mask_token):
        super().__init__()
        self.mask_token = mask_token
    
    def forward(self, temporal_data, idx_dict, temporal_cols_g):
        # Split global token and valid data
        global_token_seq = temporal_data[:, :, :1, :]
        valid_seq = temporal_data[:, :, 1:, :]

        # Append mask_token
        revert_idx = idx_dict["temporal_revert_idx"]
        mask_token = self.mask_token.unsqueeze(0).unsqueeze(1).repeat(valid_seq.shape[0], valid_seq.shape[1], revert_idx.shape[-1]-valid_seq.shape[-2], 1)
        valid_seq = torch.cat([valid_seq, mask_token], dim=-2)
        
        # Apply revert
        revert_idx = revert_idx.unsqueeze(-1).repeat(1, 1, 1, valid_seq.shape[-1])
        valid_seq = torch.gather(valid_seq, index=revert_idx, dim=-2)

        # Concat global token
        temporal_revert_data = torch.cat([global_token_seq, valid_seq], dim=-2)
        
        # Split to dictionary
        temporal_revert_dict = {}
        for n, col in enumerate(temporal_cols_g):
            temporal_revert_dict[col] = temporal_revert_data[:, :, n, :]
        
        return temporal_revert_dict

class ImgRevert(torch.nn.Module):
    def __init__(self, mask_token):
        super().__init__()
        self.mask_token = mask_token
    
    def forward(self, img_dict, idx_dict, img_cols):
        img_revert_dict = {}
        for col in img_cols:
            img_data = img_dict[col]

            # Split global token
            global_token = img_data[:, :1, :]
            valid_data = img_data[:, 1:, :]

            # Append mask token
            revert_idx = idx_dict[f"{col}_revert_idx"]
            mask_token = self.mask_token.unsqueeze(0).repeat(valid_data.shape[0], revert_idx.shape[1]-valid_data.shape[1], 1)
            valid_data = torch.cat([valid_data, mask_token], dim=-2)

            # Apply revert
            revert_idx = revert_idx.unsqueeze(-1).repeat(1, 1, valid_data.shape[-1])
            valid_data = torch.gather(valid_data, index=revert_idx, dim=-2)
            
            # Append global token
            img_revert = torch.cat([global_token, valid_data], dim=1)
            img_revert_dict[col] = img_revert

        return img_revert_dict

class NlpRevert(torch.nn.Module):
    def __init__(self, mask_token):
        super().__init__()
        self.mask_token = mask_token
    
    def forward(self, nlp_dict, idx_dict, padding_mask_dict, nlp_cols):
        nlp_revert_dict = {}
        for col in nlp_cols:
            nlp_data = nlp_dict[col]

            # Replace padding mask to mask token
            remain_padding_mask = padding_mask_dict[f"{col}_remain_padding_mask"].unsqueeze(-1).repeat(1, 1, nlp_data.shape[-1])
            nlp_data = torch.where(remain_padding_mask==1, nlp_data, self.mask_token)

            # Split global token
            global_token = nlp_data[:, :1, :]
            valid_data = nlp_data[:, 1:, :]

            # Append mask token
            revert_idx = idx_dict[f"{col}_revert_idx"]
            mask_token = self.mask_token.unsqueeze(0).repeat(valid_data.shape[0], revert_idx.shape[1]-valid_data.shape[1], 1)
            valid_data = torch.cat([valid_data, mask_token], dim=-2)

            # Apply revert
            revert_idx = revert_idx.unsqueeze(-1).repeat(1, 1, valid_data.shape[-1])
            valid_data = torch.gather(valid_data, index=revert_idx, dim=-2)
            
            # Append global token
            nlp_revert = torch.cat([global_token, valid_data], dim=1)
            nlp_revert_dict[col] = nlp_revert

        return nlp_revert_dict


class TemporalDecoderLayer(torch.nn.Module):
    def __init__(self, d_model, nhead, d_ff, dropout, activation):
        super().__init__()
        self.cross_attn = MultiheadBlockAttention(d_model, nhead, dropout)
        self.self_attn = torch.nn.MultiheadAttention(d_model, nhead, dropout, batch_first=True)

        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model)
        self.norm3 = torch.nn.LayerNorm(d_model)

        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)

        # Feed forward
        if activation == "relu":
            self.activation = torch.nn.ReLU()
        if activation == "gelu":
            self.activation = torch.nn.GELU()
        
        self.linear_ff1 = torch.nn.Linear(d_model, d_ff)
        self.linear_ff2 = torch.nn.Linear(d_ff, d_model)
        self.dropout_ff1 = torch.nn.Dropout(dropout)
        self.dropout_ff2 = torch.nn.Dropout(dropout)
    
    def forward(self, tgt, memory, cross_attn_padding_mask, self_attn_padding_mask):
        x = tgt

        self_attn_output, self_attn_weight = self._sa_block(self.norm2(x), self_attn_padding_mask)
        x = x + self_attn_output

        cross_attn_output, cross_attn_weight = self._ca_block(self.norm1(x.unsqueeze(-2)), memory, memory, cross_attn_padding_mask)
        x = x + cross_attn_output.squeeze()

        x = x + self._ff_block(self.norm3(x))
        return x, self_attn_weight, cross_attn_weight

    def _ca_block(self, query, key, value, padding_mask):
        x, attn_weight = self.cross_attn(query, key, value, padding_mask)
        return self.dropout1(x), attn_weight

    def _sa_block(self, src, padding_mask):
        padding_mask = padding_mask[:, :, 0].squeeze()
        x, attn_weight = self.self_attn(src, src, src, key_padding_mask=padding_mask, average_attn_weights=False)
        return self.dropout2(x), attn_weight

    def _ff_block(self, x):
        x = self.linear_ff2(self.dropout_ff1(self.activation(self.linear_ff1(x))))
        return self.dropout_ff2(x)

class NonTemporalDecoderLayer(torch.nn.Module):
    def __init__(self, d_model, nhead, d_ff, dropout, activation):
        super().__init__()
        self.cross_attn = torch.nn.MultiheadAttention(d_model, nhead, dropout, batch_first=True)
        self.self_attn = torch.nn.MultiheadAttention(d_model, nhead, dropout, batch_first=True)

        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model)
        self.norm3 = torch.nn.LayerNorm(d_model)

        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)

        
        # Feed forward
        if activation == "relu":
            self.activation = torch.nn.ReLU()
        if activation == "gelu":
            self.activation = torch.nn.GELU()
        
        self.linear_ff1 = torch.nn.Linear(d_model, d_ff)
        self.linear_ff2 = torch.nn.Linear(d_ff, d_model)
        self.dropout_ff1 = torch.nn.Dropout(dropout)
        self.dropout_ff2 = torch.nn.Dropout(dropout)
    
    def forward(self, tgt, memory, cross_attn_padding_mask, self_attn_padding_mask):
        x = tgt

        self_attn_output, self_attn_weight = self._sa_block(self.norm2(x), self_attn_padding_mask)
        x = x + self_attn_output

        cross_attn_output, cross_attn_weight = self._ca_block(self.norm1(x), memory, memory, cross_attn_padding_mask)
        x = x + cross_attn_output

        x = x + self._ff_block(self.norm3(x))

    
        return x, self_attn_weight, cross_attn_weight
    
    def _ca_block(self, query, key, value, padding_mask):
        x, attn_weight = self.cross_attn(query, key, value, key_padding_mask=padding_mask)
        return self.dropout1(x), attn_weight
    
    def _sa_block(self, src, padding_mask):
        x, attn_weight = self.self_attn(src, src, src, key_padding_mask=padding_mask)
        return self.dropout2(x), attn_weight

    def _ff_block(self, x):
        x = self.linear_ff2(self.dropout_ff1(self.activation(self.linear_ff1(x))))
        return self.dropout_ff2(x)

1==1

True

In [35]:
class Class(torch.nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self):
        return

class Embedding(torch.nn.Module):
    def __init__(self, col, data_info, label_encoder_dict, d_model):
        super().__init__()
        if col in data_info.modality_info["target"] + data_info.modality_info["temporal"]:
            self.embedding = TemporalEmbedding(col, data_info, label_encoder_dict, d_model)
        elif col in data_info.modality_info["img"]:
            self.embedding = ImgEmbedding(d_model)
        elif col in data_info.modality_info["nlp"]:
            self.embedding = NlpEmbedding(d_model)
    
    def forward(self, key, val, padding_mask, device):
        return self.embedding(key, val, padding_mask, device)

class GlobalToken(torch.nn.Module):
    def __init__(self, data_info, d_model):
        super().__init__()
        self.data_info = data_info
        self.global_token = torch.nn.Parameter(torch.rand(1, d_model))
    
    def forward(self, reference_dict):
        reference = reference_dict[self.data_info.modality_info["target"][0]]
        batch_size, seq_len, d_model = reference.shape

        global_token_seq = self.global_token.unsqueeze(0).repeat(batch_size, seq_len, 1)

        return {"global":global_token_seq}

class PosModEncoding(torch.nn.Module):
    def __init__(self, col, pos_enc, modality_dict, modality_embedding, d_model):
        super().__init__()
        self.pos_enc = pos_enc
        self.modality = modality_dict[col]
        self.modality_embedding = modality_embedding

    def forward(self, key, val, device):
        # Positional encoding
        val = self.pos_enc(val)

        # Modality embedding
        modality = torch.zeros(val.shape[1]).to(torch.int).to(device) + self.modality
        modality = self.modality_embedding(modality)

        return val + modality

class Remain(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.temporal_remain = TemporalRemain()
        self.img_remain = ImgRemain()
        self.nlp_remain = NlpRemain()
    
    def forward(self, data_dict, idx_dict, padding_mask_dict, remain_rto, temporal_cols, img_cols, nlp_cols, device):
        temporal_remain_dict, idx_dict, padding_mask_dict = self.temporal_remain(data_dict, idx_dict, padding_mask_dict, remain_rto["temporal"], temporal_cols, device)
        img_remain_dict, idx_dict, padding_mask_dict = self.img_remain(data_dict, idx_dict, padding_mask_dict, remain_rto["img"], img_cols, device)
        nlp_remain_dict, padding_mask_dict = self.nlp_remain(data_dict, idx_dict, padding_mask_dict, remain_rto["nlp"], nlp_cols, device)

        return temporal_remain_dict, img_remain_dict, nlp_remain_dict, idx_dict, padding_mask_dict

class Encoder(torch.nn.Module):
    def __init__(self, d_model, nhead, d_ff, dropout, activation, num_layers, to_decoder_dim):
        super().__init__()
        self.encoder_layers = torch.nn.ModuleList([copy.deepcopy(EncoderLayer(d_model, nhead, d_ff, dropout, activation)) for _ in range(num_layers)])
        self.to_decoder_dim = to_decoder_dim
    
    def forward(self, temporal_remain_dict, img_remain_dict, nlp_remain_dict, padding_mask_dict, img_cols, nlp_cols):
        (src, temporal_block_remain_shape,
            temporal_unblock_remain_idx, static_unblock_remain_idx,
            temporal_block_remain_padding_mask, static_unblock_remain_padding_mask,
            img_unblock_remain_idx_li, nlp_unblock_remain_idx_li) = self.get_src(temporal_remain_dict, img_remain_dict, nlp_remain_dict, padding_mask_dict, img_cols, nlp_cols)
        
        temporal_block_remain_padding_mask = torch.where(temporal_block_remain_padding_mask==1, 0, -torch.inf)
        static_unblock_remain_padding_mask = torch.where(static_unblock_remain_padding_mask==1, 0, -torch.inf)
        for mod in self.encoder_layers:
            src = mod(src, temporal_block_remain_shape, 
                        temporal_unblock_remain_idx, static_unblock_remain_idx,
                        temporal_block_remain_padding_mask, static_unblock_remain_padding_mask)
        
        encoding = self.to_decoder_dim(src)
        temporal_dict, img_dict, nlp_dict = self.undo_src(encoding, temporal_block_remain_shape, temporal_unblock_remain_idx, static_unblock_remain_idx, img_unblock_remain_idx_li, nlp_unblock_remain_idx_li, img_cols, nlp_cols)
        return temporal_dict, img_dict, nlp_dict
    
    def get_src(self, temporal_remain_dict, img_remain_dict, nlp_remain_dict, padding_mask_dict, img_cols, nlp_cols):
        # Unblock temporal remain
        temporal_block_remain = temporal_remain_dict["temporal_remain"]
        temporal_block_remain_shape = temporal_block_remain.shape
        batch_size, seq_len, num_modality, d_model = temporal_block_remain_shape

        temporal_unblock_remain = temporal_block_remain.view(batch_size, -1, d_model)
        temporal_block_remain_padding_mask = padding_mask_dict["temporal_remain_padding_mask"]
        temporal_unblock_remain_idx = torch.arange(0, temporal_unblock_remain.shape[1]).to(device)

        # Unblock img remain
        img_unblock_remain_li, img_unblock_remain_padding_mask_li, img_unblock_remain_idx_li = [], [], []
        unblock_idx = temporal_unblock_remain_idx[-1] + 1
        for col in img_cols:
            img_remain = img_remain_dict[col]
            img_padding_mask = padding_mask_dict[f"{col}_remain_padding_mask"]
            length = img_remain.shape[1]
            assert img_remain.shape[1] == img_padding_mask.shape[-1], f"{img_remain.shape}, {img_padding_mask.shape}"

            img_unblock_remain_li.append(img_remain)
            img_unblock_remain_padding_mask_li.append(img_padding_mask)
            img_unblock_remain_idx_li.append(torch.arange(unblock_idx, unblock_idx+length).to(device))
            unblock_idx += length
        img_unblock_remain = torch.cat(img_unblock_remain_li, dim=1)
        img_unblock_remain_padding_mask = torch.cat(img_unblock_remain_padding_mask_li, dim=1)
        img_unblock_remain_idx = torch.cat(img_unblock_remain_idx_li)

        # Unblock nlp remain
        nlp_unblock_remain_li, nlp_unblock_remain_padding_mask_li, nlp_unblock_remain_idx_li = [], [], []
        for col in nlp_cols:
            nlp_remain = nlp_remain_dict[col]
            nlp_padding_mask = padding_mask_dict[f"{col}_remain_padding_mask"]
            length = nlp_remain.shape[1]
            assert nlp_remain.shape[1] == nlp_padding_mask.shape[-1], f"{nlp_remain.shape}, {nlp_padding_mask.shape}"

            nlp_unblock_remain_li.append(nlp_remain)
            nlp_unblock_remain_padding_mask_li.append(nlp_padding_mask)
            nlp_unblock_remain_idx_li.append(torch.arange(unblock_idx, unblock_idx+length).to(device))
            unblock_idx += length
        nlp_unblock_remain = torch.cat(nlp_unblock_remain_li, dim=1)
        nlp_unblock_remain_padding_mask = torch.cat(nlp_unblock_remain_padding_mask_li, dim=1)
        nlp_unblock_remain_idx = torch.cat(nlp_unblock_remain_idx_li)

        # Static
        static_unblock_remain = torch.cat([img_unblock_remain, nlp_unblock_remain], dim=1)
        static_unblock_remain_padding_mask = torch.cat([img_unblock_remain_padding_mask, nlp_unblock_remain_padding_mask], dim=1)
        static_unblock_remain_idx = torch.cat([img_unblock_remain_idx, nlp_unblock_remain_idx])

        # Src
        src = torch.cat([temporal_unblock_remain, static_unblock_remain], dim=1)
        
        return (src, temporal_block_remain_shape,
                temporal_unblock_remain_idx, static_unblock_remain_idx,
                temporal_block_remain_padding_mask, static_unblock_remain_padding_mask,
                img_unblock_remain_idx_li, nlp_unblock_remain_idx_li)
        
    def undo_src(self, data, temporal_block_remain_shape, temporal_unblock_remain_idx, static_unblock_remain_idx, img_unblock_remain_idx_li, nlp_unblock_remain_idx_li, img_cols, nlp_cols):
        # Temporal
        temporal_dict = {}
        temporal_unblock_idx = temporal_unblock_remain_idx.unsqueeze(0).unsqueeze(-1).repeat(data.shape[0], 1, data.shape[-1])
        temporal_encoding = torch.gather(data, index=temporal_unblock_idx, dim=1).view(temporal_block_remain_shape[0], temporal_block_remain_shape[1], temporal_block_remain_shape[2], -1)
        temporal_dict["temporal"] = temporal_encoding

        # Img
        img_dict = {}
        for col, idx in zip(img_cols, img_unblock_remain_idx_li):
            img_unblock_idx = idx.unsqueeze(0).unsqueeze(-1).repeat(data.shape[0], 1, data.shape[-1])
            img_block_data = torch.gather(data, index=img_unblock_idx, dim=1)
            img_dict[col] = img_block_data

        # Nlp
        nlp_dict = {}
        for col, idx in zip(nlp_cols, nlp_unblock_remain_idx_li):
            nlp_unblock_idx = idx.unsqueeze(0).unsqueeze(-1).repeat(data.shape[0], 1, data.shape[-1])
            nlp_block_data = torch.gather(data, index=nlp_unblock_idx, dim=1)
            nlp_dict[col] = nlp_block_data
        
        return temporal_dict, img_dict, nlp_dict

class Revert(torch.nn.Module):
    def __init__(self, mask_token):
        super().__init__()
        self.temporal_revert = TemporalRevert(mask_token)
        self.img_revert = ImgRevert(mask_token)
        self.nlp_revert = NlpRevert(mask_token)
    
    def forward(self, temporal_dict, img_dict, nlp_dict, idx_dict, padding_mask_dict, temporal_cols_g, img_cols, nlp_cols):
        temporal_revert_dict = self.temporal_revert(temporal_dict["temporal"], idx_dict, temporal_cols_g)
        img_revert_dict = self.img_revert(img_dict, idx_dict, img_cols)
        nlp_revert_dict = self.nlp_revert(nlp_dict, idx_dict, padding_mask_dict, nlp_cols)

        revert_dict = {}
        revert_dict.update(temporal_revert_dict)
        revert_dict.update(img_revert_dict)
        revert_dict.update(nlp_revert_dict)
        
        return revert_dict


class Decoder(torch.nn.Module):
    def __init__(self, col, d_model, nhead, d_ff, dropout, activation, num_layers):
        super().__init__()
        self.temporal_decoder_layers = torch.nn.ModuleList([copy.deepcopy(TemporalDecoderLayer(d_model, nhead, d_ff, dropout, activation)) for _ in range(num_layers)])
        self.non_temporal_decoder_layers = torch.nn.ModuleList([copy.deepcopy(NonTemporalDecoderLayer(d_model, nhead, d_ff, dropout, activation)) for _ in range(num_layers)])
    
    def forward(self, key, val, data_dict, padding_mask_dict, temporal_cols, img_cols, nlp_cols, device):
        # Temporal
        if key in temporal_cols:
            tgt = val
            memory, cross_attn_padding_mask, self_attn_padding_mask = self.get_temporal_tgt_memory(key, val, data_dict, padding_mask_dict, temporal_cols, img_cols, nlp_cols, device)
            decoder_layers = self.temporal_decoder_layers
        elif key in img_cols + nlp_cols:
            tgt = val
            memory, cross_attn_padding_mask, self_attn_padding_mask = self.get_non_temporal_tgt_memory(key, val, data_dict, padding_mask_dict, temporal_cols, img_cols, nlp_cols, device)
            decoder_layers = self.non_temporal_decoder_layers

        for mod in decoder_layers:
            tgt, cross_attn_weight, self_attn_weight = mod(tgt, memory, cross_attn_padding_mask, self_attn_padding_mask)
        
        return tgt, cross_attn_weight, self_attn_weight

    def get_temporal_tgt_memory(self, key, val, data_dict, padding_mask_dict, temporal_cols, img_cols, nlp_cols, device):
        memory_li = []

        # Temporal memory
        for col in temporal_cols:
            if col == key:
                continue
            memory_li.append(data_dict[col])
        memory = torch.stack(memory_li, dim=-2)
        cross_attn_padding_mask = torch.ones(memory.shape[:-1]).to(device)
        
        # Static memory
        for col in img_cols + nlp_cols:
            static_data = data_dict[col].unsqueeze(1).repeat(1, memory.shape[1], 1, 1)
            static_padding_msak = padding_mask_dict[f"{col}_revert_padding_mask"].unsqueeze(1).repeat(1, memory.shape[1], 1)

            memory = torch.cat([memory, static_data], dim=-2)
            cross_attn_padding_mask = torch.cat([cross_attn_padding_mask, static_padding_msak], dim=-1)
        
        # Self attn padding mask
        self_attn_padding_mask = padding_mask_dict["temporal_revert_padding_mask"].squeeze()
        self_attn_padding_mask = torch.where(self_attn_padding_mask==1, 0, -torch.inf)
        cross_attn_padding_mask = torch.where(cross_attn_padding_mask==1, 0, -torch.inf)

        return memory, cross_attn_padding_mask, self_attn_padding_mask

    def get_non_temporal_tgt_memory(self, key, val, data_dict, padding_mask_dict, temporal_cols, img_cols, nlp_cols, device):
        memory_li, cross_attn_padding_mask_li = [], []
        
        # Temporal memory
        for col in temporal_cols:
            temporal_data = data_dict[col]
            temporal_padding_mask = padding_mask_dict[f"temporal_revert_padding_mask"].squeeze()
            
            memory_li.append(temporal_data)
            cross_attn_padding_mask_li.append(temporal_padding_mask)
        
        memory = torch.cat(memory_li, dim=1)
        cross_attn_padding_mask = torch.cat(cross_attn_padding_mask_li, dim=-1)

        # Img memory or Nlp memory
        for col in img_cols + nlp_cols:
            if key == col: continue
            static_data = data_dict[col]
            static_padding_mask = padding_mask_dict[f"{col}_revert_padding_mask"]

            memory = torch.cat([memory, static_data], dim=1)
            print(cross_attn_padding_mask.shape)
            print(static_padding_mask.shape)
            raise
            cross_attn_padding_mask = torch.cat([cross_attn_padding_mask, static_padding_mask], dim=-1)
        
        # Self attn padding mask
        self_attn_padding_mask = padding_mask_dict[f"{key}_revert_padding_mask"].squeeze()
        self_attn_padding_mask = torch.where(self_attn_padding_mask==1, 0, -torch.inf)
        cross_attn_padding_mask = torch.where(cross_attn_padding_mask==1, 0, -torch.inf)

        return memory, cross_attn_padding_mask, self_attn_padding_mask

class TemporalOutput(torch.nn.Module):
    def __init__(self, col, data_info, label_encoder_dict, d_model):
        super().__init__()
        if col in data_info.processing_info["scaling_cols"]:
            self.output = torch.nn.Sequential(
                            torch.nn.Linear(d_model, d_model),
                            torch.nn.Linear(d_model, 1))
        
        elif col in data_info.processing_info["embedding_cols"]:
            num_cls = label_encoder_dict[col].get_num_cls()
            self.output = torch.nn.Sequential(
                            torch.nn.Linear(d_model, d_model),
                            torch.nn.Linear(d_model, num_cls))

    
    def forward(self, key, val):
        return self.output(val)

class ImgOutput(torch.nn.Module):
    def __init__(self, col, d_model, patch_size=16):
        super().__init__()
        self.output = torch.nn.Sequential(
            torch.nn.Linear(d_model, d_model),
            torch.nn.Linear(d_model, 3*patch_size*patch_size))
    
    def forward(self, key, val):
        return self.output(val)
        
class NlpOutput(torch.nn.Module):
    def __init__(self, col, d_model, num_vocab=30522):
        super().__init__()
        self.output = torch.nn.Sequential(
            torch.nn.Linear(d_model, d_model),
            torch.nn.Linear(d_model, num_vocab))
    
    def forward(self, key, val):
        return self.output(val)
1==1

1==1

True

In [36]:
import torch

class Transformer(torch.nn.Module):
    def __init__(self, data_info, label_encoder_dict,
                d_model, num_layers, nhead, d_ff, dropout, activation):
        super().__init__()
        self.data_info, self.label_encoder_dict = data_info, label_encoder_dict
        self.temporal_cols, self.temporal_cols_g, self.img_cols, self.nlp_cols, self.total_cols, self.total_cols_g = self.define_cols()
        self.self_attn_weight_dict, self.cross_attn_weight_dict = {}, {}

        # 1. Embedding
        self.embedding_dict = self.init_process(mod=Embedding, args=[data_info, label_encoder_dict, d_model["encoder"]], target_cols=self.total_cols)
        # 2. Global token
        self.global_token_dict = GlobalToken(data_info, d_model["encoder"])
        # 3. Pos encoding & Modality encoding
        encoder_pos_enc = PositionalEncoding(d_model["encoder"], dropout)
        num_modality = len(self.total_cols_g)
        modality_dict = {col:n for n, col in enumerate(self.total_cols_g)}
        enc_modality_embedding = torch.nn.Embedding(num_modality, d_model["encoder"])
        self.enc_posmod_encoding_dict = self.init_process(mod=PosModEncoding, args=[encoder_pos_enc, modality_dict, enc_modality_embedding, d_model["encoder"]], target_cols=self.total_cols_g)
        # 4. Remain masking
        self.remain_dict = Remain()
        # 5. Encoding
        to_decoder_dim = torch.nn.Linear(d_model["encoder"], d_model["decoder"])
        self.encoding = Encoder(d_model["encoder"], nhead, d_ff["encoder"], dropout, activation, num_layers["encoder"], to_decoder_dim)
        # 6. Revert
        mask_token = torch.nn.Parameter(torch.rand(1, d_model["decoder"]))
        self.revert = Revert(mask_token)
        # 7. Pos encoding & Modality encoding
        decoder_pos_enc = PositionalEncoding(d_model["decoder"], dropout)
        dec_modality_embedding = torch.nn.Embedding(num_modality, d_model["decoder"])
        self.dec_posmod_encoding_dict = self.init_process(mod=PosModEncoding, args=[decoder_pos_enc, modality_dict, dec_modality_embedding, d_model["decoder"]], target_cols=self.total_cols_g)
       
        # 7. Decoding
        self.decoding = self.init_process(mod=Decoder, args=[d_model["decoder"], nhead, d_ff["decoder"], dropout, activation, num_layers["decoder"]], target_cols=self.total_cols_g)
    
    def forward(self, data_input, remain_rto, device):
        data_dict, self.idx_dict, self.padding_mask_dict = self.to_gpu(data_input, device)
        
        # 1. Embedding
        embedding_dict = self.apply_process(data=data_dict, mod=self.embedding_dict, args=[self.padding_mask_dict, device], target_cols=self.total_cols)
        # 2. Global & dummy token
        global_token_dict = self.global_token_dict(embedding_dict)
        embedding_dict.update(global_token_dict)
        # 3. Pos encoding & Modality encoding
        enc_posmod_encoding_dict = self.apply_process(data=embedding_dict, mod=self.enc_posmod_encoding_dict, args=[device], target_cols=self.total_cols_g)
        # 4. Remain masing
        temporal_remain_dict, img_remain_dict, nlp_remain_dict, self.idx_dict, self.padding_mask_dict = self.remain_dict(enc_posmod_encoding_dict, self.idx_dict, self.padding_mask_dict, remain_rto, self.temporal_cols_g, self.img_cols, self.nlp_cols, device)
        # 5. Encoding
        temporal_encoding_dict, img_encoding_dict, nlp_encoding_dict = self.encoding(temporal_remain_dict, img_remain_dict, nlp_remain_dict, self.padding_mask_dict, self.img_cols, self.nlp_cols)
        # 6. Revert
        revert_dict = self.revert(temporal_encoding_dict, img_encoding_dict, nlp_encoding_dict, self.idx_dict, self.padding_mask_dict, self.temporal_cols_g, self.img_cols, self.nlp_cols)
        # 7. Pos encoding & Modality encoding
        dec_posmod_encoding_dict = self.apply_process(data=revert_dict, mod=self.dec_posmod_encoding_dict, args=[device], target_cols=self.total_cols_g)

        # 7. Decoder
        decoding = self.apply_process(data=dec_posmod_encoding_dict, mod=self.decoding, args=[dec_posmod_encoding_dict, self.padding_mask_dict, self.temporal_cols_g, self.img_cols, self.nlp_cols, device], collate_fn=self.tidy_decoding, target_cols=self.total_cols_g)

    
    def define_cols(self):
        temporal_cols = self.data_info.modality_info["target"] + self.data_info.modality_info["temporal"]
        img_cols = self.data_info.modality_info["img"]
        nlp_cols = self.data_info.modality_info["nlp"]
        total_cols = temporal_cols + img_cols + nlp_cols
        
        total_cols_gd = ["global"] + total_cols
        temporal_cols_gd = ["global"] + temporal_cols

        return temporal_cols, temporal_cols_gd, img_cols, nlp_cols, total_cols, total_cols_gd

    def to_gpu(self, data_input, device):
        data_dict, idx_dict, padding_mask_dict = {}, {}, {}
        for key, val in data_input.items():
            if key in self.temporal_cols + self.img_cols + self.nlp_cols:
                data_dict[key] = data_input[key].to(device)
            elif key.endswith("idx"):
                idx_dict[key] = data_input[key].to(device)
            elif key.endswith("mask"):
                padding_mask_dict[key] = data_input[key].to(device)
            
        return data_dict, idx_dict, padding_mask_dict

    def init_process(self, mod, args=[], target_cols=None):
        result_dict = {}
        for col in target_cols:
            result_dict[col] = mod(col, *args)
        
        return torch.nn.ModuleDict(result_dict)

    def apply_process(self, data, mod, args=[], target_cols=None, collate_fn=None):
        result_dict = {}
        for col in target_cols:
            result_dict[col] = mod[col](col, data[col], *args)
        
        if collate_fn is not None:
            return collate_fn(result_dict)
        else:
            return result_dict

    def tidy_decoding(self, decoding_dict):
        result_dict = {}
        for key, val in decoding_dict.items():
            result_dict[key], self_attn_weight, cross_attn_weight = val
            self.self_attn_weight_dict.update({key:self_attn_weight})
            self.cross_attn_weight_dict.update({key:cross_attn_weight})

        return result_dict
1==1

True

In [37]:
model = Transformer(data_info, train_dataset.label_encoder_dict,
                        d_model, num_layers, nhead, d_ff, dropout, "gelu")
model.to(device)
summary(model, data, remain_rto, device, show_parent_layers=True, print_summary=True)

torch.Size([2, 50, 49])
torch.Size([2, 9])


RuntimeError: No active exception to reraise

# Train

In [None]:
for name, param in model.named_parameters():
    if "img_model" in name:
        param.requires_grad = False
    elif "nlp_model" in name:
        param.requires_grad = False

In [None]:
from train import *

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

epoch = 3
for e in range(epoch):
    train(e, train_dataloader, optimizer, scheduler, model, data_info, remain_rto, device)
    scheduler.step()
    # raise

  0%|          | 0/1 [00:00<?, ?it/s]


TypeError: cannot unpack non-iterable NoneType object