In [2]:
_ = """
Desciption
- Visuelle products
- Not use of pytorch_forecasting
- year, month, day, dayofweek as positional embedding feature
- Extract attention from MY transformer but not from image model's
"""

# Configurations

### Import

In [3]:
import os
import joblib
import random
from IPython.display import clear_output
import copy
from tqdm import tqdm

import numpy as np
import pandas as pd; pd.set_option("display.max_columns", None)
import matplotlib.pyplot as plt

from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from tslearn.utils import to_time_series_dataset

import torch
from PIL import Image
from torchvision import transforms
from transformers import SwinModel, SwinConfig, ViTModel, ViTConfig, Mask2FormerModel, Mask2FormerConfig
import matplotlib.cm as cm
import cv2
import pytorch_forecasting as pf
from pytorch_forecasting.models.base_model import BaseModelWithCovariates

device = torch.device("cuda")

Install h5py to use hdf5 features: http://docs.h5py.org/
  warn(h5py_msg)
  from .autonotebook import tqdm as notebook_tqdm


### Params

In [4]:
# Sampling parameter
n_smaples = None

# Data parameter
random_state = 0
encoder_len = 365
pred_len = 30
batch_size = 128
valid_start_date = "2020-01-22"

# Model hyperparameter
d_model = 128; d_model = 256; d_model = 512
nhead = 4; nhead = 8
d_ff = 256; d_ff = 512; d_ff = 1024
dropout = 0.3; dropout = 0.3
num_layers = 4; num_layers = 6

# Seed set
random.seed(random_state)
np.random.seed(random_state)
torch.manual_seed(random_state)

<torch._C.Generator at 0x7f34ca0e2730>

# Data

### Read

In [4]:
# Read transaction
df_trans = pd.read_csv("../HnM/transactions_train.csv", parse_dates=["t_dat"], dtype={"article_id":str})
df_train_raw = df_trans[df_trans["t_dat"] < valid_start_date]
df_valid_raw = df_trans[df_trans["t_dat"] >= valid_start_date]

### Preprocess

In [5]:
def preprocess(data, is_train=True):
    data = data.copy()

    # Make sales
    data = data.groupby(["article_id", "t_dat"], as_index=False).agg(sales=("customer_id","size"), price=("price","mean"))
    
    # Expand date
    def func(x):
        date_min = x["t_dat"].min()
        date_max = x["t_dat"].max()

        date_ref = pd.DataFrame(pd.date_range(date_min, date_max, freq="d"), columns=["t_dat"])
        x = pd.merge(date_ref, x, on="t_dat", how="left")
        return x

    data = data.groupby("article_id", as_index=False).apply(lambda x: func(x)).reset_index(drop=True)
    data["sales"] = data["sales"].fillna(0)
    data["price"] = data["price"].fillna(method="ffill")
    data["article_id"] = data["article_id"].fillna(method="ffill")
    data["price"] = data["price"].fillna(method="ffill")

    # Generate temporal information
    data["year"] = data["t_dat"].dt.year
    data["month"] = data["t_dat"].dt.month
    data["week"] = data["t_dat"].dt.isocalendar().week
    data.loc[(data["month"] == 12) & (data["week"] < 10), "year"] = data["year"] + 1
    data.loc[(data["month"] == 1) & (data["week"] > 10), "year"] = data["year"] - 1
    data["dayofweek"] = data["t_dat"].dt.dayofweek
    data["day"] = data["t_dat"].dt.day
    data["month"] -= 1
    data["week"] -= 1
    data["day"] -= 1
    
    # Aggregate with list
    data = data.groupby("article_id", as_index=False)[["sales", "price", "year", "month", "week", "day", "dayofweek"]].agg(list)
    data["size"] = data["sales"].str.len()

    # Filter 
    if is_train:
        data = data[data["size"] > encoder_len + pred_len]
    else: 
        data = data[data["size"] > pred_len + 1]

    return data

df_train_prep = preprocess(df_train_raw)
df_valid_prep = preprocess(df_valid_raw, is_train=False)

df_train_prep.to_parquet("df_train_prep.parquet")
df_valid_prep.to_parquet("df_valid_prep.parquet")

### Dataset

In [5]:
df_train_prep = pd.read_parquet("df_train_prep.parquet")
df_valid_prep = pd.read_parquet("df_valid_prep.parquet")

class TrainDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.sales_tensor = self.get_windowed_tensor(data["sales"].values)
        self.year_tensor = self.get_windowed_tensor(data["year"].values)
        self.month_tensor = self.get_windowed_tensor(data["month"].values)
        self.week_tensor = self.get_windowed_tensor(data["week"].values)
        self.day_tensor = self.get_windowed_tensor(data["day"].values)
        self.dayofweek_tensor = self.get_windowed_tensor(data["dayofweek"].values)
    
    def get_windowed_tensor(self, data):
        windowed_tensor_li = []
        for d in data:
            d = d.astype(float)
            d = torch.FloatTensor(d).unfold(0, encoder_len+pred_len, 1)
            windowed_tensor_li.append(d)
        windowed_tensor = torch.concat(windowed_tensor_li)
        return windowed_tensor

    def __len__(self):
        return self.sales_tensor.shape[0]
    
    def __getitem__(self, idx):
        sales = np.log1p(self.sales_tensor[idx][:-pred_len]).unsqueeze(-1)
        y = np.log1p(self.sales_tensor[idx][-pred_len:])
        
        year = self.year_tensor[idx][:-pred_len].unsqueeze(-1)
        year_y = self.year_tensor[idx][-pred_len:].unsqueeze(-1)

        month = self.month_tensor[idx][:-pred_len].type(torch.IntTensor)
        month_y = self.month_tensor[idx][-pred_len:].type(torch.IntTensor)

        week = self.week_tensor[idx][:-pred_len].type(torch.IntTensor)
        week_y = self.week_tensor[idx][-pred_len:].type(torch.IntTensor)

        day = self.day_tensor[idx][:-pred_len].type(torch.IntTensor)
        day_y = self.day_tensor[idx][-pred_len:].type(torch.IntTensor)

        dayofweek = self.dayofweek_tensor[idx][:-pred_len].type(torch.IntTensor)
        dayofweek_y = self.dayofweek_tensor[idx][-pred_len:].type(torch.IntTensor)

        
        return sales, year, month, week, day, dayofweek, y, year_y, month_y, week_y, day_y, dayofweek_y

train_dataset = TrainDataset(df_train_prep)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, prefetch_factor=8)
for data in train_dataloader:
    sales, year, month, week, day, dayofweek, y, year_y, month_y, week_y, day_y, dayofweek_y = data
    print("sales:", sales.shape)
    print("year:", year.shape)
    print("month:", month.shape)
    print("week:", week.shape)
    print("day:", day.shape)
    print("dayofweek:", dayofweek.shape)
    print("_"*100)
    print("y:", y.shape)
    print("year_y:", year_y.shape)
    print("month_y:", month_y.shape)
    print("week_y:", week_y.shape)
    print("day_y:", day_y.shape)
    print("dayofweek_y:", dayofweek_y.shape)
    break

sales: torch.Size([128, 365, 1])
year: torch.Size([128, 365, 1])
month: torch.Size([128, 365])
week: torch.Size([128, 365])
day: torch.Size([128, 365])
dayofweek: torch.Size([128, 365])
____________________________________________________________________________________________________
y: torch.Size([128, 30])
year_y: torch.Size([128, 30, 1])
month_y: torch.Size([128, 30])
week_y: torch.Size([128, 30])
day_y: torch.Size([128, 30])
dayofweek_y: torch.Size([128, 30])


# Architecture

In [8]:
class PositionalEmbedder(torch.nn.Module):
    def __init__(self, d_model):
        super().__init__()
        # Raw embedder
        self.year_embedder = torch.nn.Linear(1, d_model)
        self.month_embedder = torch.nn.Embedding(num_embeddings=12, embedding_dim=d_model)
        self.week_embedder = torch.nn.Embedding(num_embeddings=52, embedding_dim=d_model)
        self.day_embedder = torch.nn.Embedding(num_embeddings=365, embedding_dim=d_model)
        self.dayofweek_embedder = torch.nn.Embedding(num_embeddings=7, embedding_dim=d_model)
        
        # Concatenate
        self.layer_norm1 = torch.nn.LayerNorm(d_model*5)
        self.linear1 = torch.nn.Linear(d_model*5, d_model)

        # Feed forward concatenated positions
        self.linear2 = torch.nn.Linear(d_model, d_model)
        self.activation1 = torch.nn.ELU()
        self.linear3 = torch.nn.Linear(d_model, d_model)
        self.dropout = torch.nn.Dropout(dropout)

        # Residual
        self.layer_norm2 = torch.nn.LayerNorm(d_model)
    
    def forward(self, year, month, week, dayofweek, day):
        year_embedding = self.year_embedder(year)
        month_embedding = self.month_embedder(month)
        week_embedding = self.week_embedder(week)
        day_embedding = self.day_embedder(day)
        dayofweek_embedding = self.day_embedder(dayofweek)

        concat = self.linear1(self.layer_norm1(torch.concat([year_embedding, month_embedding, week_embedding, day_embedding, dayofweek_embedding], dim=-1)))
        feed_forward = self.dropout(self.linear3(self.activation1(self.linear2(concat))))
        residual = self.layer_norm2(concat + feed_forward)
        return residual

class EncoderLayer(torch.nn.Module):
    def __init__(self, d_model, nhead, dropout, d_ff):
        super().__init__()
        # Self attention
        self.mha = torch.nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, dropout=dropout, batch_first=True)
        self.dropout1 = torch.nn.Dropout(dropout)

        # Add and norm
        self.layer_norm1 = torch.nn.LayerNorm(d_model)

        # Feed forward
        self.linear1 = torch.nn.Linear(d_model, d_ff)
        self.activation = torch.nn.ELU()
        self.dropout2 = torch.nn.Dropout(dropout)
        self.linear2 = torch.nn.Linear(d_ff, d_model)
        self.dropout3 = torch.nn.Dropout(dropout)

        # Add and norm
        self.layer_norm2 = torch.nn.LayerNorm(d_model)

    def forward(self, x, positional_embedding):
        # Self attention
        self_attn = self.mha(
                    query = x + positional_embedding,
                    key = x + positional_embedding,
                    value = x,
                    need_weights = False
                    )[0]
        self_attn = self.dropout1(self_attn)

        # Add and norm
        add_norm1 = self.layer_norm1(x + self_attn)
        
        # Feed forward
        feed_forward = self.linear2(self.dropout2(self.activation(self.linear1(add_norm1))))
        feed_forward = self.dropout3(feed_forward)

        # Add and norm
        add_norm2 = self.layer_norm2(add_norm1 + feed_forward)

        return add_norm2

class Encoder(torch.nn.Module):
    def __init__(self, encoder_block, num_layers):
        super().__init__()
        # self.layers = [torch.nn.DataParallel(copy.deepcopy(encoder_block).to(device)) for _ in range(num_layers)]
        self.layers = [torch.nn.DataParallel(copy.deepcopy(encoder_block).cuda()) for _ in range(num_layers)]
    
    def forward(self, x, positional_embedding):
        output = x
        for layer in self.layers:
            output = layer(output, positional_embedding)
        return output

class HistDecoderLayer(torch.nn.Module):
    def __init__(self, d_model, nhead, dropout):
        super().__init__()
        # Self attention
        self.self_attn = torch.nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, dropout=dropout, batch_first=True)
        self.dropout1 = torch.nn.Dropout(dropout)

        # Add and norm
        self.layer_norm1 = torch.nn.LayerNorm(d_model)

        # Cross attention
        self.cross_attn = torch.nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, dropout=dropout, batch_first=True)
        self.dropout2 = torch.nn.Dropout(dropout)

        # Add and norm
        self.layer_norm2 = torch.nn.LayerNorm(d_model)

        # Feed forward
        self.linear1 = torch.nn.Linear(d_model, d_ff)
        self.activation = torch.nn.ELU()
        self.dropout3 = torch.nn.Dropout(dropout)
        self.linear2 = torch.nn.Linear(d_ff, d_model)
        self.dropout4 = torch.nn.Dropout(dropout)

        # Add and norm
        self.layer_norm3 = torch.nn.LayerNorm(d_model)
    
    def forward(self, decoder_input, dec_positional_embedding, encoder_output, enc_positional_embedding):
        # Self attention
        self_attn = self.self_attn(
                    query = decoder_input + dec_positional_embedding,
                    key = decoder_input + dec_positional_embedding,
                    value = decoder_input,
                    need_weights = False
                    )[0]
        self_attn = self.dropout1(self_attn)
        
        # Add and norm
        layer_norm1 = self.layer_norm1(self_attn + decoder_input)

        # Cross attention
        cross_attn, attn_weight = self.cross_attn(
            query = layer_norm1 + dec_positional_embedding,
            key = encoder_output + enc_positional_embedding,
            value = encoder_output,
            need_weights = True
        )
        cross_attn = self.dropout2(cross_attn)

        # Add and norm
        layer_norm2 = self.layer_norm2(cross_attn + layer_norm1)

        # Feed forward
        feed_forward = self.linear2(self.dropout3(self.activation(self.linear1(layer_norm2))))
        feed_forward = self.dropout4(feed_forward)
        
        # Add and norm
        layer_norm3 = self.layer_norm3(feed_forward + layer_norm2)

        return layer_norm3, attn_weight

class Decoder(torch.nn.Module):
    def __init__(self, decoder_block, num_layers):
        super().__init__()
        # self.layers = [copy.deepcopy(decoder_block).to(device) for _ in range(num_layers)]
        self.layers = [copy.deepcopy(decoder_block).cuda() for _ in range(num_layers)]
    
    def forward(self, dec_positional_embedding, encoder_output, enc_positional_embedding):
        output = dec_positional_embedding
        for layer in self.layers:
            output, attn_weight = layer(output, dec_positional_embedding, encoder_output, enc_positional_embedding)
        return output, attn_weight

class FeedForward(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(d_model, d_ff)
        self.activation = torch.nn.ELU()
        self.dropout1 = torch.nn.Dropout(dropout)
        self.linear2 = torch.nn.Linear(d_ff, 1)
        self.dropout2 = torch.nn.Dropout(dropout)
    
    def forward(self, x):
        return self.dropout2(self.linear2(self.dropout1(self.activation(self.linear1(x)))))

class Transformer(torch.nn.Module):
    def __init__(self, d_model, nhead, dropout):
        super().__init__()
        # Encoder
        self.enc_positional_embedder = PositionalEmbedder(d_model)
        self.sales_embedder = torch.nn.Linear(1, d_model)
        self.encoder = Encoder(EncoderLayer(d_model, nhead, dropout, d_ff), num_layers)
        self.decoder = Decoder(HistDecoderLayer(d_model, nhead, dropout), num_layers)

        # Decoder
        self.dec_positional_embedder = PositionalEmbedder(d_model)

        # Output
        self.ffn = FeedForward()
        self.sigmoid = torch.nn.Sigmoid()
    
    def forward(self, sales, year, month, week, day, dayofweek, year_y, month_y, week_y, day_y, dayofweek_y):
        # Encoder
        enc_positional_embedding = self.enc_positional_embedder(year, month, week, dayofweek, day)
        sales = self.sales_embedder(sales)
        encoder_output = self.encoder(sales, enc_positional_embedding)

        # Decoder
        dec_positional_embedding = self.dec_positional_embedder(year_y, month_y, week_y, dayofweek_y, day_y)
        dec_output, atten_weight = self.decoder(dec_positional_embedding, encoder_output, enc_positional_embedding)

        output = self.ffn(dec_output)
        sigmoid_output = self.sigmoid(output)
        
        return sigmoid_output, output, atten_weight

model = Transformer(d_model, nhead, dropout)
model = torch.nn.DataParallel(model)
# model.to(device)
model.cuda()

DataParallel(
  (module): Transformer(
    (enc_positional_embedder): PositionalEmbedder(
      (year_embedder): Linear(in_features=1, out_features=512, bias=True)
      (month_embedder): Embedding(12, 512)
      (week_embedder): Embedding(52, 512)
      (day_embedder): Embedding(365, 512)
      (dayofweek_embedder): Embedding(7, 512)
      (layer_norm1): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
      (linear1): Linear(in_features=2560, out_features=512, bias=True)
      (linear2): Linear(in_features=512, out_features=512, bias=True)
      (activation1): ELU(alpha=1.0)
      (linear3): Linear(in_features=512, out_features=512, bias=True)
      (dropout): Dropout(p=0.3, inplace=False)
      (layer_norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (sales_embedder): Linear(in_features=1, out_features=512, bias=True)
    (encoder): Encoder()
    (decoder): Decoder()
    (dec_positional_embedder): PositionalEmbedder(
      (year_embedder): Linear(in_featu

: 

# Train

In [6]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
binary_loss_fn_ = torch.nn.BCELoss()
mse_loss_fn_ = torch.nn.MSELoss()
temp = None

def binary_loss_fn(pred, y):
    # y = torch.where(y==0, 0, 1).type(torch.FloatTensor).to(device)
    y = torch.where(y==0, 0, 1).type(torch.FloatTensor).cuda()
    loss = binary_loss_fn_(pred, y)
    return loss

def mse_loss_fn(pred, y):
    # mask = torch.where(y==0, 0, 1).type(torch.FloatTensor).to(device)
    mask = torch.where(y==0, 0, 1).type(torch.FloatTensor).cuda()
    pred = pred * mask
    # y = y.to(device) * mask
    y = y.cuda() * mask
    loss = mse_loss_fn_(pred, y)
    return loss

def plot_sample(y, pred, total_loss_li, binary_loss_li, mse_loss_li):
    # Plot sample
    clear_output(wait=True)
    y_sample = y[-1]
    pred_sample = pred[-1].detach().cpu()
    
    plt.figure(figsize=(15,5))
    plt.subplot(1,2,1)
    plt.plot(total_loss_li, label="Total", linewidth=3)
    plt.twinx()
    plt.plot(binary_loss_li, label="Binary", color="orange")
    plt.plot(mse_loss_li, label="MSE", color="green")
    
    plt.title("Loss")
    plt.legend()

    plt.subplot(1,2,2)
    plt.plot(np.expm1(y_sample), label="y")
    plt.plot(np.expm1(pred_sample), label="pred")
    plt.title("Sample")
    plt.legend()
    plt.show()

def train():
    model.train()
    total_loss_li, binary_loss_li, mse_loss_li = [], [], []
    mean_total_loss_li, mean_binary_loss_li, mean_mse_loss_li = [], [], []
    pbar = tqdm(train_dataloader)

    for n, data in enumerate(pbar):
        sales, year, month, week, day, dayofweek, y, year_y, month_y, week_y, day_y, dayofweek_y = data
        
        # Train
        optimizer.zero_grad()
        # sigmoid_pred, pred, atten_weight = model(sales.to(device), 
        #                             year.to(device), 
        #                             month.to(device), 
        #                             week.to(device), 
        #                             day.to(device), 
        #                             dayofweek.to(device), 
        #                             year_y.to(device), 
        #                             month_y.to(device), 
        #                             week_y.to(device), 
        #                             day_y.to(device), 
        #                             dayofweek_y.to(device))
        sigmoid_pred, pred, atten_weight = model(sales.cuda(), 
                                    year.cuda(), 
                                    month.cuda(), 
                                    week.cuda(), 
                                    day.cuda(), 
                                    dayofweek.cuda(), 
                                    year_y.cuda(), 
                                    month_y.cuda(), 
                                    week_y.cuda(), 
                                    day_y.cuda(), 
                                    dayofweek_y.cuda())
        sigmoid_pred = sigmoid_pred.squeeze(-1)
        pred = pred.squeeze(-1)

        binary_loss = binary_loss_fn(sigmoid_pred, y)
        mse_loss = mse_loss_fn(pred, y)
        loss = binary_loss + mse_loss
        loss.backward()
        optimizer.step()

        # Report
        total_loss_li.append(loss.item())
        binary_loss_li.append(binary_loss.item())
        mse_loss_li.append(mse_loss.item())
        pbar.set_description(f"{np.round(sum(total_loss_li[-100:])/100, 5)} → binary: {np.round(sum(binary_loss_li[-100:])/100, 5)}, mse: {np.round(sum(mse_loss_li[-100:])/100, 5)}")
        if n % 10 == 0:
            mean_total_loss_li.append(sum(total_loss_li[-100:])/100)
            mean_binary_loss_li.append(sum(binary_loss_li[-100:])/100)
            mean_mse_loss_li.append(sum(mse_loss_li[-100:])/100)
            
            calc_pred = torch.where(sigmoid_pred>0.5, 1, 0) * pred
            plot_sample(y, calc_pred, mean_total_loss_li[-10:], mean_binary_loss_li[-10:], mean_mse_loss_li[-10:])


epoch = 5
for e in range(epoch):
    train()
    break

  0%|          | 0/4675 [00:01<?, ?it/s]


RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/sh-sungho.park/anaconda3/envs/cudatest/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/home/sh-sungho.park/anaconda3/envs/cudatest/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/ipykernel_1473937/1340636551.py", line 199, in forward
    dec_output, atten_weight = self.decoder(dec_positional_embedding, encoder_output, enc_positional_embedding)
  File "/home/sh-sungho.park/anaconda3/envs/cudatest/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/ipykernel_1473937/1340636551.py", line 160, in forward
    output, attn_weight = layer(output, dec_positional_embedding, encoder_output, enc_positional_embedding)
  File "/home/sh-sungho.park/anaconda3/envs/cudatest/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/ipykernel_1473937/1340636551.py", line 119, in forward
    self_attn = self.self_attn(
  File "/home/sh-sungho.park/anaconda3/envs/cudatest/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sh-sungho.park/anaconda3/envs/cudatest/lib/python3.8/site-packages/torch/nn/modules/activation.py", line 1205, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/home/sh-sungho.park/anaconda3/envs/cudatest/lib/python3.8/site-packages/torch/nn/functional.py", line 5224, in multi_head_attention_forward
    q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
  File "/home/sh-sungho.park/anaconda3/envs/cudatest/lib/python3.8/site-packages/torch/nn/functional.py", line 4787, in _in_projection_packed
    return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_mm)


# Eval

In [None]:
model.eval()

for data in train_dataloader:
    sales, year, month, week, day, dayofweek, y, year_y, month_y, week_y, day_y, dayofweek_y = data
    
    # Train
    with torch.no_grad():
        sigmoid_pred, pred, atten_weight = model(sales.to(device), 
                                    year.to(device), 
                                    month.to(device), 
                                    week.to(device), 
                                    day.to(device), 
                                    dayofweek.to(device), 
                                    year_y.to(device), 
                                    month_y.to(device), 
                                    week_y.to(device), 
                                    day_y.to(device), 
                                    dayofweek_y.to(device))
        sigmoid_pred = sigmoid_pred.squeeze(-1)
        pred = pred.squeeze(-1)
    break

In [1]:
import seaborn as sns

arr = atten_weight[10].cpu().numpy()
sns.heatmap(arr)

NameError: name 'atten_weight' is not defined