In [1]:
# %% [markdown]
# # 期权定价的深度随机神经网络
# 
# 本Notebook整合了基于Forward-Backward Stochastic Neural Networks (FBSNNs)的期权定价模型，支持苹果M芯片(MPS)、TPU和GPU加速。

# %% [markdown]
# ## 1. 导入必要的库和设备设置

# %%
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import time
from abc import ABC, abstractmethod
import os
import warnings

# 抑制警告
warnings.filterwarnings('ignore')
#-------- todo: continue from here --------
# 设备检测和设置
def setup_device():
    """自动检测并设置最佳计算设备"""
    # 检测TPU
    if 'COLAB_TPU_ADDR' in os.environ:
        try:
            import torch_xla
            import torch_xla.core.xla_model as xm
            device = xm.xla_device()
            print(f"使用TPU设备: {device}")
            return device, 'tpu'
        except ImportError:
            print("TPU检测到但torch_xla未安装，使用默认设备")
    
    # 检测MPS (苹果芯片)
    if torch.backends.mps.is_available():
        device = torch.device("mps")
        print(f"使用MPS设备: {device}")
        return device, 'mps'
    
    # 检测CUDA
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"使用CUDA设备: {device}")
        return device, 'cuda'
    
    # 默认CPU
    device = torch.device("cpu")
    print(f"使用CPU设备: {device}")
    return device, 'cpu'

# 设置设备
device, device_type = setup_device()

# TPU特定设置
if device_type == 'tpu':
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.distributed.xla_multiprocessing as xmp
# %% [markdown]
# ## 2. 定义激活函数和神经网络结构

# %%
class Sine(nn.Module):
    """正弦激活函数"""
    def __init__(self):
        super(Sine, self).__init__()

    def forward(self, x):
        return torch.sin(x)

class Naisnet(nn.Module):
    """NAIS-Net神经网络结构 - 针对加速设备优化"""
    def __init__(self, layers, stable=True, activation=None, device=None):
        super(Naisnet, self).__init__()
        
        self.layers = layers
        self.device = device
        self.layer1 = nn.Linear(in_features=layers[0], out_features=layers[1])
        self.layer2 = nn.Linear(in_features=layers[1], out_features=layers[2])
        self.layer2_input = nn.Linear(in_features=layers[0], out_features=layers[2])
        self.layer3 = nn.Linear(in_features=layers[2], out_features=layers[3])
        
        if len(layers) == 5:
            self.layer3_input = nn.Linear(in_features=layers[0], out_features=layers[3])
            self.layer4 = nn.Linear(in_features=layers[3], out_features=layers[4])
        elif len(layers) == 6:
            self.layer3_input = nn.Linear(in_features=layers[0], out_features=layers[3])
            self.layer4 = nn.Linear(in_features=layers[3], out_features=layers[4])
            self.layer4_input = nn.Linear(in_features=layers[0], out_features=layers[4])
            self.layer5 = nn.Linear(in_features=layers[4], out_features=layers[5])

        self.activation = activation
        self.epsilon = 0.01
        self.stable = stable
        
        # 初始化权重
        self._init_weights()
        
        # 移动到设备
        self.to(device)
    
    def _init_weights(self):
        """初始化权重"""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def project(self, layer, out):
        """NAIS-Net的投影层 - 针对TPU/MPS优化"""
        weights = layer.weight
        delta = 1 - 2 * self.epsilon
        RtR = torch.matmul(weights.t(), weights)
        norm = torch.norm(RtR)
        if norm > delta:
            RtR = delta ** (1 / 2) * RtR / (norm ** (1 / 2))
        
        A = RtR + torch.eye(RtR.shape[0], device=RtR.device) * self.epsilon
        return F.linear(out, -A, layer.bias)

    def forward(self, x):
        u = x

        out = self.layer1(x)
        out = self.activation(out)

        shortcut = out
        if self.stable:
            out = self.project(self.layer2, out)
            out = out + self.layer2_input(u)
        else:
            out = self.layer2(out)
        out = self.activation(out)
        out = out + shortcut

        if len(self.layers) == 4:
            out = self.layer3(out)
            return out

        if len(self.layers) == 5:
            shortcut = out
            if self.stable:
                out = self.project(self.layer3, out)
                out = out + self.layer3_input(u)
            else:
                out = self.layer3(out)
            out = self.activation(out)
            out = out + shortcut

            out = self.layer4(out)
            return out
        
        if len(self.layers) == 6:
            shortcut = out
            if self.stable:
                out = self.project(self.layer3, out)
                out = out + self.layer3_input(u)
            else:
                out = self.layer3(out)
            out = self.activation(out)
            out = out + shortcut

            shortcut = out
            if self.stable:
                out = self.project(self.layer4, out)
                out = out + self.layer4_input(u)
            else:
                out = self.layer4(out)

            out = self.activation(out)
            out = out + shortcut

            out = self.layer5(out)
            return out

        return out

# %% [markdown]
# ## 3. 定义FBSNN基类（支持多设备加速）

# %%
class FBSNN(ABC):
    """Forward-Backward Stochastic Neural Network基类 - 多设备支持"""
    
    def __init__(self, Xi, T, M, N, D, Mm, layers, mode, activation):
        # 设备设置
        self.device = device
        self.device_type = device_type
        
        # 初始化条件
        self.Xi = torch.from_numpy(Xi).float().to(self.device)
        self.Xi.requires_grad = True

        # 存储参数
        self.T = T
        self.M = M
        self.N = N
        self.D = D
        self.Mm = Mm
        self.strike = 1.0 * self.D

        self.mode = mode
        self.activation = activation
        
        # 设置激活函数
        if activation == "Sine":
            self.activation_function = Sine()
        elif activation == "ReLU":
            self.activation_function = nn.ReLU()
        elif activation == "Tanh":
            self.activation_function = nn.Tanh()

        # 初始化神经网络
        if self.mode == "FC":
            self.layers_list = []
            for i in range(len(layers) - 2):
                self.layers_list.append(nn.Linear(in_features=layers[i], out_features=layers[i + 1]))
                self.layers_list.append(self.activation_function)
            self.layers_list.append(nn.Linear(in_features=layers[-2], out_features=layers[-1]))
            self.model = nn.Sequential(*self.layers_list).to(self.device)

        elif self.mode == "Naisnet":
            self.model = Naisnet(layers, stable=True, activation=self.activation_function, device=self.device)

        # 训练记录
        self.training_loss = []
        self.iteration = []
        
        # 性能优化标志
        self.optimize_for_device()

    def optimize_for_device(self):
        """根据设备类型进行特定优化"""
        if self.device_type == 'tpu':
            # TPU优化设置
            torch.set_grad_enabled(True)
            # 启用XLA动态形状（提高TPU性能）
            os.environ['XLA_USE_BF16'] = '1'
            os.environ['XLA_USE_F16'] = '1'
            
        elif self.device_type == 'mps':
            # MPS优化设置
            # 注意：torch.backends.mps没有set_memory_efficient方法
            # 可以使用以下MPS特定的优化
            
            # 启用自动混合精度（如果可用）
            if hasattr(torch, 'autocast'):
                # 这将用于在后续训练中使用混合精度
                self.use_amp = True
            else:
                self.use_amp = False
            
            # 清空MPS缓存
            if hasattr(torch.mps, 'empty_cache'):
                torch.mps.empty_cache()
                
        elif self.device_type == 'cuda':
            # CUDA优化设置
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            self.use_amp = True  # 启用自动混合精度

    def weights_init(self, m):
        """权重初始化"""
        if type(m) == nn.Linear:
            if self.device_type == 'tpu':
                # TPU上使用更稳定的初始化
                torch.nn.init.xavier_normal_(m.weight)
            else:
                torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

    def net_u(self, t, X):
        """神经网络前向传播 - 设备优化版本"""
        # 使用torch.cat的优化版本
        input = torch.cat((t, X), 1)
        u = self.model(input)
        
        # 梯度计算优化
        if u.requires_grad:
            Du = torch.autograd.grad(outputs=u, inputs=X, 
                                   grad_outputs=torch.ones_like(u),
                                   create_graph=True, retain_graph=True)[0]
        else:
            # 如果不需要梯度，创建零张量
            Du = torch.zeros_like(X)
        return u, Du

    def Dg_tf(self, X):
        """计算g函数的梯度 - 优化版本"""
        g = self.g_tf(X)
        if g.requires_grad:
            Dg = torch.autograd.grad(outputs=g, inputs=X, 
                                   grad_outputs=torch.ones_like(g),
                                   create_graph=True, retain_graph=True)[0]
        else:
            Dg = torch.zeros_like(X)
        return Dg

    @torch.no_grad()
    def fetch_minibatch(self):
        """生成小批量数据 - 设备优化版本"""
        T = self.T
        M = self.M
        N = self.N
        D = self.D

        # 预分配内存
        Dt = np.zeros((M, N + 1, 1), dtype=np.float32)
        DW = np.zeros((M, N + 1, D), dtype=np.float32)

        dt = T / N
        Dt[:, 1:, :] = dt
        
        # 向量化随机数生成
        DW[:, 1:, :] = np.sqrt(dt) * np.random.randn(M, N, D).astype(np.float32)

        t = np.cumsum(Dt, axis=1)
        W = np.cumsum(DW, axis=1)

        # 直接创建在目标设备上
        t_tensor = torch.from_numpy(t).float().to(self.device)
        W_tensor = torch.from_numpy(W).float().to(self.device)

        return t_tensor, W_tensor

    def loss_function(self, t, W, Xi, training=True):
        """计算损失函数 - 性能优化版本"""
        if training:
            loss = torch.tensor(0.0, device=self.device, requires_grad=True)
        else:
            loss = torch.tensor(0.0, device=self.device, requires_grad=False)
            
        X_list = []
        Y_list = []

        t0 = t[:, 0, :]
        W0 = W[:, 0, :]
        X0 = Xi.repeat(self.M, 1).view(self.M, self.D)
        Y0, Z0 = self.net_u(t0, X0)

        X_list.append(X0)
        Y_list.append(Y0)

        # 使用更高效的内存管理
        for n in range(self.N):
            t1 = t[:, n + 1, :]
            W1 = W[:, n + 1, :]
            
            # 向量化计算
            mu = self.mu_tf(t0, X0, Y0, Z0)
            sigma = self.sigma_tf(t0, X0, Y0)
            dW = (W1 - W0).unsqueeze(-1)
            
            # 使用einsum提高计算效率
            sigma_dW = torch.einsum('mij,mj->mi', sigma, dW.squeeze(-1))
            X1 = X0 + mu * (t1 - t0) + sigma_dW
            
            Z_sigma_dW = torch.sum(Z0 * sigma_dW, dim=1, keepdim=True)
            Y1_tilde = Y0 + self.phi_tf(t0, X0, Y0, Z0) * (t1 - t0) + Z_sigma_dW
            
            Y1, Z1 = self.net_u(t1, X1)
            
            if training:
                loss = loss + torch.sum((Y1 - Y1_tilde) ** 2)
            else:
                loss = loss + torch.sum((Y1 - Y1_tilde) ** 2).detach()

            # 更新变量（避免不必要的拷贝）
            t0, W0, X0, Y0, Z0 = t1, W1, X1, Y1, Z1
            X_list.append(X0)
            Y_list.append(Y0)

        # 终端条件损失
        if training:
            terminal_loss = torch.sum((Y1 - self.g_tf(X1)) ** 2)
            gradient_loss = torch.sum((Z1 - self.Dg_tf(X1)) ** 2)
            loss = loss + terminal_loss + gradient_loss
        else:
            terminal_loss = torch.sum((Y1 - self.g_tf(X1)) ** 2).detach()
            gradient_loss = torch.sum((Z1 - self.Dg_tf(X1)) ** 2).detach()
            loss = loss + terminal_loss + gradient_loss

        X = torch.stack(X_list, dim=1)
        Y = torch.stack(Y_list, dim=1)

        return loss, X, Y, Y[0, 0, 0]

    def train(self, N_Iter, learning_rate, accumulation_steps=1):
        """训练模型 - 多设备优化版本"""
        loss_temp = []
        previous_it = 0 if not self.iteration else self.iteration[-1]

        # 设备特定的优化器设置
        if self.device_type == 'tpu':
            self.optimizer = optim.AdamW(self.model.parameters(), lr=learning_rate, 
                                       weight_decay=1e-4, fused=True)
        else:
            self.optimizer = optim.AdamW(self.model.parameters(), lr=learning_rate, 
                                       weight_decay=1e-4)
        
        # 学习率调度器
        scheduler = optim.lr_scheduler.OneCycleLR(
            self.optimizer, max_lr=learning_rate, 
            total_steps=N_Iter, pct_start=0.1
        )

        start_time = time.time()
        gradient_accumulation_counter = 0
        
        for it in range(previous_it, previous_it + N_Iter):
            # 动态调整时间步数
            if it >= 4000 and it < 20000:
                self.N = int(np.ceil(self.Mm ** (int(it / 4000) + 1)))
            elif it < 4000:
                self.N = int(np.ceil(self.Mm))

            self.optimizer.zero_grad()
            t_batch, W_batch = self.fetch_minibatch()
            
            # 使用自动混合精度（如果可用）
            if hasattr(self, 'use_amp') and self.use_amp and self.device_type != 'tpu':
                with torch.autocast(device_type=self.device_type if self.device_type != 'mps' else 'cpu', dtype=torch.float16):
                    loss, X_pred, Y_pred, Y0_pred = self.loss_function(t_batch, W_batch, self.Xi, training=True)
            else:
                loss, X_pred, Y_pred, Y0_pred = self.loss_function(t_batch, W_batch, self.Xi, training=True)
            
            # 梯度累积
            scaled_loss = loss / accumulation_steps
            scaled_loss.backward()
            gradient_accumulation_counter += 1
            
            if gradient_accumulation_counter == accumulation_steps:
                if self.device_type == 'tpu':
                    # TPU特定的梯度裁剪
                    xm.optimizer_step(self.optimizer)
                else:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                    self.optimizer.step()
                
                scheduler.step()
                gradient_accumulation_counter = 0
                
                # 清空MPS缓存（如果使用MPS设备）
                if self.device_type == 'mps' and it % 100 == 0:
                    if hasattr(torch.mps, 'empty_cache'):
                        torch.mps.empty_cache()

            # 设备特定的损失记录
            if self.device_type == 'tpu':
                loss_value = loss.detach().cpu().item()
            else:
                loss_value = loss.detach().item()
                
            loss_temp.append(loss_value)
            
            # 进度报告
            if it % 100 == 0:
                elapsed = time.time() - start_time
                current_lr = self.optimizer.param_groups[0]['lr']
                
                if self.device_type == 'tpu':
                    print(f'迭代: {it}, 损失: {loss_value:.3e}, Y0: {Y0_pred:.3f}, '
                          f'时间: {elapsed:.2f}s, 学习率: {current_lr:.3e}, 设备: TPU')
                else:
                    print(f'迭代: {it}, 损失: {loss_value:.3e}, Y0: {Y0_pred:.3f}, '
                          f'时间: {elapsed:.2f}s, 学习率: {current_lr:.3e}, 设备: {self.device_type.upper()}')
                
                start_time = time.time()

            # 记录损失
            if it % 100 == 0:
                avg_loss = np.mean(loss_temp)
                self.training_loss.append(avg_loss)
                loss_temp = []
                self.iteration.append(it)
                
                # TPU特定操作
                if self.device_type == 'tpu' and it % 1000 == 0:
                    xm.mark_step()
        
        # 最终标记步骤（TPU）
        if self.device_type == 'tpu':
            xm.mark_step()
                
        return np.stack((self.iteration, self.training_loss))

    def predict(self, Xi_star, t_star, W_star):
        """预测 - 设备优化版本"""
        Xi_star = torch.from_numpy(Xi_star).float().to(self.device)
        Xi_star.requires_grad = False  # 预测时不需要梯度
        
        # 设置模型为评估模式
        self.model.eval()
        
        with torch.no_grad():
            _, X_star, Y_star, _ = self.loss_function(t_star, W_star, Xi_star, training=False)
        
        # 恢复训练模式
        self.model.train()
        
        return X_star, Y_star

    def save_model(self, file_name):
        """保存模型 - 多设备兼容"""
        state = {
            'model_state_dict': self.model.state_dict(),
            'training_loss': self.training_loss,
            'iteration': self.iteration,
            'device_type': self.device_type
        }
        torch.save(state, file_name)

    def load_model(self, file_name):
        """加载模型 - 多设备兼容"""
        checkpoint = torch.load(file_name, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.training_loss = checkpoint['training_loss']
        self.iteration = checkpoint['iteration']

    @abstractmethod
    def phi_tf(self, t, X, Y, Z):
        pass

    @abstractmethod
    def g_tf(self, X):
        pass

    @abstractmethod
    def mu_tf(self, t, X, Y, Z):
        M = self.M
        D = self.D
        return torch.zeros([M, D], device=self.device)

    @abstractmethod
    def sigma_tf(self, t, X, Y):
        M = self.M
        D = self.D
        return torch.diag_embed(torch.ones([M, D], device=self.device))

# %% [markdown]
# ## 4. 定义看涨期权类

# %%
class CallOption(FBSNN):
    """看涨期权定价模型 - 多设备优化版本"""
    
    def __init__(self, Xi, T, M, N, D, Mm, layers, mode, activation):
        super().__init__(Xi, T, M, N, D, Mm, layers, mode, activation)

    def phi_tf(self, t, X, Y, Z):
        """漂移项"""
        rate = 0.01
        return rate * Y

    def g_tf(self, X):
        """终端条件"""
        temp = torch.sum(X, dim=1, keepdim=True)
        return torch.maximum(temp - self.strike, torch.tensor(0.0, device=self.device))

    def mu_tf(self, t, X, Y, Z):
        """漂移系数"""
        rate = 0.01
        return rate * X

    def sigma_tf(self, t, X, Y):
        """扩散系数"""
        sigma = 0.25
        return sigma * torch.diag_embed(X)

# %% [markdown]
# ## 5. 辅助函数

# %%
def black_scholes_call(S, K, T, r, sigma, q=0):
    """Black-Scholes看涨期权定价公式"""
    from scipy.stats import norm
    
    if T <= 0:
        return max(S - K, 0), 1.0 if S > K else 0.0
    
    d1 = (np.log(S / K) + (r - q + 0.5 * sigma ** 2) * T) / (sigma * np.sqrt(T))
    d2 = d1 - sigma * np.sqrt(T)
    call_price = (S * np.exp(-q * T) * norm.cdf(d1)) - (K * np.exp(-r * T) * norm.cdf(d2))
    delta = norm.cdf(d1)
    return call_price, delta

def calculate_option_prices(X_pred, time_array, K, r, sigma, T, q=0):
    """计算期权价格和Delta"""
    rows, cols = X_pred.shape
    option_prices = np.zeros((rows, cols))
    deltas = np.zeros((rows, cols))

    for i in range(rows):
        for j in range(cols):
            S = X_pred[i, j]
            t = time_array[j]
            time_to_maturity = T - t
            option_prices[i, j], deltas[i, j] = black_scholes_call(S, K, time_to_maturity, r, sigma, q)

    return option_prices, deltas

def figsize(scale, nplots=1):
    """设置图形大小"""
    fig_width_pt = 438.17227
    inches_per_pt = 1.0 / 72.27
    golden_mean = (np.sqrt(5.0) - 1.0) / 2.0
    fig_width = fig_width_pt * inches_per_pt * scale
    fig_height = nplots * fig_width * golden_mean
    return [fig_width, fig_height]

def check_device_memory():
    """检查设备内存使用情况"""
    if device_type == 'cuda':
        print(f"GPU内存使用: {torch.cuda.memory_allocated()/1024**3:.2f} GB")
        print(f"GPU内存缓存: {torch.cuda.memory_reserved()/1024**3:.2f} GB")
    elif device_type == 'mps':
        # MPS设备内存信息不可直接获取，但我们可以检查系统内存
        import psutil
        memory_info = psutil.virtual_memory()
        print(f"系统总内存: {memory_info.total/1024**3:.2f} GB")
        print(f"系统可用内存: {memory_info.available/1024**3:.2f} GB")
        print(f"系统使用率: {memory_info.percent}%")
    elif device_type == 'tpu':
        print("TPU设备内存信息不可直接获取")
    else:
        import psutil
        memory_info = psutil.virtual_memory()
        print(f"系统总内存: {memory_info.total/1024**3:.2f} GB")
        print(f"系统可用内存: {memory_info.available/1024**3:.2f} GB")

# %% [markdown]
# ## 6. 模型训练和测试

# %%
# 设置参数（根据设备类型调整）
if device_type == 'tpu':
    M = 512  # TPU适合大批量
    accumulation_steps = 4
elif device_type == 'cuda':
    M = 256  # GPU中等批量
    accumulation_steps = 2
elif device_type == 'mps':
    M = 128  # MPS较小批量
    accumulation_steps = 1
else:
    M = 64   # CPU小批量
    accumulation_steps = 1

N = 50
D = 1
Mm = N ** (1/5)

layers = [D + 1] + 4 * [256] + [1]

Xi = np.array([1.0] * D)[None, :]
T = 1.0

mode = "Naisnet"
activation = "Sine"

print(f"设备类型: {device_type}")
print(f"批量大小: {M}")
print(f"累积步数: {accumulation_steps}")

# 创建模型
model = CallOption(Xi, T, M, N, D, Mm, layers, mode, activation)

# %%
# 检查设备内存
check_device_memory()

# %%
# 第一阶段训练
print("开始第一阶段训练...")
n_iter = 500  # 减少迭代次数用于演示
lr = 1e-3

tot = time.time()
graph = model.train(n_iter, lr, accumulation_steps)
print(f"第一阶段训练完成，总时间: {time.time() - tot:.2f}s")

# %%
# 第二阶段训练（精细调优）
print("开始第二阶段训练...")
n_iter = 300  # 减少迭代次数用于演示
lr = 1e-5

tot = time.time()
graph = model.train(n_iter, lr, accumulation_steps)
print(f"第二阶段训练完成，总时间: {time.time() - tot:.2f}s")

# 测试模型
print("开始测试...")
np.random.seed(37)
t_test, W_test = model.fetch_minibatch()
X_pred, Y_pred = model.predict(Xi, t_test, W_test)

# 转换为numpy数组
if device_type == 'tpu':
    t_test = t_test.cpu().numpy()
    X_pred = X_pred.cpu().numpy()
    Y_pred = Y_pred.cpu().numpy()
else:
    t_test = t_test.cpu().numpy()
    X_pred = X_pred.detach().cpu().numpy()
    Y_pred = Y_pred.detach().cpu().numpy()

# 收集测试数据
test_samples = 2 if device_type == 'tpu' else 3  # 减少测试样本
for i in range(test_samples):
    t_test_i, W_test_i = model.fetch_minibatch()
    X_pred_i, Y_pred_i = model.predict(Xi, t_test_i, W_test_i)
    
    if device_type == 'tpu':
        t_test_i = t_test_i.cpu().numpy()
        X_pred_i = X_pred_i.cpu().numpy()
        Y_pred_i = Y_pred_i.cpu().numpy()
    else:
        t_test_i = t_test_i.cpu().numpy()
        X_pred_i = X_pred_i.detach().cpu().numpy()
        Y_pred_i = Y_pred_i.detach().cpu().numpy()
        
    t_test = np.concatenate((t_test, t_test_i), axis=0)
    X_pred = np.concatenate((X_pred, X_pred_i), axis=0)
    Y_pred = np.concatenate((Y_pred, Y_pred_i), axis=0)

X_pred = X_pred[:50]  # 限制数据量

# 计算Black-Scholes基准值
from scipy.stats import norm  # 导入norm函数
K = 1.0
r = 0.01
sigma = 0.25
q = 0
T = 1

# 确保X_preds形状正确
X_preds = X_pred[:, :, 0] if len(X_pred.shape) == 3 else X_pred
Y_test, Z_test = calculate_option_prices(X_preds, t_test[0, :, 0], K, r, sigma, T, q)

# 确保Y_pred形状正确
Y_pred_reshaped = Y_pred[:, :, 0] if len(Y_pred.shape) == 3 else Y_pred
errors = (Y_test[:50] - Y_pred_reshaped[:50])**2
print(f"均方误差: {errors.mean():.6f}")
print(f"误差标准差: {errors.std():.6f}")
print(f"均方根误差: {np.sqrt(errors.mean()):.6f}")

# %% [markdown]
# ## 7. 结果可视化

# %%
# 绘制训练损失
plt.figure(figsize=figsize(1.0))
plt.plot(graph[0], graph[1])
plt.xlabel('迭代次数')
plt.ylabel('损失值')
plt.yscale("log")
plt.title('训练损失变化')
plt.grid(True, alpha=0.3)
plt.show()

# 绘制预测结果对比
plt.figure(figsize=figsize(1.0))
samples = min(5, len(t_test))

# 绘制学习到的解
for i in range(samples):
    plt.plot(t_test[i, :, 0], Y_pred_reshaped[i, :], 'b-', alpha=0.7, linewidth=1.5)

# 绘制精确解
for i in range(samples):
    plt.plot(t_test[i, :, 0], Y_test[i, :], 'r--', alpha=0.7, linewidth=1.5)

plt.xlabel('时间 $t$')
plt.ylabel('期权价格 $Y_t = u(t,X_t)$')
plt.title(f'{D}维看涨期权 - {model.mode}-{model.activation}')
plt.legend(['学习解', '精确解'])
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# 详细对比图
plt.figure(figsize=figsize(1.0))

samples_plot = min(7, len(t_test))
for i in range(samples_plot):
    plt.plot(t_test[i, :, 0] * 100, Y_pred_reshaped[i, :] * 100, 'b-', alpha=0.6, linewidth=1.5)
    plt.plot(t_test[i, :, 0] * 100, Y_test[i, :] * 100, 'r--', alpha=0.6, linewidth=1.5)
    plt.plot(t_test[i, -1, 0] * 100, Y_test[i, -1] * 100, 'ko', markersize=4)

plt.plot([0], Y_test[0, 0] * 100, 'ks', markersize=6, label='$Y_0 = u(0,X_0)$')
plt.plot(t_test[0, -1, 0] * 100, Y_test[0, -1] * 100, 'ko', markersize=6, label='$Y_T = u(T,X_T)$')

plt.title(f'{D}维看涨期权 - {model.mode}-{model.activation}')
plt.legend(['学习解', '精确解', '$Y_0$', '$Y_T$'])
plt.xlabel('时间 $t$ (%)')
plt.ylabel('期权价格 $Y_t = u(t,X_t)$')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# 误差分布图
plt.figure(figsize=figsize(1.0))
absolute_errors = np.abs(Y_test[:50] - Y_pred_reshaped[:50])
plt.hist(absolute_errors.flatten(), bins=30, alpha=0.7, color='blue', edgecolor='black')
plt.xlabel('绝对误差')
plt.ylabel('频数')
plt.title('预测误差分布')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# 性能统计
print("\n=== 性能统计 ===")
print(f"设备类型: {device_type}")
print(f"批量大小: {M}")
print(f"时间步数: {N}")
print(f"训练迭代: {len(model.iteration)} 次")
print(f"最终损失: {model.training_loss[-1]:.6f}")
print(f"测试RMSE: {np.sqrt(errors.mean()):.6f}")

# 保存模型
model.save_model(f'call_option_model_{device_type}.pth')
print(f"模型已保存为: call_option_model_{device_type}.pth")

print("\n所有代码执行完成！")

# %% [markdown]
# ## 8. 多设备性能对比（可选）

# %%
def benchmark_model():
    """基准测试函数"""
    print("开始基准测试...")
    
    # 测试推理速度
    test_iterations = 10
    start_time = time.time()
    
    for i in range(test_iterations):
        t_test, W_test = model.fetch_minibatch()
        X_pred, Y_pred = model.predict(Xi, t_test, W_test)
        
        if device_type == 'tpu':
            xm.mark_step()  # TPU需要标记步骤
    
    inference_time = (time.time() - start_time) / test_iterations
    print(f"平均推理时间: {inference_time:.4f} 秒/次")
    
    # 测试训练速度
    if len(model.training_loss) > 10:
        print(f"最后10次平均损失: {np.mean(model.training_loss[-10:]):.6f}")

# 运行基准测试
benchmark_model()


使用MPS设备: mps
设备类型: mps
批量大小: 128
累积步数: 1
系统总内存: 16.00 GB
系统可用内存: 3.91 GB
系统使用率: 75.6%
开始第一阶段训练...
迭代: 0, 损失: 6.589e+01, Y0: 0.071, 时间: 2.33s, 学习率: 4.099e-05, 设备: MPS
迭代: 100, 损失: 1.449e+01, Y0: 0.067, 时间: 3.95s, 学习率: 9.674e-04, 设备: MPS
迭代: 200, 损失: 1.011e+01, Y0: 0.080, 时间: 3.75s, 学习率: 7.439e-04, 设备: MPS
迭代: 300, 损失: 1.026e+01, Y0: 0.153, 时间: 3.74s, 学习率: 4.063e-04, 设备: MPS
迭代: 400, 损失: 8.252e+00, Y0: 0.103, 时间: 4.03s, 学习率: 1.125e-04, 设备: MPS
第一阶段训练完成，总时间: 22.34s
开始第二阶段训练...
迭代: 400, 损失: 9.255e+00, Y0: 0.112, 时间: 0.04s, 学习率: 4.281e-07, 设备: MPS
迭代: 500, 损失: 8.810e+00, Y0: 0.115, 时间: 3.75s, 学习率: 8.346e-06, 设备: MPS
迭代: 600, 损失: 8.950e+00, Y0: 0.106, 时间: 3.62s, 学习率: 2.913e-06, 设备: MPS


: 