# cLSTM Lorenz-96 Demo
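- In this notebook, we train cLSTM models on data simulated from Lorenz-96 systems. For every `.npz` dataset under `root_path`, the cLSTM is trained with each penalty value in `LAMBDA_RANGE`, and the estimated Granger-causality (GC) matrix is compared with the ground truth (ROC point, variable usage, accuracy, and a side-by-side plot).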

In [1]:
import os

# Show the directory the Jupyter kernel was started in; the relative dataset
# path used later in this notebook is resolved against it.
kernel_cwd = os.getcwd()
kernel_cwd


'/root/autodl-tmp/Neural-GC-master'

In [2]:
import matplotlib.pyplot as plt
import torch

from models.clstm import cLSTM, train_model_ista

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda


In [4]:
# Optional: mirror everything printed to stdout into a timestamped log file.
# Usage: wrap the experiment code in `with tee_stdout() as log_path: ...`.
import contextlib
import os
import sys
from datetime import datetime


class _Tee:
    """Minimal file-like object that forwards every write to several streams."""

    def __init__(self, *streams):
        self._streams = streams

    def write(self, message):
        for stream in self._streams:
            stream.write(message)
        # Flush on every write so the log file stays current even if the kernel dies.
        self.flush()
        return len(message)

    def flush(self):
        for stream in self._streams:
            stream.flush()

    def __getattr__(self, name):
        # Delegate other stream attributes (e.g. `encoding`, `isatty`) to the
        # primary stream; private names are refused to avoid recursion before
        # `_streams` exists.
        if name.startswith('_'):
            raise AttributeError(name)
        return getattr(self._streams[0], name)


@contextlib.contextmanager
def tee_stdout(log_dir="logs"):
    """Copy stdout into ``<log_dir>/log_<timestamp>.txt`` for the duration of the block.

    Output still appears in the notebook as usual. Unlike reassigning
    ``sys.stdout`` by hand, the original stdout is always restored and the log
    file is always closed, even if the wrapped code raises. The file is written
    as UTF-8 because the experiment prints non-ASCII text.

    Args:
        log_dir: folder for log files; created if it does not exist.

    Yields:
        str: path of the log file being written.
    """
    os.makedirs(log_dir, exist_ok=True)
    current_time = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    log_path = os.path.join(log_dir, f"log_{current_time}.txt")
    original_stdout = sys.stdout
    with open(log_path, "w", encoding="utf-8") as log_file:
        sys.stdout = _Tee(original_stdout, log_file)
        try:
            yield log_path
        finally:
            sys.stdout = original_stdout


In [None]:
import os

import numpy as np
from sklearn.metrics import confusion_matrix

# Root folder with one sub-folder per dataset group, each holding .npz files that
# store the time series under 'X' and the ground-truth GC matrix under 'GC'.
root_path = 'datasets/lorenz/F20'
# Training hyper-parameters shared by every run.
CONTEXT = 10
LAMBDA_RIDGE = 1e-2
LEARNING_RATE = 1e-3
MAX_ITERATIONS = 20000
CHECK_EVERY = 50
# Fixed grid of penalty strengths passed as `lam`; each value yields one ROC point.
LAMBDA_RANGE = [0.1, 0.3, 0.7, 1.0, 2.0, 3.5, 5.0, 10.0, 15.0, 20.0]
# Seed set before building each model so every lambda starts from the same
# initial weights and reruns are reproducible.
SEED = 0


def roc_point(GC_true, GC_pred):
    """Return ``(fpr, tpr)`` of a binary GC estimate against the ground truth.

    ``labels=[0, 1]`` forces a 2x2 confusion matrix, so the 4-way unpacking
    cannot fail when a class is missing; an empty denominator gives ``nan``.
    """
    tn, fp, fn, tp = confusion_matrix(
        np.ravel(GC_true), np.ravel(GC_pred), labels=[0, 1]
    ).ravel()
    fpr = fp / (fp + tn) if (fp + tn) > 0 else float('nan')
    tpr = tp / (tp + fn) if (tp + fn) > 0 else float('nan')
    return fpr, tpr


def plot_gc(GC_true, GC_pred, title='GC estimated'):
    """Plot true vs. estimated GC matrices side by side, boxing mismatches in red."""
    fig, axarr = plt.subplots(1, 2, figsize=(10, 5))
    axarr[0].imshow(GC_true, cmap='Blues')
    axarr[0].set_title('GC actual')
    axarr[1].imshow(GC_pred, cmap='Blues', vmin=0, vmax=1,
                    extent=(0, len(GC_pred), len(GC_pred), 0))
    axarr[1].set_title(title)
    for ax in axarr:
        ax.set_ylabel('Affected series')
        ax.set_xlabel('Causal series')
        ax.set_xticks([])
        ax.set_yticks([])

    # Mark disagreements (row i = affected series, column j = causal series).
    n_series = len(GC_pred)
    for i in range(n_series):
        for j in range(n_series):
            if GC_true[i, j] != GC_pred[i, j]:
                rect = plt.Rectangle((j, i - 0.05), 1, 1, facecolor='none', edgecolor='red',
                                     linewidth=1)
                axarr[1].add_patch(rect)

    plt.show()
    # Release the figure explicitly: the sweep creates one figure per (file, lambda).
    plt.close(fig)


# Walk the dataset folders in a deterministic (sorted) order.
for folder1 in sorted(os.listdir(root_path)):
    folder1_path = os.path.join(root_path, folder1)
    if not os.path.isdir(folder1_path):
        continue  # ignore stray files next to the dataset folders
    for file in sorted(os.listdir(folder1_path)):
        if not file.endswith('.npz'):
            continue
        file_path = os.path.join(folder1_path, file)
        print("当前处理文件是：" + file_path)
        print("Current parameters:")
        print(f"CONTEXT = {CONTEXT},  LAMBDA_RIDGE = {LAMBDA_RIDGE}")
        print(
            f"LEARNING_RATE = {LEARNING_RATE}, MAX_ITERATIONS = {MAX_ITERATIONS}, CHECK_EVERY = {CHECK_EVERY}")

        # Read 'X' and 'GC' into memory and close the .npz archive right away.
        with np.load(file_path) as data:
            X_np = data['X']
            GC = data['GC']

        # Prepend an axis of size 1 (presumably the batch dimension cLSTM expects)
        # and move the series to the selected device.
        X = torch.tensor(X_np[np.newaxis], dtype=torch.float32, device=device)

        for LAMBDA in LAMBDA_RANGE:
            print(f"当前 LAMBDA = {LAMBDA:.4f}")
            # Build a fresh, identically seeded model for every lambda. Reusing one
            # model would warm-start each fit from the previous lambda's weights,
            # so the ROC points would not be independent fits.
            # `.to(device)` also works on CPU-only machines, unlike `.cuda(device=...)`.
            torch.manual_seed(SEED)
            clstm = cLSTM(X.shape[-1], hidden=100).to(device)
            train_loss_list = train_model_ista(
                clstm, X,
                context=CONTEXT,
                lam=LAMBDA,
                lam_ridge=LAMBDA_RIDGE,
                lr=LEARNING_RATE,
                max_iter=MAX_ITERATIONS,
                check_every=CHECK_EVERY
            )

            # Learned Granger-causality matrix as a NumPy array.
            GC_est = clstm.GC().detach().cpu().numpy()

            fpr, tpr = roc_point(GC, GC_est)
            print(f"ROC Curve Point: FPR = {fpr:.4f}, TPR = {tpr:.4f}")

            print('True variable usage = %.2f%%' % (100 * np.mean(GC)))
            print('Estimated variable usage = %.2f%%' % (100 * np.mean(GC_est)))
            print('Accuracy = %.2f%%' % (100 * np.mean(GC == GC_est)))

            plot_gc(GC, GC_est, title=f'GC estimated (lam={LAMBDA})')

当前处理文件是：datasets/lorenz/F10/time1000/lorenz-169-F10-1000.npz
Current parameters:
CONTEXT = 10,  LAMBDA_RIDGE = 0.01
LEARNING_RATE = 0.001, MAX_ITERATIONS = 20000, CHECK_EVERY = 50
