In [2]:
# test_model
import torch
from utils.test import test
from eval import compute_err

def test_model(model_dir = "results_2025_1/test_0106_all/exp_2_50_all/dataforgood/dynst_7_1_w7_s2_20250106163802/model_EN_best.pth", device=4):
    """Evaluate a saved checkpoint and return its per-split outputs as NumPy arrays.

    Runs the project's ``test`` on ``model_dir`` and unpacks ``res['outputs']``, which
    must contain three splits (train, val, test) of five items each: loss, y_real,
    y_hat, adj_real, adj_hat. Torch tensors are detached and converted to CPU NumPy
    arrays; any other value is passed through unchanged.

    Args:
        model_dir: path to a ``*.pth`` checkpoint.
        device: forwarded to ``test`` (previously hard-coded to 4).

    Returns:
        dict with keys ``"{field}_{split}"`` (train fields, then val, then test) plus
        ``"errs_test"``: one ``compute_err`` result per index of axis 2 of
        ``y_hat_test`` (presumably one per node -- confirm).

    Raises:
        ValueError: if ``res['outputs']`` does not have the 3 x 5 layout above.
    """
    # model_dir = "results_2025_1/tests_1209/exp_2_sim_graph_lambda_0/sim/dynst_7_1_w7_s2_20241209105111/model_S3_best.pth"
    res, meta_data, args = test(model_dir, logger_disable=True, device=device)

    def to_numpy(value):
        return value.detach().cpu().numpy() if isinstance(value, torch.Tensor) else value

    fields = ("loss", "y_real", "y_hat", "adj_real", "adj_hat")
    splits = ("train", "val", "test")
    outputs = list(res['outputs'])
    if len(outputs) != len(splits):
        raise ValueError(f"expected {len(splits)} output splits, got {len(outputs)}")

    # Build the result explicitly instead of introspecting locals().
    result = {}
    for split, split_outputs in zip(splits, outputs):
        values = [to_numpy(v) for v in split_outputs]
        if len(values) != len(fields):
            raise ValueError(f"expected {len(fields)} outputs for split {split!r}, got {len(values)}")
        result.update({f"{field}_{split}": value for field, value in zip(fields, values)})

    y_hat_test, y_real_test = result["y_hat_test"], result["y_real_test"]
    # One error value per index along axis 2.
    result["errs_test"] = [compute_err(y_hat_test[:, :, i, :], y_real_test[:, :, i, :], False)
                           for i in range(y_hat_test.shape[2])]
    return result

In [6]:
# 画 散点图

# from utils.utils import matplotlib_chinese
# matplotlib_chinese()
import os, numpy as np
from utils.utils import get_exp_desc
from tqdm.auto import tqdm
# 用 matplotlib 画出散点图。并标出对角线（y=x）
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
fontsize = 14

def plot(ydays, dataset, country_code, node_observation_ratio, filepath, has_inset=False):
    """Save per-node MAE scatter plots of VAST-GNN against every other model.

    The rows of ``paths[f'o{node_observation_ratio}']`` at ``(ydays, country_code)`` are the
    models to compare (index = model name, one checkpoint path per row). The first row is
    plotted on the y axis as VAST-GNN. Each remaining model gets its own subplot on the
    x axis. Points above the dashed y = x line are nodes where VAST-GNN's error is larger.

    Args:
        ydays: first index level used to select the rows of ``paths``.
        dataset: unused; kept for call-site compatibility.
        country_code: second index level used to select the rows of ``paths``.
        node_observation_ratio: selects the ``f"o{...}"`` table in ``paths``.
        filepath: destination passed to ``fig.savefig``.
        has_inset: if True, zoom each panel to 1.7x the 90th percentile of the errors.
            An upper-right inset then shows the points above that percentile.

    NOTE: relies on the module-level ``paths`` and ``test_model`` defined in other cells.
    """
    # Select this experiment's rows once. test_model re-evaluates a checkpoint,
    # so each model is evaluated exactly once.
    group = paths[f'o{node_observation_ratio}'].sort_index().loc[(ydays, country_code)]
    errs = {model_name: test_model(row.path)['errs_test'] for model_name, row in group.iterrows()}

    model_names = list(errs.keys())
    # First model = VAST-GNN (y axis); every other model gets its own subplot (x axis).
    xs, y = [errs[name] for name in model_names[1:]], errs[model_names[0]]

    # squeeze=False always returns a 2-D array of Axes. With the default squeeze=True, a
    # single comparison model produced a bare Axes object and iterating over it failed.
    fig, axes = plt.subplots(len(xs), 1, figsize=(8, 15), squeeze=False)
    axes = axes[:, 0]
    # fig, axes = plt.subplots(1, len(xs), figsize=(18, 5))

    # fig.suptitle(get_exp_desc(f'({country_code}) ', 7, 1, 7, ydays - 1, node_observation_ratio, 'en'), fontsize=fontsize)

    y_array = np.array(y)
    for i, a in enumerate(axes):
        x = np.array(xs[i])

        # Reference line y = x from 0 up to the largest error of either model.
        upper = max(x.max(), y_array.max())
        a.plot([0, upper], [0, upper], color='grey', linestyle='-.', linewidth=1.5, alpha=0.7)

        # Colour the points on each side of the diagonal differently.
        mask_above = y_array > x
        mask_below = y_array < x
        mask_equal = ~(mask_below | mask_above)  # exact ties

        a.scatter(x[mask_above], y_array[mask_above], color='green', s=50, alpha=0.4, label='y > x')  # above the diagonal: light green
        a.scatter(x[mask_below], y_array[mask_below], color='green', s=50, alpha=0.8, label='y < x')  # below the diagonal: dark green
        if mask_equal.any():
            a.scatter(x[mask_equal], y_array[mask_equal], color='grey', s=50, alpha=0.8, label='y = x')  # on the diagonal: grey

        if has_inset:
            # Upper-right inset showing the tail above the 90th percentile of both error sets.
            inset = inset_axes(a, width="30%", height="30%", loc='upper right')
            inset.scatter(x[mask_above], y_array[mask_above], color='green', s=50, alpha=0.4)
            inset.scatter(x[mask_below], y_array[mask_below], color='green', s=50, alpha=0.8)
            if mask_equal.any():
                inset.scatter(x[mask_equal], y_array[mask_equal], color='grey', s=50, alpha=0.8)

            xlim, ylim = a.get_xlim(), a.get_ylim()
            threshold = np.percentile(np.concatenate([x, y_array]), 90)

            # Zoom the main panel to 1.7x the threshold; the inset covers threshold..max.
            a.set_xlim(xlim[0], threshold * 1.7)
            a.set_ylim(ylim[0], threshold * 1.7)
            for side in ('top', 'bottom', 'left', 'right'):
                inset.spines[side].set_color('black')
            inset.set_xlim(threshold, max(x.max(), y_array.max()) * 1.1)
            inset.set_ylim(threshold, max(x.max(), y_array.max()) * 1.1)

        a.tick_params(axis='both', labelsize=fontsize, pad=5)

        # Hide all four frame lines; the arrows drawn below act as the axes.
        for side in ('top', 'right', 'bottom', 'left'):
            a.spines[side].set_visible(False)

        # Axis arrows
        a.annotate('', xy=(1.1, 0), xycoords='axes fraction', xytext=(-0.1, 0), textcoords='axes fraction',
                   arrowprops=dict(facecolor='black', shrink=0.05, width=0.5, linewidth=0.5), zorder=5)
        a.annotate('', xy=(0, 1.1), xycoords='axes fraction', xytext=(0, -0.1), textcoords='axes fraction',
                   arrowprops=dict(facecolor='black', shrink=0.05, width=0.5, linewidth=0.5), zorder=5)

        a.set_xlabel("MAE of %s" % model_names[i + 1].upper(), fontsize=fontsize)
        a.set_ylabel("MAE of VAST-GNN", fontsize=fontsize)
        # a.set_ylabel("MAE of %s" % model_names[0].capitalize(), fontsize=fontsize)

    plt.tight_layout()
    plt.subplots_adjust(wspace=0.5, hspace=None)  # spacing between panels
    fig.savefig(filepath, transparent=True)
    # Close this specific figure; plt.close() only closes the "current" one.
    plt.close(fig)
    # plt.show()

DIR_NAME = "visualization/results_plot_maes"

def draw_main(dataset_name, ratios, ydays, countries):
    """Render one scatter-plot PDF per (ratio, ydays, country) into ``DIR_NAME``.

    Files are named ``"{ydays}_{dataset_name}_{country}_{ratio}.pdf"``. Progress is shown
    with a tqdm bar that is closed when the loop ends, even if a plot raises.
    """
    # The output directory does not depend on the loop variables, so create it once.
    os.makedirs(DIR_NAME, exist_ok=True)
    with tqdm(total=len(ratios) * len(ydays) * len(countries)) as qbar:
        for ratio in ratios:
            for y in ydays:
                for country in countries:
                    filename = "_".join(map(str, (y, dataset_name, country, ratio))) + ".pdf"
                    plot(y, dataset_name, country, ratio, os.path.join(DIR_NAME, filename))
                    qbar.update()

# NOTE: `plot` looks up the module-level name `paths`, so each import below rebinds the
# table used by the `draw_main` call that follows it. Keep each import/call pair together.
from best_results import paths
draw_main(dataset_name="dataforgood", ratios=[50, 80], ydays=[3, 7, 14], countries=["EN", "FR", "IT", "ES", "NZ", "JP"])
from best_results import paths_flunet as paths
draw_main(dataset_name="flunet", ratios=[50], ydays=[3], countries=["h1n1", "h3n2", "BY", "BV"])

  0%|          | 0/36 [00:00<?, ?it/s]



  0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
# import torch
# from utils.test import test
# from eval import compute_err


# def test_model(model_path):
#     res, meta_data, args = test(model_path)
#     (
#         (loss_train, y_real_train, y_hat_train, adj_real_train, adj_hat_train),
#         (loss_val, y_real_val, y_hat_val, adj_real_val, adj_hat_val),
#         (loss_test, y_real_test, y_hat_test, adj_real_test, adj_hat_test),
#     ) = [
#         map(lambda x: x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else x, r)
#         for r in res["outputs"]
#     ]
#     err_test =  compute_err(y_hat_test, y_real_test, False)
#     return f"{err_test:.3f}"


In [None]:
# # 重新导入包 importlib
# if 'best_results' in globals():
#     print('重新导入包')
#     import importlib
#     importlib.reload(best_results)
# else:
#     import best_results

# paths = best_results.paths

# import zipfile
# from tqdm.auto import trange

# zip_file_name = 'best_results.zip'

# with zipfile.ZipFile(zip_file_name, 'w') as zipf:
#     qbar = trange(len(paths['o50']))
#     for i in qbar:
#         data = paths['o50'].iloc[i]
#         qbar.set_description(" ".join(map(str, data.name)))
#         zipf.write(data.path, f"y{'_'.join(map(str, data.name))}.pth")


In [None]:
# from best_results import paths
# paths[f'o{node_observation_ratio}'].sort_index().loc[(ydays, country_code)].iterrows()
import re
# NOTE(review): `i` is not defined in this cell; it comes from whatever an earlier cell left
# in the kernel, so this fails on a fresh kernel (as does `paths` unless imported first).
# For i == 0, `i * 4 - 1` is -1, i.e. the LAST row of the table.
paths['o50']['path'].iloc[i * 4 - 1]
# Extract (dataset, model, shift, country) from the checkpoint path; the third group is the
# `s<N>` number of the run directory name.
dataset, model, ydays_minus_1, country = re.search(r".*/(.*?)/(.*?)_\d+_\d+_w\d+_s(\d+)_.*?/model_(.*)_best.pth", paths['o50']['path'].iloc[i * 4 - 1]).groups()

In [None]:
from best_results import paths

# Not displayed: only the cell's last expression is rendered.
paths['o50']['path'].iloc[3]

# Number of rows divided by 4 (true division, so a float).
# Presumably one group of 4 model checkpoints per experiment -- confirm.
len(paths['o50']['path']) / 4

# res = test_model(path)
# res['adj_hat_test']

In [None]:
# 画 矩阵热图

import numpy as np, re
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display

from best_results import paths
# from best_results import paths_flunet as paths

# Load an initial checkpoint so the batch/day sliders below can be sized from its outputs.
# NOTE(review): with _index == 0, `_index * 4 - 1` is -1, i.e. the LAST row of the table.
_index = 0
path = paths['o50']['path'].iloc[_index * 4 - 1]
dataset, model, ydays_minus_1, country = re.search(r".*/(.*?)/(.*?)_\d+_\d+_w\d+_s(\d+)_.*?/model_(.*)_best.pth", path).groups()
res = test_model(path)
adj_real_test, adj_hat_test = res['adj_real_test'], res['adj_hat_test']

# Alpha-mask helper
def create_alpha_mask(matrix):
    """Return an (n, n) float alpha mask: 1.0 everywhere except 0.0 on the diagonal.

    Only ``matrix.shape[0]`` is used, so the mask is always square.
    """
    size = matrix.shape[0]
    return 1.0 - np.eye(size)

# Interactive controls
# widgets.SelectionSlider()
# `//` keeps `max` an int; `/` produced a float that relied on the widget converting it.
index_slider = widgets.IntSlider(value=2, min=0, max=len(paths['o50']['path']) // 4 - 1, step=1, description='Index')
batch_slider = widgets.IntSlider(value=6, min=0, max=adj_real_test.shape[0] - 1, step=1, description='Batch')
day_slider = widgets.IntSlider(value=1, min=0, max=adj_real_test.shape[1] - 1, step=1, description='Day')

# Plot callback
def update(index, batch, day):
    """Draw predicted ("Ours") vs. ground-truth ("GT") adjacency heatmaps.

    Called by ``widgets.interactive`` whenever a slider changes. When ``index`` differs
    from the loaded ``_index``, the matching checkpoint is re-evaluated with
    ``test_model`` (expensive). The batch/day sliders are then resized and reset to 0.

    NOTE: reads/writes notebook globals (``paths``, ``test_model``, the sliders and the
    loaded arrays) defined elsewhere in this cell.
    """
    global _index, path, dataset, model, ydays_minus_1, country, res, adj_real_test, adj_hat_test
    if _index != index:
        # Record the new index BEFORE touching the sliders. Changing their max/value can fire
        # this callback again, which previously saw the old _index and reloaded the model.
        _index = index
        # Rows appear to come in groups of 4 (cf. index slider max = len // 4 - 1). Take the 4th
        # row of group `index`. The old `index * 4 - 1` pointed one group earlier and wrapped
        # index 0 to the last row.
        path = paths['o50']['path'].iloc[index * 4 + 3]
        dataset, model, ydays_minus_1, country = re.search(r".*/(.*?)/(.*?)_\d+_\d+_w\d+_s(\d+)_.*?/model_(.*)_best.pth", path).groups()
        res = test_model(path)
        adj_real_test, adj_hat_test = res['adj_real_test'], res['adj_hat_test']
        batch_slider.max = adj_real_test.shape[0] - 1
        day_slider.max = adj_real_test.shape[1] - 1
        batch_slider.value = 0
        day_slider.value = 0
        # The batch/day passed in refer to the previous checkpoint and may be out of range.
        batch, day = 0, 0

    real_matrix = adj_real_test[batch, day]
    hat_matrix = adj_hat_test[batch, day]

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    fig.suptitle(f"[{'_'.join((dataset, model, ydays_minus_1, country))}] Index {index}, Batch {batch}, Day {day}", fontsize=16)

    # Min-max normalise the prediction. A constant matrix would otherwise compute 0 / 0.
    hat_min, hat_max = hat_matrix.min(), hat_matrix.max()
    if hat_max > hat_min:
        hat_matrix = (hat_matrix - hat_min) / (hat_max - hat_min)
    else:
        hat_matrix = np.zeros(hat_matrix.shape)
    # Colour range over the strictly positive entries (None -> autoscale if there are none).
    positive = hat_matrix[hat_matrix > 0]
    vmin, vmax = (positive.min(), positive.max()) if positive.size else (None, None)

    # Draw each matrix ONCE with its diagonal masked out. Masked pixels use the colormap's
    # "bad" colour, which is transparent by default. Previously an opaque copy was drawn first
    # and the alpha-masked copy on top. The diagonal therefore still showed through, and the
    # visible prediction colours did not follow the colorbar's vmin/vmax.
    hat_masked = np.ma.masked_where(create_alpha_mask(hat_matrix) == 0, hat_matrix)
    im0 = axes[0].imshow(hat_masked, cmap='magma', interpolation='nearest', vmin=vmin, vmax=vmax)
    fig.colorbar(im0, ax=axes[0], label='Value')
    axes[0].axis('off')
    axes[0].set_title('Ours')

    real_masked = np.ma.masked_where(create_alpha_mask(real_matrix) == 0, real_matrix)
    im1 = axes[1].imshow(real_masked, cmap='magma', interpolation='nearest')
    fig.colorbar(im1, ax=axes[1], label='Value')
    axes[1].axis('off')
    axes[1].set_title('GT')

    plt.show()


interactive_plot = widgets.interactive(update, index=index_slider, batch=batch_slider, day=day_slider)
display(interactive_plot)


In [None]:
import numpy as np

# Parameters
SEED = 0  # seed the global NumPy RNG so the simulation is reproducible on Restart & Run All
np.random.seed(SEED)
num_nodes = 70  # number of nodes
num_dates = 120  # number of time steps
delta_t = 1.0  # time step
a, b, c = 0.074, 0.130, 0.01  # dynamics parameters
gamma = 0.05  # network adjustment coefficient
max_cases = 1000  # soft cap on case counts

# Node and date labels
nodes = [f"R{i:03}" for i in range(num_nodes)]
dates = [f"D{i:03}" for i in range(num_dates)]

# Daily growth, cumulative cases and per-day adjacency matrices
growth = np.zeros((num_dates, num_nodes, 1))  # daily growth
cases = np.zeros((num_dates, num_nodes, 1))  # cumulative cases
adjs = np.random.rand(num_dates, num_nodes, num_nodes)  # random adjacency matrices
np.fill_diagonal(adjs[0], 0)  # no self-loops on day 0; later days are rebuilt by update_adjacency

# Initial growth; cumulative cases start equal to it
growth[0] = np.random.randint(1, 10, size=(num_nodes, 1))
cases[0] = growth[0]

# Logistic sigmoid
def sigmoid(x):
    """Element-wise logistic function 1 / (1 + exp(-x)).

    For very negative x, ``np.exp(-x)`` overflows to inf. The result is still the correct
    limit 0, so the overflow warning is silenced rather than emitted on every simulation step.
    """
    with np.errstate(over='ignore'):
        return 1 / (1 + np.exp(-x))

# Dynamics: per-node daily growth
def compute_growth(x, adj, a, b, c, max_cases):
    """Daily growth of every node under the toy dynamics.

        growth_i = max(a*x_i + b*sum_j adj[i, j]*sigmoid(x_j - x_i) - c*x_i**2, 0)
                   * (1 - sigmoid((x_i - max_cases) / 100))

    The last factor is 0.5 at x == max_cases and decays towards 0 above it.
    Assumes ``x`` is a column vector (n, 1) so that ``x.T - x`` is (n, n) -- confirm with caller.
    """
    pairwise = sigmoid(x.T - x)  # [i, j] -> sigmoid(x_j - x_i)
    spread = (adj * pairwise).sum(axis=1, keepdims=True)  # adjacency-weighted inflow per node
    logistic = a * x + b * spread - c * x**2
    saturation = 1 - sigmoid((x - max_cases) / 100)
    return np.maximum(logistic, 0) * saturation

# Network update rule
def update_adjacency(adj, x, gamma):
    """Return the next-step adjacency matrix without modifying ``adj``.

    new[i, j] = adj[i, j] + gamma * exp(-(x_i - x_j)**2 / 2), then the diagonal is
    zeroed and all weights are clipped to [0, 1].

    The previous ``adj += ...`` mutated its argument. The simulation loop passes a view into
    its history (``adjs[i_date - 1]``), which was silently overwritten with unclipped values.
    """
    x_diff = x - x.T  # [i, j] -> x_i - x_j
    new_adj = adj + gamma * np.exp(-x_diff**2 / 2)  # Gaussian-kernel reinforcement
    np.fill_diagonal(new_adj, 0)  # keep the diagonal at 0
    return np.clip(new_adj, 0, 1)  # weights stay in [0, 1]

# Dynamics iteration (forward Euler with step delta_t)
for i_date in range(1, num_dates):
    # Previous day's cases and network (adj_t is a view into adjs, not a copy)
    x_t = cases[i_date - 1]
    adj_t = adjs[i_date - 1]
    
    # Daily growth
    growth_rate = compute_growth(x_t, adj_t, a, b, c, max_cases)
    growth[i_date] = delta_t * growth_rate
    
    # Accumulate growth into cumulative cases
    cases[i_date] = cases[i_date - 1] + growth[i_date]
    
    # Next day's network
    adjs[i_date] = update_adjacency(adj_t, x_t, gamma)

import matplotlib.pyplot as plt
# Cumulative cases over time for node 5.
plt.plot(range(cases.shape[0]), cases[:, 5])
plt.show()

In [None]:
import pickle
# with open("dataset_cache_bak/dataforgood_x7_y1_w7_s2_m50.bin", "rb") as f: res = pickle.load(f)
# NOTE: pickle.load can execute arbitrary code; only load cache files you generated yourself.
with open("dataset_cache/sim_x7_y1_w7_s2_m50.bin", "rb") as f: res = pickle.load(f)
# Unpack the second item of the cached "SIM1" entry.
features, cases, adjs = res['data']["SIM1"][1]

# Cumulative sum along axis 0 of column 6 (presumably node 6 over time -- confirm).
# NOTE: this rebinds the global `cases` used by the simulation cells.
cases = cases[:, 6].cumsum(0)

import matplotlib.pyplot as plt
plt.plot(range(cases.shape[0]), cases)
plt.show()


In [None]:
import numpy as np
# Parameters
SEED = 0  # seed the global NumPy RNG so the simulation is reproducible on Restart & Run All
np.random.seed(SEED)
num_nodes = 70  # number of nodes
num_dates = 64  # number of time steps
delta_t = 1.0  # time step
a, b = 0.064, 0.01  # dynamics parameters

# Node and date labels
nodes = [f"R{i:03}" for i in range(num_nodes)]
dates = [f"D{i:03}" for i in range(num_dates)]

# Cases with shape (time, node, 1) and random per-day adjacency matrices
cases = np.zeros((num_dates, num_nodes, 1))
adjs = np.random.rand(num_dates, num_nodes, num_nodes)

# Initial case counts
cases[0] = np.random.randint(0, 10, size=(num_nodes, 1))

# Dynamics right-hand side dx/dt
def compute_dx_dt(x, adj, a, b):
    """Return dx/dt = a * x + b * sum_j adj[i, j] * sigmoid(x_j - x_i).

    Assumes ``x`` is a column vector (n, 1) so that ``x.T - x`` broadcasts to an (n, n)
    matrix whose [i, j] entry is x_j - x_i -- confirm with caller.
    """
    pull = 1 / (1 + np.exp(-(x.T - x)))  # pull[i, j] = sigmoid(x_j - x_i)
    neighbour_term = (adj * pull).sum(axis=1, keepdims=True)
    return a * x + b * neighbour_term

# Dynamics iteration
for i_date in range(1, num_dates):
    # Previous day's cases and adjacency matrix
    x_t = cases[i_date - 1]
    adj_t = adjs[i_date - 1]
    # dx/dt
    dx = compute_dx_dt(x_t, adj_t, a, b)
    # Forward-Euler update of the case counts
    cases[i_date] = x_t + delta_t * dx

import matplotlib.pyplot as plt
# Trajectory of node 0.
plt.plot(range(num_dates), cases[:, 0])
plt.show()
# Initial case counts as a comma-separated string (displayed as the cell output).
','.join(map(str, np.array(cases[0], dtype=int).squeeze(-1)))

In [None]:
# Extract graph-lambda results of the dynst experiments: (lambda, shift, country) -> (MAE, epoch)
import os, re

LAMBDA_RESULTS_DIR = '/home/hbj/workspace/vscode_workspace/lp/workspace/results/tests_1129/exp_3_dynst_lambdas/dataforgood'
ARGS_LAMBDA_LINE = 13  # 0-based line of args.txt holding the graph-lambda value(s)


def collect_lambda_results(root):
    """Parse every run directory under ``root`` that holds exactly two ``.txt`` files.

    Shift and country come from the directory path (``...s<shift>_.../<country>``).
    Lambda comes from line ``ARGS_LAMBDA_LINE`` of ``args.txt``, where all values after the
    first token must agree. The best epoch and test MAE come from ``log.txt``. Runs whose log
    lacks the best-epoch line, or has fewer than two ``[err(val/test)]`` lines, are skipped.

    Returns:
        dict mapping (lambda, shift, country) strings to (mae, epoch) strings.
    """
    results = {}
    # os.walk yields (dirpath, dirnames, filenames). The old code unpacked these as
    # (dirname, dirpath, files), which only worked because it used the first item as the path.
    for run_dir, _subdirs, files in os.walk(root):
        if len(files) != 2 or not all(f.endswith('.txt') for f in files):
            continue
        match = re.search(r"s(\d+)_.*?/(.*)", run_dir)
        shift, country = match.groups() if match else (None, None)
        with open(os.path.join(run_dir, 'args.txt'), encoding='utf-8') as f:
            lambdas = f.readlines()[ARGS_LAMBDA_LINE].split()[1:]
        assert lambdas and all(l == lambdas[0] for l in lambdas), lambdas
        lambda_ = lambdas[0]
        # The log contains Chinese markers, so decode explicitly instead of using the locale.
        with open(os.path.join(run_dir, 'log.txt'), encoding='utf-8') as f:
            lines = f.readlines()
        epoch_lines = [l for l in lines if "最小 val loss (epoch" in l]
        mae_lines = [l for l in lines if "[err(val/test)]" in l]
        if not epoch_lines or len(mae_lines) < 2:
            continue
        epoch = re.search(r"epoch (.*)\)", epoch_lines[0]).groups()[0]
        mae = re.search(r"\[err\(val/test\)\] .*?/(.*?),", mae_lines[-1]).groups()[0]
        results[(lambda_, shift, country)] = (mae, epoch)
    return results


def format_lambda_table(results,
                        lambdas=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
                        shifts=(2, 6, 13),
                        countries=("England", "France", "Italy", "Spain")):
    """Render ``results`` as a tab-separated table with one row per lambda.

    Each (shift, country) pair contributes two cells, (mae, epoch), or ('-', '-') if missing.
    Keys are looked up by the ``str`` of each value.
    """
    rows = []
    for lam in lambdas:
        cells = []
        for shift in shifts:
            for country in countries:
                cells.extend(results.get((str(lam), str(shift), str(country)), ('-', '-')))
        rows.append('\t'.join(cells))
    return '\n'.join(rows)


res = collect_lambda_results(LAMBDA_RESULTS_DIR)
# print(sorted(res.keys(), key=lambda x: (x[0], int(x[1]), {"England": 0, "France": 1, "Italy": 2, "Spain": 3}.get(x[2], 4))))
print(format_lambda_table(res))

In [None]:
# Example: python main.py --country England --shift 6  --model mpnn_lstm --result-dir tests_1119 --exp 6 --device 3
# Three shifts x five simulated datasets (SIM0-SIM4); devices are assigned round-robin.
# NOTE(review): `semi` is defined in a LATER cell; run that cell first or this raises NameError.
res = ''
i = 0
devices = (7, 8, 9)
# for device, model in zip((7, 8, 9), ('mpnn_lstm', 'lstm', 'dynst')):
# for model in ('dynst',):
for shift in [2, 6, 13]:
    for country in [f'SIM{i}' for i in range(5)]:
        res += f'python main.py --country {country:7s} --shift {shift:2} --dataset sim --result-dir tests_1209 --exp 2_sim_graph_lambda_0 --device {devices[i % len(devices)]}\n'
        i += 1
# print('\n'.join(sorted(res.split('\n')[:-1], key=lambda x: int(x[-1]))))
print(semi(res))

In [None]:
import re, math

dataset = 'flunet'

node_observed_ratios = [50, 80]
# countries = ["England", "France", "Spain", "Italy"]
# countries = ["Japan"]
countries = ["h1n1", "h3n2", "BV", "BY"]
shifts = [2, 6, 13]

# Per country and shift: baselines and/or the graph-lambda grid for dynst
def gen_lambda(mode, devices, result_dir, commoncmd = ""):
    """Generate one ``python main.py ...`` command per run, sorted by device number.

    Any ``mode`` containing 'baseline' adds three baseline runs (mpnn_lstm, mpnn_lstm with
    --maml, lstm). Any ``mode`` containing 'dynst' adds graph-lambda 0.0..0.9 x 4 ablation
    settings. Devices are assigned round-robin in generation order, and the final sort by
    device is stable.
    """
    assert mode in ['baseline', 'dynst', 'baselinedynst']
    commands = []

    def device():
        # Round-robin: the n-th generated command goes to devices[n % len(devices)].
        return devices[len(commands) % len(devices)]

    for country, shift in ((c, s) for c in countries for s in shifts):
        assert country in countries and shift in shifts
        for node_observed_ratio in node_observed_ratios:
            if 'baseline' in mode:
                commands.append(f'python main.py --dataset {dataset} {commoncmd} --country {country} --shift {shift:2} --result-dir {result_dir} --exp 1_baselines_mpnn_lstm_{node_observed_ratio} --node-observed-ratio {node_observed_ratio} --model mpnn_lstm        --device {device()}')
                commands.append(f'python main.py --dataset {dataset} {commoncmd} --country {country} --shift {shift:2} --result-dir {result_dir} --exp 1_baselines_mpnn_tl_{node_observed_ratio}   --node-observed-ratio {node_observed_ratio} --model mpnn_lstm --maml --device {device()}')
                commands.append(f'python main.py --dataset {dataset} {commoncmd} --country {country} --shift {shift:2} --result-dir {result_dir} --exp 1_baselines_lstm_{node_observed_ratio}      --node-observed-ratio {node_observed_ratio} --model lstm             --device {device()}')
            if 'dynst' in mode:
                for num_graph_lambda in range(10):
                    for i_hp in range(4):
                        # Bit 0 disables the graph, bit 1 disables the virtual node.
                        no_graph, no_virtual = bool(i_hp & 1), bool(i_hp & 2)
                        exp_name = "graph_lambda_%d" % num_graph_lambda + ("_no_graph" if no_graph else "") + ("_no_virtual_node" if no_virtual else "")
                        exp_hyperparams = (" --no-graph" if no_graph else "") + (" --no-virtual-node" if no_virtual else "")
                        commands.append(f'python main.py --dataset {dataset} {commoncmd} --country {country} --shift {shift:2} --result-dir {result_dir} --exp 2_{node_observed_ratio}_{exp_name:39} --node-observed-ratio {node_observed_ratio} --graph-lambda {num_graph_lambda / 10} {exp_hyperparams:29} --device {device()}')
            if not ('baseline' in mode or 'dynst' in mode):
                raise NotImplementedError(f"这里不应该被执行 {mode} {'baseline' in mode}")

    by_device = sorted(commands, key=lambda line: int(re.search("device +(.*)", line).groups()[0]))
    return '\n'.join(by_device)

def semi(res, num_works_per_device=2):
    """Chain commands that share a device so they run serially.

    Non-empty lines of ``res`` are grouped by the text after ``device ``, in first-seen order.
    Each group is cut into consecutive chunks of ``ceil(len / num_works_per_device)``
    commands joined with ``"; "``, giving at most ``num_works_per_device`` parallel chains
    per device. The result is wrapped in two newlines on each side.
    """
    by_device = {}
    for line in filter(None, res.split("\n")):
        key = re.search(r"device (.*)", line).groups()[0]
        by_device.setdefault(key, []).append(line)

    chains = []
    for lines in by_device.values():
        chunk = math.ceil(len(lines) / num_works_per_device)
        chains.extend("; ".join(lines[start:start + chunk]) for start in range(0, len(lines), chunk))
    return "\n\n" + "\n".join(chains) + "\n\n"

In [None]:
# Run experiments: one command per (observed ratio, shift, country); devices round-robin.
result_dir = "test_0106_all"
devices = [3, 7, 8, 9]
i = 0
res = ""
for node_observed_ratio in [50, 80]:
    for shift in [2, 6, 13]:
        for country in ["England", "France", "Italy", "Spain"]:
            res += f"python main.py --dataset dataforgood --country {country} --shift {shift:2} --result-dir {result_dir} --exp 2_{node_observed_ratio}_all --node-observed-ratio {node_observed_ratio} --device {devices[i % len(devices)]}\n"
            i += 1
        # Japan uses its own dataset flag.
        res += f"python main.py --dataset japan --country Japan --shift {shift:2} --result-dir {result_dir} --exp 2_{node_observed_ratio}_all --node-observed-ratio {node_observed_ratio} --device {devices[i % len(devices)]}\n"
        i += 1

# Up to 4 serial chains per device.
print(semi(res, 4))

In [None]:


date_result_dir = '0209'
# res = gen_lambda(("Italy", 2))
# res = gen_lambda(("Italy", 2), ("Spain", 2), ("England", 6), ("France", 6), ("England", 13), ("Spain", 13))
# res = gen_lambda('dynst', [2, 3, 4, 7, 8, 9], result_dir)
# res = gen_lambda('baseline', [9], result_dir, "--seed 5 --gendata")
# Baselines + dynst grid for dataset seeds 6..10, one result directory per seed.
res = '\n'.join([gen_lambda('baselinedynst', [2, 4, 5, 6, 7, 8, 9], f"tests_{date_result_dir}_seed{i}", f"--seed-dataset {i}") for i in range(6, 11)])
# print(res)
# Up to 3 serial chains per device.
print(semi(res, 3))

In [None]:
import os, re
from show_result import show_result

textdir, exp = "tests_0205_seed1", 2
baseline_orders = ['lstm', 'mpnn_lstm', 'mpnn_tl']
ablation_orders = ['no_graph', 'no_virtual_node', 'no_graph_no_virtual_node']

def sort_key(e):
    """Ordering key for experiment directory names.

    Baselines (``..._baselines_<model>_<ratio>``) come first, ordered by ratio and then
    ``baseline_orders``. Graph-lambda runs (``exp_<n>_<ratio>_graph_lambda_<k>[_<ablation>]``)
    follow, ordered by ratio, lambda and ``ablation_orders`` (the plain run first). Any other
    name sorts last, alphabetically, instead of raising. The caller sorts the directory listing
    before filtering it, so one unrelated directory used to crash the whole listing.
    """
    match = re.search(r"exp_\d+_(\d+)_graph_lambda_(\d+)_?([a-z_]+)?", e)
    if match:
        ratio, lambda_, ablation = match.groups()
        ablation_rank = ablation_orders.index(ablation) if ablation in ablation_orders else -1
        return 1, int(ratio), int(lambda_), ablation_rank
    match = re.search(r"baselines_(.*)_(\d+)", e)
    if match:
        baseline, ratio = match.groups()
        if baseline in baseline_orders:
            return 0, int(ratio), baseline_orders.index(baseline)
    return 2, e

# show_result(f"results/tests_{textdir}/exp_{exp}", mode=0)
# Collect show_result output for every exp_<exp> directory, in sort_key order.
res = []
for d in sorted(os.listdir(f"results/{textdir}"), key=sort_key):
    if not f'exp_{exp}' in d: continue
    # res.append(d)
    res.append(show_result(os.path.join(f"results/{textdir}", d), mode=1))

res1 = '\n'.join(res)
# Remove the header block repeated for every directory, then collapse runs of blank lines.
res2 = re.sub(r'\n+', '\n', res1.replace('''flunet
[err_test] h1n1,h3n2,BV,BY

7->1 (w7s2) | 7->1 (w7s6) | 7->1 (w7s13)''', ''))
# res_str = '\n'.join([_r for r in res for _r in r.split('\n') if len(_r) > 1 and (_r[0] == '-' or _r[0].isdigit() and _r[1] != '-')])
# print(res_str)
print(res2)

In [None]:
from utils.data_process.dataforgood import load_data
from utils.args import get_parser, process_args
# argparse parses sys.argv when called with no arguments. Inside a Jupyter kernel that
# typically holds the kernel launcher's own arguments, so parse an empty list to get defaults.
args = get_parser().parse_args([])
args.country = 'NewZealand'
args = process_args(args, False)
meta_data = load_data(args.dataset_cache_dir, args.data_dir, args.dataset, args.batch_size, args.xdays, args.ydays, args.window, args.shift, args.train_ratio, args.val_ratio, 1, enable_cache = True)
# Per country: check the dates are sorted, then print region count and first/last date.
for c in meta_data["dates"]:
    assert meta_data["dates"][c] == sorted(meta_data["dates"][c])
    print(f"{c:10}", f"{len(meta_data['regions'][c]):3}", meta_data["dates"][c][0], meta_data["dates"][c][-1])

In [None]:
# Quick check of the graph-lambda directory-name regex (cf. sort_key).
re.search(r"exp_\d+_\d+_graph_lambda_(\d+)_?([a-z_]+)?", "exp_2_80_graph_lambda_3").groups()

In [None]:
def merge_result_tables(table_a, table_b):
    """Merge two tab-separated result tables cell by cell.

    Equal cells are kept. A '-' (missing) cell is filled from the other table. On a real
    conflict, ``(row, col, a, b)`` is printed and ``table_b`` wins. Rows and cells are paired
    with ``zip``, replacing the old fixed 30 x 24 layout: that layout raised IndexError for
    shorter tables, dropped extra rows, and padded short rows with empty cells.

    Returns:
        list of rows, each a list of cell strings.
    """
    merged = []
    for i, (row_a, row_b) in enumerate(zip(table_a.split('\n'), table_b.split('\n'))):
        row = []
        for j, (cell_a, cell_b) in enumerate(zip(row_a.split('\t'), row_b.split('\t'))):
            if cell_a == cell_b or cell_b == '-':
                row.append(cell_a)
            elif cell_a == '-':
                row.append(cell_b)
            else:
                print(i, j, cell_a, cell_b)
                row.append(cell_b)
        merged.append(row)
    return merged

# NOTE(review): res_str0 and res_str are not defined in this section; they must already exist in the kernel.
res = merge_result_tables(res_str0, res_str)

_s = '\n'.join(['\t'.join(r) for r in res])
print(_s)

In [None]:
# Pasted merged result table (tab-separated).
print('''-	-	-	-	9.738	4	17.935	6	5.205	220	1.938	76	-	-	-	-	5.151	5	-	-	-	-	32.895	1
-	-	-	-	10.572	342	18.235	6	5.091	402	1.492	50	-	-	-	-	5.315	5	-	-	-	-	33.003	1
-	-	-	-	9.457	4	17.918	6	5.743	159	1.795	65	-	-	-	-	5.043	5	-	-	-	-	32.972	1
-	-	-	-	9.738	4	17.935	6	5.205	220	1.938	76	-	-	-	-	5.151	5	-	-	-	-	32.895	1
-	-	-	-	9.44	347	18.264	6	4.843	350	1.597	47	-	-	-	-	5.815	492	-	-	-	-	33.028	1
-	-	-	-	9.457	4	17.918	6	5.743	159	1.795	65	-	-	-	-	5.043	5	-	-	-	-	32.972	1
-	-	-	-	9.738	4	17.935	6	5.205	220	1.938	76	-	-	-	-	5.151	5	-	-	-	-	32.895	1
-	-	-	-	9.119	389	18.275	6	4.826	431	1.999	163	-	-	-	-	5.284	5	-	-	-	-	33.006	1
-	-	-	-	9.457	4	17.918	6	5.743	159	1.795	65	-	-	-	-	5.043	5	-	-	-	-	32.972	1
-	-	-	-	9.738	4	17.935	6	5.205	220	1.938	76	-	-	-	-	5.151	5	-	-	-	-	32.895	1
-	-	-	-	8.286	477	18.225	6	5.078	310	1.485	47	-	-	-	-	5.296	5	-	-	-	-	33.022	1
-	-	-	-	9.457	4	17.918	6	5.743	159	1.795	65	-	-	-	-	5.043	5	-	-	-	-	32.972	1
-	-	-	-	9.738	4	17.935	6	5.205	220	1.938	76	-	-	-	-	5.151	5	-	-	-	-	32.895	1
-	-	-	-	9.661	412	26.12	196	5.02	409	1.556	44	-	-	-	-	5.332	5	-	-	-	-	33.019	1
-	-	-	-	9.457	4	17.918	6	5.743	159	1.795	65	-	-	-	-	5.043	5	-	-	-	-	32.972	1
-	-	-	-	9.738	4	17.935	6	5.205	220	1.938	76	-	-	-	-	5.151	5	-	-	-	-	32.895	1
-	-	-	-	9.624	341	18.246	6	5.124	454	1.584	50	-	-	-	-	5.587	464	-	-	-	-	33.01	1
-	-	-	-	9.457	4	17.918	6	5.743	159	1.795	65	-	-	-	-	5.043	5	-	-	-	-	32.972	1
-	-	-	-	9.738	4	17.935	6	5.205	220	1.938	76	-	-	-	-	5.151	5	-	-	-	-	32.895	1
-	-	-	-	9.591	412	18.115	6	4.904	498	1.509	54	-	-	-	-	5.353	5	-	-	-	-	33.001	1
-	-	-	-	9.457	4	17.918	6	5.743	159	1.795	65	-	-	-	-	5.043	5	-	-	-	-	32.972	1
-	-	-	-	9.738	4	17.935	6	5.205	220	1.938	76	-	-	-	-	5.151	5	-	-	-	-	32.895	1
-	-	-	-	11.126	477	31.764	301	4.883	341	2.092	444	-	-	-	-	5.309	5	-	-	-	-	33.001	1
-	-	-	-	9.457	4	17.918	6	5.743	159	1.795	65	-	-	-	-	5.043	5	-	-	-	-	32.972	1
-	-	-	-	9.738	4	17.935	6	5.205	220	1.938	76	-	-	-	-	5.151	5	-	-	-	-	32.895	1
-	-	-	-	9.205	295	18.193	6	5.133	491	1.517	54	-	-	-	-	5.256	5	-	-	-	-	33.026	1
-	-	-	-	9.457	4	17.918	6	5.743	159	1.795	65	-	-	-	-	5.043	5	-	-	-	-	32.972	1''')

In [None]:
# raise OSError("这是检查 mask 的地区是否一致的函数")
import pickle as p, os
# Check that every cached dataset selected the same masked (observed) regions per country.
meta_datas = []
for f in os.listdir("dataset_cache"):
    # NOTE: pickle.load can execute arbitrary code; only run this on locally generated caches.
    with open(os.path.join("dataset_cache", f), "rb") as file:
        meta_datas.append(p.load(file))

for name in meta_datas[0]["country_names"]:
    # Compare this country's selected indices across all cache files
    # (the previous version compared "England" for every name).
    _l = [list(md["selected_indices"][name]) for md in meta_datas]
    print(name, all(sel == _l[0] for sel in _l[1:]))

In [None]:
# Safety guard: this cell rewrites checkpoints on disk ("think carefully before running").
raise NotImplementedError('执行前慎重')
import os, torch
# Convert every full-model .pth under results/tests_1015 into a state_dict checkpoint,
# keeping the original file as <name>.pth.bak.
for dirpath, dirnames, files in os.walk("results/tests_1015"):
    if not len(files): continue
    for file in files:
        if file.endswith(".pth"):
            # NOTE(review): this loads a pickled nn.Module. Newer torch releases may default to
            # weights_only loading and refuse it -- confirm against the installed version.
            _model = torch.load(os.path.join(dirpath, file))
            _state_dict = _model.state_dict()
            os.rename(os.path.join(dirpath, file), os.path.join(dirpath, file + '.bak'))
            torch.save(_state_dict, os.path.join(dirpath, file))

In [None]:
# import requests
# from lxml import html
# from tqdm import tqdm

# # 网址模板
# url_template = "https://findthatpostcode.uk/areas/{}.html"

# # 你要查询的行政区划代码列表
codes = ['E06000001', 'E06000002', 'E06000003', 'E06000004', 'E06000005', 'E06000006', 'E06000007', 'E06000008', 'E06000009', 'E06000011', 'E06000012', 'E06000013', 'E06000014', 'E06000015', 'E06000016', 'E06000017', 'E06000018', 'E06000020', 'E06000021', 'E06000022', 'E06000024', 'E06000026', 'E06000027', 'E06000030', 'E06000031', 'E06000032', 'E06000033', 'E06000034', 'E06000035', 'E06000036', 'E06000037', 'E06000038', 'E06000039', 'E06000040', 'E06000041', 'E06000042', 'E06000043', 'E06000044', 'E06000045', 'E06000046', 'E06000047', 'E06000049', 'E06000050', 'E06000051', 'E06000052', 'E06000054', 'E06000055', 'E06000056', 'E06000057', 'E06000059', 'E08000001', 'E08000002', 'E08000003', 'E08000004', 'E08000005', 'E08000007', 'E08000008', 'E08000009', 'E08000010', 'E08000011', 'E08000012', 'E08000013', 'E08000014', 'E08000015', 'E08000016', 'E08000017', 'E08000018', 'E08000019', 'E08000021', 'E08000022', 'E08000023', 'E08000024', 'E08000025', 'E08000026', 'E08000027', 'E08000030', 'E08000031', 'E08000032', 'E08000033', 'E08000034', 'E08000035', 'E08000036', 'E08000037', 'E09000001', 'E09000002', 'E09000003', 'E09000004', 'E09000005', 'E09000006', 'E09000008', 'E09000009', 'E09000010', 'E09000015', 'E09000016', 'E09000017', 'E09000018', 'E09000020', 'E09000021', 'E09000022', 'E09000023', 'E09000026', 'E09000027', 'E09000029', 'E10000002', 'E10000003', 'E10000006', 'E10000007', 'E10000008', 'E10000011', 'E10000012', 'E10000013', 'E10000014', 'E10000015', 'E10000016', 'E10000017', 'E10000018', 'E10000019', 'E10000020', 'E10000021', 'E10000023', 'E10000024', 'E10000025', 'E10000027', 'E10000028', 'E10000029', 'E10000030', 'E10000031', 'E10000032', 'E10000034']
# # Collected results
# results = []

# # Fetch the page for each code
# for code in tqdm(codes):
#     url = url_template.format(code)
#     response = requests.get(url)
#     tree = html.fromstring(response.content)
#     h2_text = tree.xpath('//html/body/main/header/h2/text()')
#     results.append(h2_text)

# # Print the results
# for result in results:
#     print(result)
# str([i[0][1:-2].strip() for i in results])
names = ['Hartlepool', 'Middlesbrough', 'Redcar and Cleveland', 'Stockton-on-Tees', 'Darlington', 'Halton', 'Warrington', 'Blackburn with Darwen', 'Blackpool', 'East Riding of Yorkshire', 'North East Lincolnshire', 'North Lincolnshire', 'York', 'Derby', 'Leicester', 'Rutland', 'Nottingham', 'Telford and Wrekin', 'Stoke-on-Trent', 'Bath and North East Somerset', 'North Somerset', 'Plymouth', 'Torbay', 'Swindon', 'Peterborough', 'Luton', 'Southend-on-Sea', 'Thurrock', 'Medway', 'Bracknell Forest', 'West Berkshire', 'Reading', 'Slough', 'Windsor and Maidenhead', 'Wokingham', 'Milton Keynes', 'Brighton and Hove', 'Portsmouth', 'Southampton', 'Isle of Wight', 'County Durham', 'Cheshire East', 'Cheshire West and Chester', 'Shropshire', 'Cornwall', 'Wiltshire', 'Bedford', 'Central Bedfordshire', 'Northumberland', 'Dorset', 'Bolton', 'Bury', 'Manchester', 'Oldham', 'Rochdale', 'Stockport', 'Tameside', 'Trafford', 'Wigan', 'Knowsley', 'Liverpool', 'St. Helens', 'Sefton', 'Wirral', 'Barnsley', 'Doncaster', 'Rotherham', 'Sheffield', 'Newcastle upon Tyne', 'North Tyneside', 'South Tyneside', 'Sunderland', 'Birmingham', 'Coventry', 'Dudley', 'Walsall', 'Wolverhampton', 'Bradford', 'Calderdale', 'Kirklees', 'Leeds', 'Wakefield', 'Gateshead', 'City of London', 'Barking and Dagenham', 'Barnet', 'Bexley', 'Brent', 'Bromley', 'Croydon', 'Ealing', 'Enfield', 'Harrow', 'Havering', 'Hillingdon', 'Hounslow', 'Kensington and Chelsea', 'Kingston upon Thames', 'Lambeth', 'Lewisham', 'Redbridge', 'Richmond upon Thames', 'Sutton', 'Buckinghamshire', 'Cambridgeshire', 'Cumbria', 'Derbyshire', 'Devon', 'East Sussex', 'Essex', 'Gloucestershire', 'Hampshire', 'Hertfordshire', 'Kent', 'Lancashire', 'Leicestershire', 'Lincolnshire', 'Norfolk', 'Northamptonshire', 'North Yorkshire', 'Nottinghamshire', 'Oxfordshire', 'Somerset', 'Staffordshire', 'Suffolk', 'Surrey', 'Warwickshire', 'West Sussex', 'Worcestershire']
# The two pasted lists are paired positionally, so they must line up one-to-one.
# (Index-based pairing silently dropped extra codes, or raised IndexError if codes was shorter.)
assert len(codes) == len(names), (len(codes), len(names))
code_to_name = dict(zip(codes, names))
code_to_name



In [None]:
import zipfile
import geopandas as gpd
import pandas as pd
import matplotlib.pyplot as plt

from utils.utils import progress_indicator


# Read the geographic data: every `.shp` layer inside the GADM zip archive.
shapes_zip_path = "data/mapfiles/gadm41_GBR_shp.zip"

# Store the archive listing under its own name; reusing `names` here would
# overwrite the region-name list built in an earlier cell.
with zipfile.ZipFile(shapes_zip_path) as z:
    zip_members = z.namelist()
# Match the real `.shp` extension, not any name that merely ends in "shp".
shps = [n for n in zip_members if n.endswith(".shp")]

qbar = progress_indicator(shps, leave=False)
# One GeoDataFrame per shapefile, in archive listing order.
shapes = []

for shp in shps:
    qbar.set_description(f"正在读取 {shp} ({shapes_zip_path})")
    # geopandas reads a member straight out of the zip via the zip://<archive>!/<member> URL.
    shapes.append(gpd.read_file(f"zip://{shapes_zip_path}!/{shp}"))
    qbar.update()


In [None]:
# Display the second GeoDataFrame read from the zip archive
# (order follows the archive's `.shp` member listing).
shapes[1]

In [None]:
import os
import torch

# Select the device: first CUDA GPU when available, otherwise the CPU.
# NOTE(review): `device` is not used by the rest of this cell.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def convert_model_files(root_dir):
    """Restore ``*.pth.bak`` backups under *root_dir* by renaming them to ``*.pth``.

    Walks ``root_dir`` recursively. Every file whose name ends in ``.pth.bak``
    is renamed in place with the trailing ``.bak`` stripped.

    Note: ``os.rename`` silently replaces an existing ``.pth`` of the same name
    on POSIX and raises ``FileExistsError`` on Windows.
    """
    for folder, _subdirs, files in os.walk(root_dir):
        backups = (name for name in files if name.endswith('.pth.bak'))
        for backup in backups:
            restored = backup[:-len('.bak')]
            os.rename(os.path.join(folder, backup), os.path.join(folder, restored))

# Replace with your target directory.
# Runs immediately: renames every `*.pth.bak` under it back to `*.pth`.
target_directory = "results"
convert_model_files(target_directory)


In [None]:
# command generator

def baseline(ratio=70):
    """Return the baseline training commands, one per line.

    Enumerates every (country set, shift, model) combination, with countries
    outermost and models innermost, and renders each as a ``python main.py``
    invocation. Each value is padded to the widest entry of its list so the
    printed commands line up in columns.

    Args:
        ratio: node-observed ratio; also embedded in the result-dir name.

    Returns:
        The commands joined by newlines (no trailing newline).
    """
    from itertools import product

    date = f"1023_2_observed_{ratio}"
    countries = ["England,Spain", "France,Italy"]
    shifts = [2, 6, 13]
    models = ["mpnn_lstm", "lstm", "dynst"]

    # Column widths are loop-invariant: compute them once, not per command.
    country_w = max(map(len, countries))
    shift_w = max(len(str(s)) for s in shifts)
    model_w = max(map(len, models))

    def cmd_pattern(c, s, m):
        # NOTE(review): the `0_EN_ES_` exp prefix is used for every country set,
        # including "France,Italy" — confirm this is intended.
        return (f"python main.py --country {c:{country_w}} --result-dir tests_{date} "
                f"--shift {s:<{shift_w}} --exp 0_EN_ES_{m:{model_w}} "
                f"--model {m:{model_w}}  --node-observed-ratio {ratio} --device 0")

    return '\n'.join(cmd_pattern(c, s, m) for c, s, m in product(countries, shifts, models))

def cmd(exp_num, ratio = 70):
    """Return the three graph-lambda experiment commands for ``exp_num``.

    Each line runs ``main.py`` with ``--exp exp_num`` and
    ``--graph-lambda 0.<exp_num>`` at node-observed ratio ``ratio``. The
    result-dir name embeds ``ratio``, as in ``baseline``, so runs with a
    non-default ratio are not filed under ``observed_70``.

    Args:
        exp_num: experiment number; must be exactly an ``int``.
        ratio: node-observed ratio, also embedded in the result-dir name.

    Returns:
        The commands as one string with a leading and trailing newline.

    Raises:
        AssertionError: if ``exp_num`` is not exactly an ``int``.
    """
    assert type(exp_num) == int
    date = f"1023_2_observed_{ratio}"
    # Named `commands` rather than `str` so the builtin is not shadowed.
    commands = f'''
python main.py --country France        --result-dir tests_{date} --shift 2  --exp {exp_num} --node-observed-ratio {ratio} --graph-lambda 0.{int(exp_num)} --device 8
python main.py --country England       --result-dir tests_{date} --shift 6  --exp {exp_num} --node-observed-ratio {ratio} --graph-lambda 0.{int(exp_num)} --device 8
python main.py --country France,Italy  --result-dir tests_{date} --shift 13 --exp {exp_num} --node-observed-ratio {ratio} --graph-lambda 0.{int(exp_num)} --device 9
'''
    return commands
# print() so the newline-joined commands render one per line (not as an escaped repr).
print(baseline())

In [None]:
from utils.custom_datetime import datetime
date = "1023_2_observed_70"
print(f"统计 {date} 于 {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
from show_result import extract_results, show_result
import os

# Results collected across all experiment directories for this date.
res = []
# Sorted so the printed order and the order of `res` are reproducible
# (os.listdir returns entries in arbitrary order).
for entry in sorted(os.listdir(f"results/tests_{date}")):
    print("\033[32m> {}\033[0m".format(entry))
    # `exp_dir` rather than `dir` so the builtin `dir()` is not shadowed.
    exp_dir = f"results/tests_{date}/{entry}"
    try:
        # Extend `res` only after both calls succeed. Otherwise a failure in
        # show_result would leave this directory's results in `res`, and the
        # fallback below would add its children's results on top of them.
        found = extract_results(exp_dir)
        show_result(exp_dir)
        res += found
    except Exception:
        # `Exception`, not a bare except, so KeyboardInterrupt is not swallowed.
        # Fallback (presumably `exp_dir` groups several runs): treat each child
        # entry as its own result directory.
        for sub in sorted(os.listdir(exp_dir)):
            print("\033[32m> {}\033[0m".format(sub))
            sub_dir = f"{exp_dir}/{sub}"
            res += extract_results(sub_dir)
            show_result(sub_dir)
            print()
    print()

In [None]:
# Keep a handle on the collected results list before `res` may be rebound
# (e.g. by `res = load_data(...)` further down). This aliases the same list
# object; it is not a copy.
res0 = res

In [None]:
# Disabled exploration cell: raising here halts the cell (and a "Run All")
# before any data is loaded. Remove the guard to compute adjacency densities.
raise NotImplementedError
from utils.args import parse_args
from utils.data_process.dataforgood import load_data
from argparse import Namespace
preprocessed_data_dir, data_dir, databinfile, batch_size, xdays, ydays, window, shift, train_ratio, val_ratio, node_observed_ratio = ('data_preprocessed', 'data/dataforgood', 'data_preprocessed/dataforgood_x7_y1_w7_s2_m50.bin', 8, 7, 1, 7, 2, 0.7, 0.1, 1)
args = Namespace(preprocessed_data_dir=preprocessed_data_dir, data_dir=data_dir, databinfile=databinfile, batch_size=batch_size, xdays=xdays, ydays=ydays, window=window, shift=shift, train_ratio=train_ratio, val_ratio=val_ratio, node_observed_ratio=node_observed_ratio)
# NOTE: rebinds the global `res` (the earlier results list was aliased as `res0`).
res = load_data(args, False)

# Presumably element [1][2] of each dataset entry is its adjacency tensor — TODO confirm against load_data.
adjs = [v[1][2] for v in res["data"].values()]
# Non-zero count over dims (1, 2), i.e. one count per index of dim 0
# (assumes adjacency is (T, N, N) — confirm).
nonzeros = [a.count_nonzero((1,2)) for a in adjs]

nodes = [adj.shape[1] for adj in adjs]

# Edge density for every dataset (was hard-coded to the first 4).
[nz / (n * n) for nz, n in zip(nonzeros, nodes)]

In [None]:
# Guard: raising here stops this cell (and a "Run All") before the destructive
# deletion below runs. Remove deliberately to execute it.
raise NotImplementedError
import os

def delete_pth_files(directory):
    """Recursively delete every ``*.pth`` file under *directory*.

    Each removed path is printed as ``Deleted: <path>``. Files such as
    ``*.pth.bak`` are left alone because the name must end in ``.pth``.
    """
    for parent, _subdirs, entries in os.walk(directory):
        targets = [os.path.join(parent, name) for name in entries if name.endswith('.pth')]
        for target in targets:
            os.remove(target)
            print(f"Deleted: {target}")

# Usage example (permanently removes every `.pth` file under results/tests_old).
delete_pth_files('results/tests_old')
