# クープマン作用素で地盤変動の時系列予測

## クープマンラボのチュートリアルデータ

## クープマンラボの中身を抽出して実行

### 畳み込み処理の定義

In [7]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# The structure of Auto-Encoder
class encoder_mlp(nn.Module):
    """Linear encoder acting on the last axis of its input.

    Args:
        t_len: Size of the input's last dimension.
        op_size: Size of the encoded last dimension.
    """

    def __init__(self, t_len, op_size):
        super().__init__()
        self.layer = nn.Linear(in_features=t_len, out_features=op_size)

    def forward(self, x):
        """Return ``x`` with its last axis mapped from ``t_len`` to ``op_size``."""
        return self.layer(x)

class decoder_mlp(nn.Module):
    """Linear decoder acting on the last axis of its input.

    Args:
        t_len: Size of the decoded last dimension.
        op_size: Size of the input's last dimension.
    """

    def __init__(self, t_len, op_size):
        super().__init__()
        self.layer = nn.Linear(in_features=op_size, out_features=t_len)

    def forward(self, x):
        """Return ``x`` with its last axis mapped from ``op_size`` to ``t_len``."""
        return self.layer(x)

class encoder_conv1d(nn.Module):
    """Encoder built from a kernel-size-1 ``Conv1d`` over the last axis.

    The input is expected as a 3-D tensor whose last axis has ``t_len``
    entries; that axis is moved to the channel position for the convolution
    and moved back afterwards.

    Args:
        t_len: Number of input channels (size of the input's last axis).
        op_size: Number of output channels (size of the output's last axis).
    """

    def __init__(self, t_len, op_size):
        super().__init__()
        self.layer = nn.Conv1d(t_len, op_size, kernel_size=1)

    def forward(self, x):
        """Map ``(b, x, t_len)`` to ``(b, x, op_size)``."""
        channels_first = x.permute(0, 2, 1)
        encoded = self.layer(channels_first)
        return encoded.permute(0, 2, 1)

class decoder_conv1d(nn.Module):
    """Decoder built from a kernel-size-1 ``Conv1d`` over the last axis.

    The input is expected as a 3-D tensor whose last axis has ``op_size``
    entries; that axis is moved to the channel position for the convolution
    and moved back afterwards.

    Args:
        t_len: Number of output channels (size of the output's last axis).
        op_size: Number of input channels (size of the input's last axis).
    """

    def __init__(self, t_len, op_size):
        super().__init__()
        self.layer = nn.Conv1d(op_size, t_len, kernel_size=1)

    def forward(self, x):
        """Map ``(b, x, op_size)`` to ``(b, x, t_len)``."""
        channels_first = x.permute(0, 2, 1)
        decoded = self.layer(channels_first)
        return decoded.permute(0, 2, 1)

class encoder_conv2d(nn.Module):
    """Encoder built from a 1x1 ``Conv2d`` over the last axis.

    The input is expected as a 4-D tensor whose last axis has ``t_len``
    entries; that axis is moved to the channel position for the convolution
    and moved back to the end afterwards.

    Args:
        t_len: Number of input channels (size of the input's last axis).
        op_size: Number of output channels (size of the output's last axis).
    """

    def __init__(self, t_len, op_size):
        super().__init__()
        self.layer = nn.Conv2d(t_len, op_size, kernel_size=1)

    def forward(self, x):
        """Map ``(b, x, y, t_len)`` to ``(b, x, y, op_size)``."""
        channels_first = x.permute(0, 3, 1, 2)
        encoded = self.layer(channels_first)
        return encoded.permute(0, 2, 3, 1)

class decoder_conv2d(nn.Module):
    """Decoder built from a 1x1 ``Conv2d`` over the last axis.

    The input is expected as a 4-D tensor whose last axis has ``op_size``
    entries; that axis is moved to the channel position for the convolution
    and moved back to the end afterwards.

    Args:
        t_len: Number of output channels (size of the output's last axis).
        op_size: Number of input channels (size of the input's last axis).
    """

    def __init__(self, t_len, op_size):
        super().__init__()
        self.layer = nn.Conv2d(op_size, t_len, kernel_size=1)

    def forward(self, x):
        """Map ``(b, x, y, op_size)`` to ``(b, x, y, t_len)``."""
        channels_first = x.permute(0, 3, 1, 2)
        decoded = self.layer(channels_first)
        return decoded.permute(0, 2, 3, 1)

# Koopman 1D structure
class Koopman_Operator1D(nn.Module):
    """Learned complex operator applied to the lowest Fourier modes of a signal.

    ``forward`` transforms the last axis with a real FFT, multiplies the first
    ``modes_x`` frequency bins by ``koopman_matrix`` (mixing the channel axis),
    zeroes every other bin and transforms back to the original length.

    Args:
        op_size: Number of channels mixed by the operator.
        modes_x: Number of low-frequency modes the operator acts on.
    """

    def __init__(self, op_size, modes_x = 16):
        super().__init__()
        self.op_size = op_size
        self.modes_x = modes_x
        self.scale = (1 / (op_size * op_size))
        initial = torch.rand(op_size, op_size, self.modes_x, dtype=torch.cfloat)
        self.koopman_matrix = nn.Parameter(self.scale * initial)

    def time_marching(self, input, weights):
        """Contract the channel axis: (batch, t, x), (t, t+1, x) -> (batch, t+1, x)."""
        return torch.einsum("btx,tfx->bfx", input, weights)

    def forward(self, x):
        """Apply the operator in Fourier space and return a real tensor shaped like ``x``."""
        signal_len = x.size(-1)
        spectrum = torch.fft.rfft(x)
        low = slice(None, self.modes_x)
        # Bins above ``modes_x`` stay zero, i.e. they are filtered out.
        marched = torch.zeros(spectrum.shape, dtype=torch.cfloat, device=x.device)
        marched[:, :, low] = self.time_marching(spectrum[:, :, low], self.koopman_matrix)
        return torch.fft.irfft(marched, n=signal_len)

class KNO1d(nn.Module):
    """Koopman neural operator for 1-D fields.

    The input is encoded, advanced ``decompose`` times by one shared
    ``Koopman_Operator1D``, combined with a pointwise convolution (``w0``) of
    the encoded state, and decoded. A plain encode -> tanh -> decode
    reconstruction of the input is returned alongside the prediction.

    Args:
        encoder: Module producing a 3-D tensor whose last axis has ``op_size``
            entries (``forward`` permutes it to ``(batch, op_size, x)``).
        decoder: Module mapping the ``op_size`` axis back to the output space.
        op_size: Number of latent channels.
        modes_x: Number of Fourier modes used by the Koopman operator.
        decompose: Number of times the Koopman operator is applied.
        linear_type: If False, ``tanh`` is applied after every Koopman step.
        normalization: If True, batch-normalise the ``w0`` branch.
    """

    def __init__(self, encoder, decoder, op_size, modes_x = 16, decompose = 4, linear_type = True, normalization = False):
        super(KNO1d, self).__init__()
        # Parameter
        self.op_size = op_size
        self.decompose = decompose
        # Layer Structure
        self.enc = encoder
        self.dec = decoder
        self.koopman_layer = Koopman_Operator1D(self.op_size, modes_x = modes_x)
        self.w0 = nn.Conv1d(op_size, op_size, 1)
        # If False, an activation function is applied after each Koopman step
        self.linear_type = linear_type
        self.normalization = normalization
        if self.normalization:
            # The latent tensor is 3-D (batch, op_size, x). BatchNorm2d only
            # accepts 4-D input and would raise on every forward pass, so the
            # 1-D variant is required here. Its parameters and buffers have the
            # same names and shapes as BatchNorm2d's.
            self.norm_layer = torch.nn.BatchNorm1d(op_size)

    def forward(self, x):
        """Return ``(prediction, reconstruction)`` for input ``x``."""
        # Reconstruct
        x_reconstruct = self.enc(x)
        x_reconstruct = torch.tanh(x_reconstruct)
        x_reconstruct = self.dec(x_reconstruct)
        # Predict
        x = self.enc(x) # Encoder
        x = torch.tanh(x)
        x = x.permute(0, 2, 1)
        x_w = x  # encoded state kept for the pointwise-convolution branch
        for _ in range(self.decompose):
            x1 = self.koopman_layer(x) # Koopman Operator
            if self.linear_type:
                x = x + x1
            else:
                x = torch.tanh(x + x1)
        if self.normalization:
            x = torch.tanh(self.norm_layer(self.w0(x_w)) + x)
        else:
            x = torch.tanh(self.w0(x_w) + x)
        x = x.permute(0, 2, 1)
        x = self.dec(x) # Decoder
        return x, x_reconstruct

# Koopman 2D structure
class Koopman_Operator2D(nn.Module):
    """Learned complex operator applied to low Fourier modes of a 2-D field.

    ``forward`` applies ``rfft2`` over the last two axes, multiplies two corner
    regions of the spectrum (the first and the last ``modes_x`` rows, each
    restricted to the first ``modes_y`` columns) by the same
    ``koopman_matrix``, zeroes everything else and transforms back.

    Args:
        op_size: Number of channels mixed by the operator.
        modes_x: Number of modes kept along the second-to-last axis.
        modes_y: Number of modes kept along the last axis.
    """

    def __init__(self, op_size, modes_x, modes_y):
        super().__init__()
        self.op_size = op_size
        self.modes_x = modes_x
        self.modes_y = modes_y
        self.scale = (1 / (op_size * op_size))
        initial = torch.rand(op_size, op_size, self.modes_x, self.modes_y, dtype=torch.cfloat)
        self.koopman_matrix = nn.Parameter(self.scale * initial)

    def time_marching(self, input, weights):
        """Contract the channel axis: (batch, t, x, y), (t, t+1, x, y) -> (batch, t+1, x, y)."""
        return torch.einsum("btxy,tfxy->bfxy", input, weights)

    def forward(self, x):
        """Apply the operator in Fourier space and return a real tensor shaped like ``x``."""
        spatial = (x.size(-2), x.size(-1))
        spectrum = torch.fft.rfft2(x)
        marched = torch.zeros(spectrum.shape, dtype=torch.cfloat, device=x.device)
        cols = slice(None, self.modes_y)
        # Same weights for both row bands; the tail band is written last, so it
        # wins wherever the two bands overlap.
        for rows in (slice(None, self.modes_x), slice(-self.modes_x, None)):
            marched[:, :, rows, cols] = self.time_marching(spectrum[:, :, rows, cols], self.koopman_matrix)
        return torch.fft.irfft2(marched, s=spatial)

class KNO2d(nn.Module):
    """Koopman neural operator for 2-D fields.

    The input is encoded, advanced ``decompose`` times by one shared
    ``Koopman_Operator2D``, combined with a 1x1 convolution (``w0``) of the
    encoded state, and decoded. A plain encode -> tanh -> decode
    reconstruction of the input is returned alongside the prediction.

    Args:
        encoder: Module producing a 4-D tensor whose last axis has ``op_size``
            entries (``forward`` permutes it to channels-first).
        decoder: Module mapping the ``op_size`` axis back to the output space.
        op_size: Number of latent channels.
        modes_x: Fourier modes kept along the first spatial axis.
        modes_y: Fourier modes kept along the second spatial axis.
        decompose: Number of times the Koopman operator is applied.
        linear_type: If False, ``tanh`` is applied after every Koopman step.
        normalization: If True, batch-normalise the ``w0`` branch.
    """

    def __init__(self, encoder, decoder, op_size, modes_x = 10, modes_y = 10, decompose = 6, linear_type = True, normalization = False):
        super().__init__()
        # Hyper-parameters
        self.op_size = op_size
        self.decompose = decompose
        self.modes_x = modes_x
        self.modes_y = modes_y
        # Layers
        self.enc = encoder
        self.dec = decoder
        self.koopman_layer = Koopman_Operator2D(self.op_size, self.modes_x, self.modes_y)
        self.w0 = nn.Conv2d(op_size, op_size, 1)
        # When False, an activation function is applied after each Koopman step
        self.linear_type = linear_type
        self.normalization = normalization
        if self.normalization:
            self.norm_layer = torch.nn.BatchNorm2d(op_size)

    def forward(self, x):
        """Return ``(prediction, reconstruction)`` for input ``x``."""
        # Reconstruction branch: x_reconstruct is the input passed through
        # encoder -> tanh -> decoder only.
        x_reconstruct = self.dec(torch.tanh(self.enc(x)))

        # Prediction branch
        latent = torch.tanh(self.enc(x)).permute(0, 3, 1, 2)
        state = latent
        for _ in range(self.decompose):
            stepped = state + self.koopman_layer(state)
            state = stepped if self.linear_type else torch.tanh(stepped)

        pointwise = self.w0(latent)
        if self.normalization:
            pointwise = self.norm_layer(pointwise)
        state = torch.tanh(pointwise + state)

        prediction = self.dec(state.permute(0, 2, 3, 1))
        return prediction, x_reconstruct

### 最適化の定義

In [8]:
import math
import torch
from torch import Tensor
from typing import List, Optional
from torch.optim.optimizer import Optimizer
import operator
from functools import reduce

# Count the number of scalar parameters in a model
def count_params(model):
    """Return the number of real scalars stored in ``model``'s parameters.

    Each complex-valued entry counts as two scalars (real and imaginary part).
    Zero-dimensional (scalar) parameters count as one entry; the previous
    ``reduce``-based product raised ``TypeError`` on their empty shape.

    Args:
        model: Object exposing ``parameters()`` that yields tensors.

    Returns:
        int: Total scalar count; ``0`` when the model has no parameters.
    """
    return sum(
        p.numel() * (2 if p.is_complex() else 1)
        for p in model.parameters()
    )


def adam(params: List[Tensor],
         grads: List[Tensor],
         exp_avgs: List[Tensor],
         exp_avg_sqs: List[Tensor],
         max_exp_avg_sqs: List[Tensor],
         state_steps: List[int],
         *,
         amsgrad: bool,
         beta1: float,
         beta2: float,
         lr: float,
         weight_decay: float,
         eps: float):
    r"""Run one Adam update in place on every tensor in ``params``.

    All list arguments are indexed in parallel with ``params``. The moment
    buffers (and, with ``amsgrad``, the running maxima) are updated in place;
    the gradient tensors themselves are not modified. The second moment is
    accumulated as ``grad * grad.conj()``.

    Args:
        params: Tensors to update.
        grads: Gradient for each parameter.
        exp_avgs: First-moment running averages.
        exp_avg_sqs: Second-moment running averages.
        max_exp_avg_sqs: Running maxima of the second moment; only read when
            ``amsgrad`` is True.
        state_steps: Step count of each parameter, used for bias correction.
        amsgrad: Normalise by the running maximum of the second moment.
        beta1: Decay rate of the first moment.
        beta2: Decay rate of the second moment.
        lr: Learning rate.
        weight_decay: Coefficient of the parameter term added to the gradient.
        eps: Constant added to the denominator.
    """
    for idx, param in enumerate(params):
        grad = grads[idx]
        first_moment = exp_avgs[idx]
        second_moment = exp_avg_sqs[idx]
        step = state_steps[idx]

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            # Out-of-place add so the caller's gradient tensor is untouched.
            grad = grad.add(param, alpha=weight_decay)

        # Update the running averages of the gradient and its squared modulus.
        first_moment.mul_(beta1).add_(grad, alpha=1 - beta1)
        second_moment.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)

        normaliser = second_moment
        if amsgrad:
            # Keep the element-wise maximum of every second moment seen so far
            # and normalise with that instead.
            running_max = max_exp_avg_sqs[idx]
            torch.maximum(running_max, second_moment, out=running_max)
            normaliser = running_max

        denom = (normaliser.sqrt() / math.sqrt(bias_correction2)).add_(eps)
        step_size = lr / bias_correction1
        param.addcdiv_(first_moment, denom, value=-step_size)


class Adam(Optimizer):
    r"""Adam optimizer whose update is performed by the module-level ``adam``.

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.
    ``weight_decay`` is forwarded unchanged to ``adam``.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)

    Raises:
        ValueError: If ``lr``, ``eps`` or ``weight_decay`` is negative, or a
            beta lies outside ``[0, 1)``.

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(Adam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Adam, self).__setstate__(state)
        # Groups restored from a state that lacks the key default to no AMSGrad.
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is
            given. (Previously the closure's loss was computed but discarded.)
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs gradients enabled.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                grads.append(p.grad)

                state = self.state[p]
                # Lazy state initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if group['amsgrad']:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])

                if group['amsgrad']:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                # Advance the per-parameter step count, then record the
                # updated value for bias correction.
                state['step'] += 1
                state_steps.append(state['step'])

            adam(params_with_grad,
                 grads,
                 exp_avgs,
                 exp_avg_sqs,
                 max_exp_avg_sqs,
                 state_steps,
                 amsgrad=group['amsgrad'],
                 beta1=beta1,
                 beta2=beta2,
                 lr=group['lr'],
                 weight_decay=group['weight_decay'],
                 eps=group['eps'])

        return loss

### ViTの中身

In [9]:
import math
from functools import partial
from collections import OrderedDict
from copy import Error, deepcopy
from re import S
from numpy.lib.arraypad import pad
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath, trunc_normal_
import torch.fft
from torch.nn.modules.container import Sequential
from torch.utils.checkpoint import checkpoint_sequential
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    The embedding is a ``Conv2d`` whose kernel size and stride both equal
    ``patch_size``.

    Args:
        img_size: Expected ``(height, width)`` of the input.
        patch_size: ``(height, width)`` of one patch.
        in_chans: Number of input channels.
        embed_dim: Size of each patch embedding.
    """

    def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768):
        super().__init__()
        patches_per_col = img_size[0] // patch_size[0]
        patches_per_row = img_size[1] // patch_size[1]
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = patches_per_row * patches_per_col
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map ``(B, C, H, W)`` to ``(B, num_patches, embed_dim)``.

        Raises:
            AssertionError: If ``H`` or ``W`` differs from ``img_size``.
        """
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        patches = self.proj(x)
        # Collapse the patch grid into one token axis, then put channels last.
        tokens = patches.flatten(2)
        return tokens.transpose(1, 2)

class Mlp(nn.Module):
    """Two-layer perceptron: linear -> activation -> dropout -> linear -> dropout.

    Args:
        in_features: Size of the input's last axis.
        hidden_features: Hidden width; falls back to ``in_features`` when falsy.
        out_features: Output width; falls back to ``in_features`` when falsy.
        act_layer: Zero-argument factory for the activation module.
        drop: Dropout probability used after both linear layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run ``x`` through the stages in order; one dropout module is used twice."""
        for stage in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = stage(x)
        return x

# Use Fourier-Transformer structure to approximate linear Koopman Operator
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention over ``(B, N, C)`` input.

    Args:
        dim: Channel size ``C``; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query/key/value projections have a bias.
        qk_scale: Scale for the attention logits; ``head_dim ** -0.5`` when falsy.
        attn_drop: Must be ``0.0``; attention dropout is not implemented.
        proj_drop: Dropout probability after the output projection.
        input_size: Stored on the module; only checked for
            ``input_size[1] == input_size[2]``.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        input_size=(4, 14, 14),
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        per_head = dim // num_heads
        self.scale = qk_scale or per_head**-0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        assert attn_drop == 0.0  # do not use
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.input_size = input_size
        assert input_size[1] == input_size[2]

    def _to_heads(self, projected, B, N, C):
        """Reshape ``(B, N, C)`` to ``(B, num_heads, N, C // num_heads)``."""
        split = projected.reshape(B, N, self.num_heads, C // self.num_heads)
        return split.permute(0, 2, 1, 3)

    def forward(self, x):
        """Return the attended tensor with the same ``(B, N, C)`` shape as ``x``."""
        B, N, C = x.shape
        q = self._to_heads(self.q(x), B, N, C)
        k = self._to_heads(self.k(x), B, N, C)
        v = self._to_heads(self.v(x), B, N, C)

        logits = (q @ k.transpose(-2, -1)) * self.scale
        weights = logits.softmax(dim=-1)

        # Merge the heads back into the channel axis.
        merged = (weights @ v).transpose(1, 2).reshape(B, N, C)
        out = self.proj_drop(self.proj(merged))
        return out.view(B, -1, C)


class At_Block(nn.Module):
    """
    Transformer Block in Fourier Domain

    The input is transformed with an FFT over its last axis. The real and
    imaginary parts are each passed through the same pre-norm attention and
    MLP residual sub-layers, recombined into a complex tensor and transformed
    back with an inverse FFT. The returned tensor is complex (output of
    ``torch.fft.ifft``).

    Args:
        dim: Channel size of the tokens.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Forwarded to ``attn_func``.
        qk_scale: Forwarded to ``attn_func``.
        drop: Dropout probability for the attention projection and the MLP.
        attn_drop: Forwarded to ``attn_func``.
        drop_path: Drop-path probability; ``0`` uses an identity instead.
        act_layer: Activation factory for the MLP.
        norm_layer: Normalisation factory, called with ``dim``.
        attn_func: Attention module factory.
    """
    def __init__(
        self,
        dim=768,
        num_heads=8,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        attn_func=Attention,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_func(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def _mix(self, part):
        """Apply the attention and MLP residual sub-layers to one real tensor."""
        part = part + self.drop_path(self.attn(self.norm1(part)))
        part = part + self.drop_path(self.mlp(self.norm2(part)))
        return part

    def forward(self, x):
        """Process ``x`` in the Fourier domain and return a complex tensor."""
        x_ft = torch.fft.fft(x, dim=-1, norm="ortho")
        # Real part first, then imaginary part, sharing all weights.
        x_r = self._mix(x_ft.real)
        x_i = self._mix(x_ft.imag)
        # Build a fresh complex tensor instead of assigning to x_ft.real /
        # x_ft.imag: those assignments write in place into views that the
        # normalisation layers saved for backward, which invalidates the
        # autograd graph.
        x_ft = torch.complex(x_r, x_i)
        x = torch.fft.ifft(x_ft, dim=-1, norm="ortho")
        return x

# Use linear AFNO1D structure to approximate linear Koopman Operator
class AFNO1D(nn.Module):
    """Block-diagonal two-layer complex MLP applied to the token-axis spectrum.

    ``forward`` takes ``(B, N, C)`` input, applies a real FFT along ``N``,
    splits the ``C`` channels into ``num_blocks`` groups, runs each kept
    frequency through two complex linear layers (ReLU on the real and
    imaginary parts after the first), soft-shrinks the result, transforms back
    and adds the input.

    Args:
        hidden_size: Channel size ``C``; must be divisible by ``num_blocks``.
        num_blocks: Number of channel groups.
        sparsity_threshold: ``lambd`` passed to ``F.softshrink``.
        hard_thresholding_fraction: Fraction of the lowest frequency modes that
            are processed; the remaining modes are set to zero.
        hidden_size_factor: Width multiplier of the hidden layer per block.
    """

    def __init__(self, hidden_size, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1, hidden_size_factor=1):
        super().__init__()
        assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}"

        self.hidden_size = hidden_size
        self.sparsity_threshold = sparsity_threshold
        self.num_blocks = num_blocks
        self.block_size = self.hidden_size // self.num_blocks
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.hidden_size_factor = hidden_size_factor
        self.scale = 0.02

        blocks = self.num_blocks
        width_in = self.block_size
        width_hidden = self.block_size * self.hidden_size_factor
        # Index 0 of the leading axis holds the real part, index 1 the imaginary part.
        self.w1 = nn.Parameter(self.scale * torch.randn(2, blocks, width_in, width_hidden))
        self.b1 = nn.Parameter(self.scale * torch.randn(2, blocks, width_hidden))
        self.w2 = nn.Parameter(self.scale * torch.randn(2, blocks, width_hidden, width_in))
        self.b2 = nn.Parameter(self.scale * torch.randn(2, blocks, width_in))

    @staticmethod
    def _block_matmul(values, weight):
        """Multiply every channel block of ``values`` by its own weight matrix."""
        return torch.einsum('...bi,bio->...bo', values, weight)

    def forward(self, x):
        """Return the filtered signal plus the input, in the input's dtype."""
        skip = x
        in_dtype = x.dtype
        x = x.float()
        B, N, C = x.shape
        n_freq = N // 2 + 1
        width_hidden = self.block_size * self.hidden_size_factor

        spectrum = torch.fft.rfft(x, dim=1, norm="ortho")
        spectrum = spectrum.reshape(B, n_freq, self.num_blocks, self.block_size)

        o1_real = torch.zeros([B, n_freq, self.num_blocks, width_hidden], device=spectrum.device)
        o1_imag = torch.zeros([B, n_freq, self.num_blocks, width_hidden], device=spectrum.device)
        o2_real = torch.zeros(spectrum.shape, device=spectrum.device)
        o2_imag = torch.zeros(spectrum.shape, device=spectrum.device)

        # Only the lowest ``kept_modes`` frequencies are processed; the buffers
        # above keep zeros for the rest.
        kept_modes = int(n_freq * self.hard_thresholding_fraction)
        low_real = spectrum[:, :kept_modes].real
        low_imag = spectrum[:, :kept_modes].imag
        mul = self._block_matmul

        # First complex layer: (a + ib)(c + id) = (ac - bd) + i(bc + ad).
        o1_real[:, :kept_modes] = F.relu(
            mul(low_real, self.w1[0]) - mul(low_imag, self.w1[1]) + self.b1[0]
        )
        o1_imag[:, :kept_modes] = F.relu(
            mul(low_imag, self.w1[0]) + mul(low_real, self.w1[1]) + self.b1[1]
        )

        # Second complex layer, no activation.
        hid_real = o1_real[:, :kept_modes]
        hid_imag = o1_imag[:, :kept_modes]
        o2_real[:, :kept_modes] = (
            mul(hid_real, self.w2[0]) - mul(hid_imag, self.w2[1]) + self.b2[0]
        )
        o2_imag[:, :kept_modes] = (
            mul(hid_imag, self.w2[0]) + mul(hid_real, self.w2[1]) + self.b2[1]
        )

        packed = torch.stack([o2_real, o2_imag], dim=-1)
        packed = F.softshrink(packed, lambd=self.sparsity_threshold)
        filtered = torch.view_as_complex(packed).reshape(B, n_freq, C)
        out = torch.fft.irfft(filtered, n=N, dim=1, norm="ortho")
        return out.type(in_dtype) + skip

class Af_Block(nn.Module):
    """
    AdaptiveFNO Block

    Pre-norm residual block: an ``AFNO1D`` spectral filter followed by an MLP.

    Args:
        dim: Channel size of the tokens.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        drop: Dropout probability inside the MLP.
        drop_path: Drop-path probability on the MLP branch; ``0`` uses identity.
        act_layer: Activation factory for the MLP.
        norm_layer: Normalisation factory, called with ``dim``.
        double_skip: If True, add the block input right after the filter and
            use that sum as the residual for the MLP branch; otherwise the
            block input is the only residual.
        num_blocks: Forwarded to ``AFNO1D``.
        sparsity_threshold: Forwarded to ``AFNO1D``.
        hard_thresholding_fraction: Forwarded to ``AFNO1D``.
    """
    def __init__(
            self,
            dim,
            mlp_ratio=4.,
            drop=0.,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            double_skip=True,
            num_blocks=8,
            sparsity_threshold=0.01,
            hard_thresholding_fraction=1.0,
        ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.filter = AFNO1D(dim, num_blocks, sparsity_threshold, hard_thresholding_fraction)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop)
        self.double_skip = double_skip

    def forward(self, x):
        """Apply filter and MLP sub-layers with their residual connections."""
        shortcut = x
        x = self.filter(self.norm1(x))

        if self.double_skip:
            x = x + shortcut
            shortcut = x

        branch = self.drop_path(self.mlp(self.norm2(x)))
        return branch + shortcut


class ViT(nn.Module):
    """Patch-based model built from AFNO blocks with a reconstruction head.

    ``forward`` patch-embeds the input, runs ``encoder_blocks``, decodes that
    latent directly as a reconstruction, then runs ``core_blocks``, adds a
    pointwise-convolution branch (``w0``) and decodes again as the prediction.

    Args:
        img_size: Expected ``(height, width)`` of the input.
        patch_size: ``(height, width)`` of one patch.
        in_chans: Number of input channels.
        out_chans: Number of output channels.
        embed_dim: Token embedding size.
        encoder_depth: Number of ``Af_Block`` modules in the encoder.
        depth: Number of ``Af_Block`` modules in the core stack.
        mlp_ratio: Forwarded to every ``Af_Block``.
        drop_rate: Dropout probability after position embedding and in blocks.
        drop_path_rate: Largest drop-path rate; rates are spaced linearly from
            0 over ``depth`` values.
        num_blocks: Forwarded to every ``Af_Block``.
        sparsity_threshold: Forwarded to every ``Af_Block``.
        hard_thresholding_fraction: Forwarded to every ``Af_Block``.
        settings: ``"MLP"`` or ``"Conv2d"``; selects the decoder head. Any
            other value creates no head, and ``decoder`` then returns the
            reshaped latent unchanged.
        encoder_network: Optional callable applied after the encoder blocks;
            skipped when falsy.
    """
    def __init__(
            self,
            img_size=(720, 1440),
            patch_size=(8, 8),
            in_chans=20,
            out_chans=20,
            embed_dim=768,
            encoder_depth = 2,
            depth=10,
            mlp_ratio=4.,
            drop_rate=0.,
            drop_path_rate=0.,
            num_blocks=16,
            sparsity_threshold=0.01,
            hard_thresholding_fraction=1.0,
            settings = "Conv2d",
            encoder_network = False
        ):
        super().__init__()
        self.encoder_network = encoder_network
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.out_chans = out_chans

        self.num_features = self.embed_dim = embed_dim
        self.num_blocks = num_blocks
        self.depth = depth
        self.settings = settings
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = PatchEmbed(img_size=self.img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # Learned position embedding, one vector per patch
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # One drop-path rate per core block, increasing linearly
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.depth)]
        self.dpr = dpr

        # Patch-grid height and width, used to un-flatten tokens in decoder()
        self.h = img_size[0] // self.patch_size[0]
        self.w = img_size[1] // self.patch_size[1]


        # Encoder Settings
        self.encoder_depth = encoder_depth
        # There are two options. Af_Block represents using the AdaptiveFNO blocks, and At_Block represents using the Fourier-Transformer blocks.
        # NOTE(review): dpr has `depth` entries but is indexed up to
        # encoder_depth - 1 here, so encoder_depth > depth raises IndexError.
        self.encoder_blocks = nn.ModuleList([
            Af_Block(dim=embed_dim, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
            num_blocks=self.num_blocks, sparsity_threshold=sparsity_threshold, hard_thresholding_fraction=hard_thresholding_fraction)
            for i in range(encoder_depth)])

        # Koopman Layers
        self.core_blocks = nn.ModuleList([
            Af_Block(dim=embed_dim, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
            num_blocks=self.num_blocks, sparsity_threshold=sparsity_threshold, hard_thresholding_fraction=hard_thresholding_fraction)
        for i in range(self.depth)])

        # NOTE(review): self.norm is created but not used by encoder(),
        # decoder() or forward() in this class.
        self.norm = norm_layer(embed_dim)

        # High-frequency component
        self.w0 = nn.Conv1d(embed_dim, embed_dim, 1) # or user-defined more complicated convolutional structure

        # Decoder Settings
        if self.settings == "MLP":
            self.decoder_pred_mlp = nn.Linear(self.embed_dim, self.out_chans*self.patch_size[0]*self.patch_size[1], bias=False)
        elif self.settings == "Conv2d":
            self.decoder_pred_conv2d = nn.ConvTranspose2d(self.embed_dim, self.out_chans, kernel_size=self.patch_size, stride=self.patch_size)

        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise ``nn.Linear`` and ``nn.LayerNorm`` submodules; others are left as built."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return parameter names to exclude from weight decay.

        NOTE(review): this class defines ``pos_embed`` but no ``cls_token``.
        """
        return {'pos_embed', 'cls_token'}

    def encoder(self, x):
        """Patch-embed ``x``, add position embeddings and run the encoder blocks."""
        # Position Encoder
        B = x.shape[0]
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        # Encoder Network (if reconstruction task is hard, please use more complicated structure)
        for blk in self.encoder_blocks:
            x = blk(x)

        if self.encoder_network:
            x = self.encoder_network(x)

        return x

    def decoder(self, x):
        """Map tokens ``(B, h*w, embed_dim)`` back to an image via the configured head."""
        B = x.shape[0]
        x = x.reshape(B, self.h, self.w, self.embed_dim)
        if self.settings == "MLP":
            x = self.decoder_pred_mlp(x)
            # Unfold each token's features into a p1 x p2 patch of c_out channels
            x = rearrange(
                x,
                "b h w (p1 p2 c_out) -> b c_out (h p1) (w p2)",
                p1=self.patch_size[0],
                p2=self.patch_size[1],
                h=self.img_size[0] // self.patch_size[0],
                w=self.img_size[1] // self.patch_size[1],
            )
        elif self.settings == "Conv2d":
            x = rearrange(x, "B H W C -> B C H W")
            x = self.decoder_pred_conv2d(x)
        return x

    def forward(self, x):
        """Return ``(prediction, reconstruction)`` for input ``x``."""
        x = self.encoder(x)
        # Reconstruction
        x_recons = self.decoder(x)
        # Prediction
        # Pointwise convolution over the channel axis of the encoded tokens
        x_w = self.w0(x.permute(0,2,1)).permute(0,2,1)
        for blk in self.core_blocks:
            x = blk(x)
        # NOTE(review): `blk` is the loop variable left over from the loop
        # above, so the last core block is applied a second time here, and an
        # empty core_blocks list (depth=0) would raise NameError. Confirm
        # whether this extra application is intentional before changing it;
        # existing checkpoints were trained with it in place.
        x = blk(x)
        x = x + x_w
        x = self.decoder(x)
        return x, x_recons

In [10]:
class koopman_vit:
    """Build, train, evaluate and save a ``ViT`` Koopman model.

    The class only orchestrates; the network (``ViT``), parameter counting and
    the optimizer (``utils``), timing (``default_timer``) and plotting
    (``plt``) come from names defined elsewhere in the file.

    Batch layout, per the original inline comments: inputs are
    ``[batch, channel, x, y]`` and multi-step targets are ``[batch, T, x, y]``
    (time on dim 1).  In every rollout the last output channel of each step
    (``im[:, -1:]``) is taken as that step's predicted frame; for a
    single-channel model this is the whole output.
    """

    def __init__(self, decoder = "Conv2d", depth = 16, resolution=(256, 256), patch_size=(4, 4),
            in_chans=1, out_chans=1, embed_dim=768, parallel = False, device = False):
        # Model Hyper-parameters
        self.decoder = decoder
        self.resolution = resolution
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_blocks = 16
        # Core Model
        self.params = 0
        self.kernel = False
        # Opt Setting (both stay False until opt_init() is called)
        self.optimizer = False
        self.scheduler = False
        self.device = device
        self.parallel = parallel
        self.loss = torch.nn.MSELoss()

    def compile(self):
        """Instantiate the ``ViT`` kernel on ``self.device`` and count its parameters."""
        self.kernel = ViT(img_size=self.resolution, patch_size=self.patch_size, in_chans=self.in_chans, out_chans=self.out_chans, num_blocks=self.num_blocks, embed_dim = self.embed_dim, depth=self.depth, settings = self.decoder).to(self.device)
        if self.parallel:
            self.kernel = torch.nn.DataParallel(self.kernel)
        self.params = utils.count_params(self.kernel)

        print("Koopman Fourier Vision Transformer has been compiled!")
        print("The Model Parameters Number is ",self.params)

    def opt_init(self, opt, lr, step_size, gamma):
        """Create the optimizer (only ``"Adam"`` is recognised) and, unless
        ``step_size`` is falsy, a ``StepLR`` scheduler."""
        if opt == "Adam":
            self.optimizer = utils.Adam(self.kernel.parameters(), lr= lr, weight_decay=1e-4)
        if step_size:
            self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=step_size, gamma=gamma)

    # ------------------------------------------------------------------ helpers
    def _rollout(self, xx, n_steps):
        """Feed the model its own output ``n_steps`` times.

        Returns ``(frames, l_recons)``: the list of per-step predicted frames
        (``im[:, -1:]``) and the summed reconstruction loss of every step.
        """
        bs = xx.shape[0]
        l_recons = 0
        frames = []
        for _ in range(n_steps):
            im, im_re = self.kernel(xx)
            l_recons = l_recons + self.loss(im_re.reshape(bs, -1), xx.reshape(bs, -1))
            frames.append(im[:, -1:])
            xx = im
        return frames, l_recons

    def _step_scheduler(self):
        """Advance the LR scheduler, if opt_init() created one."""
        if self.scheduler:
            self.scheduler.step()

    @staticmethod
    def _log_epoch(ep, elapsed, train_recons, train_pred, eval_scores = None):
        """Print one epoch line (plus the column header on epoch 0)."""
        if eval_scores is not None:
            if ep == 0:
                print("Epoch","Time","[Train Recons MSE]","[Train Pred MSE]","[Eval Recons MSE]","[Eval Pred MSE]")
            print(ep, elapsed, train_recons, train_pred, *eval_scores)
        else:
            if ep == 0:
                print("Epoch","Time","Train Recons MSE","Train Pred MSE")
            print(ep, elapsed, train_recons, train_pred)

    @staticmethod
    def _plot_frames(pred, yy, n_frames, out_prefix):
        """Plot prediction / label / error for sample 0, one figure per frame."""
        for i in range(n_frames):
            predicted = pred[0,i].cpu().detach().numpy()
            label = yy[0,i].cpu().detach().numpy()
            plt.subplot(1,3,1)
            plt.title("Predict")
            plt.imshow(predicted)
            plt.subplot(1,3,2)
            plt.imshow(label)
            plt.title("Label")
            plt.subplot(1,3,3)
            plt.imshow(predicted - label)
            plt.title("Error")
            # Save before show(): showing first can leave an empty figure to save.
            plt.savefig(f"{out_prefix}time_{i}.png")
            plt.show()
            plt.close()

    def _evaluate(self, testloader, T_out, path, is_save, is_plot):
        """Shared body of test_multi() / test_single().

        Rolls the model out ``T_out`` steps on every batch, prints the mean
        prediction and reconstruction MSE, and returns a ``[T_out, 1]`` tensor
        with the mean MSE of each step.  Saving / plotting is done for the
        first batch only, and only when the matching flag is set.
        """
        out_prefix = path or ""
        time_error = torch.zeros([T_out,1])
        test_pred_full = 0
        test_recons_full = 0
        loc = 0
        with torch.no_grad():
            for xx, yy in testloader:
                bs = xx.shape[0]
                xx = xx.to(self.device)
                yy = yy.to(self.device)
                frames, l_recons = self._rollout(xx, T_out)
                for t, frame in enumerate(frames):
                    time_error[t] = time_error[t] + self.loss(frame, yy[:, t:t + 1]).item()
                pred = torch.cat(frames, 1)

                test_recons_full += l_recons.item() / T_out
                l_pred = self.loss(pred.reshape(bs, -1), yy.reshape(bs, -1))
                test_pred_full += l_pred.item()

                # `and`, not `&`: `loc == 0 & flag` parses as `loc == (0 & flag)`,
                # which ignored the flag and always fired on the first batch.
                if loc == 0 and is_save:
                    torch.save({"pred":pred, "yy":yy}, f"{out_prefix}pred_yy.pt")
                if loc == 0 and is_plot:
                    self._plot_frames(pred, yy, T_out, out_prefix)

                loc = loc + 1
        test_pred_full = test_pred_full / loc
        test_recons_full = test_recons_full / loc
        time_error = time_error / loc
        print("Total prediction test mse error is ",test_pred_full)
        print("Total reconstruction test mse error is ",test_recons_full)
        return time_error

    # ----------------------------------------------------------------- training
    def train_multi(self, epochs, trainloader, T_out = 10, evalloader = False):
        """Train on ``T_out``-step autoregressive rollouts.

        Loss per batch is ``5 * prediction MSE + 0.5 * summed reconstruction
        MSE``.  When ``evalloader`` is truthy the same rollout is evaluated
        after each epoch.  Requires opt_init() to have created the optimizer.

        NOTE(review): the kernel is never switched to ``eval()`` for the
        evaluation pass, so any dropout stays active there - confirm intent.
        """
        T_eval = T_out
        for ep in range(epochs):
            self.kernel.train()
            t1 = default_timer()
            train_recons_full = 0
            train_pred_full = 0
            for xx, yy in trainloader:
                xx = xx.to(self.device) # [batchsize,1,x,y]
                yy = yy.to(self.device) # [batchsize,T,x,y]
                bs = xx.shape[0]
                frames, l_recons = self._rollout(xx, T_out)
                # Stack the predicted frames on the time axis (dim 1) so the
                # flattened prediction lines up element-wise with yy.
                pred = torch.cat(frames, 1)

                l_pred = self.loss(pred.reshape(bs, -1), yy.reshape(bs, -1))
                loss = 5 * l_pred + 0.5 * l_recons

                train_pred_full += l_pred.item()
                train_recons_full += l_recons.item()/T_out

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            train_pred_full = train_pred_full / len(trainloader)
            train_recons_full = train_recons_full / len(trainloader)
            t2 = default_timer()

            eval_scores = None
            if evalloader:
                test_pred_full = 0
                test_recons_full = 0
                with torch.no_grad():
                    for xx, yy in evalloader:
                        xx = xx.to(self.device)
                        yy = yy.to(self.device)
                        # Per-batch size and a fresh loss accumulator; the old
                        # code reused both from the training loop.
                        bs = xx.shape[0]
                        frames, l_recons = self._rollout(xx, T_eval)
                        pred = torch.cat(frames, 1)
                        l_pred = self.loss(pred.reshape(bs, -1), yy.reshape(bs, -1))

                        test_recons_full += l_recons.item() / T_eval
                        test_pred_full += l_pred.item()
                eval_scores = (test_recons_full / len(evalloader), test_pred_full / len(evalloader))
            self._step_scheduler()
            self._log_epoch(ep, t2 - t1, train_recons_full, train_pred_full, eval_scores)

    def test_multi(self, testloader, step = 1, T_out = 5, path = False, is_save = False, is_plot = False):
        """Evaluate a ``T_out``-step rollout; returns the per-step mean MSE
        as a ``[T_out, 1]`` tensor.  ``step`` is accepted but unused."""
        return self._evaluate(testloader, T_out, path, is_save, is_plot)

    def train_single(self, epochs, trainloader, evalloader = False):
        """Train on single-step prediction (one forward pass per batch).

        Loss per batch is ``5 * prediction MSE + 0.5 * reconstruction MSE``.
        Requires opt_init() to have created the optimizer.
        """
        for ep in range(epochs):
            self.kernel.train()
            t1 = default_timer()
            train_recons_full = 0
            train_pred_full = 0
            for x, y in trainloader:
                x = x.to(self.device) # [batchsize,1,64,64]
                y = y.to(self.device) # [batchsize,1,64,64]
                bs = x.shape[0]

                im,im_re = self.kernel(x)

                l_recons = self.loss(im_re.reshape(bs, -1), x.reshape(bs, -1))
                l_pred = self.loss(im.reshape(bs, -1), y.reshape(bs, -1))

                loss = 5 * l_pred + 0.5 * l_recons

                train_pred_full += l_pred.item()
                train_recons_full += l_recons.item()

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            train_pred_full = train_pred_full / len(trainloader)
            train_recons_full = train_recons_full / len(trainloader)
            t2 = default_timer()

            eval_scores = None
            if evalloader:
                test_pred_full = 0
                test_recons_full = 0
                with torch.no_grad():
                    for x, y in evalloader:
                        x = x.to(self.device)
                        y = y.to(self.device)
                        bs = x.shape[0]  # was the stale size of the last training batch

                        im, im_re = self.kernel(x)

                        l_recons = self.loss(im_re.reshape(bs, -1), x.reshape(bs, -1))
                        l_pred = self.loss(im.reshape(bs, -1), y.reshape(bs, -1))

                        test_recons_full += l_recons.item()
                        test_pred_full += l_pred.item()
                eval_scores = (test_recons_full / len(evalloader), test_pred_full / len(evalloader))
            self._step_scheduler()
            self._log_epoch(ep, t2 - t1, train_recons_full, train_pred_full, eval_scores)

    def test_single(self, testloader, T_out = 1, path = False, is_save = False, is_plot = False):
        """Evaluate with a default horizon of one step; returns the per-step
        mean MSE as a ``[T_out, 1]`` tensor."""
        return self._evaluate(testloader, T_out, path, is_save, is_plot)

    def save(self, path):
        """Serialise this wrapper, the kernel and its ``state_dict`` to ``path``.

        The target directory is not created here; it must already exist.
        """
        torch.save({"koopman":self,"model":self.kernel,"model_params":self.kernel.state_dict()}, path)

### CNNの学習クラス

In [24]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from timeit import default_timer
# ep = 1000 # Training Epoch
# o = 32 # Koopman Operator Size
# m = 16 # Modes
# r = 8 # Power of Koopman Matrix
class koopman:
    """Build, train, evaluate and save a Koopman neural-operator model.

    ``backbone`` selects the operator network (``"KNO1d"``, ``"KNO2d"``,
    ``"DMD"`` or ``"SSM"``) and ``autoencoder`` the encoder/decoder pair
    (``"MLP"``, ``"Conv1d"`` or ``"Conv2d"``).  ``o`` is the operator size,
    ``m`` the number of modes, ``r`` the ``decompose`` argument and ``t_in``
    the number of input frames handed to the encoder.

    The multi-step methods index time on the *last* tensor axis
    (``yy[..., t:t + 1]``, ``im[..., -1:]``).
    """

    def __init__(self, backbone = "KNO2d", autoencoder = "Conv2d", o = 16, m = 16, r = 8, t_in = 1, device = False):
        self.backbone = backbone
        self.autoencoder = autoencoder
        self.operator_size = o
        self.modes = m
        self.decompose = r
        self.device = device
        self.t_in = t_in
        # Core Model
        self.params = 0
        self.kernel = False
        # Opt Setting (both stay False until opt_init() is called)
        self.optimizer = False
        self.scheduler = False
        self.loss = torch.nn.MSELoss()

    def compile(self):
        """Instantiate encoder, decoder and backbone on ``self.device``.

        An unknown ``autoencoder`` prints ``"Wrong!"`` and returns with
        ``self.kernel`` left as ``False`` (previously the method went on and
        failed on the unbound ``encoder`` variable).  An unknown ``backbone``
        likewise leaves ``self.kernel`` as ``False``.
        """
        if self.autoencoder == "MLP":
            encoder = encoder_mlp(self.t_in, self.operator_size)
            decoder = decoder_mlp(self.t_in, self.operator_size)
            print("The autoencoder type is MLP.")
        elif self.autoencoder == "Conv1d":
            encoder = encoder_conv1d(self.t_in, self.operator_size)
            decoder = decoder_conv1d(self.t_in, self.operator_size)
            print("The autoencoder type is Conv1d.")
        elif self.autoencoder == "Conv2d":
            encoder = encoder_conv2d(self.t_in, self.operator_size)
            decoder = decoder_conv2d(self.t_in, self.operator_size)
            print("The autoencoder type is Conv2d.")
        else:
            print("Wrong!")
            return
        if self.backbone == "KNO1d":
            self.kernel = KNO1d(encoder, decoder, self.operator_size, modes_x = self.modes, decompose = self.decompose).to(self.device)
            print("KNO1d model is completed.")
        elif self.backbone == "KNO2d":
            self.kernel = KNO2d(encoder, decoder, self.operator_size, modes_x = self.modes, modes_y = self.modes,decompose = self.decompose).to(self.device)
            print("KNO2d model is completed.")
        elif self.backbone == "DMD":
            self.kernel = KNO2d_DMD(encoder, decoder, self.operator_size, decompose = self.decompose).to(self.device)
            print("KNO2d_DMD model is completed.")
        elif self.backbone == "SSM":
            self.kernel = SSM2d(encoder, decoder, self.operator_size, decompose = self.decompose, hidden_size=self.operator_size ).to(self.device)
            print("SSM2d model is completed.")
        if self.kernel is not False:
            self.params = count_params(self.kernel)
            print("Koopman Model has been compiled!")
            print("The Model Parameters Number is ",self.params)

    def opt_init(self, opt, lr, step_size, gamma):
        """Create the optimizer (only ``"Adam"`` is recognised) and, unless
        ``step_size`` is falsy, a ``StepLR`` scheduler."""
        if opt == "Adam":
            self.optimizer = Adam(self.kernel.parameters(), lr= lr, weight_decay=1e-4)
        if step_size:
            self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=step_size, gamma=gamma)

    # ------------------------------------------------------------------ helpers
    def _rollout(self, xx, n_steps, shift):
        """Autoregressive rollout over a sliding input window.

        Each step appends the newest predicted frame (``im[..., -1:]``) to the
        window and drops the first ``shift`` frames.  Returns ``(frames,
        l_recons)``: the list of predicted frames and the summed
        reconstruction loss over all steps.
        """
        bs = xx.shape[0]
        l_recons = 0
        frames = []
        for _ in range(n_steps):
            im, im_re = self.kernel(xx)
            l_recons = l_recons + self.loss(im_re.reshape(bs, -1), xx.reshape(bs, -1))
            newest = im[..., -1:]
            frames.append(newest)
            xx = torch.cat((xx[..., shift:], newest), dim=-1)
        return frames, l_recons

    def _step_scheduler(self):
        """Advance the LR scheduler, if opt_init() created one."""
        if self.scheduler:
            self.scheduler.step()

    @staticmethod
    def _log_epoch(ep, elapsed, train_recons, train_pred, eval_scores = None):
        """Print one epoch line (plus the column header on epoch 0)."""
        if eval_scores is not None:
            if ep == 0:
                print("Epoch","Time","[Train Recons MSE]","[Train Pred MSE]","[Eval Recons MSE]","[Eval Pred MSE]")
            print(ep, elapsed, train_recons, train_pred, *eval_scores)
        else:
            if ep == 0:
                print("Epoch","Time","Train Recons MSE","Train Pred MSE")
            print(ep, elapsed, train_recons, train_pred)

    # ------------------------------------------------------------- single step
    def train_single(self, epochs, trainloader, evalloader = False):
        """Train on single-step prediction.

        Loss per batch is ``5 * prediction MSE + 0.5 * reconstruction MSE``.
        Requires opt_init() to have created the optimizer.
        """
        for ep in range(epochs):
            # Train
            self.kernel.train()
            t1 = default_timer()
            train_recons_full = 0
            train_pred_full = 0
            for xx, yy in trainloader:
                bs = xx.shape[0]
                xx = xx.to(self.device)
                yy = yy.to(self.device)
                pred,im_re = self.kernel(xx)

                l_recons = self.loss(im_re.reshape(bs, -1), xx.reshape(bs, -1))
                l_pred = self.loss(pred.reshape(bs, -1), yy.reshape(bs, -1))

                train_pred_full += l_pred.item()
                train_recons_full += l_recons.item()

                loss = 5*l_pred + 0.5*l_recons
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            train_pred_full = train_pred_full / len(trainloader)
            train_recons_full = train_recons_full / len(trainloader)
            t2 = default_timer()

            # Test
            eval_scores = None
            if evalloader:
                test_pred_full = 0
                test_recons_full = 0
                with torch.no_grad():
                    for xx, yy in evalloader:
                        bs = xx.shape[0]
                        xx = xx.to(self.device)
                        yy = yy.to(self.device)

                        pred,im_re = self.kernel(xx)

                        l_recons = self.loss(im_re.reshape(bs, -1), xx.reshape(bs, -1))
                        l_pred = self.loss(pred.reshape(bs, -1), yy.reshape(bs, -1))

                        test_pred_full += l_pred.item()
                        test_recons_full += l_recons.item()
                eval_scores = (test_recons_full/len(evalloader), test_pred_full/len(evalloader))

            self._step_scheduler()
            self._log_epoch(ep, t2 - t1, train_recons_full, train_pred_full, eval_scores)

    def test_single(self, testloader):
        """Print mean single-step prediction / reconstruction MSE and return
        the mean prediction MSE."""
        test_pred_full = 0
        test_recons_full = 0
        with torch.no_grad():
            for xx, yy in testloader:
                bs = xx.shape[0]
                xx = xx.to(self.device)
                yy = yy.to(self.device)

                pred,im_re = self.kernel(xx)

                l_recons = self.loss(im_re.reshape(bs, -1), xx.reshape(bs, -1))
                l_pred = self.loss(pred.reshape(bs, -1), yy.reshape(bs, -1))

                test_pred_full += l_pred.item()
                test_recons_full += l_recons.item()
        test_pred_full = test_pred_full/len(testloader)
        test_recons_full = test_recons_full/len(testloader)
        print("Total prediction test mse error is ",test_pred_full)
        print("Total reconstruction test mse error is ",test_recons_full)
        return test_pred_full

    # -------------------------------------------------------------- multi step
    def train(self, epochs, trainloader, step = 1, T_out = 40, T_eval = 80, evalloader = False):
        """Train on ``T_out``-step rollouts; optionally evaluate ``T_eval`` steps.

        Training loss per batch is ``5 * prediction MSE + 0.5 * summed
        reconstruction MSE``; the input window advances by ``step`` frames per
        training step and by one frame per evaluation step (as before).

        The evaluation prediction MSE is computed over the frames that both
        the rollout and ``yy`` provide (``min(T_eval, yy.shape[-1])``).  The
        previous code had that computation commented out and reported the
        last *training* batch's loss instead.
        """
        for ep in range(epochs):
            self.kernel.train()
            t1 = default_timer()
            train_recons_full = 0
            train_pred_full = 0
            for xx, yy in trainloader:
                xx = xx.to(self.device)
                yy = yy.to(self.device)
                bs = xx.shape[0]
                frames, l_recons = self._rollout(xx, T_out, step)
                pred = torch.cat(frames, -1)

                l_pred = self.loss(pred.reshape(bs, -1), yy.reshape(bs, -1))
                loss = 5 * l_pred + 0.5 * l_recons

                train_pred_full += l_pred.item()
                train_recons_full += l_recons.item()/T_out

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            train_pred_full = train_pred_full / len(trainloader)
            train_recons_full = train_recons_full / len(trainloader)
            t2 = default_timer()

            eval_scores = None
            if evalloader:
                test_pred_full = 0
                test_recons_full = 0
                with torch.no_grad():
                    for xx, yy in evalloader:
                        xx = xx.to(self.device)
                        yy = yy.to(self.device)
                        # Per-batch size and a fresh loss accumulator; the old
                        # code reused both from the training loop.
                        bs = xx.shape[0]
                        frames, l_recons = self._rollout(xx, T_eval, 1)
                        pred = torch.cat(frames, -1)
                        horizon = min(pred.shape[-1], yy.shape[-1])
                        l_pred = self.loss(pred[..., :horizon].reshape(bs, -1), yy[..., :horizon].reshape(bs, -1))

                        test_recons_full += l_recons.item() / T_eval
                        test_pred_full += l_pred.item()
                eval_scores = (test_recons_full / len(evalloader), test_pred_full / len(evalloader))
            self._step_scheduler()
            self._log_epoch(ep, t2 - t1, train_recons_full, train_pred_full, eval_scores)

    def test(self, testloader, step = 1, T_out = 80, path = False, is_save = True, is_plot = False):
        """Evaluate a ``T_out``-step rollout on every batch.

        Prints the mean prediction / reconstruction MSE and returns a
        ``[T_out, 1]`` tensor with the mean MSE of each step.  With
        ``is_save`` every batch's ``pred`` / ``yy`` pair is written to
        ``<path>pred_yy_loc<k>.pt``; with ``is_plot`` sample 0 of every batch
        is plotted and written to ``<path>time_<i>_loc<k>.png``.  A falsy
        ``path`` now means "no prefix" (current directory) instead of the
        literal prefix ``"False"``.  ``step`` is accepted but unused.
        """
        out_prefix = path or ""
        time_error = torch.zeros([T_out,1])
        test_pred_full = 0
        test_recons_full = 0
        loc = 0
        with torch.no_grad():
            for xx, yy in testloader:
                bs = xx.shape[0]
                xx = xx.to(self.device)
                yy = yy.to(self.device)
                frames, l_recons = self._rollout(xx, T_out, 1)
                for t, frame in enumerate(frames):
                    time_error[t] = time_error[t] + self.loss(frame, yy[..., t:t + 1]).item()
                pred = torch.cat(frames, -1)

                test_recons_full += l_recons.item() / T_out
                l_pred = self.loss(pred.reshape(bs, -1), yy.reshape(bs, -1))
                test_pred_full += l_pred.item()
                if is_save:
                    torch.save({"pred": pred, "yy": yy}, f"{out_prefix}pred_yy_loc{loc}.pt")

                if is_plot:
                    for i in range(T_out):
                        predicted = pred[0,...,i].cpu().detach().numpy()
                        label = yy[0,...,i].cpu().detach().numpy()
                        plt.subplot(1,3,1)
                        plt.title("Predict")
                        plt.imshow(predicted)
                        plt.subplot(1,3,2)
                        plt.imshow(label)
                        plt.title("Label")
                        plt.subplot(1,3,3)
                        plt.imshow(predicted - label)
                        plt.title("Error")
                        # Save before show(): showing first can leave an empty figure to save.
                        plt.savefig(f"{out_prefix}time_{i}_loc{loc}.png")
                        plt.show()
                        plt.close()
                loc = loc + 1
        test_pred_full = test_pred_full / loc
        test_recons_full = test_recons_full / loc
        time_error = time_error / loc
        print("Total prediction test mse error is ",test_pred_full)
        print("Total reconstruction test mse error is ",test_recons_full)
        return time_error

    def save(self, path):
        """Serialise this wrapper, the kernel and its ``state_dict`` to ``path``.

        The parent directory is created when missing.  ``exist_ok=True`` keeps
        an existing directory from raising, and a bare file name (empty
        directory part) skips the creation step altogether.
        """
        (fpath,_) = os.path.split(path)
        if fpath:
            os.makedirs(fpath, exist_ok=True)
        torch.save({"koopman":self,"model":self.kernel,"model_params":self.kernel.state_dict()}, path)



In [25]:
# Build the Navier-Stokes tutorial loaders (navier_stokes is defined elsewhere).
data_path = "./tutorial/NavierStokes_V1e-3_N5000_T50-013/ns_V1e-3_N5000_T50.mat"
train_loader, eval_loader = navier_stokes(
    data_path,
    batch_size = 1,
    T_in = 10,
    T_out = 40,
    type = "1e-3",
    sub = 1,
    reshape= True,
)

# Pull the first training batch through an explicit iterator.
dataiter = iter(train_loader)
x, y = next(dataiter)

# Shapes of the input and target tensors.
print("Input x shape:", x.shape)
print("Target y shape:", y.shape)

# Value range of the input.
print("\nInput x stats:")
for label, stat in (("Min value:", x.min()), ("Max value:", x.max()), ("Mean value:", x.mean())):
    print(label, stat.item())

# Where the tensor lives and how it is stored.
print("\nDevice:", x.device)
print("Data type:", x.dtype)

# Statistics of the first sample, reduced over its axes 1 and 2.
first_sample = x[0]
print("\nFirst image in batch stats:")
print("Channel-wise mean:", first_sample.mean(dim=(1,2)))
print("Channel-wise std:", first_sample.std(dim=(1,2)))


In [26]:
## Parameter definitions:
# Use the GPU when there is one; fall back to the CPU so the cell still runs elsewhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hyper parameters
ep = 1 # Training Epoch
o = 32 # Koopman Operator Size
m = 16 # Modes
r = 8 # Power of Koopman Matrix

# koopman_vit signature, for reference:
#     koopman_vit(decoder = "Conv2d", depth = 16, resolution=(256, 256), patch_size=(4, 4),
#                 in_chans=1, out_chans=1, embed_dim=768, parallel = False, device = False)
ViT_KNO = koopman_vit(decoder = "Conv2d", resolution=(64, 64), patch_size=(2, 2),
            in_chans=10, out_chans=10, embed_dim=768, depth = 16, parallel = True, device=device)
ViT_KNO.compile()
# The training loop calls self.optimizer.zero_grad()/step(), and the optimizer
# only exists after opt_init(); without this call train_multi() fails on the
# first batch.  The values mirror the opt_init() call of the KNO2d cell.
ViT_KNO.opt_init("Adam", lr = 0.005, step_size = 250, gamma = 0.5)
T_out = 40
# ViT_KNO.train_single(epochs=ep, trainloader = train_loader, evalloader = eval_loader)
# ViT_KNO.test_single(test_loader)
ViT_KNO.train_multi(epochs=ep, trainloader = train_loader, evalloader = eval_loader, T_out = T_out)
# NOTE(review): test_loader is not created by the navier_stokes() call above
# (it returns train_loader / eval_loader) - confirm which loader is intended.
ViT_KNO.test_multi(test_loader)


## SBASデータローダ

In [27]:
import os
import h5py
import numpy as np
import cv2
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
def resize_with_cv2(disp_arr, target_size=(64, 64)):
    """Resize every channel of a 3-D ``(H, W, C)`` array with OpenCV.

    Each channel is resized independently with bilinear interpolation
    (``cv2.INTER_LINEAR``) and the results are stacked back on the last axis.
    ``target_size`` is handed to ``cv2.resize`` unchanged.
    """
    _, _, n_channels = disp_arr.shape
    planes = []
    for idx in range(n_channels):
        plane = cv2.resize(disp_arr[:, :, idx], target_size, interpolation=cv2.INTER_LINEAR)
        planes.append(plane)
    return np.stack(planes, axis=-1)


# SBASデータのロード
def load_disp_data(parent_dir, target_size=(64, 64), target_channels=100):
    """Load and resize the displacement stack of every sub-folder of ``parent_dir``.

    For each sub-directory (in sorted order) the dataset ``'cum'`` of
    ``cum_filt.h5`` is read, transposed from ``(C, H, W)`` to ``(H, W, C)``,
    cut to the first ``target_channels`` channels, has its NaNs replaced by
    0 and is resized to ``target_size``.  Sub-folders without the file, or
    with fewer than ``target_channels`` channels, are skipped with a message.

    Args:
        parent_dir: directory whose sub-folders are scanned.
        target_size: size handed to ``resize_with_cv2``.  It used to be
            ignored, so the helper's default was always applied.
        target_channels: number of leading channels to keep.

    Returns:
        A list with one resized ``(H, W, C)`` array per accepted sub-folder
        (empty when nothing qualified).
    """
    disp_data = []

    for subfolder in sorted(os.listdir(parent_dir)):
        subfolder_path = os.path.join(parent_dir, subfolder)

        if not os.path.isdir(subfolder_path):
            continue

        h5_file_path = os.path.join(subfolder_path, "cum_filt.h5")

        if not os.path.exists(h5_file_path):
            print(f"File not found: {h5_file_path}")
            continue

        print(f"Loading: {h5_file_path}")
        with h5py.File(h5_file_path, "r") as h5_file:
            disp_arr_chw = h5_file['cum']
            disp_arr = np.transpose(disp_arr_chw, (1, 2, 0))  # (H, W, C)

            h, w, c = disp_arr.shape
            print(f"Shape of disp_arr: {h} x {w} x {c}")

            # Skip stacks that have fewer channels than requested
            if c < target_channels:
                print(f"Skipping {subfolder} (Channels: {c} < {target_channels})")
                continue

            # Keep only the first target_channels channels
            disp_arr = disp_arr[:, :, :target_channels]

            # Replace NaN with 0
            disp_arr = np.nan_to_num(disp_arr, nan=0.0)

            # Resize every channel to the requested target_size (bilinear)
            disp_resized = resize_with_cv2(disp_arr, target_size)

            print(f"Resized shape: {disp_resized.shape}")
            disp_data.append(disp_resized)

    return disp_data

parent_dir = "E:/2024/koopman/sbas/result"
disp_data = load_disp_data(parent_dir)
# Fail with a clear message instead of an IndexError in the slicing below.
if not disp_data:
    raise RuntimeError(f"No usable cum_filt.h5 stack was found under {parent_dir}")
disp_data = np.array(disp_data)
# Time horizon
T_in=20 # number of input frames
T_out=40 # number of frames to predict
batch_size =4
# Training data
# the first T_in frames are the input
train_a = disp_data[:,:,:,:T_in]
# the following T_out frames are the training target
train_u = disp_data[:,:,:,T_in:T_out+T_in]
# Testing data
# same inputs as training; the target spans 2 * T_out frames after the input window
test_a = train_a
test_u = disp_data[:,:,:,T_in:T_out+T_in+T_out]
# The data loaded here is the SBAS displacement stack; the old message named
# an unrelated dataset.
print("SBAS displacement dataset has been loaded successfully!")
print("X train shape:", train_a.shape, "Y train shape:", train_u.shape)
print("X test shape:", test_a.shape, "Y test shape:", test_u.shape)

train_a = torch.from_numpy(train_a).float()  # convert to float32 tensors
train_u = torch.from_numpy(train_u).float()
test_a = torch.from_numpy(test_a).float()
test_u = torch.from_numpy(test_u).float()


train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size, shuffle=False)

# Take an iterator over the training loader
data_iter = iter(train_loader)
# and fetch its first batch
batch = next(data_iter)
for tmp in batch:
    # batch is the pair [input, target] of tensors, laid out as
    # batch x height x width x frames; with the settings above the last axis
    # holds T_in (input) and T_out (target) frames.
    print(tmp.size())

## CNNベースの学習

In [28]:
import torch
import os
import h5py
    
# Select the first CUDA device for all computation
torch.cuda.set_device(0)
device = torch.device("cuda")

# Output locations: figures / prediction dumps and the final metrics file
fig_path = "./demo/fig/"
save_path = "./demo/result/"
for out_dir in (fig_path, save_path):
    os.makedirs(out_dir, exist_ok=True)

# Hyper parameters
ep = 1000 # Training Epoch
o = 32 # Koopman Operator Size
m = 16 # Modes
r = 8 # Power of Koopman Matrix

# Model: build, attach optimizer + step scheduler, then train
koopman_model = koopman(
    backbone = "KNO2d",
    autoencoder = "Conv2d",
    o = o,
    m = m,
    r = r,
    t_in = 20,
    device = device,
)
koopman_model.compile()
koopman_model.opt_init("Adam", lr = 0.005, step_size=250, gamma=0.5)
koopman_model.train(epochs=ep, trainloader = train_loader, evalloader = test_loader)

# Result and Saving
time_error = koopman_model.test(test_loader, path = fig_path, is_save = True, is_plot = True)
filename = f"koopmanAE{o}m{m}r{r}.pt"
result = {"time_error":time_error,"params":koopman_model.params}
torch.save(result, save_path + filename)

In [46]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os

# Export side-by-side images (prediction | ground truth | difference) for every
# sample and time channel stored in a saved prediction file.
file_path = "./demo/fig/pred_yy_loc2.pt"

# map_location="cpu" lets the file load on machines without the GPU it was saved from.
data = torch.load(file_path, map_location="cpu")
# Author's note: both arrays had shape [4, 64, 64, 80] for this file.
pred = data["pred"].detach().cpu().numpy()
yy = data["yy"].detach().cpu().numpy()

# Create the output directory.
save_dir = "./demo/pred_yy_images/pred_yy_loc2"
os.makedirs(save_dir, exist_ok=True)

# Global minima / maxima so that every image shares the same colour scale.
diff_arr = pred - yy
global_min_diff = diff_arr.min()
global_max_diff = diff_arr.max()

global_min = min(pred.min(), yy.min())
global_max = max(pred.max(), yy.max())

# Loop-invariant ranges. A zero range (constant data) would otherwise divide by
# zero and yield NaN images, so fall back to 1 (everything then maps to 0).
value_range = global_max - global_min
if value_range == 0:
    value_range = 1.0
diff_range = global_max_diff - global_min_diff
if diff_range == 0:
    diff_range = 1.0

# Save one combined image per (sample, time channel).
for batch_idx in range(pred.shape[0]):
    for channel in range(pred.shape[3]):  # last axis is iterated as the time channel
        img_pred = pred[batch_idx, :, :, channel]
        img_yy = yy[batch_idx, :, :, channel]
        img_diff = img_pred - img_yy  # difference image

        # Prediction and ground truth share one scale; the difference has its own.
        img_pred_norm = (img_pred - global_min) / value_range
        img_yy_norm = (img_yy - global_min) / value_range
        img_diff_norm = (img_diff - global_min_diff) / diff_range

        # Apply the rainbow colour map to values clipped into [0, 1].
        img_pred_colored = plt.cm.rainbow(np.clip(img_pred_norm, 0, 1))
        img_yy_colored = plt.cm.rainbow(np.clip(img_yy_norm, 0, 1))
        img_diff_colored = plt.cm.rainbow(np.clip(img_diff_norm, 0, 1))

        # Concatenate horizontally into a single image.
        combined_img = np.hstack([img_pred_colored, img_yy_colored, img_diff_colored])

        img_filename = f"{save_dir}/batch{batch_idx}_channel{channel}.png"
        plt.imsave(img_filename, combined_img)

print("画像の保存が完了しました！")


## ほかの手法との比較

### AutoEncoder+動的モード分解

In [None]:
import torch
import torch.nn as nn
import torch.linalg as linalg

class DMD_Operator2D(nn.Module):
    """Linear operator applied along the channel axis of a 4-D tensor.

    The operator matrix ``A`` is a plain attribute (neither a parameter nor a
    buffer). It is ``None`` until :meth:`compute_dmd_matrix` is called;
    :meth:`forward` falls back to an identity matrix in that case.
    """

    def __init__(self, op_size):
        super(DMD_Operator2D, self).__init__()
        self.op_size = op_size
        self.A = None  # DMD matrix

    def compute_dmd_matrix(self, X, Y):
        """Set ``A`` to the least-squares solution of ``A @ X = Y``.

        Args:
            X: snapshot matrix (or batch of matrices).
            Y: snapshots of the same shape as ``X``, shifted by one step.
        """
        # pinv(X) equals V @ diag(1/S) @ U^H for full-rank X, but discards
        # (near-)zero singular values instead of dividing by them, so a
        # rank-deficient X no longer yields inf/NaN entries.
        self.A = Y @ linalg.pinv(X)

    def forward(self, x):
        """Multiply every spatial location's channel vector by ``A``.

        Args:
            x: tensor unpacked as (B, C, H, W).

        Returns:
            Tensor of shape (B, C, H, W).
        """
        B, C, H, W = x.shape
        # reshape (not view) so that permuted / non-contiguous inputs are accepted
        x_flat = x.reshape(B, C, -1)  # (B, C, H*W)
        if self.A is None:
            # initial state: identity, created with the input's device and dtype
            self.A = torch.eye(C, device=x.device, dtype=x.dtype)
        # A may have been fitted on another device / dtype than the input
        A = self.A.to(device=x.device, dtype=x.dtype)
        x_next = torch.matmul(A, x_flat)
        return x_next.reshape(B, C, H, W)  # restore the original shape

class KNO2d_DMD(nn.Module):
    """Autoencoder whose latent state is advanced by a ``DMD_Operator2D`` layer.

    Args:
        encoder: module mapping the input to the latent representation.
        decoder: module mapping a latent representation back to the output.
        op_size: number of latent channels (channels of ``w0`` / ``norm_layer``).
        decompose: number of times the DMD layer is applied.
        linear_type: if True the update is ``x + dmd(x)``; otherwise a tanh is
            applied after each update.
        normalization: if True, BatchNorm2d is applied to the ``w0`` branch.

    NOTE(review): nothing in this class calls ``dmd_layer.compute_dmd_matrix``,
    so unless a caller fits it, the layer uses its identity fallback.
    """

    def __init__(self, encoder, decoder, op_size, decompose=6, linear_type=True, normalization=False):
        super(KNO2d_DMD, self).__init__()
        self.op_size = op_size
        self.decompose = decompose
        self.enc = encoder
        self.dec = decoder
        self.dmd_layer = DMD_Operator2D(self.op_size)
        self.w0 = nn.Conv2d(op_size, op_size, 1)
        self.linear_type = linear_type
        self.normalization = normalization
        if self.normalization:
            self.norm_layer = nn.BatchNorm2d(op_size)

    def forward(self, x):
        """Return ``(prediction, reconstruction)`` for input ``x``.

        The encoder output is permuted with (0, 3, 1, 2) before the operator
        and back with (0, 2, 3, 1) before decoding, i.e. the encoder/decoder
        work channel-last while the operator and ``w0`` work channel-first.
        """
        # Encode once: the reconstruction and prediction branches start from
        # the same tanh(enc(x)), so a second encoder pass is redundant work
        # (assumes a deterministic encoder — no dropout inside it).
        latent = torch.tanh(self.enc(x))

        # Reconstruct
        x_reconstruct = self.dec(latent)

        # Predict
        x = latent.permute(0, 3, 1, 2)
        x_w = x  # kept for the skip branch
        for _ in range(self.decompose):
            x1 = self.dmd_layer(x)  # DMD Operator
            x = x + x1 if self.linear_type else torch.tanh(x + x1)

        skip = self.w0(x_w)
        if self.normalization:
            skip = self.norm_layer(skip)
        x = torch.tanh(skip + x)

        x = x.permute(0, 2, 3, 1)
        x = self.dec(x)  # Decoder
        return x, x_reconstruct


In [None]:
import torch
import os
import h5py
    
# Select the compute device (CUDA device 0).
torch.cuda.set_device(0)
device = torch.device("cuda")

# Output directories for figures and saved results.
fig_path = "./demo/fig/"
save_path = "./demo/result/"
os.makedirs(fig_path, exist_ok=True)
os.makedirs(save_path, exist_ok=True)

# Hyper parameters
ep = 1000 # Training Epoch
o = 32 # Koopman Operator Size
m = 16 # Modes
r = 8 # Power of Koopman Matrix

# Model
koopman_model = koopman(backbone = "DMD", autoencoder = "Conv2d", o = o, m = m, r = r, t_in = 20, device = device)
koopman_model.compile()
koopman_model.opt_init("Adam", lr = 0.005, step_size=250, gamma=0.5)
koopman_model.train(epochs=ep, trainloader = train_loader, evalloader = test_loader)

# Result and Saving
time_error = koopman_model.test(test_loader, path = fig_path, is_save = True, is_plot = True)
# The file name carries the backbone so that runs of other backbones with the
# same o/m/r values, saved into the same directory, do not overwrite this result.
filename = "dmd_time_error_op" + str(o) + "m" + str(m) + "r" + str(r) + ".pt"
torch.save({"time_error":time_error,"params":koopman_model.params}, save_path + filename)

### AutoEncoder+Mamba

In [None]:
class StateSpace2D(nn.Module):
    """Single-step linear state-space layer applied independently per pixel.

    With row vectors, the layer computes ``h' = h A + x B^T`` and
    ``y = h' C^T + x D``. The hidden state ``h`` is re-initialised to zeros on
    every call, so the ``h A`` term is zero and ``A`` receives no gradient.

    Args:
        hidden_size: size of the hidden state.
        state_size: number of input/output channels (last axis of the input).
    """

    def __init__(self, hidden_size, state_size):
        super(StateSpace2D, self).__init__()
        self.hidden_size = hidden_size
        self.state_size = state_size

        # State space model parameters
        self.A = nn.Parameter(torch.randn(hidden_size, hidden_size))  # State transition matrix
        self.B = nn.Parameter(torch.randn(hidden_size, state_size))   # Input matrix
        self.C = nn.Parameter(torch.randn(state_size, hidden_size))   # Output matrix
        self.D = nn.Parameter(torch.randn(state_size, state_size))    # Direct transmission matrix

    def forward(self, x):
        """Apply the layer to a channel-last tensor.

        Args:
            x: tensor of shape (batch, height, width, channels) with
                ``channels == state_size``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size, height, width, channels = x.shape
        # Flatten the spatial axes: every pixel becomes one row vector.
        x_flat = x.reshape(batch_size, height * width, channels)

        # Hidden state starts at zero on every call.
        h = torch.zeros(batch_size, height * width, self.hidden_size,
                        device=x.device, dtype=x.dtype)

        # State update. B is stored as (hidden, state), so it must be
        # transposed to map channel vectors into the hidden space; without the
        # transpose the product only type-checks when hidden_size == state_size.
        h_next = torch.matmul(h, self.A) + torch.matmul(x_flat, self.B.transpose(0, 1))

        # Output equation: project back to `state_size` channels plus the
        # direct feed-through term.
        y = torch.matmul(h_next, self.C.transpose(0, 1)) + torch.matmul(x_flat, self.D)

        # Restore the original (batch, height, width, channels) layout.
        return y.reshape(batch_size, height, width, channels)

class SSM2d(nn.Module):
    """Autoencoder whose latent state is advanced by a ``StateSpace2D`` layer.

    Args:
        encoder: module mapping the input to the latent representation.
        decoder: module mapping a latent representation back to the output.
        op_size: number of latent channels (``state_size`` of the state-space
            layer and channels of ``w0`` / ``norm_layer``).
        decompose: number of times the state-space layer is applied.
        linear_type: if True the update is ``x + ssm(x)``; otherwise a tanh is
            applied after each update.
        normalization: if True, BatchNorm2d is applied to the ``w0`` branch.
        hidden_size: hidden size passed to ``StateSpace2D``.
    """

    def __init__(self, encoder, decoder, op_size, decompose=6,
                 linear_type=True, normalization=False, hidden_size=64):
        super(SSM2d, self).__init__()
        # Parameter
        self.op_size = op_size
        self.decompose = decompose

        # Layer Structure
        self.enc = encoder
        self.dec = decoder
        self.state_space = StateSpace2D(hidden_size=hidden_size, state_size=op_size)
        self.w0 = nn.Conv2d(op_size, op_size, 1)
        self.linear_type = linear_type
        self.normalization = normalization
        if self.normalization:
            self.norm_layer = torch.nn.BatchNorm2d(op_size)

    def forward(self, x):
        """Return ``(prediction, reconstruction)`` for input ``x``.

        The state-space layer consumes the encoder output as-is, while ``w0``
        receives it permuted with (0, 3, 1, 2); the sum is permuted back with
        (0, 2, 3, 1) before decoding.
        """
        # Encode once: the reconstruction and prediction branches start from
        # the same tanh(enc(x)), so a second encoder pass is redundant work
        # (assumes a deterministic encoder — no dropout inside it).
        latent = torch.tanh(self.enc(x))

        # Reconstruct
        x_reconstruct = self.dec(latent)

        # Predict
        x_w = latent.permute(0, 3, 1, 2)  # channel-first copy for the skip branch
        x = latent

        # Apply state space model iteratively
        for _ in range(self.decompose):
            x1 = self.state_space(x)  # State Space Model
            x = x + x1 if self.linear_type else torch.tanh(x + x1)

        skip = self.w0(x_w)
        if self.normalization:
            skip = self.norm_layer(skip)
        x = torch.tanh(skip + x.permute(0, 3, 1, 2))

        x = x.permute(0, 2, 3, 1)
        x = self.dec(x)  # Decoder
        return x, x_reconstruct

In [None]:
import torch
import os
import h5py
    
# Select the compute device (CUDA device 0).
torch.cuda.set_device(0)
device = torch.device("cuda")


# Output directories for figures and saved results.
fig_path = "./demo/fig/"
save_path = "./demo/result/"
os.makedirs(fig_path, exist_ok=True)
os.makedirs(save_path, exist_ok=True)

# Hyper parameters
ep = 1000 # Training Epoch
o = 32 # Koopman Operator Size
m = 16 # Modes
r = 8 # Power of Koopman Matrix

# Model
koopman_model = koopman(backbone = "SSM", autoencoder = "Conv2d", o = o, m = m, r = r, t_in = 20, device = device)
koopman_model.compile()
koopman_model.opt_init("Adam", lr = 0.005, step_size=250, gamma=0.5)
koopman_model.train(epochs=ep, trainloader = train_loader, evalloader = test_loader)

# Result and Saving
time_error = koopman_model.test(test_loader, path = fig_path, is_save = True, is_plot = True)
# The file name carries the backbone so that runs of other backbones with the
# same o/m/r values, saved into the same directory, do not overwrite this result.
filename = "ssm_time_error_op" + str(o) + "m" + str(m) + "r" + str(r) + ".pt"
torch.save({"time_error":time_error,"params":koopman_model.params}, save_path + filename)

### AutoEncoder無しのDMD

In [None]:
import h5py
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from numpy import dot, multiply, diag, power
from numpy import pi, exp, sin, cos, cosh, tanh, real, imag
from numpy.linalg import inv, eig, pinv
from scipy.linalg import svd, svdvals
from scipy.integrate import odeint, ode, complex_ode
from warnings import warn
from scipy.linalg import svd
from scipy.interpolate import RectBivariateSpline
from scipy.interpolate import griddata
import scipy.integrate

from matplotlib import animation
from matplotlib import pyplot as plt
from pydmd import DMD
from pydmd.plotter import plot_modes_2D
from scipy import interpolate
def resize_with_cv2(disp_arr, target_size=(64, 64)):
    """Resize each channel of a 3-D array with OpenCV (bilinear interpolation).

    The array is unpacked as (h, w, c); every ``disp_arr[:, :, i]`` slice is
    resized to ``target_size`` and the results are stacked on the last axis.
    """
    _, _, n_channels = disp_arr.shape
    resized_planes = []
    for idx in range(n_channels):
        plane = disp_arr[:, :, idx]
        resized_planes.append(
            cv2.resize(plane, target_size, interpolation=cv2.INTER_LINEAR)
        )
    return np.stack(resized_planes, axis=-1)
def create_animation(arr, title='Time Series Animation', cmap='viridis', interval=200):
    """
    Visualise a 3-D NumPy array as an animation.

    Parameters:
    -----------
    arr : numpy.ndarray
        3-D array indexed as (row, column, time); one frame per index of the
        last axis.
    title : str, optional
        Title prefix; the current frame number is appended.
    cmap : str, optional
        Colour map name.
    interval : int, optional
        Delay between frames in milliseconds.

    Returns:
    --------
    matplotlib.animation.FuncAnimation
        The animation object.
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    # Fix the colour range over the whole array (NaNs ignored) so that the
    # scale does not change between frames.
    colour_limits = {"vmin": np.nanmin(arr), "vmax": np.nanmax(arr)}

    # Draw the first frame and attach a colour bar.
    im = ax.imshow(arr[:, :, 0], cmap=cmap, **colour_limits)
    plt.colorbar(im, ax=ax, label='Value')

    ax.set_title(f'{title} - Frame 0')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')

    def update(frame):
        # Swap in the image data for this frame and refresh the title.
        im.set_array(arr[:, :, frame])
        ax.set_title(f'{title} - Frame {frame}')
        return [im]

    n_frames = arr.shape[2]  # length of the time axis
    anim = animation.FuncAnimation(
        fig,
        update,
        frames=n_frames,
        interval=interval,
        blit=True,
    )
    return anim

In [None]:
import h5py
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
from numpy import dot, multiply, diag, power
from numpy import pi, exp, sin, cos, cosh, tanh, real, imag
from numpy.linalg import inv, eig, pinv
from scipy.linalg import svd, svdvals
from scipy.integrate import odeint, ode, complex_ode
from warnings import warn
from scipy.linalg import svd
from scipy.interpolate import RectBivariateSpline
from scipy.interpolate import griddata
import scipy.integrate

from matplotlib import animation
from pydmd import DMD
from pydmd.plotter import plot_modes_2D
from scipy import interpolate
import torch
from scipy import linalg

def calculate_r2_score(y_true, y_pred):
    """
    Compute the R2 score while ignoring NaN entries.

    Parameters:
    -----------
    y_true : numpy.ndarray
        Ground-truth values.
    y_pred : numpy.ndarray
        Predicted values.

    Returns:
    --------
    float
        R2 score, or NaN when no valid pair remains or when the valid
        ground-truth values have zero variance.
    """
    # Keep only positions where both arrays hold a number.
    valid = ~np.isnan(y_true) & ~np.isnan(y_pred)
    truth = y_true[valid]
    estimate = y_pred[valid]

    if truth.size == 0:
        return np.nan

    residual_ss = np.sum((truth - estimate) ** 2)
    total_ss = np.sum((truth - np.mean(truth)) ** 2)

    # A constant ground truth leaves R2 undefined.
    if total_ss == 0:
        return np.nan
    return 1 - (residual_ss / total_ss)

def calculate_metrics(predictions, ground_truth):
    """
    Evaluate predictions against the ground truth, ignoring NaN entries.

    Parameters:
    -----------
    predictions : numpy.ndarray
        Predicted values.
    ground_truth : numpy.ndarray
        Ground-truth values.

    Returns:
    --------
    dict
        Keys 'MSE', 'RMSE', 'MAE', 'R2', 'Relative_Error', 'Max_Error' and
        'Valid_Data_Ratio' (percentage of positions where both inputs are
        non-NaN). All error metrics are NaN when no valid position exists.
    """
    # Collapse all leading axes into one; the last axis is kept.
    pred_flat = predictions.reshape(-1, predictions.shape[-1])
    true_flat = ground_truth.reshape(-1, ground_truth.shape[-1])

    # Positions where both arrays hold a number.
    valid = ~np.isnan(pred_flat) & ~np.isnan(true_flat)
    valid_ratio = np.mean(valid) * 100

    pred_valid = pred_flat[valid]
    true_valid = true_flat[valid]

    if len(true_valid) == 0:
        empty = dict.fromkeys(
            ('MSE', 'RMSE', 'MAE', 'R2', 'Relative_Error', 'Max_Error'), np.nan
        )
        empty['Valid_Data_Ratio'] = 0.0
        return empty

    error = pred_valid - true_valid
    abs_error = np.abs(error)

    mse = np.mean(error ** 2)
    return {
        'MSE': mse,
        'RMSE': np.sqrt(mse),
        'MAE': np.mean(abs_error),
        'R2': calculate_r2_score(true_valid, pred_valid),
        # 1e-8 keeps the ratio finite where the ground truth is exactly zero
        'Relative_Error': np.mean(abs_error / (np.abs(true_valid) + 1e-8)),
        'Max_Error': np.max(abs_error),
        'Valid_Data_Ratio': valid_ratio,
    }

def calculate_temporal_metrics(predictions, ground_truth):
    """
    Compute error metrics separately for every index of the last axis.

    Parameters:
    -----------
    predictions : numpy.ndarray
        Predicted values.
    ground_truth : numpy.ndarray
        Ground-truth values.

    Returns:
    --------
    dict
        'temporal_mse', 'temporal_mae' and 'temporal_r2', each a 1-D array
        with one entry per step of the last axis. Steps without any valid
        (non-NaN) pair keep their initial value of 0.
    """
    n_steps = predictions.shape[-1]
    temporal_mse = np.zeros(n_steps)
    temporal_mae = np.zeros(n_steps)
    temporal_r2 = np.zeros(n_steps)

    for t in range(n_steps):
        # All values belonging to step t, as flat vectors.
        pred_t = predictions[..., t].flatten()
        true_t = ground_truth[..., t].flatten()

        # Drop positions where either side is NaN.
        valid = ~np.isnan(pred_t) & ~np.isnan(true_t)
        pred_valid = pred_t[valid]
        true_valid = true_t[valid]

        if len(true_valid) == 0:
            continue

        error = pred_valid - true_valid
        temporal_mse[t] = np.mean(error ** 2)
        temporal_mae[t] = np.mean(np.abs(error))
        temporal_r2[t] = calculate_r2_score(true_valid, pred_valid)

    return {
        'temporal_mse': temporal_mse,
        'temporal_mae': temporal_mae,
        'temporal_r2': temporal_r2,
    }






# Plain DMD baseline: for every sub-folder of `parent_dir`, load the cumulative
# displacement cube from cum_filt.h5, resize it (resize_with_cv2 defaults to
# 64x64), fit a DMD on the first frames and forecast the following ones.
parent_dir = "E:/2024/koopman/sbas/result"

# Window layout along the time axis.
n_input_steps = 20    # frames used to fit the DMD
n_future_steps = 80   # frames to forecast and compare against the ground truth
n_total_steps = n_input_steps + n_future_steps

pred_data = []
gt_data = []

for subfolder in sorted(os.listdir(parent_dir)):
    subfolder_path = os.path.join(parent_dir, subfolder)

    if not os.path.isdir(subfolder_path):
        continue

    h5_file_path = os.path.join(subfolder_path, "cum_filt.h5")

    if not os.path.exists(h5_file_path):
        print(f"File not found: {h5_file_path}")
        continue

    print(f"Loading: {h5_file_path}")
    with h5py.File(h5_file_path, "r") as cumh5:
        disp_arr_chw = cumh5['cum']
        disp_arr = np.transpose(disp_arr_chw, (1, 2, 0))        # C, H, W => H, W, C
        # Skip series that are too short for the input + forecast window.
        if disp_arr.shape[2] < n_total_steps:
            continue

        disp_arr = np.nan_to_num(disp_arr, nan=0.0)  # replace NaN with zero
        disp_arr = resize_with_cv2(disp_arr)  # bilinear interpolation
        h = disp_arr.shape[0]
        w = disp_arr.shape[1]

        inputdata = disp_arr[:, :, :n_input_steps]
        gtdata = disp_arr[:, :, n_input_steps:n_total_steps]

        # DMD
        # NOTE(review): this relies on PyDMD treating the last axis of a 3-D
        # array as time and flattening the leading axes — confirm for the
        # installed PyDMD version.
        dmd = DMD(svd_rank=1, tlsq_rank=2, exact=True, opt=True)
        dmd.fit(inputdata)

        # Check that the input window is reconstructed with the expected shape.
        arr_3d = dmd.reconstructed_data.reshape(h, w, n_input_steps).real
        print("Shape arr_3d: {}".format(arr_3d.shape))

        #######################################
        ############ Forecast #################
        #######################################
        # Extend the DMD time window by the forecast horizon.
        dmd.dmd_time["tend"] += n_future_steps

        # Bring the reconstruction (input window + forecast) back to image form.
        arr_3d_pred = dmd.reconstructed_data.reshape(h, w, n_total_steps).real
        print("Shape after manipulation: {}".format(arr_3d_pred.shape))
        arr_3d_pred = np.nan_to_num(arr_3d_pred, nan=0.0)  # replace NaN with zero
        pred_data.append(arr_3d_pred)
        gt_data.append(gtdata)

if not pred_data:
    # Without this check the 4-D slicing below fails with an opaque IndexError.
    raise RuntimeError(f"No usable cum_filt.h5 series found under {parent_dir}")

# Keep only the forecast part of every prediction (drop the fitted input window).
all_predictions_u = np.array(pred_data)
all_predictions_u = all_predictions_u[:, :, :, n_input_steps:]
all_ground_truth_u = np.array(gt_data)

# Accuracy evaluation
metrics = calculate_metrics(all_predictions_u, all_ground_truth_u)
temporal_metrics = calculate_temporal_metrics(all_predictions_u, all_ground_truth_u)

# Report
print("\nAccuracy Metrics:")
print("-" * 40)
for metric_name, value in metrics.items():
    print(f"{metric_name:15s}: {value:.6f}")

print("\nTemporal Error Analysis:")
print("-" * 40)
print(f"Best timestep (MSE): {np.argmin(temporal_metrics['temporal_mse'])} "
      f"(MSE: {np.min(temporal_metrics['temporal_mse']):.6f})")
print(f"Worst timestep (MSE): {np.argmax(temporal_metrics['temporal_mse'])} "
      f"(MSE: {np.max(temporal_metrics['temporal_mse']):.6f})")
print(f"Average R2 score: {np.mean(temporal_metrics['temporal_r2']):.6f}")



### AutoEncoder無しのSSM

In [None]:
import os
import h5py
import numpy as np
import matplotlib.pyplot as plt
from pykalman import KalmanFilter

# データの読み込みと前処理
def load_and_preprocess_data(parent_dir):
    """Load displacement cubes from every sub-folder of ``parent_dir``.

    For each sub-folder containing ``cum_filt.h5`` the 'cum' dataset is
    transposed from (C, H, W) to (H, W, C), NaNs are replaced with zero and
    the cube is resized with ``resize_with_cv2``. Series with fewer than 100
    frames are skipped.

    Returns:
        Tuple ``(inputs, targets)`` of arrays: frames 0-19 and frames 20-99 of
        every accepted series, stacked along a new first axis.
    """
    input_data, gt_data = [], []

    for subfolder in sorted(os.listdir(parent_dir)):
        subfolder_path = os.path.join(parent_dir, subfolder)
        if not os.path.isdir(subfolder_path):
            continue

        h5_file_path = os.path.join(subfolder_path, "cum_filt.h5")

        if not os.path.exists(h5_file_path):
            print(f"File not found: {h5_file_path}")
            continue

        with h5py.File(h5_file_path, "r") as cumh5:
            cum = cumh5['cum']
            # Axis 0 of the stored cube becomes the time axis after the
            # transpose, so too-short series can be rejected from the dataset
            # shape alone, before reading, NaN-filling and resizing the data.
            if cum.shape[0] < 100:
                continue

            disp_arr = np.transpose(cum, (1, 2, 0))  # C, H, W -> H, W, C
            disp_arr = np.nan_to_num(disp_arr, nan=0.0)  # replace NaN with zero
            disp_arr = resize_with_cv2(disp_arr)

            input_data.append(disp_arr[:, :, :20])
            gt_data.append(disp_arr[:, :, 20:100])

    return np.array(input_data), np.array(gt_data)

# 状態空間モデルによる将来予測
def predict_with_ssm(input_data, future_steps=80):
    """Forecast every pixel's time series with a 1-D Kalman state-space model.

    Args:
        input_data: array unpacked as (height, width, time_steps).
        future_steps: number of steps to forecast beyond the input.

    Returns:
        Array of shape (height, width, time_steps + future_steps) holding the
        smoothed state means for the observed steps followed by the forecast.
    """
    height, width, time_steps = input_data.shape
    predictions = np.zeros((height, width, time_steps + future_steps))
    predictions[:, :, :time_steps] = input_data

    # The future steps have no observations. They are marked as masked so the
    # smoother treats them as missing and propagates the state forward; padding
    # with plain zeros would feed "observed zeros" into the model instead.
    future_mask = np.zeros(time_steps + future_steps, dtype=bool)
    future_mask[time_steps:] = True

    for i in range(height):
        for j in range(width):
            ts = np.nan_to_num(input_data[i, j, :], nan=0.0)  # replace NaN with zero

            kf = KalmanFilter(
                transition_matrices=[[1]],   # state transition matrix
                observation_matrices=[[1]],  # observation matrix
                initial_state_mean=ts[0],
                n_dim_obs=1
            )
            kf = kf.em(ts, n_iter=10)  # estimate parameters with the EM algorithm

            padded = np.ma.array(
                np.concatenate([ts, np.zeros(future_steps)]), mask=future_mask
            )
            future_state_means, _ = kf.smooth(padded)

            predictions[i, j, :] = future_state_means[:, 0]

    return predictions

# 精度評価
def calculate_metrics(predictions, ground_truth):
    """Return the mean squared error and its square root as a dict."""
    squared_error = (predictions - ground_truth) ** 2
    mse = np.mean(squared_error)
    return {'MSE': mse, 'RMSE': np.sqrt(mse)}

# Run: forecast every sample with the state-space model and score the result.
data_dir = "E:/2024/koopman/sbas/result"
input_data, ground_truth = load_and_preprocess_data(data_dir)
predictions=[]
for one_data in input_data:
    predictions.append(predict_with_ssm(one_data))
predictions=np.array(predictions)
# `predictions` is 4-D (sample, height, width, time). The 20 input steps must
# be dropped on the LAST axis; `[:, :, 20:]` would cut the width axis instead
# and no longer match the shape of `ground_truth`.
metrics = calculate_metrics(predictions[:, :, :, 20:], ground_truth)

print("Prediction Metrics:")
for key, value in metrics.items():
    print(f"{key}: {value:.6f}")

### AutoEncoder無しのkoopman

In [None]:
import os
import h5py
import numpy as np
import matplotlib.pyplot as plt
def resize_with_cv2(disp_arr, target_size=(64, 64)):
    """Resize a 3-D array channel by channel using OpenCV.

    Each slice along the last axis is resized to ``target_size`` with
    bilinear interpolation; the resized slices are stacked on the last axis.
    """
    def _resize_plane(index):
        return cv2.resize(
            disp_arr[:, :, index], target_size, interpolation=cv2.INTER_LINEAR
        )

    _, _, channel_count = disp_arr.shape
    planes = list(map(_resize_plane, range(channel_count)))
    return np.stack(planes, axis=-1)
    
# データの読み込みと前処理
def load_and_preprocess_data(parent_dir):
    """Load displacement cubes from every sub-folder of ``parent_dir``.

    For each sub-folder containing ``cum_filt.h5`` the 'cum' dataset is
    transposed from (C, H, W) to (H, W, C), NaNs are replaced with zero and
    the cube is resized with ``resize_with_cv2``. Series with fewer than 100
    frames are skipped.

    Returns:
        Tuple ``(inputs, targets)`` of arrays: frames 0-19 and frames 20-99 of
        every accepted series, stacked along a new first axis.
    """
    input_data, gt_data = [], []
    for subfolder in sorted(os.listdir(parent_dir)):
        subfolder_path = os.path.join(parent_dir, subfolder)
        if not os.path.isdir(subfolder_path):
            continue

        h5_file_path = os.path.join(subfolder_path, "cum_filt.h5")
        if not os.path.exists(h5_file_path):
            print(f"File not found: {h5_file_path}")
            continue

        with h5py.File(h5_file_path, "r") as cumh5:
            cum = cumh5['cum']
            # Axis 0 of the stored cube becomes the time axis after the
            # transpose, so too-short series can be rejected from the dataset
            # shape alone, before reading, NaN-filling and resizing the data.
            if cum.shape[0] < 100:
                continue

            disp_arr = np.transpose(cum, (1, 2, 0))  # C, H, W -> H, W, C
            disp_arr = np.nan_to_num(disp_arr, nan=0.0)  # replace NaN with zero
            disp_arr = resize_with_cv2(disp_arr)

            input_data.append(disp_arr[:, :, :20])
            gt_data.append(disp_arr[:, :, 20:100])

    return np.array(input_data), np.array(gt_data)

# クープマン作用素を用いた将来予測
def koopman_forecast(input_data, future_steps=80):
    """Forecast each pixel's series with a scalar linear one-step operator.

    For every pixel, a 1x1 matrix K is fitted by least squares (pseudo-inverse)
    so that K maps each value to its successor, and K is then applied
    repeatedly starting from the last observed value.

    Args:
        input_data: array unpacked as (height, width, time_steps).
        future_steps: number of steps to forecast.

    Returns:
        Array of shape (height, width, time_steps + future_steps): the input
        followed by the forecast.
    """
    n_rows, n_cols, n_obs = input_data.shape
    forecast = np.zeros((n_rows, n_cols, n_obs + future_steps))
    forecast[:, :, :n_obs] = input_data

    for row, col in np.ndindex(n_rows, n_cols):
        # NaNs are treated as zero for fitting only; the copied input keeps them.
        series = np.nan_to_num(input_data[row, col, :], nan=0.0)

        current = series[:-1].reshape(-1, 1)   # states at time t
        following = series[1:].reshape(-1, 1)  # states at time t + 1

        # Least-squares estimate of the one-step dynamics matrix.
        operator = np.linalg.pinv(current) @ following

        # Roll the last observed value forward one step at a time.
        state = np.array([series[-1]])
        for step in range(future_steps):
            state = operator @ state
            forecast[row, col, n_obs + step] = state[0]

    return forecast

# 精度評価
def calculate_metrics(predictions, ground_truth):
    """Compute MSE and RMSE between predictions and ground truth."""
    residual = predictions - ground_truth
    mse = np.mean(residual ** 2)
    rmse = np.sqrt(mse)
    return dict(MSE=mse, RMSE=rmse)

# Run: forecast every sample with the per-pixel Koopman approximation and
# score the forecast horizon (time index 20 onwards) against the ground truth.
data_dir = "E:/2024/koopman/sbas/result"
input_data, ground_truth = load_and_preprocess_data(data_dir)

predictions = np.array([koopman_forecast(one_data) for one_data in input_data])
forecast_part = predictions[:, :, :, 20:]
metrics = calculate_metrics(forecast_part, ground_truth)

print("Prediction Metrics:")
for key, value in metrics.items():
    print(f"{key}: {value:.6f}")
