In [123]:
import argparse
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
from pyats.datastructures import AttrDict

# Project imports keep their original relative order so the star imports
# shadow (or are shadowed by) exactly the same names as before.
from models.tcn import TCN
from exp.exp_informer import Exp_Informer
from models.dtcn_encoder import *
from models.tcn_encoder import *
from models.lstm_encoder import LSTM_Encoder
from models.informer_encoder import Informer_Encoder

In [124]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [125]:
# Display which device was selected.
device

'cuda'

In [126]:
def _get_data(self, flag):
    """Build the dataset and its DataLoader for one data split.

    NOTE(review): this module-level copy is not called anywhere in the visible
    notebook (the cells below use ``exp._get_data``), and the ``Dataset_*`` /
    ``DataLoader`` names are not imported explicitly here -- presumably they
    arrive through the star imports; confirm before relying on it.

    Args:
        self: Object exposing the experiment configuration as ``self.args``.
        flag: Split selector. ``'test'`` and ``'pred'`` get dedicated loader
            settings; any other value falls through to the shuffled
            configuration that drops the last incomplete batch.

    Returns:
        tuple: ``(data_set, data_loader)``.

    Raises:
        KeyError: If ``self.args.data`` is not one of the known dataset names.
    """
    args = self.args

    dataset_classes = {
        'ETTh1': Dataset_ETT_hour,
        'ETTh2': Dataset_ETT_hour,
        'ETTm1': Dataset_ETT_minute,
        'ETTm2': Dataset_ETT_minute,
        'WTH': Dataset_Custom,
        'ECL': Dataset_Custom,
        'Solar': Dataset_Custom,
        'custom': Dataset_Custom,
    }
    Data = dataset_classes[self.args.data]
    # Time-feature encoding is switched on only for the 'timeF' embedding.
    timeenc = 1 if args.embed == 'timeF' else 0

    if flag == 'test':
        shuffle_flag, drop_last = False, True
        batch_size, freq = args.batch_size, args.freq
    elif flag == 'pred':
        # Prediction runs one window at a time on the dedicated dataset class.
        shuffle_flag, drop_last = False, False
        batch_size, freq = 1, args.detail_freq
        Data = Dataset_Pred
    else:
        shuffle_flag, drop_last = True, True
        batch_size, freq = args.batch_size, args.freq

    data_set = Data(
        root_path=args.root_path,
        data_path=args.data_path,
        flag=flag,
        size=[args.seq_len, args.label_len, args.pred_len],
        features=args.features,
        target=args.target,
        inverse=args.inverse,
        timeenc=timeenc,
        freq=freq,
        cols=args.cols,
    )
    print(flag, len(data_set))

    data_loader = DataLoader(
        data_set,
        batch_size=batch_size,
        shuffle=shuffle_flag,
        num_workers=args.num_workers,
        drop_last=drop_last,
    )
    return data_set, data_loader

In [180]:
# Experiment configuration. Key order is kept as-is because it determines the
# order in which the settings are printed by the setup cell.
# enc_in / dec_in / c_out are not set here; the setup cell derives them from
# `data_parser`.
args = AttrDict({
    # --- model / data selection ---
    "model": 'lstm-moco',
    "data": 'ETTh1',
    "data_path": 'ETTh1.csv',
    "pred_len": 24,
    "label_len": 24,
    "root_path": './data/ETT',
    "cos_lr": True,
    "loss_lambda": 0.5,
    # --- architecture / optimisation ---
    "batch_size": 32,
    "checkpoints": './checkpoints/',
    "d_model": 320,
    "e_layers": 5,
    "features": 'M',
    "freq": 'h',
    "embed": 'timeF',
    "loss": 'mse',
    "dropout": 0.1,
    "kernel_size": 3,
    "l2norm": True,
    "seq_len": 48,
    "target": 'OT',
    "train_epochs": 1,
    "mask_rate": 0.3,
    "learning_rate": 0.001,
    "patience": 3,
    "gpu": 0,
    # --- contrastive-learning / misc switches ---
    "moco_average_pool": False,
    "data_aug": "cost",
    "mare": False,
    "time_feature_embed": False,
    "inverse": False,
    "cols": None,
    "num_workers": 0,
    "closs_decay": False,
    "use_gpu": False,
    "use_multi_gpu": False,
    "des": "Exp",
})

In [181]:
def _dataset_spec(csv_name, target, n_features):
    """Return the per-dataset settings: CSV file, target column and the
    [enc_in, dec_in, c_out] triples for the 'M', 'S' and 'MS' feature modes."""
    return {
        'data': csv_name,
        'T': target,
        'M': [n_features, n_features, n_features],
        'S': [1, 1, 1],
        'MS': [n_features, n_features, 1],
    }


data_parser = {
    'ETTh1': _dataset_spec('ETTh1.csv', 'OT', 7),
    'ETTh2': _dataset_spec('ETTh2.csv', 'OT', 7),
    'ETTm1': _dataset_spec('ETTm1.csv', 'OT', 7),
    'ETTm2': _dataset_spec('ETTm2.csv', 'OT', 7),
    'WTH': _dataset_spec('WTH.csv', 'WetBulbCelsius', 12),
    'ECL': _dataset_spec('ECL.csv', 'MT_320', 321),
    'Solar': _dataset_spec('solar_AL.csv', 'POWER_136', 137),
}

# Known datasets override the file name, target column and channel counts.
if args.data in data_parser:
    data_info = data_parser[args.data]
    args.data_path = data_info['data']
    args.target = data_info['T']
    args.enc_in, args.dec_in, args.c_out = data_info[args.features]

# Keep the full frequency string in `detail_freq` and shorten `freq` to its
# last character. NOTE: re-running this cell derives `detail_freq` from the
# already shortened `freq`.
args.detail_freq = args.freq
args.freq = args.freq[-1:]

# These two switches are forced off regardless of the values configured above.
args.time_feature_embed = False
args.mare = False

print('Args in experiment:')
print(args)

Exp = Exp_Informer
exp = Exp(args)

Args in experiment:
AttrDict({'model': 'lstm-moco', 'data': 'ETTh1', 'data_path': 'ETTh1.csv', 'pred_len': 24, 'label_len': 24, 'root_path': './data/ETT', 'cos_lr': True, 'loss_lambda': 0.5, 'batch_size': 32, 'checkpoints': './checkpoints/', 'd_model': 320, 'e_layers': 5, 'features': 'M', 'freq': 'h', 'embed': 'timeF', 'loss': 'mse', 'dropout': 0.1, 'kernel_size': 3, 'l2norm': True, 'seq_len': 48, 'target': 'OT', 'train_epochs': 1, 'mask_rate': 0.3, 'learning_rate': 0.001, 'patience': 3, 'gpu': 0, 'moco_average_pool': False, 'data_aug': 'cost', 'mare': False, 'time_feature_embed': False, 'inverse': False, 'cols': None, 'num_workers': 0, 'closs_decay': False, 'use_gpu': False, 'use_multi_gpu': False, 'des': 'Exp', 'enc_in': 7, 'dec_in': 7, 'c_out': 7, 'detail_freq': 'h'})
Use CPU
Use Data augmentation method: cost
l2norm True
Mask Rate: 0.5
[INFO] NOT Using Time Features.
[INFO] Number of parameters:  971399
TCN_MoCo(
  (encoder_q): TCNBase(
    (enc_embedding): Linear(in_features=7, ou

In [182]:
# Rebuild the checkpoint path from the run's hyper-parameters. Two naming
# schemes exist: a short one for the plain encoders and a longer one for the
# MoCo / CoST end-to-end variants.
_PLAIN_ENCODER_MODELS = ('lstm', 'tcn', 'dtcn', 'informer-encoder')

if args.model in _PLAIN_ENCODER_MODELS:
    model_path = './checkpoints/contrast-{}_{}_contrastL{}_mask{}_l2norm{}_ft{}_sl{}_pl{}_dm{}_el{}_{}_{}/checkpoint.pth'.format(
                args.model, args.data, args.loss_lambda,
                args.mask_rate, str(args.l2norm), args.features,
                args.seq_len, args.pred_len,
                args.d_model, args.e_layers, args.des, 0)
elif "moco" in args.model or args.model == "cost-e2e":
    model_path = './checkpoints/contrast-{}_{}_contrastL{}_mask{}_l2norm{}_ft{}_timeF{}_mare{}_cldecay{}_sl{}_pl{}_dm{}_el{}_avg{}_cos{}_aug{}_{}_{}/checkpoint.pth'.format(
                    args.model, args.data, args.loss_lambda, args.mask_rate, str(args.l2norm), args.features,
                    str(args.time_feature_embed), str(args.mare), str(args.closs_decay), args.seq_len, args.pred_len,
                    args.d_model, args.e_layers, str(args.moco_average_pool), str(args.cos_lr),
                    args.data_aug, args.des, 0)
else:
    # Previously an unknown model fell through both branches and failed later
    # with a NameError on `model_path`; fail here with a clear message instead.
    raise ValueError(
        "No checkpoint naming scheme is defined for model {!r}".format(args.model))

# The weights are loaded onto the CPU first and moved to `device` in a later cell.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load(
    model_path,
    map_location=torch.device('cpu'))


In [183]:
# Rename every state-dict key so that the substring "tcn" reads "encoder"
# before the weights are loaded into the model.
checkpoint = dict(
    (name.replace("tcn", "encoder"), weights)
    for name, weights in checkpoint.items()
)

In [184]:
# Load the (renamed) weights into the experiment's model; the displayed result
# reports whether all keys matched.
exp.model.load_state_dict(checkpoint)


<All keys matched successfully>

In [185]:
# Move the model to the selected device and switch it to evaluation mode
# (`eval()` returns the module itself, so the calls can be chained).
exp.model = exp.model.to(device).eval()
print("done")

done


In [133]:

# Build the validation dataset and its DataLoader through the experiment object.
# NOTE(review): the `_get_data` copy defined earlier in this notebook shuffles
# every split other than 'test'/'pred'; if Exp_Informer._get_data behaves the
# same way, the batches analysed below differ between runs -- confirm.
vali_data, vali_loader = exp._get_data(flag = 'val')

val 2857


In [134]:
# Number of batches the validation loader yields.
len(vali_loader)

89

In [135]:
# Sanity check: show the loader object.
vali_loader

<torch.utils.data.dataloader.DataLoader at 0x152394158fa0>

In [136]:
# Mean-squared-error criterion; the gradient cells below use it to score each
# forecast step against its label.
criterion =  nn.MSELoss()

In [137]:
def sqence_line(sp, dim):
    """Plot one feature of a sample together with its input-gradient strip.

    The figure shows the input series followed by the labels up to the
    sample's forecast step, the prediction at that step, and a background
    strip whose colour encodes the softmax-normalised gradient magnitudes.

    Args:
        sp: Record dict read via the keys ``"input"``, ``"label"``, ``"pred"``,
            ``"step"`` and ``"gradient"``. The tensors must already be on the
            CPU because they are converted with ``.numpy()``.
        dim: Index of the feature (last axis) to draw.

    Reads the module-level ``args`` for ``seq_len`` and ``pred_len``.
    """
    history = sp["input"].detach().numpy()
    future = sp["label"][:sp["step"] + 1, :].numpy()
    predict = float(sp["pred"][sp["step"], dim])

    # Input followed by the labels seen so far, restricted to one feature.
    series = list(np.concatenate((history, future), axis=0)[:, dim])
    n_points = len(series)

    # The prediction is appended at the same x position as the last label.
    x = list(range(n_points)) + [n_points - 1]
    y = series + [predict]

    input_end = args.seq_len - args.pred_len + sp["step"] + 1
    input_unseen_end = args.seq_len + sp["step"] + 1

    fig, ax = plt.subplots(figsize=(15, 5))

    # Gradient magnitudes for this feature, normalised over time and repeated
    # to two rows so that imshow draws them as a band.
    dim_grad = torch.nn.functional.softmax(sp['gradient'][:, dim].unsqueeze(0), dim=1).numpy()
    dim_grad = np.repeat(dim_grad, 2, axis=0)
    ax.imshow(
        dim_grad,
        cmap='GnBu',
        alpha=1,
        origin="upper",
        interpolation='nearest',
        aspect='auto',
        extent=(-0.5, input_end - 0.5, min(y) - 0.1, max(y) + 0.1),
    )

    ax.plot(x[:input_end], y[:input_end], "ko-", label="Input")
    ax.plot(
        x[input_end - 1:input_unseen_end],
        y[input_end - 1:input_unseen_end],
        color="grey", linestyle=":", marker="o", linewidth=1, markersize=3,
        label="Unseen",
    )
    # Redraw the last input point on top of the dotted "Unseen" segment.
    ax.plot(x[input_end - 1], y[input_end - 1], "ko-")
    ax.plot(x[input_unseen_end - 1:-1], y[input_unseen_end - 1:-1], "ro", markersize=10, label="Label")
    ax.plot(x[-1], y[-1], color="orange", marker="s", linestyle='None', markersize=8, label="Prediction")

    ax.set_xticks(list(range(0, len(x), 2)))
    ax.set_ylim(ymax=max(y) + 0.1, ymin=min(y) - 0.1)
    ax.grid(color='grey', linestyle='--', linewidth=0.5)
    ax.legend(loc='best')
    plt.show()

### Average over batch and dimensions

In [173]:
# Input-gradient magnitude per forecast step, summed over batch and features.
#
# Fix: the previous version read `seq_x.grad` after every `loss.backward()`.
# `exp.model.zero_grad()` only clears parameter gradients, so `seq_x.grad` kept
# accumulating and step i actually held the sum of the gradients of steps 0..i.
# `torch.autograd.grad` returns the gradient of exactly one loss.
N_GRAD_BATCHES = 11  # validation batches to analyse (the loop used to stop after ind == 10)

# The forward/backward passes run with the model in train mode, as before.
# NOTE(review): presumably needed for the backward pass of the recurrent
# encoder, but it also enables dropout and any train-only updates inside the
# model -- confirm this is intended for the analysis.
exp.model.train()

batch_grad = []
for ind, batch in enumerate(vali_loader):
    print(ind, batch[0].shape)
    seq_x, seq_y, seq_x_mark, seq_y_mark = (t.float().to(device) for t in batch)
    seq_x.requires_grad_(True)

    if 'moco' in args.model or "cost" in args.model:
        prediction_y, _, contrast_loss = exp.model(seq_x, seq_x_mark)
    else:
        prediction_y, _, _ = exp.model(seq_x)

    # Only the last `pred_len` steps of the target are forecast.
    seq_y = seq_y[:, -args.pred_len:, :]

    buff = []
    for i in range(args.pred_len):
        loss = criterion(prediction_y[:, i, :], seq_y[:, i, :])

        # Keep the graph alive until the last forecast step has been processed.
        (data_grad,) = torch.autograd.grad(
            loss, seq_x, retain_graph=(i < args.pred_len - 1))

        saliency = torch.abs(data_grad[:, 0:args.seq_len])
        # Sum over the batch axis, then over features -> one value per input
        # time step. Stored on the CPU so the plotting cells can call .numpy().
        buff.append(saliency.sum(0).sum(-1).unsqueeze(0).cpu())

    batch_grad.append(buff)

    if ind == N_GRAD_BATCHES - 1:
        break


0 torch.Size([32, 48, 7])
1 torch.Size([32, 48, 7])
2 torch.Size([32, 48, 7])
3 torch.Size([32, 48, 7])
4 torch.Size([32, 48, 7])
5 torch.Size([32, 48, 7])
6 torch.Size([32, 48, 7])
7 torch.Size([32, 48, 7])
8 torch.Size([32, 48, 7])
9 torch.Size([32, 48, 7])
10 torch.Size([32, 48, 7])


In [None]:
# One heat-map per forecast step: gradient magnitude per input time step,
# summed over all analysed batches (raw values, no normalisation).
cmap = sns.cm.rocket_r
for i in range(args.pred_len):
    print(i)
    # Stack the step-i rows of every batch, then reduce over the batch axis.
    step_rows = torch.cat([per_batch[i] for per_batch in batch_grad], dim=0)
    # .cpu() is required before .numpy(): the gradients may live on the GPU.
    avg_all = torch.sum(step_rows, 0).unsqueeze(0).detach().cpu().numpy()

    plt.figure(figsize=(5, 1))
    sns.heatmap(avg_all, cmap=cmap)

In [None]:
# Shape of the most recent input batch left over from the gradient loop.
seq_x.shape

In [None]:
# Same per-step heat-maps as the raw version, but softmax-normalised over the
# input time axis so that every row sums to one.
cmap = sns.cm.rocket_r
for i in range(args.pred_len):
    print(i)
    # Stack the step-i rows of every batch, then reduce over the batch axis.
    step_rows = torch.cat([per_batch[i] for per_batch in batch_grad], dim=0)
    summed = torch.sum(step_rows, 0).unsqueeze(0)
    # .cpu() is required before .numpy(): the gradients may live on the GPU.
    avg_all = torch.nn.functional.softmax(summed, dim=1).detach().cpu().numpy()

    plt.figure(figsize=(5, 1))
    sns.heatmap(avg_all, cmap=cmap)

## Sample-wise gradients

In [186]:
# Collect one record per (sample, forecast step): the input window, prediction,
# label, per-sample loss and the magnitude of the loss gradient w.r.t. the input.
#
# Fix: the previous version read `seq_x.grad` after every `loss.backward()`.
# `exp.model.zero_grad()` does not clear `seq_x.grad`, so every stored gradient
# was the running sum over all earlier backward calls. `torch.autograd.grad`
# returns the gradient of exactly one loss.
N_SAMPLE_BATCHES = 11  # validation batches to analyse (the loop used to stop after ind == 10)
N_SAMPLE_STEPS = 10    # leading forecast steps analysed for every sample

# NOTE(review): train mode is kept from the original cell; it also enables
# dropout and any train-only updates inside the model -- confirm this is intended.
exp.model.train()

samples = []
for ind, batch in enumerate(vali_loader):
    print(ind, batch[0].shape)
    seq_x, seq_y, seq_x_mark, seq_y_mark = (t.float().to(device) for t in batch)
    seq_x.requires_grad_(True)

    if 'moco' in args.model or "cost" in args.model:
        prediction_y, _, contrast_loss = exp.model(seq_x, seq_x_mark)
    else:
        prediction_y, _, _ = exp.model(seq_x)

    # Only the last `pred_len` steps of the target are forecast.
    seq_y = seq_y[:, -args.pred_len:, :]

    n_in_batch = seq_y.shape[0]
    for sample in range(n_in_batch):
        # Per-sample CPU copies, shared by all step records of this sample.
        sample_input = seq_x[sample, :, :].cpu().detach()
        sample_pred = prediction_y[sample, :, :].cpu().detach()
        sample_label = seq_y[sample, :, :].cpu().detach()

        for i in range(N_SAMPLE_STEPS):
            loss = criterion(prediction_y[sample, i, :], seq_y[sample, i, :])

            # The graph is needed again until the very last (sample, step) pair.
            is_last = sample == n_in_batch - 1 and i == N_SAMPLE_STEPS - 1
            (data_grad,) = torch.autograd.grad(loss, seq_x, retain_graph=not is_last)

            # Only the leading part of the input window is kept for step i.
            end = args.seq_len - args.pred_len + i + 1

            samples.append({
                # NOTE(review): `step` is stored as i + 1 although loss and
                # gradient belong to forecast index i; the plotting helpers
                # index `pred`/`label` with `step` directly and draw the
                # gradient strip over `end + 1` cells -- confirm which
                # convention is intended.
                "step": i + 1,
                "input": sample_input,
                "pred": sample_pred,
                "label": sample_label,
                "loss": loss.item(),
                "gradient": torch.abs(data_grad[sample, 0:end, :]).cpu(),
            })

    if ind == N_SAMPLE_BATCHES - 1:
        break


0 torch.Size([32, 48, 7])
1 torch.Size([32, 48, 7])
2 torch.Size([32, 48, 7])
3 torch.Size([32, 48, 7])
4 torch.Size([32, 48, 7])
5 torch.Size([32, 48, 7])
6 torch.Size([32, 48, 7])
7 torch.Size([32, 48, 7])
8 torch.Size([32, 48, 7])
9 torch.Size([32, 48, 7])
10 torch.Size([32, 48, 7])


In [187]:
# Total number of (sample, forecast-step) records collected.
len(samples)

3520

In [188]:
# Distinct forecast-step labels present in the collected records.
{record["step"] for record in samples}

{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}

In [189]:
# Record selection for the single-sample plots:
#   "min" / "max" -> record with the smallest / largest loss
#   "index"       -> the record at position `tg_index`
mode = "min"
tg_index = 300

In [None]:
# Pick the record `sp` (and its position `ind`) according to `mode`.
# The extreme loss is computed once instead of once per record, and an
# unsupported mode or out-of-range index now fails loudly instead of silently
# leaving `sp` pointing at whatever the loop visited last.
losses = [record['loss'] for record in samples]

if mode == "min":
    ind = losses.index(min(losses))  # first record with the smallest loss
elif mode == "max":
    ind = losses.index(max(losses))  # first record with the largest loss
elif mode == "index":
    if not 0 <= tg_index < len(samples):
        raise IndexError(
            "tg_index {} is outside the {} collected records".format(tg_index, len(samples)))
    ind = tg_index
else:
    raise ValueError("mode must be 'min', 'max' or 'index', got {!r}".format(mode))

sp = samples[ind]
print(ind, sp['loss'])

In [None]:
# Draw the selected record `sp`: one figure per input feature.
# (The unused `cmap`/`dim_grad` computations and the commented-out heat-map
# code were removed; `sqence_line` builds its own gradient strip.)
for dim in range(sp['gradient'].shape[-1]):
    sqence_line(sp, dim)
    

## Plot all input dimensions in a single figure

In [190]:
# Root directory for the saved receptive-field figures (relative to the notebook).
out_dir = "./output_receptive_field"

In [191]:

def group_plot(sp, ind, output_dir):
    """Save one figure with a panel per input feature for a single record.

    Every panel shows the input series followed by the labels up to the
    record's forecast step, the prediction at that step, and a background
    strip coloured by the softmax-normalised gradient magnitudes.

    Args:
        sp: Record dict read via the keys ``"input"``, ``"label"``, ``"pred"``,
            ``"step"``, ``"loss"`` and ``"gradient"``. The tensors must already
            be on the CPU because they are converted with ``.numpy()``.
        ind: Position of the record; only used in the output file name.
        output_dir: Existing directory the PNG is written to.

    Reads the module-level ``args`` (``seq_len``, ``pred_len``, ``model``,
    ``data``). At most the first 20 features are drawn.
    """
    feature_size = min(sp['gradient'].shape[-1], 20)

    # squeeze=False keeps `axs` an array even when there is a single feature;
    # without it plt.subplots returns a bare Axes and `axs[dim]` fails.
    fig, axs = plt.subplots(feature_size, figsize=(15, 3*feature_size), squeeze=False)
    axs = axs[:, 0]

    # Everything below is identical for all features, so compute it once.
    input_x = sp["input"].detach().numpy()
    target_label = sp["label"][:sp["step"]+1, :].numpy()
    concat_all = np.concatenate((input_x, target_label), axis=0)

    input_end = args.seq_len - args.pred_len + sp["step"] + 1
    input_unseen_end = args.seq_len + sp["step"] + 1

    for dim in range(feature_size):
        ax = axs[dim]
        predict = float(sp["pred"][sp["step"], dim])

        concat_ts = list(concat_all[:, dim])

        # The prediction is appended at the same x position as the last label.
        x = list(range(len(concat_ts)))
        y = list(concat_ts)
        x.append(len(concat_ts)-1)
        y.append(predict)

        # Gradient magnitudes for this feature, normalised over time and
        # repeated to two rows so that imshow draws them as a band.
        dim_grad = torch.nn.functional.softmax(sp['gradient'][:, dim].unsqueeze(0), dim=1).numpy()
        dim_grad = np.repeat(dim_grad, 2, axis=0)
        ax.imshow(dim_grad, cmap='GnBu', alpha=1, origin="upper", interpolation='nearest', aspect='auto', extent=(-0.5, input_end-0.5, min(y)-0.1, max(y)+0.1))

        ax.plot(x[:input_end], y[:input_end], "ko-", label="Input")
        ax.plot(x[input_end-1:input_unseen_end], y[input_end-1:input_unseen_end], color="grey", linestyle=":", marker="o", linewidth=1, markersize=3, label="Unseen")
        # Redraw the last input point on top of the dotted "Unseen" segment.
        ax.plot(x[input_end-1], y[input_end-1], "ko-")
        ax.plot(x[input_unseen_end-1:-1], y[input_unseen_end-1:-1], "ro", markersize=10, label="Label")
        ax.plot(x[-1], y[-1], color="orange", marker="s", linestyle='None', markersize=8, label="Prediction")

        ax.set_xticks(list(range(0, len(x), 2)))
        ax.set_ylim(ymax=max(y)+0.1, ymin=min(y)-0.1)
        ax.grid(color='grey', linestyle='--', linewidth=0.5)
        ax.legend(loc='best')

    fig.suptitle(f"Receptive Field of {args.model}. \n Dataset: {args.data}, Sample MSE = {str(sp['loss'])[:6]} at {sp['step']}th forecasting step. \nInput length = {args.seq_len}, Prediction length = {args.pred_len}. \n\n", fontsize=16, wrap=True)

    fig.tight_layout(pad=0.4, w_pad=0.5, h_pad=1.0)

    file_name = f"{output_dir}/model-{args.model}_data-{args.data}_feature-{feature_size}_val-sample{ind}_pred-at-{sp['step']}_inputleng-{args.seq_len}_predleng-{args.pred_len}.png"
    fig.savefig(file_name, bbox_inches='tight')
    # Close this specific figure so figures do not pile up when called in a loop.
    plt.close(fig)


In [192]:
# Pick a sparse random subset (about 3 %, capped at 100) of record indices to
# plot. `candidate_ind` used to exist only as leftover kernel state from a
# commented-out line, so a fresh "Restart & Run All" failed with a NameError.
# A dedicated seeded generator makes the index choice repeatable without
# touching the global `random` state.
CANDIDATE_SEED = 0
_candidate_rng = random.Random(CANDIDATE_SEED)
candidate_ind = [x for x in range(len(samples)) if _candidate_rng.uniform(0, 1) > 0.97][:100]

len(samples)

3520

In [160]:
# Show which record indices were selected for plotting.
print(candidate_ind)

[19, 43, 92, 108, 139, 166, 208, 229, 288, 302, 319, 377, 465, 474, 522, 599, 600, 634, 757, 837, 862, 887, 892, 964, 997, 1017, 1041, 1052, 1063, 1107, 1288, 1318, 1326, 1388, 1427, 1433, 1472, 1515, 1603, 1610, 1624, 1629, 1819, 1893, 1897, 1911, 1915, 1930, 2096, 2159, 2291, 2321, 2332, 2340, 2345, 2347, 2373, 2521, 2573, 2607, 2645, 2666, 2676, 2860, 2862, 2898, 2953, 2976, 3020, 3051, 3105, 3126, 3135, 3190, 3268, 3287, 3306, 3327, 3363, 3407, 3420, 3427, 3494]


In [193]:
# Per-model output directory. makedirs also creates `out_dir` itself when it
# does not exist yet (os.mkdir failed in that case) and is safe to re-run.
path = out_dir + "/" + args.model
os.makedirs(path, exist_ok=True)

In [194]:
# Save a grouped figure for every selected record.
# A set gives constant-time membership tests instead of scanning the
# candidate list once per record.
candidate_lookup = set(candidate_ind)

for ind, sp in enumerate(samples):
    if ind in candidate_lookup:
        group_plot(sp, ind, path)

In [None]:
# Count the entries directly under the output directory.
!ls output_receptive_field/ | wc -l


In [122]:
# WARNING: recursively deletes output_receptive_field/tcn/ and every figure
# saved in it. Do not run this as part of "Run All".
! rm -r output_receptive_field/tcn/
