In [1]:
import joblib

# Torch-related
import torch
from pytorch_model_summary import summary

# Custom defined
from config import fine_tuning
from libs.data import load_dataset, collate_fn
from architecture.architecture import MaskedBlockAutoencoder
from architecture.shared_module import patchify, unpatchify

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
is_test_mode = False
is_new_dataset = True
config = fine_tuning
# Prefer the GPU, but fall back to the CPU so the notebook still runs on hosts without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if is_new_dataset:
    # Build the fine-tuning dataset from scratch.
    train_dataset = load_dataset(is_test_mode, None, config, mode="fine_tuning", verbose=True)
else:
    # Reuse a previously serialized dataset.
    # NOTE: torch.load unpickles the file, so only load datasets from a trusted source.
    suffix = "_test" if is_test_mode else ""
    train_dataset = torch.load(f"src/fine_tuning_train_dataset{suffix}")

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=lambda x: collate_fn(x, config),
    # Pinned host memory only helps host-to-GPU copies; skip it when running on the CPU.
    pin_memory=device.type == "cuda",
    num_workers=16,
    prefetch_factor=32,
)
# For single-process debugging, drop `num_workers` and `prefetch_factor` from the call above.

# Sanity check: print the shape of every tensor in the first batch
# (the "scaler" and "raw" entries are skipped).
for batch in train_dataloader:
    for key, val in batch.items():
        if "scaler" not in key and "raw" not in key:
            print(key, val.shape)
    break

100%|██████████| 2/2 [00:00<00:00, 2381.77it/s]
100%|██████████| 112213/112213 [00:04<00:00, 26581.59it/s]


sales torch.Size([32, 300, 1])
day torch.Size([32, 300])
dow torch.Size([32, 300])
month torch.Size([32, 300])
holiday torch.Size([32, 300])
price torch.Size([32, 300, 1])
target_fcst_mask torch.Size([32, 300])
temporal_padding_mask torch.Size([32, 300])
img_path torch.Size([32, 3, 224, 224])
detail_desc torch.Size([32, 82])
detail_desc_revert_padding_mask torch.Size([32, 83])
detail_desc_remain_idx torch.Size([32, 82])
detail_desc_masked_idx torch.Size([32, 0])
detail_desc_revert_idx torch.Size([32, 82])
information torch.Size([32, 32])
information_revert_padding_mask torch.Size([32, 33])
information_remain_idx torch.Size([32, 32])
information_masked_idx torch.Size([32, 0])
information_revert_idx torch.Size([32, 32])


In [3]:
path = "saved_model_epoch2_2024-05-08 12:48:39.442299"
# NOTE: joblib.load and torch.load both unpickle their input; only load trusted files.
label_encoder_dict = joblib.load("./src/label_encoder_dict.pkl")

mbae_encoder = MaskedBlockAutoencoder(config, label_encoder_dict)
# Map the checkpoint to the CPU while loading: this works even when the weights were
# saved from a GPU that is not available here, and avoids a temporary copy on the GPU.
# The model itself is moved to `device` in the next cell.
mbae_encoder.load_state_dict(torch.load(path, map_location="cpu"))
# mbae_encoder = mbae_encoder.encoder

<All keys matched successfully>

In [4]:
# Move the model to the target device; the trailing semicolon suppresses the module repr.
mbae_encoder.to(device);

''

In [5]:
import os
from tqdm import tqdm
from collections import defaultdict
from IPython.display import clear_output
from transformers import AutoTokenizer
import pandas as pd

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Element-wise (unreduced) losses: get_loss applies the padding masks and does the averaging itself.
mse_loss = torch.nn.MSELoss(reduction="none")
ce_loss = torch.nn.CrossEntropyLoss(reduction="none")
# Tokenizer used by plot_sample to turn token ids back into text.
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")

def get_loss(pred_dict, y_dict, idx_dict, padding_mask_dict):
    """Compute the masked reconstruction loss per output key and the overall loss.

    Relies on the notebook globals `config`, `device`, `mse_loss`, `ce_loss`
    (both created with ``reduction="none"``) and `patchify`.

    Args:
        pred_dict: Maps each output key to the model prediction for that key.
        y_dict: Maps each output key to its target tensor; targets are moved to `device`.
        idx_dict: Unused; kept so existing callers do not have to change.
        padding_mask_dict: Must contain "temporal_padding_mask" when temporal keys are
            present and "<key>_revert_padding_mask" for every NLP key. Positions where
            the temporal mask equals 1 are counted; the NLP mask is used as a
            multiplicative weight after dropping its first position.

    Returns:
        (loss_dict, total_loss): `loss_dict` maps each key to its mean loss over the
        positions that count for that key; `total_loss` is the sum of all element-wise
        losses divided by the total number of counted elements across keys.

    Raises:
        ValueError: If a key belongs to none of the configured column groups, or a
            temporal key is neither a scaling nor an embedding column.
        ZeroDivisionError: If `pred_dict` is empty.
    """
    loss_dict = {}
    loss_sum, cnt = 0, 0

    for key, pred in pred_dict.items():
        y = y_dict[key].to(device)

        ### Temporal loss
        if key in config.temporal_cols:
            if key in config.scaling_cols:
                # Drop only the trailing feature axis; a bare squeeze() would also
                # drop the batch axis when the batch holds a single sample.
                loss = mse_loss(pred, y).squeeze(-1)
            elif key in config.embedding_cols:
                loss = ce_loss(pred.view(-1, pred.shape[-1]), y.view(-1).to(torch.long))
                loss = loss.view(y.shape)
            else:
                raise ValueError(
                    f"Temporal column {key!r} is neither a scaling nor an embedding column"
                )

            # Only positions whose padding mask equals 1 contribute.
            mask = (padding_mask_dict["temporal_padding_mask"] == 1).long()
            loss = loss * mask
            n_valid = mask.sum()
        ### Img loss
        elif key in config.img_cols:
            # The first predicted position has no patch target, so it is dropped.
            pred = pred[:, 1:, :]
            y = patchify(y, config.patch_size)
            loss = mse_loss(pred, y)
            n_valid = loss.numel()
        ### Nlp loss
        elif key in config.nlp_cols:
            # The first predicted position has no token target, so it is dropped
            # (and the padding mask is shifted accordingly).
            pred = pred[:, 1:, :]
            loss = ce_loss(pred.reshape(-1, pred.shape[-1]), y.reshape(-1).to(torch.long))
            loss = loss.view(y.shape)

            mask = padding_mask_dict[f"{key}_revert_padding_mask"][:, 1:]
            loss = loss * mask
            # Count this key's own valid tokens (previously the temporal mask count
            # from an earlier iteration was added here by mistake).
            n_valid = mask.sum()
        else:
            raise ValueError(f"Key {key!r} does not belong to any configured column group")

        key_loss_sum = loss.sum()
        loss_sum += key_loss_sum
        cnt += n_valid
        loss_dict[key] = key_loss_sum / n_valid

    total_loss = loss_sum / cnt
    return loss_dict, total_loss

def obtain_loss_dict_for_plot(total_loss, loss_dict, loss_li_dict, mean_loss_li_dict):
    """Record the latest losses and their running means for plotting.

    The overall loss is stored under the key "total", followed by every entry of
    `loss_dict` under its own key. For each key the scalar value is appended to
    `loss_li_dict[key]`, and the mean of that history so far is appended to
    `mean_loss_li_dict[key]`.

    Both history dicts are updated in place and returned as
    (loss_li_dict, mean_loss_li_dict).
    """
    named_losses = [("total", total_loss), *loss_dict.items()]

    for name, value in named_losses:
        history = loss_li_dict[name]
        history.append(value.item())
        mean_loss_li_dict[name].append(np.mean(history))

    return loss_li_dict, mean_loss_li_dict

def plot_sample(nrows, ncols, config, mean_loss_li_dict, output_dict, data_dict, decoding_weight_dict):
    """Draw running-loss curves plus one qualitative sample per key on the current figure.

    For every key in `mean_loss_li_dict` one subplot shows the running-mean loss curve,
    followed by up to three sample subplots. Each key advances the subplot index by
    four (1 + 3), so the rows line up when `ncols` is 4.

    Args:
        nrows, ncols: Grid dimensions passed to `plt.subplot`.
        config: Provides `temporal_cols`, `embedding_cols`, `img_cols`, `nlp_cols`
            and `patch_size`.
        mean_loss_li_dict: Maps "total" and each output key to a list of running-mean
            losses; its iteration order determines the subplot order.
        output_dict: Maps each output key to the model output tensor for the batch.
        data_dict: The input batch; besides the output keys, "detail_desc",
            "information" and "img_path_raw" are read.
        decoding_weight_dict: Maps keys to decoder weight tensors (presumably attention
            weights — TODO confirm); `shape[1]` is used as a segment length.

    Only the first sample of the batch (`idx = 0`) is visualised. Uses the notebook
    globals `tokenizer` and `unpatchify`. Ends by calling `plt.tight_layout()` and
    `plt.show()`.
    """
    idx, plot_idx = 0, 1
    for key, val in mean_loss_li_dict.items():
        # Individual loss
        plt.subplot(nrows, ncols, plot_idx)
        plt.plot(val)
        plt.title(f"{key}: {val[-1]}")
        plot_idx += 1; 
        # The overall loss has no sample to show: skip its three sample slots.
        if key=="total": 
            plot_idx += 3; continue

        pred, y = output_dict[key].detach().cpu().squeeze(), data_dict[key].squeeze()
        # Segment lengths taken from shape[1] of each weight tensor; every temporal column maps
        # to the single "temporal" entry (the last one iterated wins). The comprehension's own
        # `key`/`val` are local to it and do not overwrite the loop variables.
        length_dict = {"temporal" if key in config.temporal_cols else key :val.shape[1] for key, val in decoding_weight_dict.items()}
        
        # Temporal sample
        if key in config.temporal_cols:
            ### Sample
            # Embedding (categorical) columns: turn logits into predicted class ids.
            if key in config.embedding_cols: pred = torch.argmax(pred, dim=-1)
            plt.subplot(nrows, ncols, plot_idx)
            plt.plot(y[idx]); plt.plot(pred[idx])
            ### Weight
            # Average this sample's weights over dim 0, leaving one weight per position.
            decoder_weight = decoding_weight_dict[key][idx].mean(dim=0).detach().cpu()
            # decoder_weight = decoding_weight_dict[key][idx].min(dim=0).values.detach().cpu()
            ###### Img
            # Slice out the image segment, which follows the temporal segment.
            img_decoder_weight = decoder_weight[length_dict["temporal"]:length_dict["temporal"]+length_dict["img_path"]]
            # Drop the first position of the segment (presumably a global/CLS-style token — TODO confirm).
            img_decoder_weight = img_decoder_weight[1:]
            plt.subplot(nrows, ncols, plot_idx+1)
            # Lay the per-patch weights out on the patch grid (a 224-pixel image side is assumed here).
            plt.imshow(img_decoder_weight.reshape(224//config.patch_size, 224//config.patch_size))
            ###### Nlp
            os.environ["TOKENIZERS_PARALLELISM"] = "false"
            # "detail_desc" segment follows the image segment; everything after it is treated as "information".
            nlp_decoder_weight1 = decoder_weight[length_dict["temporal"]+length_dict["img_path"]:length_dict["temporal"]+length_dict["img_path"]+length_dict["detail_desc"]]
            nlp_decoder_weight2 = decoder_weight[length_dict["temporal"]+length_dict["img_path"]+length_dict["detail_desc"]:]
            # Drop the first position of each text segment before joining them.
            nlp_decoder_weight = torch.cat([nlp_decoder_weight1[1:], nlp_decoder_weight2[1:]], dim=-1)

            # Decode the token ids to text and re-tokenize to get one string per token.
            nlp1 = tokenizer.tokenize(tokenizer.decode(data_dict["detail_desc"][idx]))
            nlp2 = tokenizer.tokenize(tokenizer.decode(data_dict["information"][idx]))
            # nlp2 = []
            text = nlp1 + nlp2

            # NOTE(review): this requires len(text) to equal the number of weights — confirm the
            # decode/tokenize round trip always preserves the token count.
            df = pd.DataFrame({"text":text, "weight":nlp_decoder_weight})
            df = df[df["text"]!="[PAD]"]
            plt.subplot(nrows, ncols, plot_idx+2)
            sns.barplot(df["weight"])
            # NOTE(review): the recorded notebook output appears to show a matplotlib warning raised
            # from this line — consider fixing the tick positions before setting the labels.
            plt.gca().set_xticklabels(df["text"], rotation=90)

        # Img sample
        elif key in config.img_cols:
            # Drop the first predicted position, rebuild the image from patches, move channels last.
            pred = unpatchify(pred[:, 1:, :]).permute(0,2,3,1)
            y = y.permute(0,2,3,1)
            
            # Left to right: raw image, model input image, reconstruction.
            plt.subplot(nrows, ncols, plot_idx)
            plt.imshow(data_dict["img_path_raw"][idx].permute(1,2,0))

            plt.subplot(nrows, ncols, plot_idx+1)
            plt.imshow(y[idx])

            plt.subplot(nrows, ncols, plot_idx+2)
            plt.imshow(pred[idx])

        # Nlp sample
        elif key in config.nlp_cols:
            # NOTE(review): the decoded strings are assigned but never displayed or used afterwards.
            pred = tokenizer.decode(torch.argmax(pred, dim=-1)[idx])
            y = tokenizer.decode(y[idx])
        
        # Advance past the three sample slots of this key.
        plot_idx += 3

    plt.tight_layout()
    plt.show()

def train_epoch(model, optimizer, dataloader, config, e):
    """Run one evaluation pass over `dataloader`, plotting progress every 20 batches.

    Despite its name, this function performs no parameter updates: the model is put
    in eval mode and the forward pass runs under `torch.no_grad()`.

    Args:
        model: Callable as `model(data, device)` returning the 5-tuple
            (decoding_output_dict, encoding_weight_dict, decoding_weight_dict,
            idx_dict, padding_mask_dict).
        optimizer: Unused; kept so existing callers do not have to change.
        dataloader: Iterable of batches.
        config: Forwarded to `plot_sample`.
        e: Epoch index; unused, kept so existing callers do not have to change.

    Returns:
        The mean total loss over all batches as a float, or NaN when `dataloader`
        yields no batches.
    """
    pbar = tqdm(dataloader)
    loss_li_dict, mean_loss_li_dict = defaultdict(list), defaultdict(list)
    model.eval()

    running_loss, n_batches = 0.0, 0

    for n, data in enumerate(pbar):
        with torch.no_grad():
            # Use the `model` argument rather than a notebook-level global.
            decoding_output_dict, encoding_weight_dict, decoding_weight_dict, idx_dict, padding_mask_dict = model(data, device)
        loss_dict, loss = get_loss(decoding_output_dict, data, idx_dict, padding_mask_dict)

        running_loss += loss.item()
        n_batches += 1

        # Plot
        if n % 20 == 0:
            nrows, ncols = 11, 4
            plt.figure(figsize=(25,25))
            clear_output(wait=True)

            # The plotted loss history is sampled only on plotting steps.
            loss_li_dict, mean_loss_li_dict = obtain_loss_dict_for_plot(loss, loss_dict, loss_li_dict, mean_loss_li_dict)
            plot_sample(nrows, ncols, config, mean_loss_li_dict, decoding_output_dict, data, decoding_weight_dict)

    return running_loss / n_batches if n_batches else float("nan")

True

In [6]:
# Optimizer and scheduler are created here and handed to train_epoch / stepped once per epoch.
# NOTE(review): train_epoch runs the model in eval mode under torch.no_grad() and nothing in the
# visible code calls optimizer.step(), so the scheduler steps below change the learning rate
# without any weight update taking place — confirm whether this is intended.
optimizer = torch.optim.AdamW(mbae_encoder.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
epoch = 10

# Maps the epoch index to whatever train_epoch returns for that epoch.
epoch_loss = {}
for e in range(epoch):
    loss = train_epoch(mbae_encoder, optimizer, train_dataloader, config, e)
    scheduler.step()
    epoch_loss[e] = loss

    # NOTE(review): the disabled block below uses `datetime`, which is not imported in the
    # visible cells — import it before re-enabling.
    # # Save model
    # if not is_test_mode:
    #     now = datetime.datetime.now()
    #     path = f"./saved_model_epoch{e}_{now}"
    #     torch.save(mbae_encoder.state_dict(), path)

print(epoch_loss)

  plt.gca().set_xticklabels(df["text"], rotation=90)
  plt.gca().set_xticklabels(df["text"], rotation=90)
  plt.gca().set_xticklabels(df["text"], rotation=90)
  plt.gca().set_xticklabels(df["text"], rotation=90)
  plt.gca().set_xticklabels(df["text"], rotation=90)
  plt.gca().set_xticklabels(df["text"], rotation=90)
