### Import

In [1]:
import copy

from tqdm import tqdm; tqdm.pandas()
from collections import defaultdict
from IPython.display import clear_output

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import MinMaxScaler, StandardScaler

import torch
from pytorch_model_summary import summary
from transformers import ViTModel, AutoImageProcessor, BertModel, AutoTokenizer

from rawdata import RawData, Preprocess
from data import DataInfo, Dataset, collate_fn
from data import NoneScaler, LogScaler, CustomLabelEncoder

from architecture import Transformer

import cv2

device = torch.device("cuda")

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

  from .autonotebook import tqdm as notebook_tqdm


### Config

In [2]:
test_mode = True

# Raw data
is_prep_data_exist = True

# Data loader
MIN_MEANINGFUL_SEQ_LEN = 100
MAX_SEQ_LEN = 100
PRED_LEN = 100

modality_info = {
    "group": ["article_id", "sales_channel_id"],
    "target": ["sales"],
    "temporal": ["day", "dow", "month", "holiday", "price"],
    "img": ["img_path"],
    "nlp": ["detail_desc"]
}
processing_info = {
    "scaling_cols": {"sales": StandardScaler, "price": StandardScaler},
    "embedding_cols": ["day",  "dow", "month", "holiday"],
    "img_cols": ["img_path"],
    "nlp_cols": ["detail_desc"]
}

# Model
batch_size = 16
nhead = 4
dropout = 0.1

d_model = {"encoder":256, "decoder":128}
d_ff = {"encoder":256, "decoder":128}
num_layers = {"encoder":2, "decoder":2}
remain_rto = {"temporal":0.75, "img":0.25, "nlp":0.25}

# Data

### Raw data

In [3]:
if test_mode:
    df_preprocessed = pd.read_parquet("src/df_preprocessed_test.parquet")
else:
    if not is_prep_data_exist:
        rawdata = RawData()
        df_trans, df_meta, df_holiday = rawdata.get_raw_data()
        preprocess = Preprocess(df_trans, df_meta, df_holiday)
        df_preprocessed = preprocess.main()
    else:
        df_preprocessed = pd.read_parquet("src/df_preprocessed.parquet")

### Dataset

In [4]:
df_train = df_preprocessed[(df_preprocessed["meaningful_size"] >= MIN_MEANINGFUL_SEQ_LEN) & (df_preprocessed["time_idx"] <= MAX_SEQ_LEN-1)]
df_train = df_train[~pd.isna(df_train["detail_desc"])]
df_valid = df_preprocessed[(df_preprocessed["meaningful_size"] >= MIN_MEANINGFUL_SEQ_LEN) & (df_preprocessed["time_idx"] <= MAX_SEQ_LEN-1 + PRED_LEN)]
df_valid = df_valid[~pd.isna(df_valid["detail_desc"])]

data_info = DataInfo(modality_info, processing_info)

In [5]:
train_dataset = Dataset(df_train, data_info, remain_rto)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: collate_fn(x, data_info), pin_memory=True, num_workers=16, prefetch_factor=4)
# train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: collate_fn(x, data_info))

for data in train_dataloader:
    [print(key, val.shape) for key, val in data.items() if "scaler" not in key and "raw" not in key]
    break

100%|██████████| 29274/29274 [00:00<00:00, 30087.38it/s]


sales torch.Size([16, 100, 1])
day torch.Size([16, 100])
dow torch.Size([16, 100])
month torch.Size([16, 100])
holiday torch.Size([16, 100])
price torch.Size([16, 100, 1])
temporal_padding_mask torch.Size([16, 100, 1])
img_path torch.Size([16, 3, 224, 224])
detail_desc torch.Size([16, 57])
detail_desc_remain_idx torch.Size([16, 14])
detail_desc_masked_idx torch.Size([16, 42])
detail_desc_revert_idx torch.Size([16, 56])
detail_desc_remain_padding_mask torch.Size([16, 14])
detail_desc_masked_padding_mask torch.Size([16, 42])
detail_desc_revert_padding_mask torch.Size([16, 56])


# Architecture

In [6]:
import math
class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = torch.permute(x, (1,0,2))
        x = x + self.pe[:x.size(0)]
        x = torch.permute(x, (1,0,2))
        return self.dropout(x)

class MultiheadBlockSelfAttention(torch.nn.Module):
    def __init__(self, d_model, nhead, dropout):
        super().__init__()
        self.nhead = nhead

        self.q_linear = torch.nn.Linear(d_model, d_model)
        self.k_linear = torch.nn.Linear(d_model, d_model)
        self.v_linear = torch.nn.Linear(d_model, d_model)

        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, query, key, value, temporal_shape, temporal_padding_mask, img_padding_mask, nlp_padding_mask, temporal_idx, img_idx_li, nlp_idx_li, temporal_side=True, non_temporal_side=True):
        # Linear transformation
        Q = self.q_linear(query)
        K = self.k_linear(key)
        V = self.v_linear(value)

        if temporal_side:
            temporal_attn_output, temporal_attn_weight = self.temporal_side_attn(Q, K, V, temporal_shape, temporal_padding_mask, img_padding_mask, nlp_padding_mask, temporal_idx, img_idx_li, nlp_idx_li)
            print("temporal_attn_output:", temporal_attn_output.shape)
            print("temporal_attn_weight:", temporal_attn_weight.shape)
        
        if non_temporal_side:
            non_temporal_attn_output, non_temporal_attn_weight = self.non_temporal_side_attn(Q, K, V, temporal_shape, temporal_padding_mask, img_padding_mask, nlp_padding_mask, temporal_idx, img_idx_li, nlp_idx_li)
    
    def temporal_side_attn(self, query, key, value, temporal_shape, temporal_padding_mask, img_padding_mask, nlp_padding_mask, temporal_idx, img_idx_li, nlp_idx_li):
        # Q
        Q = torch.gather(query, index=temporal_idx.unsqueeze(0).unsqueeze(-1).repeat(query.shape[0], 1, query.shape[-1]), dim=1)
        Q = Q.view(temporal_shape)

        # K and V
        def get_KV(KV, temporal_shape, temporal_padding_mask, img_padding_mask, nlp_padding_mask, temporal_idx, img_idx_li, nlp_idx_li):
            KV_temporal = torch.gather(query, index=temporal_idx.unsqueeze(0).unsqueeze(-1).repeat(query.shape[0], 1, query.shape[-1]), dim=1)
            KV_temporal = KV_temporal.view(temporal_shape)

            KV_img_li = []
            for img_idx in img_idx_li:
                KV_img_temp = torch.gather(KV, index=img_idx.unsqueeze(0).unsqueeze(-1).repeat(KV.shape[0], 1, KV.shape[-1]), dim=1)
                KV_img_li.append(KV_img_temp)
            KV_img = torch.cat(KV_img_li, dim=1)

            KV_nlp_li = []
            for nlp_idx in nlp_idx_li:
                KV_nlp_temp = torch.gather(KV, index=nlp_idx.unsqueeze(0).unsqueeze(-1).repeat(KV.shape[0], 1, KV.shape[-1]), dim=1)
                KV_nlp_li.append(KV_nlp_temp)
            KV_nlp = torch.cat(KV_nlp_li, dim=1)

            KV_non_temporal = torch.cat([KV_img, KV_nlp], dim=1)
            KV_non_temporal = KV_non_temporal.unsqueeze(1).repeat(1, KV_temporal.shape[1], 1, 1)
            
            KV = torch.cat([KV_temporal, KV_non_temporal], dim=-2)

            temporal_padding_mask = temporal_padding_mask.view(temporal_shape[:-1])
            img_padding_mask = img_padding_mask.unsqueeze(1).repeat(1, KV.shape[1], 1)
            nlp_padding_mask = nlp_padding_mask.unsqueeze(1).repeat(1, KV.shape[1], 1)

            padding_mask = torch.cat([temporal_padding_mask, img_padding_mask, nlp_padding_mask], dim=-1)
            padding_mask = torch.where(padding_mask == 1, 0, -torch.inf)

            return KV, padding_mask
        
        K, padding_mask = get_KV(key, temporal_shape, temporal_padding_mask, img_padding_mask, nlp_padding_mask, temporal_idx, img_idx_li, nlp_idx_li)
        V, _ = get_KV(value, temporal_shape, temporal_padding_mask, img_padding_mask, nlp_padding_mask, temporal_idx, img_idx_li, nlp_idx_li)
    
        # Split head
        batch_size, seq_len, _, d_model = Q.shape
        Q = Q.view(batch_size, seq_len, -1, self.nhead, d_model//self.nhead).permute(0, 3, 1, 2, 4)
        K = K.view(batch_size, seq_len, -1, self.nhead, d_model//self.nhead).permute(0, 3, 1, 2, 4)
        V = V.view(batch_size, seq_len, -1, self.nhead, d_model//self.nhead).permute(0, 3, 1, 2, 4)

        # Scaled dot product attention
        ### 1. Q·K^t
        QK = Q @ K.permute(0,1,2,4,3)
        logits = QK / math.sqrt(d_model//self.nhead)
        padding_mask = padding_mask.unsqueeze(1).unsqueeze(-2).repeat(1, logits.shape[1], 1, logits.shape[-2], 1)
        logits += padding_mask

        ### 2. Softmax
        attn_weight = torch.nn.functional.softmax(logits, dim=-1)

        ### 3. Matmul V
        attn_output = attn_weight @ V

        ### 4. Concat heads
        attn_output = attn_output.permute(0,2,3,1,4).reshape(batch_size, seq_len, -1, d_model)

        return attn_output, attn_weight

        

1==1

True

In [7]:
class TemporalEmbedding(torch.nn.Module):
    def __init__(self, col, data_info, label_encoder_dict, d_model):
        super().__init__()
        if col in data_info.processing_info["scaling_cols"]:
            self.embedding = torch.nn.Linear(1, d_model)
        
        elif col in data_info.processing_info["embedding_cols"]:
            num_cls = label_encoder_dict[col].get_num_cls()
            self.embedding = torch.nn.Embedding(num_cls, d_model)
    
    def forward(self, key, val, padding_mask_dict, device):
        return self.embedding(val)

class ImgEmbedding(torch.nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.img_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.downsize_linear = torch.nn.Linear(768, d_model)
    
    def forward(self, key, val, padding_mask_dict, device):
        embedding = self.img_model(val).last_hidden_state
        embedding = self.downsize_linear(embedding)
        return embedding

class NlpEmbedding(torch.nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.nlp_model = BertModel.from_pretrained("google-bert/bert-base-uncased")
        self.downsize_linear = torch.nn.Linear(768, d_model)   
    
    def forward(self, key, val, padding_mask_dict, device):
        # Make token_type_ids
        token_type_ids = torch.zeros(val.shape).to(torch.int).to(device)

        # Make attention mask
        attention_mask = padding_mask_dict[f"{key}_revert_padding_mask"]
        mask_for_global_token = torch.ones(attention_mask.shape[0], 1).to(device)
        attention_mask = torch.cat([attention_mask, mask_for_global_token], dim=-1)

        # Embed data
        inputs = {"input_ids":val, "token_type_ids":token_type_ids, "attention_mask":attention_mask}
        embedding = self.nlp_model(**inputs).last_hidden_state
        embedding = self.downsize_linear(embedding)
        
        return embedding


class TemporalRemain(torch.nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, data_dict, remain_rto, temporal_cols, device):
        result_dict, idx_dict = {}, {}
        # Concat data
        concat_data_li = []
        for col in temporal_cols:
            concat_data_li.append(data_dict[col])
        
        concat_data = torch.stack(concat_data_li, dim=-2)
    
        # Remain mask
        num_modality = concat_data.shape[-2]
        num_remain = int(num_modality * remain_rto)
        
        noise = torch.rand(concat_data.shape[:-1]).to(device)
        shuffle_idx = torch.argsort(noise, dim=-1)

        remain_idx = shuffle_idx[:, :, :num_remain]
        masked_idx = shuffle_idx[:, :, num_remain:]
        revert_idx = torch.argsort(shuffle_idx, dim=-1)

        # Apply mask
        concat_data = torch.gather(concat_data, index=remain_idx.unsqueeze(-1).repeat(1, 1, 1, concat_data.shape[-1]), dim=-2)

        result_dict["temporal"] = concat_data
        idx_dict.update({"temporal_remain_idx":remain_idx, "temporal_masked_idx":masked_idx, "temporal_revert_idx":revert_idx})

        return result_dict, idx_dict

class ImgRemain(torch.nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, data_dict, remain_rto, img_cols, device):
        result_dict, idx_dict, padding_mask_dict = {}, {}, {}
        # Get indexs
        for col in img_cols:
            val = data_dict[col]
            num_remain = int(val.shape[1] * remain_rto)
            noise = torch.rand(val.shape[0], val.shape[1]).to(device)
            shuffle_idx = torch.argsort(noise, dim=1)

            remain_idx = shuffle_idx[:, :num_remain]
            masked_idx = shuffle_idx[:, num_remain:]
            revert_idx = torch.argsort(shuffle_idx, dim=1)

            remain_padding_mask = torch.ones(remain_idx.shape).to(device)
            revert_padding_mask = torch.ones(revert_idx.shape).to(device)

            # Apply mask
            val = torch.gather(val, index=remain_idx.unsqueeze(-1).repeat(1, 1, val.shape[-1]), dim=1)

            result_dict[col] = val
            idx_dict.update({f"{col}_remain_idx":remain_idx, f"{col}_masked_idx":masked_idx, f"{col}_revert_idx":revert_idx})
            padding_mask_dict.update({f"{col}_remain_padding_mask":remain_padding_mask, f"{col}_revert_padding_mask":revert_padding_mask})

        return result_dict, idx_dict, padding_mask_dict

class NlpRemain(torch.nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, data_dict, idx_dict, remain_rto, nlp_cols, device):
        result_dict = {}
        for col in nlp_cols:
            val = data_dict[col]
            remain_idx = idx_dict[f"{col}_remain_idx"].unsqueeze(-1).repeat(1, 1, val.shape[-1])
            val = torch.gather(val, index=remain_idx, dim=1)
            result_dict[col] = val
        return result_dict


class TemporalEncoderLayer(torch.nn.Module):
    def __init__(self, d_model, nhead, d_ff, dropout, activation):
        super().__init__()
        self.self_attn = MultiheadBlockSelfAttention(d_model, nhead, dropout)


        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model)
        self.dropout1 = torch.nn.Dropout(dropout)
        
        # Feed forward
        if activation == "relu":
            self.activation = torch.nn.ReLU()
        if activation == "gelu":
            self.activation = torch.nn.GELU()
        
        self.linear_ff1 = torch.nn.Linear(d_model, d_ff)
        self.linear_ff2 = torch.nn.Linear(d_ff, d_model)
        self.dropout_ff1 = torch.nn.Dropout(dropout)
        self.dropout_ff2 = torch.nn.Dropout(dropout)

    def forward(self, src, temporal_shape, temporal_padding_mask, img_padding_mask, nlp_padding_mask, temporal_idx, img_idx_li, nlp_idx_li):
        attn_output, attn_weight = self._sa_block(self.norm1(src), temporal_shape, temporal_padding_mask, img_padding_mask, nlp_padding_mask, temporal_idx, img_idx_li, nlp_idx_li)
        x = x + self_attn_output

        x = x + self._ff_block(self.norm2(x))
        return x, attn_weight

    def _sa_block(self, src, temporal_shape, temporal_padding_mask, img_padding_mask, nlp_padding_mask, temporal_idx, img_idx_li, nlp_idx_li):
        x, attn_weight = self.self_attn(src, src, src, temporal_shape, temporal_padding_mask, img_padding_mask, nlp_padding_mask, temporal_idx, img_idx_li, nlp_idx_li)
        return self.dropout1(x), attn_weight

    def _ff_block(self, x):
        x = self.linear_ff2(self.dropout_ff1(self.activation(self.linear_ff1(x))))
        return self.dropout_ff2(x)

1==1

True

In [8]:
arr = torch.tensor([
    [1,2,3,4,5,6,7,8,9,10],
    [11,12,13,14,15,16,17,18,19,20],
    [21,22,23,24,25,26,27,28,29,30]
])

temporal = torch.tensor([0,1,2,3,4]).unsqueeze(0).repeat(arr.shape[0], 1)
img = torch.tensor([5,6,7]).unsqueeze(0).repeat(arr.shape[0], 1)
nlp = torch.tensor([8,9]).unsqueeze(0).repeat(arr.shape[0], 1)

print(arr.shape)
print(temporal.shape)

torch.gather(arr, index=img, dim=-1)

torch.Size([3, 10])
torch.Size([3, 5])


tensor([[ 6,  7,  8],
        [16, 17, 18],
        [26, 27, 28]])

In [9]:
class Embedding(torch.nn.Module):
    def __init__(self, col, data_info, label_encoder_dict, d_model):
        super().__init__()
        if col in data_info.modality_info["target"] + data_info.modality_info["temporal"]:
            self.embedding = TemporalEmbedding(col, data_info, label_encoder_dict, d_model)
        elif col in data_info.modality_info["img"]:
            self.embedding = ImgEmbedding(d_model)
        elif col in data_info.modality_info["nlp"]:
            self.embedding = NlpEmbedding(d_model)
    
    def forward(self, key, val, padding_mask, device):
        return self.embedding(key, val, padding_mask, device)

class PosModEncoding(torch.nn.Module):
    def __init__(self, col, pos_enc, num_modality, modality, d_model):
        super().__init__()
        self.pos_enc = pos_enc
        self.modality = modality[col]
        self.modality_embedding = torch.nn.Embedding(num_modality, d_model)

    def forward(self, key, val, device):
        # Positional encoding
        val = self.pos_enc(val)

        # Modality embedding
        modality = torch.zeros(val.shape[1]).to(torch.int).to(device) + self.modality
        modality = self.modality_embedding(modality)

        return val + modality

class Remain(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.temporal_remain = TemporalRemain()
        self.img_remain = ImgRemain()
        self.nlp_remain = NlpRemain()
    
    def forward(self, data_dict, idx_dict, padding_mask_dict, remain_rto, temporal_cols, img_cols, nlp_cols, device):
        temporal_dict, temporal_idx_dict = self.temporal_remain(data_dict, remain_rto["temporal"], temporal_cols, device)
        idx_dict.update(temporal_idx_dict)

        img_dict, img_idx_dict, img_padding_mask_dict = self.img_remain(data_dict, remain_rto["img"], img_cols, device)
        idx_dict.update(img_idx_dict)
        padding_mask_dict.update(img_padding_mask_dict)

        nlp_dict = self.nlp_remain(data_dict, idx_dict, remain_rto["nlp"], nlp_cols, device)

        return temporal_dict, img_dict, nlp_dict, idx_dict, padding_mask_dict

class Encoder(torch.nn.Module):
    def __init__(self, d_model, nhead, d_ff, dropout, activation, num_layers):
        super().__init__()
        self.encoder_layers = torch.nn.ModuleList([copy.deepcopy(TemporalEncoderLayer(d_model, nhead, d_ff, dropout, activation)) for _ in range(num_layers)])
    
    def forward(self, temporal_dict, img_dict, nlp_dict, padding_mask_dict, img_cols, nlp_cols, device):
        src, temporal_shape, temporal_padding_mask, img_padding_mask, nlp_padding_mask, temporal_idx, img_idx_li, nlp_idx_li = self.get_src(temporal_dict, img_dict, nlp_dict, padding_mask_dict, img_cols, nlp_cols, device)
        for mod in self.encoder_layers:
            src, attn_weight = mod(src, temporal_shape, temporal_padding_mask, img_padding_mask, nlp_padding_mask, temporal_idx, img_idx_li, nlp_idx_li)
    
    def get_src(self, temporal_dict, img_dict, nlp_dict, padding_mask_dict, img_cols, nlp_cols, device):
        # Temporal
        temporal = temporal_dict["temporal"]
        temporal_shape = temporal.shape
        batch_size, seq_len, _, d_model = temporal_shape
        temporal = temporal.view(batch_size, -1, d_model)
        temporal_padding_mask = torch.ones(temporal.shape[:-1]).to(device)
        temporal_idx = torch.arange(0, temporal.shape[1]).to(device)

        # Image
        idx = temporal_idx[-1]+1
        img_li, img_padding_mask_li, img_idx_li = [], [], []
        for col in img_cols:
            img_data = img_dict[col]
            img_padding_mask = padding_mask_dict[f"{col}_remain_padding_mask"]

            img_li.append(img_data)
            img_padding_mask_li.append(img_padding_mask)
            img_idx_li.append(torch.arange(idx, idx+img_data.shape[1]).to(device))
            idx += img_data.shape[1]

        img = torch.cat(img_li, dim=1)
        img_padding_mask = torch.cat(img_padding_mask_li, dim=1)

        # Nlp
        nlp_li, nlp_padding_mask_li, nlp_idx_li = [], [], []
        for col in nlp_cols:
            nlp_data = nlp_dict[col]
            nlp_padding_mask = padding_mask_dict[f"{col}_remain_padding_mask"]

            nlp_li.append(nlp_data)
            nlp_padding_mask_li.append(nlp_padding_mask)
            nlp_idx_li.append(torch.arange(idx, idx+nlp_data.shape[1]).to(device))
            idx += nlp_data.shape[1]

        nlp = torch.cat(nlp_li, dim=1)
        nlp_padding_mask = torch.cat(nlp_padding_mask_li, dim=1)
        
        src = torch.cat([temporal, img, nlp], dim=1)
        assert idx == src.shape[1], f"{idx}, {src.shape}"
        
        return src, temporal_shape, temporal_padding_mask, img_padding_mask, nlp_padding_mask, temporal_idx, img_idx_li, nlp_idx_li



    def __forward(self, temporal_dict, img_dict, nlp_dict, padding_mask_dict, img_cols, nlp_cols, device):
        # Temporal
        temporal_query, temporal_keyval, temporal_keyval_padding_mask = self.get_temporal_qkv(temporal_dict, img_dict, nlp_dict, padding_mask_dict, img_cols, nlp_cols, device)
        
        for mod in self.temporal_encoder_layers:
            temporal_query, temporal_keyval = mod(temporal_query, temporal_keyval, temporal_keyval_padding_mask)
        return
    
    def __get_temporal_qkv(self, temporal_dict, img_dict, nlp_dict, padding_mask_dict, img_cols, nlp_cols, device):
        query = temporal_dict["temporal"]

        # Get non temporal keyval
        non_temporalkeyval_data_li, non_temporal_keyval_padding_mask_li = [], []
        for col in img_cols+nlp_cols:
            non_temporal_keyval_data = img_dict[col] if col in img_cols else nlp_dict[col]
            non_temporal_keyval_padding_mask = padding_mask_dict[f"{col}_remain_padding_mask"]
            non_temporalkeyval_data_li.append(non_temporal_keyval_data)
            non_temporal_keyval_padding_mask_li.append(non_temporal_keyval_padding_mask)

        non_temporal_keyval_data = torch.cat(non_temporalkeyval_data_li, dim=1)
        non_temporal_keyval_padding_mask = torch.cat(non_temporal_keyval_padding_mask_li, dim=1)

        # Get temporal keyval
        temporal_keyval_data = temporal_dict["temporal"]
        temporal_keyval_padding_mask = torch.ones(temporal_keyval_data.shape[:-1]).to(device)

        # Get total keyval
        non_temporal_keyval_data = non_temporal_keyval_data.unsqueeze(1).repeat(1, query.shape[1], 1, 1)
        non_temporal_keyval_padding_mask = torch.cat(non_temporal_keyval_padding_mask_li, dim=1).unsqueeze(1).repeat(1, query.shape[1], 1)

        total_keyval_data = torch.cat([temporal_keyval_data, non_temporal_keyval_data], dim=-2)
        total_keyval_padding_mask = torch.cat([temporal_keyval_padding_mask, non_temporal_keyval_padding_mask], dim=-1)

        return query, total_keyval_data, total_keyval_padding_mask


1==1

True

In [10]:
class Transformer(torch.nn.Module):
    def __init__(self, data_info, label_encoder_dict,
                d_model, num_layers, nhead, d_ff, dropout, activation):
        super().__init__()
        self.data_info, self.label_encoder_dict = data_info, label_encoder_dict
        self.temporal_cols, self.img_cols, self.nlp_cols, self.total_cols = self.define_cols()
        self.self_attn_weight_dict, self.cross_attn_weight_dict = {}, {}
        
        # 1. Embedding
        self.embedding_dict = self.init_process(mod=Embedding, args=[self.data_info, self.label_encoder_dict, d_model["encoder"]])
        # 2. Pos encoding and modality embedding
        num_modality = len(self.total_cols)
        modality = {col:n for n, col in enumerate(self.total_cols)}
        encoder_pos_enc = PositionalEncoding(d_model["encoder"], dropout)
        self.pos_mod_encoding_dict = self.init_process(mod=PosModEncoding, args=[encoder_pos_enc, num_modality, modality, d_model["encoder"]])
        # 3. Remain mask
        self.remain_dict = Remain()
        # 4. Encoding
        self.encoding = Encoder(d_model["encoder"], nhead, d_ff["encoder"], dropout, activation, num_layers["encoder"])

    def forward(self, data_input, remain_rto, device):
        data_dict, self.idx_dict, self.padding_mask_dict = self.to_gpu(data_input, device)
        
        # 1. Embedding
        embedding_dict = self.apply_process(data=data_dict, mod=self.embedding_dict, args=[self.padding_mask_dict, device])
        # 2. Pos encoding and modality embedding
        pos_mod_encoding_dict = self.apply_process(data=embedding_dict, mod=self.pos_mod_encoding_dict, args=[device])
        # 3. Remain mask
        temporal_dict, img_dict, nlp_dict, self.idx_dict, self.padding_mask_dict = self.remain_dict(pos_mod_encoding_dict, self.idx_dict, self.padding_mask_dict, remain_rto, self.temporal_cols, self.img_cols, self.nlp_cols, device)
        # 4. Encoding
        encoding = self.encoding(temporal_dict, img_dict, nlp_dict, self.padding_mask_dict, self.img_cols, self.nlp_cols, device)


    
    def define_cols(self):
        temporal_cols = self.data_info.modality_info["target"] + self.data_info.modality_info["temporal"]
        img_cols = self.data_info.modality_info["img"]
        nlp_cols = self.data_info.modality_info["nlp"]
        total_cols = temporal_cols + img_cols + nlp_cols

        return temporal_cols, img_cols, nlp_cols, total_cols

    def to_gpu(self, data_input, device):
        data_dict, idx_dict, padding_mask_dict = {}, {}, {}
        for key, val in data_input.items():
            if key in self.temporal_cols + self.img_cols + self.nlp_cols:
                data_dict[key] = data_input[key].to(device)
            elif key.endswith("idx"):
                idx_dict[key] = data_input[key].to(device)
            elif key.endswith("padding_mask"):
                padding_mask_dict[key] = data_input[key].to(device)
            
        return data_dict, idx_dict, padding_mask_dict

    def init_process(self, mod, args=[], target_cols=None):
        result_dict = {}
        target_cols = self.total_cols if target_cols is None else target_cols
        for col in target_cols:
            result_dict[col] = mod(col, *args)
        
        return torch.nn.ModuleDict(result_dict)

    def apply_process(self, data, mod, args=[], target_cols=None, collate_fn=None):
        result_dict = {}
        target_cols = self.total_cols if target_cols is None else target_cols
        for col in target_cols:
            result_dict[col] = mod[col](col, data[col], *args)
        
        if collate_fn is not None:
            return collate_fn(result_dict)
        else:
            return result_dict

In [11]:
model = Transformer(data_info, train_dataset.label_encoder_dict,
                        d_model, num_layers, nhead, d_ff, dropout, "gelu")
model.to(device)
summary(model, data, remain_rto, device, show_parent_layers=True, print_summary=True)

temporal_attn_output: torch.Size([16, 100, 4, 256])
temporal_attn_weight: torch.Size([16, 4, 100, 4, 67])


TypeError: cannot unpack non-iterable NoneType object

# Train

In [None]:
for name, param in model.named_parameters():
    if "img_model" in name:
        param.requires_grad = False
    elif "nlp_model" in name:
        param.requires_grad = False

In [None]:
import warnings
warnings.filterwarnings("ignore")

def patchify(imgs, patch_size=16):
    """
    imgs: (N, 3, H, W)
    x: (N, L, patch_size**2 *3)
    """
    p = patch_size
    assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

    h = w = imgs.shape[2] // p
    x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
    x = torch.einsum('nchpwq->nhwpqc', x)
    x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
    return x

def unpatchify(x, patch_size=16):
    """
    x: (N, L, patch_size**2 *3)
    imgs: (N, 3, H, W)
    """
    p = patch_size
    h = w = int(x.shape[1]**.5)
    assert h * w == x.shape[1]
    
    x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
    x = torch.einsum('nhwpqc->nchpwq', x)
    imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
    return imgs