In [1]:
import joblib

# Torch-related
import torch
from pytorch_model_summary import summary

# Custom defined
from config import fine_tuning
from libs.data import load_dataset, collate_fn, Dataset
from architecture.architecture import MaskedBlockAutoencoder
from architecture.shared_module import patchify, unpatchify

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Run configuration --------------------------------------------------------
is_test_mode = True      # selects the "_test" dataset file; also forwarded to load_dataset
is_new_rawdata = False   # forwarded to load_dataset when the dataset is rebuilt
is_new_dataset = False   # True: rebuild via load_dataset; False: load the cached dataset file
config = fine_tuning
# device = torch.device("cuda")
device = torch.device("cpu")

# --- Dataset ------------------------------------------------------------------
if is_new_dataset:
    train_dataset = load_dataset(is_test_mode, is_new_rawdata, config, mode="fine_tuning", verbose=True)
else:
    suffix = "_test" if is_test_mode else ""
    # The cached file holds the dataset object itself (it is used directly as
    # `train_dataset`), so it must be fully unpickled. Since PyTorch 2.6 `torch.load`
    # defaults to `weights_only=True`, which rejects arbitrary objects, so opt out
    # explicitly (kwarg available since torch 1.13). Only load files you created yourself:
    # unpickling can execute arbitrary code.
    train_dataset = torch.load(f"src/fine_tuning_dataset{suffix}", weights_only=False)

# Higher-throughput settings used on GPU: pin_memory=True, num_workers=16, prefetch_factor=32.
# NOTE(review): with worker processes started via "spawn", the lambda collate_fn below
# cannot be pickled; this only works with fork-based workers.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=lambda batch: collate_fn(batch, config),
)

# --- Sanity check: tensor shapes of one batch (scaler/raw entries skipped) -----
sample_batch = next(iter(train_dataloader), None)
assert sample_batch is not None, "train_dataloader produced no batches"
for key, val in sample_batch.items():
    if "scaler" not in key and "raw" not in key:
        print(key, val.shape)

# Kept for cells that still refer to this batch as `_`; prefer `sample_batch`.
_ = sample_batch

sales torch.Size([2, 146, 1])
day torch.Size([2, 146])
dow torch.Size([2, 146])
month torch.Size([2, 146])
holiday torch.Size([2, 146])
price torch.Size([2, 146, 1])
temporal_padding_mask torch.Size([2, 146])
target_fcst_mask torch.Size([2, 146])
img_path torch.Size([2, 3, 224, 224])
detail_desc torch.Size([2, 7])
detail_desc_revert_padding_mask torch.Size([2, 8])
detail_desc_remain_idx torch.Size([2, 7])
detail_desc_masked_idx torch.Size([2, 0])
detail_desc_revert_idx torch.Size([2, 7])


In [3]:
# Label encoders saved earlier in the pipeline (a joblib pickle: only load trusted files).
label_encoder_dict = joblib.load("./src/label_encoder_dict.pkl")

# The masked block autoencoder whose encoder the forecaster below reuses.
mbae_model = MaskedBlockAutoencoder(config, label_encoder_dict)

# Checkpoint of the pre-trained MBAE.
# NOTE(review): the load below is commented out, so `path` is unused and `mbae_model`
# keeps whatever weights its constructor produced (no pre-trained weights are applied).
path = "saved_model_epoch9_2024-05-24 19:04:47.683949"
# mbae_model.load_state_dict(torch.load(path))

class Forecaster(torch.nn.Module):
    """Forecasting model built on top of an MBAE encoder.

    For now it only routes a collated batch through the MBAE encoder and returns the
    encoder's output; no forecasting layers are attached yet.
    """

    def __init__(self, config, mbae_model):
        """
        Args:
            config: experiment config. Its `temporal_cols`, `img_cols` and `nlp_cols`
                decide which batch entries are treated as model inputs.
            mbae_model: model exposing an `.encoder` submodule. Only that encoder is
                kept, so it is the only registered child module of this model.
        """
        super().__init__()
        self.config = config
        self.mbae_encoder = mbae_model.encoder

    def forward(self, data_input, device):
        """Move a batch to `device`, run the MBAE encoder on it, and return the result.

        Args:
            data_input: mapping of batch entries (as produced by the DataLoader's
                collate_fn); values must support `.to(device)`.
            device: target torch device.

        Returns:
            Whatever `self.mbae_encoder(data_dict, idx_dict, mask_dict, device)` returns.
            (The output used to be computed and then discarded, so this returned None.)
        """
        data_dict, idx_dict, mask_dict = self.to_gpu(data_input, device)
        return self.mbae_encoder(data_dict, idx_dict, mask_dict, device)

    def to_gpu(self, data_input, device):
        """Split a batch into (data, index, mask) dicts and move every kept value to `device`.

        Despite the name, `device` can be any torch device, CPU included.
        Routing rules, checked in this order for each key:
          * listed in config.temporal_cols / img_cols / nlp_cols -> data_dict
          * ends with "idx"                                   -> idx_dict
          * ends with "mask"                                  -> mask_dict
          * anything else                                     -> dropped

        Returns:
            (data_dict, idx_dict, mask_dict)
        """
        data_cols = self.config.temporal_cols + self.config.img_cols + self.config.nlp_cols
        data_dict, idx_dict, mask_dict = {}, {}, {}
        for key, val in data_input.items():
            # Data columns are checked first so a column whose name happens to end in
            # "idx"/"mask" is still treated as model input.
            if key in data_cols:
                data_dict[key] = val.to(device)
            elif key.endswith("idx"):
                idx_dict[key] = val.to(device)
            elif key.endswith("mask"):
                mask_dict[key] = val.to(device)
        return data_dict, idx_dict, mask_dict

# Draw a batch explicitly instead of reusing `_`: IPython also binds `_` to the last
# displayed output, so depending on it is fragile under Restart & Run All.
sample_batch = next(iter(train_dataloader))

model = Forecaster(config, mbae_model)
model.to(device)
# Positional args after the model are passed on to Forecaster.forward(data_input, device).
# print_summary=True already prints the table; the trailing `;` suppresses the echo of
# the returned value (previously hidden behind a stray `""` expression).
summary(model, sample_batch, device, show_parent_layers=True, print_summary=True);

---------------------------------------------------------------------------------------
   Parent Layers       Layer (type)        Output Shape         Param #     Tr. Param #
      Forecaster      MBAEEncoder-1                          19,232,640      19,157,376
Total params: 19,232,640
Trainable params: 19,157,376
Non-trainable params: 75,264
---------------------------------------------------------------------------------------


''