# Import & Setting

### Import

In [1]:
import os

import numpy as np
import pandas as pd
pd.set_option("display.max_columns", None)

import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from IPython.display import clear_output

import torch
from PIL import Image
from torchvision import transforms

import pytorch_forecasting as pf
from pytorch_forecasting.models.base_model import BaseModelWithCovariates

from transformers import SwinModel

device = torch.device("cuda")

  from .autonotebook import tqdm as notebook_tqdm


### Setting

In [2]:
# Params for sampling 
num_samples = 100

# Params for Train_test_split 
train_test_split_rto = 0.1

# Dataset
window_size = 24
predict_length = 6
batch_size = 128

# Model
d_model = 128
dropout = 0.3
nhead = 4
num_layers = 4
d_ff = 512
# d_model = 512
# dropout = 0.3
# nhead = 8
# num_layers = 6
# d_ff = 2048

# Preprocess Data

### Read data

In [3]:
df_raw = pd.read_csv("HnM/transactions_train.csv", dtype={"article_id":str}) # Read data
# df_raw = df_raw.iloc[:1000]
df_raw.head()

Unnamed: 0,t_dat,customer_id,article_id,price,sales_channel_id
0,2018-09-20,000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...,663713001,0.050831,2
1,2018-09-20,000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...,541518023,0.030492,2
2,2018-09-20,00007d2de826758b65a93dd24ce629ed66842531df6699...,505221004,0.015237,2
3,2018-09-20,00007d2de826758b65a93dd24ce629ed66842531df6699...,685687003,0.016932,2
4,2018-09-20,00007d2de826758b65a93dd24ce629ed66842531df6699...,685687004,0.016932,2


### Feature Engineering

In [4]:
df_engineered = df_raw.copy()
print(df_engineered["article_id"].nunique())

# Generate sales
df_engineered["t_dat"] = pd.to_datetime(df_engineered["t_dat"])

df_engineered["year"] = df_engineered["t_dat"].dt.year
df_engineered["month"] = df_engineered["t_dat"].dt.month

df_engineered["t_dat"] = df_engineered["year"].astype(str) + "-" + df_engineered["month"].astype(str) + "-" + "01"
df_engineered["t_dat"] = pd.to_datetime(df_engineered["t_dat"])
df_engineered = df_engineered.groupby(["t_dat", "article_id"], as_index=False).agg(price=("price", "mean"), sales=("customer_id", "count"))

# Explode dates
def func(x):
    full_date = pd.DataFrame(pd.date_range(x["t_dat"].min(), x["t_dat"].max(), freq="MS"), columns=["t_dat"])
    x = x.merge(full_date, on="t_dat", how="right").reset_index(drop=True)
    x["article_id"] = x["article_id"].unique()[0]
    x["sales"] = x["sales"].fillna(0)
    x["price"] = x["price"].fillna(0)
    return x
df_engineered = df_engineered.groupby("article_id", as_index=False).apply(lambda x: func(x)).reset_index(drop=True)
df_engineered["time_idx"] = df_engineered.groupby("article_id").cumcount()

# Deal date
df_engineered["year"] = df_engineered["t_dat"].dt.year
df_engineered["month"] = df_engineered["t_dat"].dt.month - 1
df_engineered["day"] = df_engineered["t_dat"].dt.day - 1
df_engineered["dayofweek"] = df_engineered["t_dat"].dt.dayofweek

# Deal product life
df_engineered["min_date"] = df_engineered.groupby("article_id")["t_dat"].transform("min")
df_engineered["max_date"] = df_engineered.groupby("article_id")["t_dat"].transform("max")
df_engineered["prod_life"] = (df_engineered["max_date"] - df_engineered["min_date"]).dt.days + 1

# Generate transaction date count
df_engineered["trans_date_cnt"] = df_engineered.groupby(["article_id"])["t_dat"].transform("count")

# Generate non-zero dates rto
df_engineered["nonzero_date_rto"] = df_engineered["trans_date_cnt"] / df_engineered["prod_life"]

# Deal image
df_engineered["img_path"] = df_engineered["article_id"].apply(lambda x: f'./HnM/images/{x[:3]}/{x}.jpg') # Generate image path
df_engineered["is_valid"] = df_engineered["img_path"].apply(lambda x: 1 if os.path.isfile(x) else 0) # Check whether the article has corresponding image file

df_engineered.head()

104547


Unnamed: 0,t_dat,article_id,price,sales,time_idx,year,month,day,dayofweek,min_date,max_date,prod_life,trans_date_cnt,nonzero_date_rto,img_path,is_valid
0,2018-09-01,108775015,0.008026,662.0,0,2018,8,0,5,2018-09-01,2020-07-01,670,23,0.034328,./HnM/images/010/0108775015.jpg,1
1,2018-10-01,108775015,0.008081,1532.0,1,2018,9,0,0,2018-09-01,2020-07-01,670,23,0.034328,./HnM/images/010/0108775015.jpg,1
2,2018-11-01,108775015,0.007889,1660.0,2,2018,10,0,3,2018-09-01,2020-07-01,670,23,0.034328,./HnM/images/010/0108775015.jpg,1
3,2018-12-01,108775015,0.008277,1207.0,3,2018,11,0,5,2018-09-01,2020-07-01,670,23,0.034328,./HnM/images/010/0108775015.jpg,1
4,2019-01-01,108775015,0.008369,1522.0,4,2019,0,0,1,2018-09-01,2020-07-01,670,23,0.034328,./HnM/images/010/0108775015.jpg,1


### Filter

In [5]:
df_filtered = df_engineered.copy()

# Filtering
df_filtered = df_filtered[df_filtered["is_valid"] == 1] # Valid if having corresponding image
df_filtered = df_filtered[df_filtered["prod_life"] >= 1] 
# df_filtered = df_filtered[df_filtered["sales"] >= predict_length + 1] # Length should be at least greater than prediction length

df_filtered.head()

Unnamed: 0,t_dat,article_id,price,sales,time_idx,year,month,day,dayofweek,min_date,max_date,prod_life,trans_date_cnt,nonzero_date_rto,img_path,is_valid
0,2018-09-01,108775015,0.008026,662.0,0,2018,8,0,5,2018-09-01,2020-07-01,670,23,0.034328,./HnM/images/010/0108775015.jpg,1
1,2018-10-01,108775015,0.008081,1532.0,1,2018,9,0,0,2018-09-01,2020-07-01,670,23,0.034328,./HnM/images/010/0108775015.jpg,1
2,2018-11-01,108775015,0.007889,1660.0,2,2018,10,0,3,2018-09-01,2020-07-01,670,23,0.034328,./HnM/images/010/0108775015.jpg,1
3,2018-12-01,108775015,0.008277,1207.0,3,2018,11,0,5,2018-09-01,2020-07-01,670,23,0.034328,./HnM/images/010/0108775015.jpg,1
4,2019-01-01,108775015,0.008369,1522.0,4,2019,0,0,1,2018-09-01,2020-07-01,670,23,0.034328,./HnM/images/010/0108775015.jpg,1


### Sample data

In [6]:
df_sampled = df_filtered.copy()

# Sample by sales amount
sample_id_li = df_sampled.groupby("article_id").agg({"sales":"sum"}).sort_values("sales", ascending=False) # Sort article_id by number of sales
sample_id_li = sample_id_li.iloc[:num_samples].index if num_samples else sample_id_li.index # Slice article_id
df_sampled = df_sampled[df_sampled["article_id"].isin(sample_id_li)].reset_index(drop=True)

df_sampled = df_sampled.sort_values(["article_id", "t_dat"])

df_sampled.head()

Unnamed: 0,t_dat,article_id,price,sales,time_idx,year,month,day,dayofweek,min_date,max_date,prod_life,trans_date_cnt,nonzero_date_rto,img_path,is_valid
0,2018-09-01,108775015,0.008026,662.0,0,2018,8,0,5,2018-09-01,2020-07-01,670,23,0.034328,./HnM/images/010/0108775015.jpg,1
1,2018-10-01,108775015,0.008081,1532.0,1,2018,9,0,0,2018-09-01,2020-07-01,670,23,0.034328,./HnM/images/010/0108775015.jpg,1
2,2018-11-01,108775015,0.007889,1660.0,2,2018,10,0,3,2018-09-01,2020-07-01,670,23,0.034328,./HnM/images/010/0108775015.jpg,1
3,2018-12-01,108775015,0.008277,1207.0,3,2018,11,0,5,2018-09-01,2020-07-01,670,23,0.034328,./HnM/images/010/0108775015.jpg,1
4,2019-01-01,108775015,0.008369,1522.0,4,2019,0,0,1,2018-09-01,2020-07-01,670,23,0.034328,./HnM/images/010/0108775015.jpg,1


### Post-sampling

In [7]:
df_post = df_sampled.copy()

### LabelEncode image path
imgpath_encoder = LabelEncoder()
df_post["img_path"] = imgpath_encoder.fit_transform(df_post["img_path"])

# Train test split
num_samples = num_samples if num_samples else df_post["article_id"].nunique()
num_train = int(np.round(num_samples * train_test_split_rto))
sample_id_li_train = sample_id_li[:num_train]

df_train = df_post[df_post["article_id"].isin(sample_id_li_train)].reset_index(drop=True)
df_valid = df_post[~df_post["article_id"].isin(sample_id_li_train)].reset_index(drop=True)
assert df_train.shape[0] + df_valid.shape[0] == df_post.shape[0]

# Make Dataset

In [8]:
from pytorch_forecasting.data import GroupNormalizer, NaNLabelEncoder
from torch.utils.data._utils.collate import default_collate

train_dataset = pf.TimeSeriesDataSet(
    data=df_train,
    time_idx="time_idx",
    target="sales",
    group_ids=["img_path"],
    # static_reals=["img_path"], # image is a static information which does not change by time
    min_encoder_length=window_size,
    max_encoder_length=window_size,
    min_prediction_idx=predict_length,
    max_prediction_length=predict_length,
    time_varying_unknown_reals=["sales", "month", "day", "dayofweek"],
    target_normalizer=None,
    scalers={
        "month":None, "day":None, "dayofweek":None},
    categorical_encoders={
        "img_path":NaNLabelEncoder(add_nan=True)
        }
)
valid_dataset = pf.TimeSeriesDataSet.from_dataset(train_dataset, df_post, predict=True, stop_randomization=True)

train_dataloader = train_dataset.to_dataloader(batch_size=batch_size, shuffle=True)
valid_dataloader = valid_dataset.to_dataloader(train=False, batch_size=2, shuffle=False, drop_last=True)



AssertionError: filters should not remove entries all entries - check encoder/decoder lengths and lags

In [None]:
import joblib

joblib.dump(train_dataloader, "train_dataloader.pkl")
joblib.dump(valid_dataloader, "valid_dataloader.pkl")
joblib.dump(train_dataset, "train_dataset.pkl")
joblib.dump(imgpath_encoder, "imgpath_encoder.pkl")

['imgpath_encoder.pkl']

# Model

### Architecture

In [None]:
import joblib

train_dataset = joblib.load("train_dataset.pkl")
train_dataloader = joblib.load("train_dataloader.pkl")
valid_dataloader = joblib.load("valid_dataloader.pkl")
imgpath_encoder = joblib.load("imgpath_encoder.pkl")

In [None]:
class PositionalEncoding(torch.nn.Module):
    # PE(pos, 2i) = sin(pos/10000^{2i/d_model}), 
    # PE(pos, 2i+1) = cos(pos/10000^{2i/d_model})
    def __init__(self, max_len, d_model, dropout):
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout)

        position = torch.arange(max_len).reshape(-1,1).to(device)
        i = torch.arange(d_model).to(device)//2
        exp_term = 2*i/d_model
        div_term = torch.pow(10000, exp_term).reshape(1, -1)
        self.pos_encoded = position / div_term

        self.pos_encoded[:, 0::2] = torch.sin(self.pos_encoded[:, 0::2])
        self.pos_encoded[:, 1::2] = torch.cos(self.pos_encoded[:, 1::2])

    def forward(self, x):
        output = x + self.pos_encoded[:x.shape[1], :]
        return self.dropout(output)
    
class Mask(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def get_padding_mask(self, arr):
        res = torch.eq(arr, 0).type(torch.FloatTensor).to(device)
        res = torch.where(res==1, -torch.inf, 0)
        return res
    
    def get_lookahead_mask(self, arr):
        seq_len = arr.shape[1]
        mask = torch.triu(torch.ones((seq_len, seq_len))*-1e-9, 1).to(device)
        return mask

    def forward(self, arr):
        padding_mask = self.get_padding_mask(arr)
        lookahead_mask = self.get_lookahead_mask(arr)
        return padding_mask, lookahead_mask

In [None]:
class MultimodalTransformerDecoderLayer(torch.nn.Module):
    def __init__(self, d_model, nhead, d_ff, dropout, batch_first):
        super().__init__()
        # self.swin_transformer = swin_transformer
        self.attn = torch.nn.MultiheadAttention(d_model, nhead, dropout, batch_first=batch_first)
        # self.linear2 = torch.nn.Linear(self.swin_transformer.config.hidden_size, d_model)
        self.layernorm = torch.nn.LayerNorm(d_model)

        self.fc1 = torch.nn.Linear(d_model, d_ff)
        self.relu1 = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(d_ff, d_model)
        self.relu2 = torch.nn.ReLU()
    
    def forward(self, dec_input, enc_output, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # ### Self attention
        # swin_transformer_ = self.swin_transformer(dec_input).last_hidden_state
        # print(swin_transformer_.shape)
        # linear2_ = self.linear2(swin_transformer_)

        ### Cross attention
        attn_, attn_weight = self.attn(query=dec_input, key=enc_output, value=enc_output)
        layernorm_ = self.layernorm(dec_input + attn_)

        ### Feed forward
        relu1_ = self.relu1(self.fc1(layernorm_))
        relu2_ = self.relu2(self.fc2(relu1_))
        ff = layernorm_ + relu2_

        return ff


In [None]:
class MultimodalTransformer(torch.nn.Module):
    def __init__(self, max_seq_len, d_model, dropout, nhead, d_ff, num_layers, swin_transformer):
        super().__init__()
        # Encoder
        self.enc_mask = Mask()
        self.linear_embedding = torch.nn.Linear(1, d_model)
        self.month_embedding = torch.nn.Embedding(num_embeddings=12, embedding_dim=d_model)
        self.day_embedding = torch.nn.Embedding(num_embeddings=31, embedding_dim=d_model)
        self.dayofweek_embedding = torch.nn.Embedding(num_embeddings=7, embedding_dim=d_model)

        # self.linear1 = torch.nn.Linear(d_model*4, d_model)
        self.linear1 = torch.nn.Linear(d_model*2, d_model)
        self.enc_pos_encoding = PositionalEncoding(max_seq_len, d_model, dropout)
        self.encoder = torch.nn.TransformerEncoder(torch.nn.TransformerEncoderLayer(d_model, nhead, d_ff, dropout, batch_first=True), num_layers)
        
        # Decoder
        self.swin_transformer = swin_transformer
        self.linear2 = torch.nn.Linear(self.swin_transformer.config.hidden_size, d_model)
        self.decoder = torch.nn.TransformerDecoder(torch.nn.TransformerDecoderLayer(d_model, nhead, d_ff, dropout, batch_first=True), num_layers)
        # self.decoder = torch.nn.TransformerDecoder(MultimodalTransformerDecoderLayer(d_model, nhead, d_ff, dropout, batch_first=True), num_layers)
        self.flatten = torch.nn.Flatten()
        self.linear3 = torch.nn.Linear(d_model*49, d_model)
        self.linear4 = torch.nn.Linear(d_model, predict_length)

    def forward(self, enc_input, dec_input):
        # Encoding
        sales, month, day, dayofweek = enc_input
        enc_padding_mask, _ = self.enc_mask(sales.squeeze())
        linear_embedding_ = self.linear_embedding(sales)
        month_embedding = self.month_embedding(month)
        # day_embedding = self.day_embedding(day)
        # dayofweek_embedding = self.dayofweek_embedding(dayofweek)

        # enc_input = torch.concat([linear_embedding_, month_embedding, day_embedding, dayofweek_embedding], axis=-1)
        enc_input = torch.concat([linear_embedding_, month_embedding], axis=-1)
        enc_input = self.relu1 = torch.nn.ReLU()(self.linear1(enc_input))

        enc_pos_encoding_ = self.enc_pos_encoding(enc_input)
        enc_output_ = self.encoder(enc_pos_encoding_)
        # enc_output_ = self.encoder(enc_pos_encoding_, src_key_padding_mask=enc_padding_mask)
        
        # Decoding
        swin_transformer_ = self.swin_transformer(dec_input).last_hidden_state
        dec_output_ = self.relu1 = torch.nn.ReLU()(self.linear2(swin_transformer_))

        dec_output = self.decoder(tgt=dec_output_, memory=enc_output_)

        # Final
        flatten_ = self.flatten(dec_output)
        linear3_ = self.relu1 = torch.nn.ReLU()(self.linear3(flatten_))
        linear4_ = self.relu1 = torch.nn.ReLU()(self.linear4(linear3_))
        
        return linear4_

In [None]:
class MultimodalTransformerFromDataset(BaseModelWithCovariates):
    def __init__(self, imgpath_encoder, predict_length, swin_transformer, window_size, d_model, dropout, nhead, d_ff, num_layers, 
                 static_categoricals, time_varying_categoricals_encoder, time_varying_categoricals_decoder, static_reals, 
                 time_varying_reals_encoder,  time_varying_reals_decoder, x_reals, x_categoricals, embedding_labels, embedding_paddings, 
                 categorical_groups, embedding_sizes, **kwargs):
        self.save_hyperparameters()
        super().__init__(**kwargs)

        self.imgpath_encoder = imgpath_encoder
        self.predict_length = predict_length
        self.network = MultimodalTransformer(window_size, d_model, dropout, nhead, d_ff, num_layers, swin_transformer)
        # self.network.to(device)
    
    def forward(self, data):
        # Gather time series data
        sales = data[0]["encoder_cont"][:, :, 0].unsqueeze(-1) # shape: (batch_size, window_size, 1)
        month = data[0]["encoder_cont"][:, :, 1].type(torch.int)
        day = data[0]["encoder_cont"][:, :, 2].type(torch.int)
        dayofweek = data[0]["encoder_cont"][:, :, 3].type(torch.int)
        y = data[1][0] # shape: (batch_size, predict_length)

        # Gather image data
        img_path = data[0]["groups"].squeeze() # Label encoded image_path → shape: (batch_size, ) 
        img_path = self.imgpath_encoder.inverse_transform(img_path) # The real image path e.g) 'HnM/images/068/0687169002.jpg' → shape: (batch_size, )

        # Process image data
        img_li = []
        transform = transforms.Compose([
            transforms.Resize((224,224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]) # Transform image based on ImageNet standard

        for n, path in enumerate(img_path): # Iterate images
            img = transform(Image.open(path).convert("RGB")) # Transform an image
            img_li.append(img)
        img_tensor = torch.stack(img_li, dim=0) # Put all the images together
        
        # Prediction
        pred = self.network(
            (sales.to(device), month.to(device), day.to(device), dayofweek.to(device)), 
            img_tensor.to(device)
            )
        pred = self.transform_output(prediction=pred, target_scale=data[0]["target_scale"].to(device)) # Inverse transform the output
        
        return pred, y.to(device)

In [None]:
import gc
torch.cuda.empty_cache()
gc.collect()

swin_transformer = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224") # Get pre-trained SwinTransformer
swin_transformer.to(device)

model = MultimodalTransformerFromDataset.from_dataset(
    train_dataset,
    predict_length=predict_length,
    swin_transformer=swin_transformer,
    window_size=window_size,
    d_model=d_model,
    dropout=dropout,
    nhead=nhead,
    d_ff=d_ff,
    num_layers=num_layers,
    imgpath_encoder= imgpath_encoder
    )
model.to(device); print()

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(





### Train

In [None]:
# optimizer = torch.optim.Adam(model.parameters(1e-5))
# loss_fn = torch.nn.MSELoss()
# loss_fn_ = torch.nn.MSELoss(reduction="none")
# train_loss_li, valid_loss_li = [], []


# def plot_loss(train_loss_li, valid_loss_li):
#     plt.plot(train_loss_li, label="train")
#     plt.plot(valid_loss_li, label="valid")
#     plt.title("loss")
#     plt.legend()

# def plot_bestsample(loss, pred, y, iter, msg):
#     loss = loss.mean(axis=1) # Shape: (batch_size, )
#     _, best_idx_li = torch.sort(loss)
#     for best_idx in best_idx_li:
#         # best_pred = torch.round(pred[best_idx])
#         best_pred = pred[best_idx]
#         best_pred = best_pred.cpu().detach().numpy() # Sales is always int → Shape: (predict_length, )
#         best_pred[best_pred < 0] = 0 # Sales never becomes negative
#         best_y = y[best_idx].cpu().detach().numpy()
#         if (np.max(best_y) < 10): # If predicted value is all 0, consider as not the best
#             continue
#         break
    
#     plt.plot(best_pred, label="pred")
#     plt.plot(best_y, label="y", color="gray", alpha=0.3)
#     plt.title(f"{iter}th iter: Best example amongst {msg} dataset")
#     plt.legend()

# def train():
#     total_train_loss, total_valid_loss = 0, 0
#     for n, train_data in enumerate(train_dataloader):
#         clear_output(wait=True)

#         # Train
#         model.train(True)
#         optimizer.zero_grad()
#         train_pred, train_y = model(train_data)

#         # Get train loss
#         train_loss = loss_fn(train_pred, train_y) # Shape: (batch_size, predict_length)
#         train_loss.backward()
#         train_loss_raw = loss_fn_(train_pred, train_y)
#         total_train_loss += train_loss.item()
#         train_loss_li.append(total_train_loss/(n+1))
#         optimizer.step()

#         # Validation
#         model.eval()
#         valid_data = next(iter(valid_dataloader))
#         valid_pred, valid_y = model(valid_data)

#         # Get validation loss
#         valid_loss = loss_fn(valid_pred, valid_y)
#         valid_loss_raw = loss_fn_(valid_pred, valid_y)
#         total_valid_loss += valid_loss.item()
#         valid_loss_li.append(total_valid_loss/(n+1))

#         # Plot
#         plt.figure(figsize=(18,5))
#         plt.subplot(1,3,1); plot_loss(train_loss_li, valid_loss_li)
#         plt.subplot(1,3,2); plot_bestsample(train_loss_raw, train_pred, train_y, n, "TRAIN")
#         plt.subplot(1,3,3); plot_bestsample(valid_loss_raw, valid_pred, valid_y, n, "VALID")
#         plt.show()

#         # Report
#         print(f"\r {n}/{len(train_dataloader)} → train_loss: {np.mean(train_loss_li)}, valid_loss: {np.mean(valid_loss_li)}", end="")
            
# for epoch in range(10):
#     mean_train_loss = train()

In [None]:
optimizer = torch.optim.Adam(model.parameters())
loss_fn = torch.nn.MSELoss()
loss_fn_ = torch.nn.MSELoss(reduction="none")
train_loss_li, valid_loss_li = [], []
temp = None

def plot_loss(train_loss_li, valid_loss_li):
    plt.plot(train_loss_li, label="train")
    plt.plot(valid_loss_li, label="valid")
    plt.title("Loss")
    plt.legend()

def plot_bestsample(loss, pred, y, iter, msg):
    loss = loss.mean(axis=1) # Shape: (batch_size, )
    _, best_idx_li = torch.sort(loss)
    for best_idx in best_idx_li:
        # best_pred = torch.round(pred[best_idx])
        best_pred = pred[best_idx]
        best_pred = best_pred.cpu().detach().numpy() # Sales is always int → Shape: (predict_length, )
        best_pred[best_pred < 0] = 0 # Sales never becomes negative
        best_y = y[best_idx].cpu().detach().numpy()
        if (np.max(best_y) < 10): # If predicted value is all 0, consider as not the best
            continue
        break
    
    plt.plot(best_pred, label="pred")
    plt.plot(best_y, label="y", color="gray", alpha=0.3)
    plt.title(f"{iter}th iter: Best example amongst {msg} dataset")
    plt.legend()

def train():
    global temp
    total_train_loss = 0
    # for n, (train_data, valid_data) in enumerate(zip(train_dataloader, valid_dataloader)):
    for n, train_data in enumerate(train_dataloader):
        # Train
        model.train(True)
        optimizer.zero_grad()
        train_pred, train_y = model(train_data)

        # Get train loss
        train_loss = loss_fn(train_pred, train_y) # Shape: (batch_size, predict_length)
        train_loss.backward()
        total_train_loss += train_loss.item()
        optimizer.step()

        # Report
        print(f"\r {n}/{len(train_dataloader)} → train_loss: {total_train_loss / (n+1)}", end="")
    return total_train_loss / len(train_dataloader)

def val():
    total_valid_loss = 0
    # for n, (train_data, valid_data) in enumerate(zip(train_dataloader, valid_dataloader)):
    for n, valid_data in enumerate(valid_dataloader):
        model.eval()
        # Predict
        valid_pred, valid_y = model(valid_data)

        # Get train loss
        valid_loss = loss_fn(valid_pred, valid_y) # Shape: (batch_size, predict_length)
        total_valid_loss += valid_loss.item()
        optimizer.step()

        # Report
        print(f"\r {n}/{len(train_dataloader)} → valid_loss: {total_valid_loss / (n+1)}", end="")
    return total_valid_loss / len(train_dataloader)

epoch = 10
for e in range(epoch):
    mean_train_loss = train(); train_loss_li.append(mean_train_loss)
    mean_valid_loss = val(); valid_loss_li.append(mean_valid_loss)
    print(f"\r {e}/{epoch} → train_loss: {np.mean(mean_train_loss)}, valid_loss: {np.mean(mean_valid_loss)}", end="")
    clear_output(wait=True)
    plot_loss(train_loss_li, valid_loss_li)
    plt.show()

 46/117 → train_loss: 31186.310983211435

KeyboardInterrupt: 