In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from scipy.interpolate import interp1d

import sys
sys.path.append('../../wuchengzhou')
import sagan

In [None]:
def preprocess(wave, flux, wave_center, half_width=50.0, n_points=200):
    """Cut a window around a line, resample it onto a fixed grid and standardise it.

    Args:
        wave: 1-D wavelength array, assumed sorted in ascending order (the
            slicing below relies on it).
        flux: 1-D flux array aligned with ``wave``.
        wave_center: centre of the window, in the same units as ``wave``.
        half_width: half-width of the window (default 50; the original code
            described this as a ±50 Å range).
        n_points: number of samples on the output grid (default 200).

    Returns:
        (new_wave, new_flux): the uniform grid
        ``linspace(wave_center - half_width, wave_center + half_width, n_points)``
        and the cubic-interpolated flux on it, minus its median and divided by
        its standard deviation. If the interpolated window is flat (zero spread
        up to round-off), the flux is only median-subtracted instead of
        producing inf/NaN.

    Raises:
        ValueError: if ``wave`` does not cover the whole window, or the window
            holds fewer than 4 samples (cubic interpolation needs 4).
    """
    wave = np.asarray(wave, dtype=float)
    flux = np.asarray(flux, dtype=float)
    lo = wave_center - half_width
    hi = wave_center + half_width

    # Take the last sample <= lo through the first sample >= hi (inclusive).
    # That guarantees the output grid lies inside the interpolation range.
    # A nearest-sample index plus an exclusive slice end could leave the grid
    # edges outside it, which makes interp1d raise.
    start = max(int(np.searchsorted(wave, lo, side='right')) - 1, 0)
    stop = min(int(np.searchsorted(wave, hi, side='left')) + 1, wave.size)
    local_wave = wave[start:stop]
    local_flux = flux[start:stop]
    if local_wave.size < 4 or local_wave[0] > lo or local_wave[-1] < hi:
        raise ValueError(
            f"spectrum does not cover [{lo}, {hi}] with at least 4 samples"
        )

    # Resample onto a fixed number of points.
    interp_fn = interp1d(local_wave, local_flux, kind='cubic')
    new_wave = np.linspace(lo, hi, n_points)
    new_flux = interp_fn(new_wave)

    # Standardise: subtract the median, divide by the standard deviation.
    centered = new_flux - np.median(new_flux)
    scale = np.std(new_flux)
    # A flat window has (round-off-level) zero spread; dividing would give inf/NaN.
    if scale <= 1e-12 * np.max(np.abs(new_flux)):
        return new_wave, centered
    return new_wave, centered / scale

In [3]:
class DynamicSpectrumEncoder(nn.Module):
    """Encode a 1-D spectrum of arbitrary length together with its centre wavelength.

    The spectrum passes through a small fully-convolutional stack followed by
    adaptive average pooling, so any input length N is reduced to a fixed
    ``pooled_length`` feature map. The centre wavelength is embedded by a
    small MLP. Both are flattened and concatenated.

    Args:
        pooled_length: number of positions kept by the adaptive pooling
            (default 100, as before).

    Attributes:
        output_dim: width of the vector returned by ``forward``, i.e.
            ``64 * pooled_length + 32`` (6432 with the default).
    """

    def __init__(self, pooled_length=100):
        super().__init__()
        self.pooled_length = pooled_length
        self.output_dim = 64 * pooled_length + 32
        # Fully-convolutional architecture + global (adaptive) pooling.
        self.feature_extractor = nn.Sequential(
            nn.Conv1d(1, 32, 5, padding='same'),  # 'same' padding keeps the length
            nn.ReLU(),
            nn.Conv1d(32, 64, 3, padding='same'),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(pooled_length),  # any length -> pooled_length positions
        )

        # Centre-wavelength encoder.
        # NOTE(review): wave_center is fed unnormalised; if it is a raw
        # wavelength in Å (thousands), consider scaling it first.
        self.wave_encoder = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU(),
            nn.Linear(16, 32),
        )

    def forward(self, spectrum, wave_center):
        """Return the concatenated spectrum and centre-wavelength features.

        Args:
            spectrum: (batch, 1, N) tensor; a (batch, N) tensor is also accepted.
            wave_center: (batch, 1) tensor; a (batch,) tensor is also accepted.

        Returns:
            (batch, 64 * pooled_length + 32) tensor.
        """
        if spectrum.dim() == 2:
            spectrum = spectrum.unsqueeze(1)
        if wave_center.dim() == 1:
            wave_center = wave_center.unsqueeze(1)

        # Spectral features: (batch, 64, pooled_length) -> (batch, 64 * pooled_length).
        spec_feat = self.feature_extractor(spectrum).flatten(1)
        # Centre-wavelength features: (batch, 32).
        wave_feat = self.wave_encoder(wave_center)
        # Feature fusion: (batch, 64 * pooled_length + 32).
        return torch.cat([spec_feat, wave_feat], dim=1)

In [None]:
def gaussian_profile(wavelength, amp, sigma, dv, wave_center):
    """Evaluate a Doppler-shifted Gaussian line profile.

    The rest wavelength ``wave_center`` is shifted by the velocity ``dv``
    using the first-order Doppler relation
    ``center = wave_center * (1 + dv / c)``.

    Args:
        wavelength: tensor of wavelengths at which to evaluate the profile.
        amp: peak amplitude (tensor or number, broadcast against ``wavelength``).
        sigma: Gaussian width, in the same units as ``wavelength``.
        dv: velocity offset in km/s.
        wave_center: rest wavelength, in the same units as ``wavelength``.

    Returns:
        ``amp * exp(-0.5 * ((wavelength - center) / sigma) ** 2)`` with
        torch broadcasting across all arguments.
    """
    c = 299792.458  # speed of light in km/s (exact); 3e5 skewed shifts by ~0.07%
    delta_lambda = (dv / c) * wave_center
    center = wave_center + delta_lambda
    return amp * torch.exp(-0.5 * ((wavelength - center) / sigma) ** 2)

def compute_loss(model, batch, device):
    """Training loss for a Gaussian-decomposition model on one batch.

    Args:
        model: callable that takes a (batch, 1, N) spectrum and returns
            ``(components_prob, amps_pred, sigma_pred, dv_pred)``; it must also
            expose ``max_components``.
        batch: tuple ``(wave, flux, target_params)``. ``wave`` and ``flux`` are
            presumably (batch, N) tensors (TODO confirm against the dataset).
            ``target_params`` is a dict with keys 'n_components',
            'wave_center' and 'component_mask'.
        device: torch device that predictions and targets are moved to.

    Returns:
        Scalar tensor: reconstruction MSE + 0.1 * mean(1 / (sigma + 1e-3))
        + binary cross-entropy on component presence.
    """
    # Unpack the batch and move every target to the model's device.
    wave, flux, target_params = batch
    wave = wave.to(device)
    flux = flux.to(device)
    n_components = torch.as_tensor(target_params['n_components'], device=device)
    # (batch, 1) so it broadcasts per sample along the wavelength axis.
    wave_center = torch.as_tensor(
        target_params['wave_center'], dtype=flux.dtype, device=device
    ).reshape(-1, 1)

    # Model prediction. The model must see the observed flux, which the
    # reconstruction below is compared against; the original passed `wave`.
    components_prob, amps_pred, sigma_pred, dv_pred = model(flux.unsqueeze(1))

    # Assumes sigma_pred / dv_pred hold one value per sample, i.e. (batch,),
    # as NarrowLinePredictor produces. TODO confirm for the model actually used.
    sigma_col = sigma_pred.reshape(-1, 1)
    dv_col = dv_pred.reshape(-1, 1)

    # Reconstruct the spectrum as a sum of Gaussian components.
    pred_flux = torch.zeros_like(flux)
    for i in range(model.max_components):
        # Only components with index < the target n_components contribute.
        mask = (n_components > i).to(flux.dtype)
        # (batch, 1): one amplitude per sample, broadcast along wavelength.
        amp = (amps_pred[:, i] * components_prob[:, i] * mask).reshape(-1, 1)
        pred_flux = pred_flux + gaussian_profile(wave, amp, sigma_col, dv_col, wave_center)

    # Reconstruction loss.
    reconstruction_loss = F.mse_loss(pred_flux, flux)

    # Parameter constraint (example: penalise very small sigma).
    sigma_constraint = torch.mean(1 / (sigma_pred + 1e-3))

    # Component-existence loss; BCE requires a float target on the same device.
    component_target = torch.as_tensor(
        target_params['component_mask'], device=device
    ).to(components_prob.dtype)
    component_loss = F.binary_cross_entropy(components_prob, component_target)

    total_loss = reconstruction_loss + 0.1 * sigma_constraint + component_loss
    return total_loss

In [None]:
class NarrowLinePredictor(nn.Module):
    """Predict narrow-line Gaussian parameters from a fused feature vector.

    The kinematics (``sigma`` and ``dv``) come from one shared head, so all
    lines use the same values. The amplitudes come from a separate linear
    head per line.

    Args:
        feature_dim: width of ``combined_feat``. The default 64*20+32 (1312)
            preserves the original layer sizes. It does NOT match the output
            of DynamicSpectrumEncoder with its AdaptiveAvgPool1d(100), which is
            64*100+32 = 6432. Pass the width of the features you actually feed in.
        line_names: lines that get their own amplitude head
            (default: 'Halpha', 'Hbeta').
        n_amps: number of amplitudes predicted per line (default 2).
    """

    def __init__(self, feature_dim=64 * 20 + 32, line_names=('Halpha', 'Hbeta'), n_amps=2):
        super().__init__()
        self.feature_dim = feature_dim
        # Shared parameter head: outputs [log_sigma, dv].
        self.shared_predictor = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 2),
        )

        # Per-line amplitude heads.
        self.line_specific = nn.ModuleDict(
            {name: nn.Linear(feature_dim, n_amps) for name in line_names}
        )

    def forward(self, combined_feat, line_type):
        """Return ``{'sigma': (batch,), 'dv': (batch,), 'amp': (batch, n_amps)}``.

        Raises:
            KeyError: if ``line_type`` has no amplitude head.
        """
        if line_type not in self.line_specific:
            raise KeyError(
                f"unknown line_type {line_type!r}; available: {list(self.line_specific.keys())}"
            )
        # Shared parameters.
        shared_params = self.shared_predictor(combined_feat)
        sigma = torch.exp(shared_params[:, 0])  # exp keeps sigma strictly positive
        dv = shared_params[:, 1]

        # Line-specific amplitudes; softplus keeps them non-negative.
        amp = F.softplus(self.line_specific[line_type](combined_feat))
        return {'sigma': sigma, 'dv': dv, 'amp': amp}

In [None]:
from torch.utils.data import Dataset, DataLoader

class SpectrumDataset(Dataset):
    """Placeholder dataset of spectra; sample generation is not implemented yet.

    compute_loss unpacks each batch as ``(wave, flux, target_params)``, where
    ``target_params`` holds 'n_components', 'wave_center' and 'component_mask'.
    ``__getitem__`` should eventually yield samples in that form.

    Args:
        n_samples: number of samples the dataset reports (default 1000).
    """

    def __init__(self, n_samples=1000):
        # TODO: implement the data-generation logic.
        self.n_samples = n_samples

    def __len__(self):
        # DataLoader's default sampler needs a length; without it, iteration
        # fails with an opaque TypeError.
        return self.n_samples

    def __getitem__(self, idx):
        raise NotImplementedError(
            "SpectrumDataset.__getitem__ is not implemented yet: data generation is still a TODO"
        )

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(0)  # reproducible weight initialisation and batch shuffling

N_EPOCHS = 100
BATCH_SIZE = 32
LEARNING_RATE = 1e-3

# NOTE(review): GaussianPredictor is not defined in the visible cells.
# compute_loss requires it to expose `max_components` and to return
# (components_prob, amps_pred, sigma_pred, dv_pred).
model = GaussianPredictor().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

epoch_losses = []  # mean training loss per epoch, for inspection after the run
model.train()
for epoch in range(N_EPOCHS):
    # A fresh SpectrumDataset is built every epoch, as before. If its data
    # generation is deterministic or expensive, hoist this out of the loop.
    train_loader = DataLoader(SpectrumDataset(), batch_size=BATCH_SIZE, shuffle=True)
    running_loss, n_batches = 0.0, 0
    for batch in train_loader:
        loss = compute_loss(model, batch, device)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        n_batches += 1
    epoch_losses.append(running_loss / max(n_batches, 1))

In [None]:
def hybrid_fitting(wave, flux, line_centers):
    """Fit emission lines by seeding a least-squares fit with model predictions.

    Args:
        wave, flux: observed spectrum, passed through to the prediction and
            fitting helpers.
        line_centers: rest wavelengths of the lines to fit. Each is also used
            as the key into the prediction dict.

    Returns:
        Whatever ``run_levmar_fitting`` returns for the assembled initial guess.
    """
    c_kms = 299792.458  # speed of light in km/s

    # NOTE(review): relies on the notebook-global `model` and on the helpers
    # predict_parameters / run_levmar_fitting, which are not defined in the
    # visible cells.
    init_params = predict_parameters(model, wave, flux, line_centers)

    # Flatten into the [amp, sigma, center, amp, sigma, center, ...] layout
    # used by the classical fit.
    params_init = []
    for center in line_centers:
        params = init_params[center]
        # dv is a velocity (km/s, as in gaussian_profile; TODO confirm for
        # predict_parameters). Convert it to a wavelength shift with the
        # Doppler relation; adding km/s directly to a wavelength mixes units.
        shifted_center = center * (1.0 + params['dv'] / c_kms)
        for amp in params['amps']:
            params_init.extend([amp, params['sigma'], shifted_center])

    # Run the classical least-squares fit.
    final_params = run_levmar_fitting(wave, flux, params_init)
    return final_params