### Imports

In [1]:
import os

import joblib

# Preprocessing
import pandas as pd
from sklearn.preprocessing import MinMaxScaler, StandardScaler

# Torch-related
import torch
from pytorch_model_summary import summary

# Custom defined
from libs.raw_data import *
from libs.dataset import *
from architecture.architecture import *

# Prefer the GPU, but fall back to CPU so the notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


# Config

In [2]:
# Run mode
is_test_mode = False
mode = "pre-train"  # alternative: mode = "fine_tuning"
is_from_pretrained = True

# Raw data (flag passed to get_raw_data; presumably selects the cached preprocessed frame)
is_prep_data_exist = True

# Data loader: sequence-window lengths
MIN_MEANINGFUL_SEQ_LEN = 100
MAX_SEQ_LEN = 100
PRED_LEN = 50

# Raw columns feeding each modality
modality_info = dict(
    group=["article_id", "sales_channel_id"],
    target=["sales"],
    temporal=["day", "dow", "month", "holiday", "price"],
    img=["img_path"],
    nlp=["detail_desc"],
)
# Per-column preprocessing: scaler class, label-embedding, image and text columns
processing_info = dict(
    scaling_cols={"sales": StandardScaler, "price": StandardScaler},
    embedding_cols=["day", "dow", "month", "holiday"],
    img_cols=["img_path"],
    nlp_cols=["detail_desc"],
)

# Model
batch_size, nhead, dropout, patch_size = 16, 4, 0.1, 16

d_model = dict(encoder=64, decoder=32)
d_ff = dict(encoder=64, decoder=32)
num_layers = dict(encoder=2, decoder=2)
# Share of tokens left visible per modality during masked pre-training
# (no masking: remain_rto = {"temporal":1., "img":1., "nlp":1.})
remain_rto = dict(temporal=0.25, img=0.25, nlp=0.25)

# Data

### Load data

In [3]:
data_info = DataInfo(modality_info, processing_info)
df_prep = get_raw_data(is_test_mode, is_prep_data_exist)

if mode == "pre-train":
    # Training windows: long-enough series, first MAX_SEQ_LEN steps only, and a text description present
    df_train = df_prep.loc[
        (df_prep["meaningful_size"] >= MIN_MEANINGFUL_SEQ_LEN)
        & (df_prep["time_idx"] <= MAX_SEQ_LEN - 1)
        & ~pd.isna(df_prep["detail_desc"])
    ]

### Make dataset

In [4]:
if mode == "pre-train":
    train_dataset = Dataset(df_train, data_info, remain_rto)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: collate_fn(x, data_info), pin_memory=True, num_workers=16, prefetch_factor=4)

    # Persist the fitted label encoders; they are reloaded from this path before building the model
    # and again in the fine-tuning section. Create the directory so the dump cannot fail on a fresh checkout.
    os.makedirs("./src", exist_ok=True)
    joblib.dump(train_dataset.label_encoder_dict, "./src/label_encoder_dict.pkl")

    # Sanity check: print the shape of every tensor in one batch (scaler / raw entries are skipped).
    # `data` is kept as a global because the commented-out `summary(...)` call below uses it.
    for data in train_dataloader:
        for key, val in data.items():
            if "scaler" not in key and "raw" not in key:
                print(key, val.shape)
        break

# Pre-train

In [5]:
from tqdm import tqdm; tqdm.pandas()
from collections import defaultdict
from IPython.display import clear_output

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Element-wise criteria (no reduction): callers mask out padded / unmasked positions before averaging.
mse_loss, ce_loss = (
    torch.nn.MSELoss(reduction="none"),
    torch.nn.CrossEntropyLoss(reduction="none"),
)

def patchify(imgs, patch_size=16):
    """Split square images into flattened, non-overlapping square patches.

    Args:
        imgs: tensor of shape (N, C, H, W) with H == W and H divisible by
            ``patch_size``. Any channel count C is accepted (RGB uses C == 3).
        patch_size: side length ``p`` of each patch.

    Returns:
        Tensor of shape (N, L, p*p*C), L = (H // p) ** 2. Patches are in
        row-major order; each patch is flattened in (row, col, channel) order.
    """
    p = patch_size
    n, c, height, width = imgs.shape
    assert height == width and height % p == 0

    h = w = height // p
    x = imgs.reshape(n, c, h, p, w, p)
    x = torch.einsum("nchpwq->nhwpqc", x)
    return x.reshape(n, h * w, p * p * c)

def unpatchify(x, patch_size=16):
    """Inverse of ``patchify``: reassemble flattened patches into square images.

    Args:
        x: tensor of shape (N, L, p*p*C) where L is a perfect square. The channel
            count C is inferred from the last dimension (3 for RGB patches).
        patch_size: side length ``p`` of each patch.

    Returns:
        Tensor of shape (N, C, H, W) with H == W == sqrt(L) * p.
    """
    p = patch_size
    n, num_patches, patch_dim = x.shape
    h = w = int(num_patches ** .5)
    assert h * w == num_patches
    assert patch_dim % (p * p) == 0
    c = patch_dim // (p * p)

    x = x.reshape(n, h, w, p, p, c)
    x = torch.einsum("nhwpqc->nchpwq", x)
    return x.reshape(n, c, h * p, w * p)


def get_loss(output_dict, data_dict, idx_dict, padding_mask_dict, col_info, data_info):
    """Compute one masked reconstruction loss per column.

    ``output_dict`` and ``data_dict`` must list the same keys in the same order.

    - Temporal columns: MSE for columns in ``processing_info["scaling_cols"]``,
      cross-entropy for ``processing_info["embedding_cols"]``. The loss is averaged
      over positions where ``temporal_padding_mask == 1`` and the
      ``temporal_block_masked_idx`` row contains exactly one 1.
    - Image columns: the first output token is dropped. The target is patchified
      with the notebook-level ``patch_size``. MSE is averaged over the patches
      listed in ``{key}_masked_idx``.
    - Text columns: the first token is dropped from both prediction and target.
      Cross-entropy is gathered at ``{key}_masked_idx`` and averaged over
      ``{key}_masked_padding_mask[:, 1:]``.

    Empty masks yield a loss of 0 instead of nan.

    Args:
        col_info: ``(temporal_cols, img_cols, nlp_cols)``.

    Returns:
        dict mapping column name to a scalar loss tensor.

    Raises:
        ValueError: a temporal column is neither scaled nor embedded, or a column
            belongs to no modality. Previously these silently reused the previous
            column's loss.
    """
    F = torch.nn.functional
    loss_dict = {}
    temporal_cols, img_cols, nlp_cols = col_info
    scaling_cols = data_info.processing_info["scaling_cols"]
    embedding_cols = data_info.processing_info["embedding_cols"]

    for (key, pred), (key_, y) in zip(output_dict.items(), data_dict.items()):
        assert key == key_

        # Temporal loss
        if key in temporal_cols:
            if key in scaling_cols:
                # squeeze(-1) drops only the feature axis so a batch of one keeps its batch dim
                loss = F.mse_loss(pred, y, reduction="none").squeeze(-1)
            elif key in embedding_cols:
                loss = F.cross_entropy(pred.reshape(-1, pred.shape[-1]), y.reshape(-1).to(torch.long), reduction="none")
                loss = loss.view(y.shape)
            else:
                raise ValueError(f"Temporal column {key!r} is neither a scaling nor an embedding column")
            padding_mask = padding_mask_dict["temporal_padding_mask"]
            masked_idx = idx_dict["temporal_block_masked_idx"]
            masking_mask = (masked_idx == 1).sum(dim=-1)
            total_mask = torch.where((padding_mask == 1) & (masking_mask == 1), 1, 0)
            # clamp(min=1): nothing masked -> loss 0 instead of 0/0 = nan
            loss = (loss * total_mask).sum() / total_mask.sum().clamp(min=1)

        # Img loss
        elif key in img_cols:
            pred = pred[:, 1:, :]  # drop the leading token (presumably a global/CLS token)
            y = patchify(y, patch_size)
            loss = F.mse_loss(pred, y, reduction="none")
            masked_idx = idx_dict[f"{key}_masked_idx"]
            masked_idx = masked_idx.unsqueeze(-1).repeat(1, 1, loss.shape[-1])
            masked_loss = torch.gather(loss, index=masked_idx, dim=1)
            # Mean of an empty tensor is nan; an empty selection contributes 0 instead
            loss = masked_loss.mean() if masked_loss.numel() > 0 else masked_loss.sum()

        # Nlp loss
        elif key in nlp_cols:
            pred = pred[:, 1:, :]
            y = y[:, 1:]
            loss = F.cross_entropy(pred.reshape(-1, pred.shape[-1]), y.reshape(-1).to(torch.long), reduction="none")
            loss = loss.view(y.shape)
            masked_idx = idx_dict[f"{key}_masked_idx"]
            masked_loss = torch.gather(loss, index=masked_idx, dim=1)
            padding_mask = padding_mask_dict[f"{key}_masked_padding_mask"][:, 1:]
            loss = (masked_loss * padding_mask).sum() / padding_mask.sum().clamp(min=1)

        else:
            raise ValueError(f"Column {key!r} is not assigned to any modality in col_info")

        loss_dict[key] = loss

    return loss_dict


def obtain_loss_dict_for_plot(total_loss, loss_dict, loss_li_dict, mean_loss_li_dict):
    """Append the latest losses and their running means to the history dicts.

    The total loss is recorded under ``"total"`` first, followed by every entry of
    ``loss_dict`` in order. Each stored mean is the average over that key's full history.
    Both history dicts are mutated in place and also returned.
    """
    for name, value in [("total", total_loss), *loss_dict.items()]:
        history = loss_li_dict[name]
        history.append(value.item())
        mean_loss_li_dict[name].append(np.mean(history))

    return loss_li_dict, mean_loss_li_dict

def plot_loss_sample(nrows, ncols, idx, mean_loss_li_dict, output_dict, data_dict, col_info, patch_size, embedding_cols=None):
    """Draw one subplot row per loss key on the current figure.

    Each row holds the running-mean loss curve. For temporal keys it adds target vs.
    prediction of sample ``idx``. For image keys it adds the input image and the
    unpatchified reconstruction. Every key advances the subplot index by 3 slots,
    so the layout assumes ``ncols == 3`` and ``nrows`` >= number of keys.

    Args:
        embedding_cols: columns whose predictions are logits (argmax before plotting).
            Defaults to the notebook-level ``data_info.processing_info["embedding_cols"]``.
    """
    if embedding_cols is None:
        embedding_cols = data_info.processing_info["embedding_cols"]
    temporal_cols, img_cols, _ = col_info

    plot_idx = 1
    for key, val in mean_loss_li_dict.items():
        # Individual loss curve
        plt.subplot(nrows, ncols, plot_idx)
        plt.plot(val)
        plt.title(f"{key}: {val[-1]}")
        plot_idx += 1

        # Sample
        if key in temporal_cols:
            # squeeze(-1) drops only the feature axis, so indexing by `idx` still selects a sample when B == 1
            pred, y = output_dict[key].squeeze(-1), data_dict[key].squeeze(-1)
            if key in embedding_cols:
                pred = torch.argmax(pred, dim=-1)

            plt.subplot(nrows, ncols, plot_idx)
            plt.plot(y[idx].detach().cpu())
            plt.plot(pred[idx].detach().cpu())
        elif key in img_cols:
            pred = unpatchify(output_dict[key][:, 1:, :], patch_size).permute(0, 2, 3, 1)
            y = data_dict[key].permute(0, 2, 3, 1)

            plt.subplot(nrows, ncols, plot_idx)
            plt.imshow(y[idx].detach().cpu())

            plt.subplot(nrows, ncols, plot_idx + 1)
            plt.imshow(pred[idx].detach().cpu())

        plot_idx += 2

    plt.tight_layout()
    plt.show()


def pre_train(e, data_loader, optimizer, model, data_info, patch_size):
    """Run one pre-training epoch of the masked block auto-encoder.

    Each batch goes through ``model(data, remain_rto, device)``; note that
    ``remain_rto`` and ``device`` are read from the notebook namespace. The summed
    per-column losses from ``get_loss`` are back-propagated. Losses are recorded on
    every step. Every 20 steps the output is cleared and the running-mean curves plus
    sample reconstructions are redrawn.

    Args:
        e: epoch index (shown in the progress bar).
        patch_size: patch side length used to unpatchify image reconstructions for plotting.
    """
    pbar = tqdm(data_loader, desc=f"pre-train epoch {e}")
    loss_li_dict, mean_loss_li_dict = defaultdict(list), defaultdict(list)
    model.train()  # nothing below switches to eval mode, so once per epoch suffices
    # The column/modality split depends only on data_info, so compute it once
    col_info = model.define_col_modalities(data_info)

    for n, data in enumerate(pbar):
        optimizer.zero_grad()
        output_dict, data_dict, idx_dict, padding_mask_dict = model(data, remain_rto, device)

        loss_dict = get_loss(output_dict, data_dict, idx_dict, padding_mask_dict, col_info, data_info)
        total_loss = torch.stack(list(loss_dict.values())).sum()
        total_loss.backward()
        optimizer.step()

        # Record every step so the running means cover all batches, not only the plotted ones
        loss_li_dict, mean_loss_li_dict = obtain_loss_dict_for_plot(total_loss, loss_dict, loss_li_dict, mean_loss_li_dict)

        # Plot
        if n % 20 == 0:
            nrows, ncols, idx = 10, 3, 0
            plt.figure(figsize=(15, 15))
            clear_output(wait=True)
            plot_loss_sample(nrows, ncols, idx, mean_loss_li_dict, output_dict, data_dict, col_info, patch_size)

True

In [6]:
if mode == "pre-train":
    # Define model
    label_encoder_dict = joblib.load("./src/label_encoder_dict.pkl")
    model = MaskedBlockAutoEncoder(data_info, label_encoder_dict,
                            d_model, num_layers, nhead, d_ff, dropout, "gelu",
                            patch_size, is_from_pretrained)
    model.to(device)
    # summary(model, data, remain_rto, device, show_parent_layers=True, print_summary=True)

    # Optionally freeze the pretrained image / text backbones (parameters named "img_model" / "nlp_model")
    freeze_backbones = False
    if freeze_backbones:
        for name, param in model.named_parameters():
            if "img_model" in name or "nlp_model" in name:
                param.requires_grad = False

    # Train
    epoch = 3
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # scheduler.step() runs once per epoch, so T_max=epoch anneals over the whole run
    # (T_max=100 with 3 steps left the learning rate practically constant).
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epoch)

    for e in range(epoch):
        pre_train(e, train_dataloader, optimizer, model, data_info, patch_size)
        scheduler.step()

In [7]:
if mode == "pre-train":
    import datetime

    # Timestamped checkpoint file (state_dict only); reloaded by path in the fine-tuning section
    path = f"./saved_model_{datetime.datetime.now()}"
    print(path)
    torch.save(model.state_dict(), path)

# TODO: Add positional encoding to the forecasters

# Fine-Tuning

In [8]:
path = './saved_model_2024-04-26 20:43:39.549405'

label_encoder_dict = joblib.load("./src/label_encoder_dict.pkl")
mbae = MaskedBlockAutoEncoder(data_info, label_encoder_dict,
                            d_model, num_layers, nhead, d_ff, dropout, "gelu",
                            patch_size, is_from_pretrained=True)
# The checkpoint was saved from the GPU model, but `mbae` is still on CPU here: map the
# tensors to CPU so they are not first materialized on the GPU only to be copied back.
# NOTE: torch.load unpickles the file; only load checkpoints you produced yourself.
mbae.load_state_dict(torch.load(path, map_location="cpu"))
# mbae_encoder = mbae.mbae_encoder
# mbae_encoder.to(device)
model = mbae
model.to(device)

# Remain ratio 1 keeps every token visible (cf. the empty detail_desc_masked_idx in the batch printout below)
remain_rto = {"temporal":1, "img":1, "nlp":1}



In [9]:
# Same filters as the training windows, but extended by PRED_LEN steps for the forecast horizon
df_valid = df_prep.loc[
    (df_prep["meaningful_size"] >= MIN_MEANINGFUL_SEQ_LEN)
    & (df_prep["time_idx"] <= MAX_SEQ_LEN + PRED_LEN - 1)
    & ~pd.isna(df_prep["detail_desc"])
]

In [10]:
from tqdm import tqdm; tqdm.pandas()

# Sklearn-related
from sklearn.base import BaseEstimator, TransformerMixin
import numpy as np

# Torch-related
import torch
from transformers import AutoImageProcessor, AutoTokenizer
from PIL import Image

# Shared ViT image processor used by valid_collate_fn to turn the list of PIL images into
# a "pixel_values" tensor (the batch printout shows shape (16, 3, 224, 224)).
transform_img = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

class ValidDataset(torch.utils.data.Dataset):
    """Fine-tuning / forecasting dataset with one item per group of rows.

    Groups are defined by ``data_info.modality_info["group"]``. Each item holds:

    - the label-encoded and scaled temporal columns;
    - a forecast mask that is 0 on the last ``PRED_LEN`` steps;
    - the RGB image;
    - the tokenized text with remain/masked token indices.
    """

    def __init__(self, data, data_info, remain_rto, label_encoder_dict):
        """
        Args:
            data: DataFrame with the group, temporal, image-path and text columns.
            data_info: object exposing ``modality_info`` and ``processing_info`` dicts.
            remain_rto: per-modality keep ratio; only ``remain_rto["nlp"]`` is used here.
            label_encoder_dict: fitted encoder per embedding column (used for transform only).
        """
        super().__init__()
        self.data_info, self.remain_rto = data_info, remain_rto
        self.tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")

        # Reuse the label encoders fitted on the pre-training data
        self.label_encoder_dict = label_encoder_dict

        # One DataFrame per group. Iterating the groupby directly, rather than calling
        # apply purely for its append side effect, yields exactly one frame per group.
        group_cols = self.data_info.modality_info["group"]
        self.dataset = tuple(group for _, group in tqdm(data.groupby(group_cols)))

    def transform_label_encoder(self, data):
        """Map every embedding column through its fitted label encoder."""
        result_dict = {}
        embedding_cols = self.data_info.processing_info["embedding_cols"]
        for col in embedding_cols:
            result_dict[col] = self.label_encoder_dict[col].transform(data[col].values)
        return result_dict

    def scale_data(self, data):
        """Scale each scaling column with a fresh per-series scaler.

        The scaler is fit only on the history, i.e. everything except the last
        ``PRED_LEN`` rows, which ``__getitem__`` marks as the forecast horizon.
        Fitting on the horizon would leak future statistics into the inputs. Series
        no longer than ``PRED_LEN`` fall back to fitting on all rows. The fitted
        scaler is returned under ``"{col}_scaler"`` so predictions can be inverted.
        """
        result_dict = {}
        scaling_cols = self.data_info.processing_info["scaling_cols"]
        for col, scaler_cls in scaling_cols.items():
            values = data[col].values.reshape(-1, 1)
            history = values[:-PRED_LEN] if len(values) > PRED_LEN else values
            scaler = scaler_cls()
            scaler.fit(history)
            result_dict[col] = scaler.transform(values)
            result_dict[f"{col}_scaler"] = scaler
        return result_dict

    def apply_nlp_remain(self, data):
        """Tokenize the series' single text and split its tokens into remain/masked sets.

        The first token is kept aside as the global token. The others are shuffled
        and the first ``int(len * remain_rto["nlp"])`` are kept visible. Each padding
        mask is one element longer than its index array (presumably for the global token).
        """
        result_dict = {}
        nlp_cols = self.data_info.modality_info["nlp"]
        remain_rto = self.remain_rto["nlp"]

        for col in nlp_cols:
            nlp_data = set(data[col].values); assert len(nlp_data) == 1
            total_token = self.tokenizer(next(iter(nlp_data)), return_tensors="np")["input_ids"].squeeze()

            # Split global token and valid token
            global_token = total_token[:1]
            valid_token = total_token[1:]

            # Get remain/masked/revert indices
            valid_token_shape = valid_token.shape; assert len(valid_token_shape)==1

            num_remain = int(valid_token_shape[0] * remain_rto)
            noise = np.random.rand(valid_token_shape[0])
            shuffle_idx = np.argsort(noise)

            remain_idx = shuffle_idx[:num_remain]
            masked_idx = shuffle_idx[num_remain:]
            revert_idx = np.argsort(shuffle_idx)

            remain_padding_mask = np.ones(remain_idx.shape[0]+1)
            masked_padding_mask = np.ones(masked_idx.shape[0]+1)
            revert_padding_mask = np.ones(revert_idx.shape[0]+1)

            result_dict.update({f"{col}":total_token, f"{col}_raw":nlp_data,
                                f"{col}_remain_idx":remain_idx, f"{col}_masked_idx":masked_idx, f"{col}_revert_idx":revert_idx,
                                f"{col}_remain_padding_mask":remain_padding_mask, f"{col}_masked_padding_mask":masked_padding_mask, f"{col}_revert_padding_mask":revert_padding_mask})

        return result_dict

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        result_dict = {}
        data = self.dataset[idx]

        # Label encode and scale data
        embedding_cols = self.transform_label_encoder(data)
        scaling_cols = self.scale_data(data)
        result_dict.update(**embedding_cols, **scaling_cols)

        # Temporal forecast/padding mask: the last PRED_LEN steps are the forecast horizon
        target_col = self.data_info.modality_info["target"]
        target_fcst_mask = np.ones(data[target_col].shape).squeeze()
        target_fcst_mask[-PRED_LEN:] = 0

        temporal_padding_mask = np.ones(data[target_col].shape).squeeze()
        result_dict.update({"target_fcst_mask":target_fcst_mask, "temporal_padding_mask":temporal_padding_mask})

        # Img: close the file right away; convert() returns an independent in-memory image
        img_cols = self.data_info.modality_info["img"]
        for col in img_cols:
            img_path = set(data[col].values); assert len(img_path) == 1
            with Image.open(next(iter(img_path))) as img:
                result_dict[f"{col}_raw"] = img.convert("RGB")

        # Nlp
        nlp_result_dict = self.apply_nlp_remain(data)
        result_dict.update(nlp_result_dict)

        return result_dict

def valid_collate_fn(batch_li, data_info):
    """Collate ``ValidDataset`` items into one batch dict.

    Variable-length arrays are right-padded with zeros via ``pad_sequence`` (batch
    first), so padded positions read 0 in every padding mask. Raw images and texts
    are passed through as Python lists.

    Args:
        batch_li: list of dicts as returned by ``ValidDataset.__getitem__``.
        data_info: object exposing ``modality_info`` and ``processing_info`` dicts.

    Returns:
        dict of batched tensors plus ``*_raw`` lists. Key order: temporal columns,
        temporal masks, image entries, then text entries.
    """
    def pad(tensors):
        # Right-pad a list of variable-length tensors to (batch, max_len, ...)
        return torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True)

    result_dict = {}
    embedding_cols = data_info.processing_info["embedding_cols"]

    # Temporal: embedding columns as int, everything else as float
    for col in data_info.modality_info["target"] + data_info.modality_info["temporal"]:
        tensor_type = torch.int if col in embedding_cols else torch.float
        result_dict[col] = pad([torch.from_numpy(batch[col]).to(tensor_type) for batch in batch_li])

    # Temporal masks (dtype kept as produced by the dataset)
    for mask_key in ("target_fcst_mask", "temporal_padding_mask"):
        result_dict[mask_key] = pad([torch.from_numpy(batch[mask_key]) for batch in batch_li])

    # Img
    for col in data_info.modality_info["img"]:
        img_raw = [batch[f"{col}_raw"] for batch in batch_li]
        result_dict[f"{col}_raw"] = img_raw
        result_dict[f"{col}"] = transform_img(img_raw, return_tensors="pt")["pixel_values"]

    # Nlp
    for col in data_info.modality_info["nlp"]:
        result_dict[col] = pad([torch.from_numpy(batch[col]).to(torch.int) for batch in batch_li])
        for suffix in ("remain_idx", "masked_idx", "revert_idx"):
            key = f"{col}_{suffix}"
            result_dict[key] = pad([torch.from_numpy(batch[key]).to(torch.int64) for batch in batch_li])
        # Padding masks are always float, independent of how the temporal columns were typed
        for suffix in ("remain_padding_mask", "masked_padding_mask", "revert_padding_mask"):
            key = f"{col}_{suffix}"
            result_dict[key] = pad([torch.from_numpy(batch[key]).to(torch.float) for batch in batch_li])
        result_dict[f"{col}_raw"] = [batch[f"{col}_raw"] for batch in batch_li]

    return result_dict

True

In [11]:
valid_dataset = ValidDataset(df_valid, data_info, remain_rto, label_encoder_dict)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: valid_collate_fn(x, data_info), pin_memory=True, num_workers=16, prefetch_factor=4)

# Sanity check: print tensor shapes of one batch (skip the *_raw lists, which have no .shape)
for valid_data in valid_dataloader:
    for key, val in valid_data.items():
        if "scaler" not in key and "raw" not in key:
            print(key, val.shape)
    break

100%|██████████| 29274/29274 [00:01<00:00, 28429.63it/s]


sales torch.Size([16, 150, 1])
day torch.Size([16, 150])
dow torch.Size([16, 150])
month torch.Size([16, 150])
holiday torch.Size([16, 150])
price torch.Size([16, 150, 1])
target_fcst_mask torch.Size([16, 150])
temporal_padding_mask torch.Size([16, 150])
img_path torch.Size([16, 3, 224, 224])
detail_desc torch.Size([16, 61])
detail_desc_remain_idx torch.Size([16, 60])
detail_desc_masked_idx torch.Size([16, 0])
detail_desc_revert_idx torch.Size([16, 60])
detail_desc_remain_padding_mask torch.Size([16, 61])
detail_desc_masked_padding_mask torch.Size([16, 1])
detail_desc_revert_padding_mask torch.Size([16, 61])


In [12]:
# from torch.nn import functional as F

# def _generate_square_subsequent_mask(sz, device, dtype):
#     return torch.triu(
#         torch.full((sz, sz), float('-inf'), dtype=dtype, device=device),
#         diagonal=1,
#     )

# def _get_seq_len(src, batch_first):
#     if src.is_nested:
#         return None
#     else:
#         src_size = src.size()
#         if len(src_size) == 2:
#             # unbatched: S, E
#             return src_size[0]
#         else:
#             # batched: B, S, E if batch_first else S, B, E
#             seq_len_pos = 1 if batch_first else 0
#             return src_size[seq_len_pos]

# def _detect_is_causal_mask(mask, is_causal=None,size=None):
#     # Prevent type refinement
#     make_causal = (is_causal is True)

#     if is_causal is None and mask is not None:
#         sz = size if size is not None else mask.size(-2)
#         causal_comparison = _generate_square_subsequent_mask(
#             sz, device=mask.device, dtype=mask.dtype)

#         # Do not use `torch.equal` so we handle batched masks by
#         # broadcasting the comparison.
#         if mask.size() == causal_comparison.size():
#             make_causal = bool((mask == causal_comparison).all())
#         else:
#             make_causal = False

#     return make_causal

# class EncoderLayer_(torch.nn.TransformerEncoderLayer):
#     def forward(self, src, pos_enc, src_mask=None, src_key_padding_mask=None, is_causal=False):
#         x = src
#         attn_output, attn_weight = self._sa_block(x, pos_enc, src_mask, src_key_padding_mask, is_causal=is_causal)
#         x = self.norm1(x + attn_output)
#         x = self.norm2(x + self._ff_block(x))

#         return x, attn_weight

#     # self-attention block
#     def _sa_block(self, x, pos_enc, attn_mask, key_padding_mask, is_causal=False):
#         x, attn_weight = self.self_attn(x+pos_enc, x+pos_enc, x,
#                            attn_mask=attn_mask,
#                            key_padding_mask=key_padding_mask,
#                            need_weights=True, is_causal=is_causal, average_attn_weights=False)
#         return self.dropout1(x), attn_weight

# class Encoder_(torch.nn.TransformerEncoder):
#     def forward(self, src, pos_enc=0, mask=None, src_key_padding_mask=None, is_causal=None):
#        ################################################################################################################
#         src_key_padding_mask = F._canonical_mask(
#             mask=src_key_padding_mask,
#             mask_name="src_key_padding_mask",
#             other_type=F._none_or_dtype(mask),
#             other_name="mask",
#             target_type=src.dtype
#         )

#         mask = F._canonical_mask(
#             mask=mask,
#             mask_name="mask",
#             other_type=None,
#             other_name="",
#             target_type=src.dtype,
#             check_other=False,
#         )

#         output = src
#         convert_to_nested = False
#         first_layer = self.layers[0]
#         src_key_padding_mask_for_layers = src_key_padding_mask
#         why_not_sparsity_fast_path = ''
#         str_first_layer = "self.layers[0]"
#         batch_first = first_layer.self_attn.batch_first
#         if not hasattr(self, "use_nested_tensor"):
#             why_not_sparsity_fast_path = "use_nested_tensor attribute not present"
#         elif not self.use_nested_tensor:
#             why_not_sparsity_fast_path = "self.use_nested_tensor (set in init) was not True"
#         elif first_layer.training:
#             why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
#         elif not src.dim() == 3:
#             why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
#         elif src_key_padding_mask is None:
#             why_not_sparsity_fast_path = "src_key_padding_mask was None"
#         elif (((not hasattr(self, "mask_check")) or self.mask_check)
#                 and not torch._nested_tensor_from_mask_left_aligned(src, src_key_padding_mask.logical_not())):
#             why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask was not left aligned"
#         elif output.is_nested:
#             why_not_sparsity_fast_path = "NestedTensor input is not supported"
#         elif mask is not None:
#             why_not_sparsity_fast_path = "src_key_padding_mask and mask were both supplied"
#         elif torch.is_autocast_enabled():
#             why_not_sparsity_fast_path = "autocast is enabled"

#         if not why_not_sparsity_fast_path:
#             tensor_args = (
#                 src,
#                 first_layer.self_attn.in_proj_weight,
#                 first_layer.self_attn.in_proj_bias,
#                 first_layer.self_attn.out_proj.weight,
#                 first_layer.self_attn.out_proj.bias,
#                 first_layer.norm1.weight,
#                 first_layer.norm1.bias,
#                 first_layer.norm2.weight,
#                 first_layer.norm2.bias,
#                 first_layer.linear1.weight,
#                 first_layer.linear1.bias,
#                 first_layer.linear2.weight,
#                 first_layer.linear2.bias,
#             )
#             _supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
#             if torch.overrides.has_torch_function(tensor_args):
#                 why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
#             elif src.device.type not in _supported_device_type:
#                 why_not_sparsity_fast_path = f"src device is neither one of {_supported_device_type}"
#             elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
#                 why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
#                                               "input/output projection weights or biases requires_grad")

#             if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None):
#                 convert_to_nested = True
#                 output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
#                 src_key_padding_mask_for_layers = None

#         seq_len = _get_seq_len(src, batch_first)
#         is_causal = _detect_is_causal_mask(mask, is_causal, seq_len)
#        ################################################################################################################

#         for mod in self.layers:
#             output, attn_weight = mod(output, pos_enc, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)

#         if convert_to_nested:
#             output = output.to_padded_tensor(0., src.size())

#         if self.norm is not None:
#             output = self.norm(output)

#         return output, attn_weight

In [13]:
# class Forecaster(torch.nn.Module):
#     def __init__(self, mbae_encoder, activation):
#         super().__init__()
#         self.mbae_encoder = mbae_encoder
#         self.pos_enc = PositionalEncoding(d_model["encoder"], dropout)
#         self.linear1 = torch.nn.Linear(d_model["encoder"], d_model["encoder"]) 
#         # self.forecaster = Encoder_(EncoderLayer_(d_model["decoder"], nhead, d_ff["decoder"], dropout, activation, batch_first=True), num_layers["decoder"])
#         self.forecaster = torch.nn.TransformerEncoder(torch.nn.TransformerEncoderLayer(d_model["encoder"], nhead, d_ff["encoder"], dropout, activation, batch_first=True), 1)
#         self.linear2 = torch.nn.Linear(d_model["encoder"], 1)
    
#     def forward(self, data_input, remain_rto, device):
#         encoding_dict, encoding_weight_dict,\
#             data_dict, idx_dict, padding_mask_dict = self.mbae_encoder(data_input, remain_rto, device)
#         encoding = encoding_dict["global"]
#         encoding = self.linear1(encoding)
#         # encoding = self.pos_enc(encoding)

#         padding_mask = torch.where((padding_mask_dict["temporal_padding_mask"].squeeze()==1), 0, -torch.inf)
#         # output, weight = self.forecaster(encoding, src_key_padding_mask=padding_mask)
#         output = self.forecaster(encoding, src_key_padding_mask=padding_mask)
#         weight = None
#         output = self.linear2(output).squeeze()

#         return output, data_dict, padding_mask_dict, encoding_weight_dict, weight

# model = Forecaster(mbae_encoder, "gelu")
# model.to(device)
# summary(model, valid_data, remain_rto, device, show_parent_layers=True, print_summary=True)
# ""

In [14]:
# for name, param in model.named_parameters():
#     if "img_model" in name:
#         param.requires_grad = False
#     elif "nlp_model" in name:
#         param.requires_grad = False
#     # if "mbae_encoder" in name:
#     #     param.requires_grad = False

In [15]:
# Fresh optimizer and scheduler for fine-tuning; unreduced criteria so the loop can apply its own mask
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=100)
mse_loss, ce_loss = torch.nn.MSELoss(reduction="none"), torch.nn.CrossEntropyLoss(reduction="none")

def train(e):
    """Fine-tune ``model`` for one epoch on ``valid_dataloader``.

    The loss is the MSE of the predicted ``sales`` against ``data["sales"]``,
    averaged over steps that are real (``temporal_padding_mask == 1``) and inside the
    forecast horizon (``target_fcst_mask == 0``).

    Reads notebook-level ``model``, ``optimizer``, ``valid_dataloader``,
    ``remain_rto``, ``device`` and ``mse_loss``.

    Args:
        e: epoch index (shown in the progress bar).

    Returns:
        list of running-mean losses, one per step.
    """
    pbar = tqdm(valid_dataloader, desc=f"fine-tune epoch {e}")
    loss_li, mean_loss_li = [], []
    running_sum = 0.0
    model.train()  # nothing in the loop switches to eval mode

    for n, data in enumerate(pbar, start=1):
        optimizer.zero_grad()
        output_dict, data_dict, idx_dict, padding_mask_dict = model(data, remain_rto, device)
        output = output_dict["sales"].squeeze()
        target = data["sales"].squeeze().to(device)

        mask = torch.where(
            (padding_mask_dict["temporal_padding_mask"].squeeze() == 1)
            & (padding_mask_dict["target_fcst_mask"].squeeze() == 0),
            1, 0,
        )
        # clamp(min=1): a batch without horizon steps gives 0 instead of 0/0 = nan
        loss = (mse_loss(output, target) * mask).sum() / mask.sum().clamp(min=1)

        loss.backward()
        optimizer.step()

        # Running mean kept incrementally instead of re-averaging the whole history each step
        loss_li.append(loss.item())
        running_sum += loss_li[-1]
        mean_loss_li.append(running_sum / n)
        pbar.set_postfix(mean_loss=mean_loss_li[-1])

    return mean_loss_li


# NOTE(review): the recorded run hit CUDA OOM at step 2/1830 with batch_size=16 and
# remain_rto=1 for every modality (no token is masked). Consider a smaller batch_size,
# freezing the image/text backbones, or mixed precision before re-running.
epoch = 1
for e in range(epoch):
    train(e)
    scheduler.step()

  0%|          | 2/1830 [00:03<45:50,  1.50s/it]  


OutOfMemoryError: CUDA out of memory. Tried to allocate 330.00 MiB (GPU 0; 23.67 GiB total capacity; 21.83 GiB already allocated; 29.25 MiB free; 23.24 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF