In [38]:
import numpy as np
import os, sys, torch
# from pytorch_lightning import LightningModule
# import src.utils as utils
from pathlib import Path
# /home/kan/ML_application/s4/outputs/2025-04-21/11-19-18/checkpoints/
torch.set_printoptions(
    threshold=float('inf'),      # すべての要素を表示
    precision=10,                # 小数点以下10桁まで表示
    linewidth=10**4              # 1行に表示する文字数（折り返し防止）
)
%matplotlib inline
import datetime
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm

fig_size_horizontal=15

def is_array(obj):
    return isinstance(obj, (list, np.ndarray))
def DESN_observer(
        u_ex, dim_rv, x_init=None, win=None, w=None, xinit_seed=0,winseed=0,wseed=0,
        rho=0.5,w_norm=True, sign=1, densin=0.1,density=0.5, activation="tanh", verbose=False):
    dim_u_ex = u_ex.shape[0] # Obtain the dimension of input and output
    Time_leng = u_ex.shape[1]

    if is_array(x_init)==False:
        np.random.seed(xinit_seed)
        x_init=np.random.uniform(-1.0, 1.0, dim_rv)
    if is_array(win)==False:
        np.random.seed(winseed)
        win = np.random.normal(
            loc=0, scale=1.0/dim_rv**0.5, size=(dim_rv, dim_u_ex), ) * (np.random.rand(dim_rv, dim_u_ex)<densin)
        # win = (2*np.random.rand(dim_rv,dim_u_ex)-1) * (np.random.rand(dim_rv, dim_u_ex)<densin)
    if is_array(w)==False:
        np.random.seed(wseed)
        w = (2*np.random.rand(dim_rv,dim_rv)-1) * (np.random.rand(dim_rv,dim_rv)<density)
    
    eig = np.linalg.eigvals(w)
    w = sign*w if dim_rv==1 or w_norm==False else rho*sign*w / np.max( np.abs(eig) )
    # print("w=",w, " eig=",eig)
    # print("win", win.shape, " vars", np.var(win.T[0]), np.var(win.T[1]))
    
    states_series = np.zeros((dim_rv, Time_leng))
    states_series[:,0]=x_init
    for t in range(1,Time_leng):
        # print("time", t)
        # print("internal weight", w.shape)
        # # print(w)
        # print("states_series[:,t-1:t]", states_series[:,t-1:t].shape)
        # print(states_series[:,t-1:t])
        # print("win", win.shape)
        # # print(win)
        # print("u_ex[:, t-1:t]", u_ex[:, t-1:t].shape)
        # print(u_ex[:, t-1:t])
        if activation == "lin":
            states_series[:,t:t+1] = w @ states_series[:,t-1:t] + win @ u_ex[:, t-1:t]
        elif activation == "tanh":
            states_series[:,t:t+1] = np.tanh(w @ states_series[:,t-1:t] + (win @ u_ex[:, t-1:t]) )
        else:
            print("Error!!! No such activation function.")
            sys.exit()    
    return states_series

def rank_svd(A, thresh='No', rcond=1e-15,):
    dim_A = np.min(A.shape)

    if thresh=='N':
        finfo = np.finfo(A.dtype).eps #if finfo is None else finfo  # douBle=2^16
        u_, sigma_, v_ = np.linalg.svd(A, full_matrices=False)
        sigmax = sigma_.max()
        # print("rank_svd", A, sigma_, sigmax)
        rcond_=(sigmax) * dim_A * finfo
        index = np.where(sigma_ > rcond_)[0]
        rank=index.shape[0]
        rcond = rcond_/(sigmax)
    else:
        rank=np.linalg.matrix_rank(A)
    return rank, rcond

def MC_0ex_inv(
        N, X, Z, maxtau=51, Two=10**4, T=10**6,
        thresh='No', rcond=1e-15, debug=True):
    taus = np.arange(maxtau)# for MF
    MF = []
    print("Z and X shapes", Z.shape, X.shape)
    X, Z = X[Two:], Z.reshape(-1, 1)
    Z2 = (Z[Two:].T @ Z[Two:]).reshape(-1)[0]
    print("Z and X shapes", Z.shape, X.shape, " Z2!", Z2)
    B = X.T @ X
    B = B.astype(dtype=np.complex128)
    rank, rcond = rank_svd(B, thresh=thresh, rcond=rcond)
    B_1=np.linalg.pinv(B, rcond=rcond)
    rank_B1, _ = rank_svd(B_1, thresh=thresh)
    print("rankB1", rank_B1)
    for tau in taus[1:]:
        # print(Z[Two-tau:T+Two-tau].shape, X.shape)
        ZX = Z[Two-tau:T+Two-tau].T @ X
        ipc = ZX @ B_1 @ ZX.T /Z2
        ipc=ipc.reshape(-1)[0]
        if debug ==True:
            print(tau,ipc)#, (1-rho**2)*rho**(2*tau-2))
        MF.append(ipc)
    MF = np.array(MF, dtype=np.complex128)
    MC = np.sum(MF)
    if debug ==True:
        print("MC_0ex_inv:")
        print("rank: XX:=%d"%(rank))
        if T < 1000:
            rank_X_XX_X, rcond = rank_svd(X@B_1@X.T, thresh=thresh)
            print("rank: X_XX_X:=%d"%(rank_X_XX_X))
    return MC, MF, rank

def MC_3exNocor(
    N, lams, Two=int(1e4), maxtau=51, thresh='No', rcond=1e-15, debug=False
    ):
    # K=Two# Condition: K <= Two
    # delays = np.arange(0,K)# accuracy for MC
    taus = np.arange(1,maxtau) # for MF
    
    Cs_rho=0
    MF_rho=[]
    B=np.array([[1/(1-lams[i]*lams[j]) for i in range(N)] for j in range(N)])
    B=B.astype(dtype=np.complex128)
    rank, rcond = rank_svd(B, thresh=thresh, rcond=rcond)
    
    B_1=np.linalg.pinv(B, rcond=rcond)
    # print(B) # print(B_1) # print(B@B_1)
    for tau in taus:
        tauind=tau-1
        h=(lams**tauind).reshape((N,1))
        ipc=h.T @ B_1 @ h
        ipc=ipc.reshape(-1)[0]
        if debug:
            print(tauind,ipc)#, (1-rho**2)*rho**(2*tauind-2))
        Cs_rho+=ipc
        MF_rho.append(ipc)
    if debug:
        print(Cs_rho)
    
    return Cs_rho, np.array(MF_rho), rank

# get paramas and MF

In [53]:
def get_params(filepath):
    ckpt = torch.load(filepath, map_location='cpu')
    # print(ckpt.keys())
    state_dict = ckpt['state_dict']
    # print(state_dict.keys())
    
    # for key in state_dict.keys():
    #     print(key)
    lams_arr = []
    log_dt_arr = []
    for rep_layer in range(n_layer):
        # print("rep_layer", rep_layer)
        log_dt = state_dict["model.layers.%d.layer.kernel.kernel.log_dt"%rep_layer]
        log_w_real = state_dict["model.layers.%d.layer.kernel.kernel.log_w_real"%rep_layer]
        w_imag = state_dict["model.layers.%d.layer.kernel.kernel.w_imag"%rep_layer]
        dts = torch.exp(log_dt) # (H)
        # print("dts", dts.shape)
        # print(dts)
        mindt, maxdt, meandt=torch.min(dts), torch.max(dts), torch.mean(dts)
        # print("dts",mindt, maxdt, meandt)
        
        w_real = -torch.exp(log_w_real)
        # print("w_real", w_real.shape)
        # print("w_imag", w_imag.shape)
        w = w_real + 1j * w_imag
        # print("w",w)
        
        dtA = w*dt
        A = torch.exp(dtA)[0] # (H N)
        # print("A", A.shape, A)
        rho = torch.max(torch.abs(A))
        # print("rho", rho)
        
        log_dt_arr.append(log_dt), lams_arr.append(A)
        # w *= rho/np.max(np.abs(lams))
        # dtA = w * dts.unsqueeze(-1)  # (H N)
        # A = torch.exp(dtA) # (H N)
    return log_dt_arr, lams_arr

base_dir = "/home/kan/ML_application/s4/outputs/"
# max_depth = 2
# base_dir_ = Path(base_dir)

dt = [0.01, 1][0]
n_layer = 6
data_dict={}

# for file in base_dir_.glob("*/*/checkpoints/*"):
#     if file.is_file():
#         parent_path = file.parents[1]
#         print("parent_path",parent_path)
#         print("file",file)

files = ["2025-05-13/18-29-34/"]
for file in files:
    dirpath = base_dir+file+"checkpoints/"
    file_paths = []
    for root, dirs, files in os.walk(dirpath):
        for file in files:
            if file.endswith('.ckpt'):
                file_paths.append(os.path.join(root, file))
    
    for filepath in file_paths:
        print(filepath)
        filepath_split = filepath.split("/")
        print(filepath_split)
        log_dt_arr, lams_arr = get_params(filepath)
        print("")
        data_dict[filepath] = [log_dt_arr, lams_arr]
    
    break


/home/kan/ML_application/s4/outputs/2025-05-13/18-29-34/checkpoints/epoch=00-metric=0.5153.ckpt
['', 'home', 'kan', 'ML_application', 's4', 'outputs', '2025-05-13', '18-29-34', 'checkpoints', 'epoch=00-metric=0.5153.ckpt']

/home/kan/ML_application/s4/outputs/2025-05-13/18-29-34/checkpoints/last.ckpt
['', 'home', 'kan', 'ML_application', 's4', 'outputs', '2025-05-13', '18-29-34', 'checkpoints', 'last.ckpt']

/home/kan/ML_application/s4/outputs/2025-05-13/18-29-34/checkpoints/epoch=01-metric=0.5332.ckpt
['', 'home', 'kan', 'ML_application', 's4', 'outputs', '2025-05-13', '18-29-34', 'checkpoints', 'epoch=01-metric=0.5332.ckpt']

/home/kan/ML_application/s4/outputs/2025-05-13/18-29-34/checkpoints/epoch=02-metric=0.6884.ckpt
['', 'home', 'kan', 'ML_application', 's4', 'outputs', '2025-05-13', '18-29-34', 'checkpoints', 'epoch=02-metric=0.6884.ckpt']

/home/kan/ML_application/s4/outputs/2025-05-13/18-29-34/checkpoints/epoch=03-metric=0.7367.ckpt
['', 'home', 'kan', 'ML_application', 's4', 

In [None]:
inputname = ["gauss", "uniform"][1]
shuff = [True, False][1]
ex_input_scale1 = [0, 0.05, 0.1, 0.2, 0.5, 1][2]
ex_input_scale2 = [0, 0.05, 0.1, 0.2, 0.5, 1][0]
densin_shared = [True, False][0]
activation = ['lin', 'tanh'][0]
for rep, (key, val) in enumerate(data_dict.items()):
    nrows, ncols = 1, 3
    axis_wide, axis_high = 8.0, 6.0
    fig_ratio = 1
    wspace, hspace = 0.4, 0.35
    fig_wide = ncols*(axis_wide+wspace)
    fig_high = nrows*(axis_high+hspace)
    fig_wide_size = fig_size_horizontal*fig_ratio
    fig_high_size = fig_size_horizontal*fig_ratio*fig_high/fig_wide
    fig, ax = plt.subplots(nrows=nrows, ncols=ncols, 
                        figsize = (fig_wide_size, fig_high_size)
                        )#, subplot_kw=dict(projection="polar"))
    fig.subplots_adjust(wspace=wspace, hspace=hspace)
    title=key.split("/")[-1]
    print(title)
    fig.suptitle(title, x=0.5, y=0.98, fontsize=20)
    # plt.figure(figsize=(10, 6))

    print("key", key)
    dts_arr=val[0]
    lams_arr=val[1]
    # print(key, lams_arr)
    axid = ax[0]
    # axid.set_title('%s'%key.split("/")[0], x= -0.5, y=0.5, )
    for rep_layer in range(n_layer):
        dts = dts_arr[rep_layer]
        dts = dts[np.argsort(np.abs(dts))]
        axid.plot(
            dts,
            # label="num", #+ noisename,
            lw=5,
            )
    axid.grid()
    
    axid = ax[1]
    axid.set_xlabel('indices')
    # axid.set_ylim(-10**3, 10**3)
    # axid.set_ylim(-0, 1.1)
    # axid.set_yticks([0, 0.5, 1.0])
    axid.set_ylabel('eigenvalue')
    axid.grid()
    for rep_layer in range(n_layer):
        lams = np.abs(lams_arr[rep_layer])
        lams = lams[np.argsort(np.abs(lams))]
        
        # print(lams)
        axid.plot(
            lams,
            # label="num", #+ noisename,
            lw=5
            )
    
    axid = ax[2]
    axid.set_xlabel(r'$\tau$')
    # axid.set_ylim(-10**3, 10**3)
    # axid.set_ylim(-0, 1.1)
    # axid.set_yticks([0, 0.5, 1.0])
    axid.set_ylabel('MF')
    axid.grid()
    for rep_layer in range(n_layer):
        lams = lams_arr[rep_layer]
        dim_rv = lams.shape[0]
        print(dim_rv, lams.shape)
        # mc, mf, rank =MC_3exNocor(dim_rv, lams.numpy(), Two=int(1e4), maxtau=51, debug=False)
        
        #################################
        Two = 5*10**2
        T = 10**4
        np.random.seed(2)
        vsigma=1
        u = 2*np.random.rand(Two+T)-1
        u *= ex_input_scale1/np.std(u)
        u_ex = np.array([u])
        # np.random.seed(seedin+1)
        # win = (2*np.random.rand(dim_rv,2)-1) * (np.random.rand(dim_rv,2)<densin)
        # win = np.random.normal(dim_rv,2) * (np.random.rand(dim_rv, 2)<densin)
        # np.random.seed(w_index+3)
        if densin_shared:
            w1 = (2*np.random.rand(dim_rv)-1) * (np.random.rand(dim_rv)<1.0)
            win = np.array([w1]).T
        # else:
        #     win = np.random.normal(
        #         loc=0, scale=1.0/dim_rv**0.5, size=(dim_rv, 2), ) * (np.random.rand(dim_rv, 2)<1.0)
            # print("win", win.shape, " vars", np.var(win.T[0]), np.var(win.T[1]))
        
        maxtau=200
        states_series=DESN_observer(
            u_ex, dim_rv, 
            x_init=np.zeros(dim_rv), w = np.diag(lams), win = win,
            rho=1.0, w_norm=True,# False if lam_name=="gauss" else True,
            densin = 1.0, density = None,
            activation = activation,)

        rcond=1e-15
        # print(np.var(u), np.var(v))
        MC_num, MFs_num, ranks_num = MC_0ex_inv(
            dim_rv, states_series.T, u, maxtau=maxtau, Two=Two, T=T, thresh='N', rcond=rcond,
            debug=False,
            )
        axid.plot(
            MFs_num,
            # label="num", #+ noisename,
            lw=5
            )

# Experiments that do not work

In [27]:
import torch, wandb
import numpy as np
np.set_printoptions(threshold=np.inf, precision=10, linewidth=10**4)
torch.set_printoptions(threshold=np.inf, precision=10, linewidth=10**4)
api = wandb.Api()
print(api)

project_name = 'hippo'
run_id = wandb.run.id  # 現在のランIDを取得
print(run_id)

runs = api.runs(project_name)
print(runs)

for idx, run in enumerate(runs):
    print(idx, run.id)
    print(run)
    metrics = run.history()
    # for rep, met in enumerate(metrics):
    #     print(met)
    
    artifacts = run.logged_artifacts()  # ランに関連する全てのアーティファクト
    for artifact in artifacts:
        print(artifact)
        if 'model' in artifact.name.lower():  # モデルアーティファクトを特定する
            print(f"Artifact Name: {artifact.name}")
            print(f"Artifact Type: {artifact.type}")
            print(f"Artifact Metadata: {artifact.metadata}")  # メタデータにモデルに関する情報が含まれていることがあります
    if idx==2:
        break

# run = api.run("ユーザー名/プロジェクト名/ランID")
# artifact = run.use_artifact("モデル名:latest") #（例：last.ckpt）
# artifact_dir = artifact.download()

# model = MyModel.load_from_checkpoint(artifact_dir + "/last.ckpt") # モデル読み込み（Lightningなど使ってるなら）

<wandb.apis.public.api.Api object at 0x7fa2f0e0b3d0>
x2hzde8t
<Runs jingchuan0guan-the-university-of-tokyo-hospital/hippo>
0 bxdoprfi
<Run jingchuan0guan-the-university-of-tokyo-hospital/hippo/bxdoprfi (finished)>
<Artifact QXJ0aWZhY3Q6MTU4MjYzMzQyOQ==>
1 wxox9i81
<Run jingchuan0guan-the-university-of-tokyo-hospital/hippo/wxox9i81 (finished)>
<Artifact QXJ0aWZhY3Q6MTY0NTIxMTYzMQ==>
2 u4c4hkof
<Run jingchuan0guan-the-university-of-tokyo-hospital/hippo/u4c4hkof (finished)>
<Artifact QXJ0aWZhY3Q6MTY0OTMyOTU1MQ==>


In [None]:
# @hydra.main(config_path="./configs/model/", config_name="s4d", version_base=None)
# def get_model(model_config: DictConfig):
#     print(model_config)  # DictConfigとしてアクセス可能
#     for key in model_config.keys():
#         print(key, model_config[key])
#     model = utils.instantiate(registry.model, model_config)#.load_from_checkpoint("path/to/model.ckpt")
#     print(model)
    
#     state_dict = model.state_dict()
#     for k, v in state_dict.items():
#         print(k, v.shape)

# import sys
# sys.path.append('/home/kan/ML_application/s4/src')
# layer_key = ["s4d", ][0]
# print(registry.model["model"])
# print(registry.layer[layer_key])
# print(registry.model["model"](layer=registry.layer[layer_key]))
# ckpt ファイルをロード
# model = utils.instantiate(registry.model, layer="s4d", )
# registry.model["model"](layer=layer).load_from_checkpoint("path/to/model.ckpt")

# for root, dirs, files in os.walk(base_dir):
#     # 今の階層の深さを数える
#     depth = root[len(base_dir):].count(os.sep)
#     print(depth)
#     if depth > max_depth:
#         # さらに深い階層の探索をやめる
#         dirs[:] = []
#         continue

#     for file in files:
#         print(os.path.join(root, file))

In [None]:
from pytorch_lightning import LightningModule
# import src.utils as utils
from pathlib import Path

import hydra
from hydra import initialize, compose
from omegaconf import DictConfig, OmegaConf
# from hydra.utils import instantiate

from src.utils import registry, instantiate

def check_model(model_name="model", cfg_name="config", cfg_path="./configs/model/",  model_path="", ):
    with initialize(config_path=cfg_path):
        cfg = compose(config_name=cfg_name)
    print("cfg", cfg)
    OmegaConf.set_struct(cfg, False)
    # cfg.model._target_ = "torch.nn.Identity" 
    ModelClass = hydra.utils.get_class(registry.model[model_name])
    # model = ModelClass(cfg)
    
    # model = instantiate(cfg.model) # registry.model,
    model = instantiate(registry.model, cfg) # モデル構造のインスタンス化（この時点でランダム初期化）
    print(model)
    print("type", type(model))
    ckpt = torch.load(model_path, map_location="cpu")# チェックポイントからstate_dictを読み込み
    # state_dict = {k.replace("model.", ""): v for k, v in ckpt["state_dict"].items()}
    # model.load_state_dict(state_dict)
    print("Checkpoint keys:", ckpt.keys())  # 追加
    
    state_dict = ckpt.get("state_dict", ckpt)
    model.load_state_dict(state_dict)
    # if "state_dict" in ckpt:
    #     state_dict = ckpt["state_dict"]
    # else:
    #     state_dict = ckpt
    # model.load_state_dict(state_dict)
    
    # for key in cfg.keys():
    #     print(key, cfg[key])
    # model = utils.instantiate(registry.model, cfg).load_from_checkpoint(model_path)

    print(model)
    
# /home/kan/ML_application/s4/outputs/2025-04-21/11-19-18/checkpoints/

base_dir = "/home/kan/ML_application/s4/outputs/"
max_depth = 2

base_dir_ = Path(base_dir)
for file in base_dir_.glob("*/*/checkpoints/*"):
    if file.is_file():
        parent_path = file.parents[1]
        print(parent_path)
        print(file)
        relative_path = parent_path.parts[parent_path.parts.index("outputs"):]  # s4 以降の部分を抽出
        cfg_path = "/".join(relative_path)+ "/.hydra"
        print(cfg_path)
        check_model(model_name="model", cfg_path=cfg_path, model_path=file)


/home/kan/ML_application/s4/outputs/2025-04-20/22-59-56
/home/kan/ML_application/s4/outputs/2025-04-20/22-59-56/checkpoints/last.ckpt
outputs/2025-04-20/22-59-56/.hydra


The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with initialize(config_path=cfg_path):


InstantiationException: Cannot instantiate config of type type.
Top level config must be an OmegaConf DictConfig/ListConfig object,
a plain dict/list, or a Structured Config class or instance.