In [None]:
#尝试使用复数卷积
import torch
import torch.nn as nn
from torch.nn import Module, Conv1d
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
import torch.utils.data as data
import os
import matplotlib.pyplot as plt
import thop
from torchinfo import summary
import time
import math
from einops import rearrange, repeat
import smtplib
from email.mime.text import MIMEText
def set_seed(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    every visible CUDA device), and configures cuDNN for deterministic
    kernel selection so that repeated runs with the same seed match.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random  # local import: the file header does not import ``random``

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all covers every visible GPU, including cuda:0.
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run, which
    # would defeat ``deterministic = True``; turn it off explicitly.
    torch.backends.cudnn.benchmark = False


seed = 1000
set_seed(seed)
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0') if use_cuda else torch.device('cpu')
print(device)
torch.cuda.empty_cache()

class RealToComplex(nn.Module):
    """Learnable lift of a real sequence into a complex one.

    Two independent ``Linear(seq_len, seq_len)`` maps are applied to the same
    real input; one produces the real part, the other the imaginary part.
    """
    def __init__(self, seq_len):
        super().__init__()
        # Real projection is created first, then the imaginary one, so seeded
        # initialization stays the same.
        self.real_proj = nn.Linear(seq_len, seq_len)
        self.imag_proj = nn.Linear(seq_len, seq_len)

    def forward(self, x):
        """Map real ``x`` of shape [batch, 1, seq_len] to complex [batch, 1, seq_len]."""
        flat = torch.squeeze(x, 1)
        re_part = self.real_proj(flat)
        im_part = self.imag_proj(flat)
        lifted = torch.complex(re_part, im_part)
        return torch.unsqueeze(lifted, 1)

class ComplexEncoder(nn.Module):
    """Transmitter: map a real input block to power-normalized 2-row samples.

    Pipeline: trainable real->complex lift, a complex convolution stack plus
    a complex convolution shortcut, real/imaginary concatenation, a
    per-position ``Linear(64 -> 2)``, ``BatchNorm1d`` over the sequence axis,
    and per-sample average-power normalization.

    Args:
        N: Sequence length. It is also the feature count of the final
            ``BatchNorm1d(N)``, which normalizes dim 1 of a ``[batch, N, 2]``
            tensor, so the input length must equal ``N``.
        G: Group count of the two grouped complex convolutions; must divide
            256, 64 and 32.
    """
    def __init__(self, N, G):
        super().__init__()
        self.N = N
        self.G = G
        # The real->complex projection and BatchNorm1d(N) both act on the
        # sequence axis, so the sequence length is N. (It used to be a
        # hard-coded 128, which only worked for N == 128.)
        self.seq_len = N

        assert 256 % G == 0, "256 must be divisible by G"
        assert 64 % G == 0, "64 must be divisible by G"
        assert 32 % G == 0, "32 must be divisible by G"

        # Trainable real -> complex converter.
        self.real_to_complex = RealToComplex(self.seq_len)

        # Complex conv stack 1 -> 256 -> 64 -> 32 channels. Every layer has
        # stride 1 and padding (kernel_size - 1) / 2, so the length is kept.
        self.encoder = nn.Sequential(
            ComplexConv1d(1, 256, 5, 1, 2),
            ComplexReLU(),
            ComplexConv1d(256, 64, 3, 1, 1, groups=G),
            ComplexReLU(),
            ComplexConv1d(64, 32, 3, 1, 1, groups=G),
        )

        # Residual path from the complex input straight to 32 channels.
        self.shortcut = ComplexConv1d(1, 32, kernel_size=3, stride=1, padding=1)

        # Applied to [batch, N, 64] (32 real + 32 imaginary features per
        # position): ReLU -> per-position Linear(64 -> 2) -> BatchNorm1d(N).
        self.timedis1 = nn.Sequential(
            nn.ReLU(),
            TimeDistributed(nn.Linear(64, 2), True),
            nn.BatchNorm1d(N),
        )

    def forward(self, x):
        """Encode real ``x`` of shape [batch, 1, N] into a [batch, 2, N] tensor.

        Each sample is scaled so that the mean over positions of
        ``tx[:, 0]**2 + tx[:, 1]**2`` is 1 (up to the 1e-8 stabilizer).
        """
        batch = x.shape[0]
        z = self.real_to_complex(x)                        # complex [batch, 1, N]

        # Main complex conv path plus complex shortcut.
        out1 = self.encoder(z)
        out1_ori = self.shortcut(z)
        combined = out1 + out1_ori                         # complex [batch, 32, N]

        # Back to real: real parts in channels 0..31, imaginary in 32..63.
        tx_input = torch.cat([combined.real, combined.imag], dim=1)  # [batch, 64, N]

        tx = self.timedis1(
            tx_input.reshape(batch, 64, -1).transpose(1, 2)
        )                                                  # [batch, N, 2]
        # NOTE(review): ``reshape`` (not ``transpose``) of [batch, N, 2] into
        # [batch, 2, N] interleaves positions. tx[:, 0, :] ends up holding
        # both outputs of positions 0..N/2-1, and tx[:, 1, :] those of
        # positions N/2..N-1. It is kept as-is so existing checkpoints keep
        # their meaning. ``.transpose(1, 2)`` would give a per-position
        # two-row layout instead.
        tx = tx.reshape(batch, 2, -1)                      # [batch, 2, N]

        # Per-sample power normalization.
        power_per_timestep = tx.pow(2).sum(dim=1)          # [batch, N]
        avg_power = power_per_timestep.mean(dim=1, keepdim=True)  # [batch, 1]
        scale_factor = 1.0 / torch.sqrt(avg_power + 1e-8)
        tx_normalized = tx * scale_factor.unsqueeze(-1)

        return tx_normalized

def apply_complex_1d(fr, fi, x):
    """Apply a complex operator given by its real part ``fr`` and imaginary part ``fi``.

    For complex ``x`` this returns
    ``(fr(Re x) - fi(Im x)) + 1j * (fr(Im x) + fi(Re x))``.
    ``ComplexConv1d`` uses it with two real ``Conv1d`` layers.
    """
    x_re, x_im = x.real, x.imag
    out_re = fr(x_re) - fi(x_im)
    out_im = fr(x_im) + fi(x_re)
    return torch.complex(out_re, out_im)

class TimeDistributed(nn.Module):
    """Apply ``module`` independently at every position of a sequence.

    Inputs with at most 2 dims go straight to ``module``. Otherwise all
    leading dims are flattened into one (``[-1, features]``), ``module`` is
    applied, and the result is reshaped back:

    - with ``batch_first``: dim 0 of ``x`` is kept and the rest is inferred;
    - otherwise: dim 1 of ``x`` is kept and dim 0 is inferred.

    NOTE: an identical ``TimeDistributed`` class is defined again later in
    this module; at import time that later definition replaces this one.
    """
    def __init__(self, module, batch_first=False):
        super().__init__()
        self.module = module
        self.batch_first = batch_first

    def forward(self, x):
        if x.dim() <= 2:
            return self.module(x)
        flat_in = x.contiguous().view(-1, x.size(-1))
        flat_out = self.module(flat_in)
        n_out = flat_out.size(-1)
        if not self.batch_first:
            return flat_out.view(-1, x.size(1), n_out)
        return flat_out.contiguous().view(x.size(0), -1, n_out)

class ComplexReLU(nn.Module):
    """Split ReLU for complex tensors: ReLU on the real and imaginary parts separately."""
    def forward(self, x):
        rectified_re = F.relu(x.real)
        rectified_im = F.relu(x.imag)
        return torch.complex(rectified_re, rectified_im)
    
class ComplexConv1d(Module):
    """1-D convolution with complex weights, built from two real ``Conv1d`` layers.

    ``conv_r`` holds the real part and ``conv_i`` the imaginary part of the
    kernel (each with its own bias when ``bias`` is True). ``forward``
    combines them with ``apply_complex_1d``. Arguments mirror
    ``torch.nn.Conv1d``.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super(ComplexConv1d, self).__init__()
        assert in_channels % groups == 0, 'in_channels must be divisible by groups'
        assert out_channels % groups == 0, 'out_channels must be divisible by groups'

        shared = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # Real part first, then imaginary: keeps the seeded init order.
        self.conv_r = Conv1d(in_channels, out_channels, **shared)
        self.conv_i = Conv1d(in_channels, out_channels, **shared)

    def forward(self, inp):
        """Convolve the complex tensor ``inp`` with the complex kernel."""
        return apply_complex_1d(self.conv_r, self.conv_i, inp)

class NaiveComplexBatchNorm1d(Module):
    """Complex batch norm that normalizes real and imaginary parts independently.

    Two separate ``BatchNorm1d`` layers, each with its own statistics and
    affine parameters, are applied to ``inp.real`` and ``inp.imag``. No
    real/imaginary covariance is modeled. Arguments mirror
    ``torch.nn.BatchNorm1d``.
    """
    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
    ):
        super(NaiveComplexBatchNorm1d, self).__init__()
        bn_args = (num_features, eps, momentum, affine, track_running_stats)
        self.bn_r = nn.BatchNorm1d(*bn_args)
        self.bn_i = nn.BatchNorm1d(*bn_args)

    def forward(self, inp):
        normed_re = self.bn_r(inp.real)
        normed_im = self.bn_i(inp.imag)
        return torch.complex(normed_re, normed_im)

class ComplexMultiScaleConvBlock(nn.Module):
    """Three parallel complex conv branches (kernel 3, 5, 7) on a 1-channel input.

    Each branch is ComplexConv1d -> NaiveComplexBatchNorm1d -> ComplexReLU.
    The branches use stride 1 and padding (k - 1) / 2, so the length is
    kept. Their K-channel outputs are concatenated into 3*K channels.
    """
    def __init__(self, K):
        super().__init__()
        # Attribute names and creation order are kept so checkpoints and
        # seeded initialization are unaffected.
        self.conv3x1 = ComplexConv1d(1, K, kernel_size=3, padding=1)
        self.conv5x1 = ComplexConv1d(1, K, kernel_size=5, padding=2)
        self.conv7x1 = ComplexConv1d(1, K, kernel_size=7, padding=3)

        self.bn3 = NaiveComplexBatchNorm1d(K)
        self.bn5 = NaiveComplexBatchNorm1d(K)
        self.bn7 = NaiveComplexBatchNorm1d(K)

        self.relu = ComplexReLU()

    def forward(self, x):
        stages = (
            (self.conv3x1, self.bn3),
            (self.conv5x1, self.bn5),
            (self.conv7x1, self.bn7),
        )
        branch_outputs = [self.relu(bn(conv(x))) for conv, bn in stages]
        return torch.cat(branch_outputs, dim=1)

class SE(nn.Module):
    """Squeeze-and-excitation gate for complex feature maps.

    The real and imaginary parts are stacked into 2*Cin real channels and
    averaged over the length axis. They then pass through a 1x1-conv
    bottleneck (2*Cin -> Cin with ReLU, then Cin -> 2*Cin with sigmoid). The
    first Cin gates scale the real part and the last Cin gates scale the
    imaginary part.
    """
    def __init__(self, Cin):
        super().__init__()
        self.squeeze = nn.AdaptiveAvgPool1d(1)
        # Bottleneck: 2*Cin (real + imaginary) -> Cin -> 2*Cin.
        self.compress = nn.Conv1d(Cin*2, Cin, 1)
        self.excitation = nn.Conv1d(Cin, Cin*2, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        re, im = x.real, x.imag
        channels = re.shape[1]
        pooled = self.squeeze(torch.cat([re, im], dim=1))    # [batch, 2*C, 1]
        hidden = self.relu(self.compress(pooled))            # [batch, C, 1]
        gates = self.sigmoid(self.excitation(hidden))        # [batch, 2*C, 1]
        gate_re = gates[:, :channels]
        gate_im = gates[:, channels:]
        return torch.complex(re * gate_re, im * gate_im)

class PreAttn(nn.Module):
    """Complex-valued receiver front end.

    Reads a real [batch, 2, length] input as a single complex channel
    (row 0 = real part, row 1 = imaginary part) and applies:
    multi-scale complex convs -> 1x1 complex fusion conv -> BN -> ReLU ->
    SE gate, plus a 1x1 complex shortcut. The result is returned as a real
    tensor [batch, 2*(K//2), length]: real channels first, then imaginary.
    """
    def __init__(self, K):
        super(PreAttn, self).__init__()
        self.K = K
        half = K // 2  # number of complex channels

        # Creation order is kept for seeded initialization.
        self.complex_multiscale = ComplexMultiScaleConvBlock(half)
        self.complex_conv1x1 = ComplexConv1d(3 * half, half, kernel_size=1)
        self.bn = NaiveComplexBatchNorm1d(half)
        self.relu = ComplexReLU()
        self.shortcut = ComplexConv1d(1, half, kernel_size=1)
        self.attention = SE(half)

    def forward(self, x):
        # [batch, 2, L] -> [batch, L, 2] -> complex [batch, L] -> [batch, 1, L]
        pairs = x.permute(0, 2, 1).contiguous()
        z = torch.view_as_complex(pairs).unsqueeze(1)

        fused = self.complex_conv1x1(self.complex_multiscale(z))
        gated = self.attention(self.relu(self.bn(fused)))

        merged = gated + self.shortcut(z)
        return torch.cat([merged.real, merged.imag], dim=1)

class TimeDistributed(nn.Module):
    """Apply ``module`` independently at every position of a sequence.

    NOTE: this is a verbatim re-definition of the ``TimeDistributed`` class
    defined earlier in this module. Being later, it is the one actually bound
    to the name.

    Inputs with at most 2 dims go straight to ``module``. Otherwise the input
    is flattened to ``[-1, features]``, passed through ``module``, and
    reshaped back. With ``batch_first``, dim 0 of ``x`` is kept; otherwise
    dim 1 of ``x`` is kept and dim 0 is inferred.
    """
    def __init__(self, module, batch_first=False):
        super().__init__()
        self.module = module
        self.batch_first = batch_first

    def forward(self, x):
        if len(x.shape) <= 2:
            return self.module(x)
        per_step = self.module(x.contiguous().view(-1, x.size(-1)))
        out_features = per_step.size(-1)
        if self.batch_first:
            return per_step.contiguous().view(x.size(0), -1, out_features)
        return per_step.view(-1, x.size(1), out_features)

class PreNorm(nn.Module):
    """Apply ``LayerNorm(dim)`` to the input, then call ``fn`` on the result.

    Extra keyword arguments given to ``forward`` are passed through to ``fn``.
    """
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

class FeedForward(nn.Module):
    """Sparse mixture-of-experts feed-forward layer with token-level top-k routing.

    A linear gate scores every expert for every token. Each token is sent to
    its ``top_k`` highest-scoring experts, and their outputs are summed with
    weights given by a softmax over those ``top_k`` logits (followed by
    dropout on the weights).

    Args:
        dim: Token feature size (input and output).
        hidden_dim: Hidden size of each expert MLP (Linear-GELU-Linear).
        dropout: Dropout probability inside each expert and on the routing
            weights.
        num_experts: Number of expert MLPs (default 28, the former fixed value).
        top_k: Experts used per token (default 12, the former fixed value);
            must satisfy ``1 <= top_k <= num_experts``.

    Raises:
        ValueError: If ``top_k`` is outside ``[1, num_experts]``.
    """
    def __init__(self, dim, hidden_dim, dropout=0., num_experts=28, top_k=12):
        super().__init__()
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be in [1, num_experts={num_experts}], got {top_k}")
        self.num_experts = num_experts  # number of experts
        self.top_k = top_k              # experts selected per token
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, dim),
                nn.Dropout(dropout)
            ) for _ in range(self.num_experts)
        ])
        self.gate = nn.Linear(dim, self.num_experts)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Route tokens of ``x`` ([batch, seq_len, dim]) and return the same shape."""
        batch_size, seq_len, d = x.shape
        x_flat = x.reshape(-1, d)                                       # [T, d], T = batch*seq_len
        num_tokens = x_flat.shape[0]

        logits = self.gate(x_flat)                                      # [T, num_experts]
        topk_logits, topk_indices = logits.topk(k=self.top_k, dim=-1)   # [T, top_k]
        topk_gates = torch.softmax(topk_logits, dim=-1)
        topk_gates = self.dropout(topk_gates)

        # Flatten the (token, slot) pairs once. Every expert below selects
        # its rows from these loop-invariant tensors.
        x_repeated = x_flat.repeat_interleave(self.top_k, dim=0)        # [T*top_k, d]
        expert_indices = topk_indices.flatten()                         # [T*top_k]
        gate_flat = topk_gates.flatten()                                # [T*top_k]
        token_ids = torch.arange(num_tokens, device=x.device).repeat_interleave(self.top_k)

        # Accumulate each expert's weighted output back onto its tokens.
        out_flat = torch.zeros_like(x_flat)
        for expert_id, expert in enumerate(self.experts):
            mask = expert_indices == expert_id
            if not mask.any():
                continue
            expert_output = expert(x_repeated[mask])
            weighted_output = expert_output * gate_flat[mask].unsqueeze(1)
            out_flat.index_add_(0, token_ids[mask], weighted_output)

        return out_flat.reshape(batch_size, seq_len, d)

cuda:0


In [None]:
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: Feature size of each token.
        heads: Number of attention heads.
        dim_head: Size of each head; q, k and v each use heads*dim_head features.
        dropout: Dropout on the attention weights and after the output projection.

    The output projection is replaced by ``nn.Identity`` only when there is
    a single head whose size already equals ``dim``.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = heads * dim_head
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        """Attend over dim 1 of ``x`` ([b, n, dim]) and return [b, n, dim]."""
        def split_heads(t):
            # b: batch, n: tokens (channels in this model), h: heads, d: dim_head
            return rearrange(t, 'b n (h d) -> b h n d', h=self.heads)

        q, k, v = (split_heads(t) for t in self.to_qkv(x).chunk(3, dim=-1))

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))

        context = torch.matmul(weights, v)
        merged = rearrange(context, 'b h n d -> b n (h d)')
        return self.to_out(merged)

class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm encoder layers.

    Each layer computes ``x = x + Attention(LayerNorm(x))`` followed by
    ``x = x + FeedForward(LayerNorm(x))``.
    """
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(depth):
            attn_block = PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout))
            ff_block = PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            self.layers.append(nn.ModuleList([attn_block, ff_block]))

    def forward(self, x):
        for attn_block, ff_block in self.layers:
            x = x + attn_block(x)
            x = x + ff_block(x)
        return x

class BinaryVoVUnit(nn.Module):
    """VoVNet-style one-shot-aggregation block that also shortens the sequence.

    ``num_branches`` Conv1d(k=3, pad=1)-BN-ReLU stages run in series. The
    input and every stage output are concatenated and fused by a 1x1
    Conv-BN-ReLU into ``out_channels``. A final unpadded conv of width
    ``reduce_kernel`` then shortens the length by ``reduce_kernel - 1``
    (default 8, e.g. 135 -> 128).

    With ``use_residual``, the input is length-matched by an average pool of
    the same width (stride 1). When ``in_channels != out_channels`` it is
    also projected by a 1x1 Conv-BN. It is then added to the main path.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output.
        branch_channels: Channels of each serial branch.
        num_branches: Number of serial branches.
        use_residual: Whether to add the (pooled, projected) input.
        reduce_kernel: Width of the length-reducing conv / pool; the output
            length is ``L - reduce_kernel + 1``. Defaults to 8, the previously
            hard-coded value.

    Raises:
        ValueError: If ``reduce_kernel < 1``.
    """
    def __init__(self, in_channels, out_channels, branch_channels, num_branches, use_residual,
                 reduce_kernel=8):
        super().__init__()
        if reduce_kernel < 1:
            raise ValueError(f"reduce_kernel must be >= 1, got {reduce_kernel}")
        self.use_residual = use_residual
        self.reduce_kernel = reduce_kernel

        # Serial branches; each one feeds the next.
        self.branches = nn.ModuleList()
        ch_in = in_channels
        for _ in range(num_branches):
            self.branches.append(
                nn.Sequential(
                    nn.Conv1d(ch_in, branch_channels, kernel_size=3, padding=1),
                    nn.BatchNorm1d(branch_channels),
                    nn.ReLU(inplace=False)
                )
            )
            ch_in = branch_channels

        # Concatenate input + all branch outputs, then compress with a 1x1 conv.
        self.concat_conv = nn.Sequential(
            nn.Conv1d(in_channels + num_branches * branch_channels, out_channels, kernel_size=1),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=False)
        )

        # Length reduction: an unpadded conv removes reduce_kernel - 1 samples.
        self.len_reduce = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, kernel_size=reduce_kernel, stride=1, padding=0, bias=False),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=False)
        )

        # Residual path: match length (and channels, if needed).
        if use_residual:
            # Length only: average pool; L_out = L_in - reduce_kernel + 1.
            self.res_len = nn.AvgPool1d(kernel_size=reduce_kernel, stride=1)
            # Channel projection when the counts differ.
            self.res_proj = None
            if in_channels != out_channels:
                self.res_proj = nn.Sequential(
                    nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False),
                    nn.BatchNorm1d(out_channels)
                )

    def forward(self, x):
        """Map [N, in_channels, L] to [N, out_channels, L - reduce_kernel + 1]."""
        identity = x

        outs = [x]
        for br in self.branches:
            x = br(x)
            outs.append(x)

        x = torch.cat(outs, dim=1)
        x = self.concat_conv(x)          # [N, C_out, L]
        x = self.len_reduce(x)           # [N, C_out, L - reduce_kernel + 1]

        if self.use_residual:
            identity = self.res_len(identity)       # [N, C_in, L - reduce_kernel + 1]
            if self.res_proj is not None:
                identity = self.res_proj(identity)  # [N, C_out, L - reduce_kernel + 1]
            x = x + identity

        return x

       
class Decoder(nn.Module):
    """Receiver back end: Transformer over channel tokens, then per-position sigmoid.

    Args:
        N: Token feature size of the Transformer (128 when built by ``Net``).
        G: Stored on the instance; not used by this class.
        K: Number of tokens (channels) in the input.
    """
    def __init__(self, N, G, K):
        super(Decoder, self).__init__()
        self.N = N
        self.G = G
        self.K = K
        # Operates on [batch, K, N]: K tokens of size N.
        self.decoder = Transformer(dim=N, depth=2, heads=8, dim_head=32, mlp_dim=N* 2, dropout=0.05)

        # [batch, N, K] -> Linear(K -> 1) at every position -> sigmoid -> [batch, N, 1]
        self.timedis2 = nn.Sequential(
            TimeDistributed(nn.Linear(K, 1), batch_first=True),
            nn.Sigmoid(),
        )

    def forward(self, R):
        """Map R (expected [batch, K, N]) to per-position probabilities [batch, 1, N]."""
        n_batch = R.shape[0]
        tokens = self.decoder(R)                                                    # [batch, K, N]
        per_position = tokens.reshape(n_batch, self.K, -1).transpose(1, 2).contiguous()  # [batch, N, K]
        probs = self.timedis2(per_position)                                         # [batch, N, 1]
        return probs.reshape(n_batch, 1, -1)

class Net(nn.Module):
    """End-to-end SISO autoencoder over a frequency-selective channel.

    ``forward`` chains: ``ComplexEncoder`` transmitter -> ``channel``
    (per-sample multipath convolution plus AWGN) -> ``PreAttn`` feature
    extractor -> ``BinaryVoVUnit`` -> ``Decoder``. The input is shaped
    [batch, 1, 128].
    """
    def __init__(self):
        super(Net, self).__init__()
        self.transmitter = ComplexEncoder(128, 4)
        # in_channels=70 matches PreAttn(70): 35 complex channels -> 70 real.
        self.receiver = nn.Sequential(
            BinaryVoVUnit(in_channels=70, out_channels=35, branch_channels=5, num_branches=3, use_residual=True),
            Decoder(128, 2, 35),
            )
        self.feature_extractor = PreAttn(70)

        # Xavier init for every real Conv1d/Linear weight, including the
        # Conv1d pairs inside ComplexConv1d; BatchNorm starts as identity.
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def channel(self, input_signals, SNR, H_r, H_i, batch_size):
        """Pass the real/imaginary signal pair through a per-sample multipath channel plus AWGN.

        Args:
            input_signals: Tensor [2, batch_size, L]; index 0 is the real
                part and index 1 the imaginary part.
            SNR: Signal-to-noise ratio in dB, relative to the mean power
                ``mean(real**2 + imag**2)`` over the whole batch.
            H_r, H_i: Real and imaginary channel taps, [batch_size, T]; each
                sample is convolved with its own taps.
            batch_size: Number of samples in ``input_signals``.

        Returns:
            Tensor [2, batch_size, L + T - 1]. This is the full (untruncated)
            complex linear convolution plus white Gaussian noise of variance
            ``S / (2 * 10**(SNR/10))`` on each of the real and imaginary parts.
        """
        assert input_signals.dim() == 3
        real = input_signals[0]
        imag = input_signals[1]
        S = torch.mean(real ** 2 + imag ** 2)
        snr_linear = 10 ** (SNR / 10.0)
        noise_variance = S / (2 * snr_linear)

        def conv(x, h):
            # x: [batch_size, L], h: [batch_size, T].
            # F.conv1d is a cross-correlation, so flip the taps to obtain a
            # true convolution. Moving the batch into the channel axis with
            # groups=batch applies each sample's own filter to that sample.
            taps = h.flip(dims=[-1]).reshape(batch_size, 1, -1)        # [batch_size, 1, T]
            y = F.conv1d(x.reshape(1, x.shape[0], -1), taps,
                         padding=h.shape[-1] - 1, groups=x.shape[0])   # [1, batch_size, L+T-1]
            return y.reshape(input_signals.shape[1], -1)               # [batch_size, L+T-1]

        # Complex multiply-accumulate: (r + i*m) * (h_r + i*h_i).
        out_r = conv(real, H_r) - conv(imag, H_i)
        out_i = conv(imag, H_r) + conv(real, H_i)

        out = torch.stack((out_r, out_i), dim=0)
        # randn_like places the noise on out's device and dtype.
        noise = torch.sqrt(noise_variance) * torch.randn_like(out)
        out += noise
        return out

    def forward(self, x, snr, h_r, h_i, batch_size):
        """Encode ``x`` [batch, 1, 128], send it through the channel, and decode it.

        Returns the receiver output. ``Decoder(128, ...)`` expects length 128
        after ``BinaryVoVUnit`` removes 7 samples, so the taps presumably need
        T == 8 (TODO confirm against callers).
        """
        tx = self.transmitter(x)                                     # [B, 2, 128]
        xx = torch.stack((tx[:, 0, :], tx[:, 1, :]), dim=0)          # [2, B, 128]
        rx = self.channel(xx, snr, h_r, h_i, batch_size)             # [2, B, 128 + T - 1]
        r = torch.stack((rx[0], rx[1]), dim=1)                       # [B, 2, 128 + T - 1]
        z1 = self.feature_extractor(r)                               # [B, 70, 128 + T - 1]
        out = self.receiver(z1)
        return out


In [None]:
def H_sample(TDL_dB, delay, batch_size=None):
    """Draw random complex tapped-delay-line channel taps.

    Tap ``delay[i]`` gets a zero-mean Gaussian gain whose real and imaginary
    parts each have variance ``10 ** (TDL_dB[i] / 10) / 2``. All other taps
    up to index ``delay[-1]`` are zero, so ``delay[-1]`` is assumed to be
    the largest delay.

    Args:
        TDL_dB: Per-path power in dB.
        delay: Tap index of each path.
        batch_size: If truthy, draw an independent channel per sample and
            return tensors of shape [batch_size, delay[-1] + 1]; otherwise
            return 1-D tensors of length delay[-1] + 1.

    Returns:
        (H_r, H_i): real and imaginary taps on the module-level ``device``.
    """
    path_power = 10 ** (torch.tensor(TDL_dB).to(device) / 10)
    n_taps = delay[-1] + 1

    if not batch_size:
        std = torch.sqrt(path_power / 2).to(device)
        H_r = torch.zeros(n_taps, device=device)
        H_i = torch.zeros(n_taps, device=device)
        H_r[delay] = torch.randn(len(delay)).to(device) * std
        H_i[delay] = torch.randn(len(delay)).to(device) * std
        return H_r, H_i

    std = torch.sqrt(path_power.unsqueeze(0) / 2)
    H_r = torch.zeros(batch_size, n_taps, device=device)
    H_i = torch.zeros(batch_size, n_taps, device=device)
    H_r[:, delay] = torch.randn(batch_size, len(delay), device=device) * std
    H_i[:, delay] = torch.randn(batch_size, len(delay), device=device) * std
    return H_r, H_i

def send_notification_email(filename_prefix, best_loss_value, epoch):
    """Send a "target reached" notification e-mail for a training run (best effort).

    Credentials come from the environment variables
    ``NOTIFY_EMAIL_SENDER``, ``NOTIFY_EMAIL_RECEIVER`` and
    ``NOTIFY_EMAIL_PASSWORD``. If unset, the values previously hard-coded
    here are used. Any failure is printed and swallowed so that a mail
    problem never interrupts training. The SMTP connection has a 30 s
    timeout so a stuck server cannot hang the training loop.

    Args:
        filename_prefix: Run identifier, used in the subject and the body.
        best_loss_value: Metric value reported in the body (the caller in
            this file passes the current validation BER).
        epoch: Epoch at which the target was reached.
    """
    import os

    try:
        # Mail settings. Prefer environment variables over credentials
        # committed to source code.
        sender = os.environ.get('NOTIFY_EMAIL_SENDER', '@163.com')
        receiver = os.environ.get('NOTIFY_EMAIL_RECEIVER', '8@qq.com')
        password = os.environ.get('NOTIFY_EMAIL_PASSWORD', 'BV')  # SMTP authorization code

        # Mail content.
        subject = f"训练达标通知 - {filename_prefix}"
        body = (f"训练已达到目标值!\n\n"
                f"训练标识: {filename_prefix}\n"
                f"当前epoch: {epoch}\n"
                f"best_loss值: {best_loss_value:.2e}\n"
                f"时间: {time.strftime('%Y-%m-%d %H:%M:%S')}")

        msg = MIMEText(body, 'plain', 'utf-8')
        msg['From'] = sender
        msg['To'] = receiver
        msg['Subject'] = subject

        with smtplib.SMTP_SSL('smtp.163.com', 465, timeout=30) as server:
            server.login(sender, password)
            server.sendmail(sender, receiver, msg.as_string())

        print("通知邮件已发送")
    except Exception as e:
        # Deliberate best effort: report, never raise into the training loop.
        print(f"发送邮件失败: {e}")

def _append_log(log_filename, message):
    """Append ``message`` followed by a newline to the text log ``log_filename``."""
    with open(log_filename, 'a') as log_file:
        log_file.write(message + "\n")


def _save_curve_plots(train_losses, val_losses, ber_records,
                      loss_path, ber_path, loss_title, ber_title):
    """Save the train/validation loss curves and the BER curve as two log-scale PNGs."""
    plt.figure()
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.yscale('log')
    plt.title(loss_title)
    plt.legend()
    plt.grid(True)
    plt.savefig(loss_path)
    plt.close()

    plt.figure()
    plt.plot(ber_records, label='BER')
    plt.xlabel('Epoch')
    plt.ylabel('BER')
    plt.yscale('log')
    plt.title(ber_title)
    plt.legend()
    plt.grid(True)
    plt.savefig(ber_path)
    plt.close()


def train(train_SNR, batch_size, batch_num, lr, epochs, TDL_dB, delay, sample_num):
    """Train ``Net`` end to end on random bits over sampled multipath channels.

    Each epoch runs ``batch_num`` optimizer steps. Each step draws a fresh
    batch of random 128-bit blocks, sums the BCE loss over ``sample_num``
    independent channel draws, and performs one backward pass. Validation
    then runs ``batch_num`` batches (one channel draw each) and measures BCE
    and bit error rate at ``train_SNR``.

    Side effects, all using the prefix from ``generate_filename_prefix``:
    - a text log;
    - three checkpoints (best validation loss, best training loss, best BER);
    - loss/BER plots every 5 epochs and at the end;
    - one notification e-mail the first time the BER drops below 1e-5.

    Training resumes from ``best.pth`` when that file exists. It stops early
    after 25 consecutive epochs without a new best validation BER.

    Args:
        train_SNR: SNR in dB used for training and validation.
        batch_size: Blocks per batch.
        batch_num: Batches per epoch (for both training and validation).
        lr: Initial Adam learning rate.
        epochs: Maximum number of epochs.
        TDL_dB, delay: Tapped-delay-line profile passed to ``H_sample``.
        sample_num: Channel draws per training batch.
    """
    ber_threshold = 1e-5      # BER that triggers the one-time notification
    threshold_reached = False
    stop_patience = 25        # epochs without a new best BER before stopping

    # File name prefix for every artifact of this run.
    filename_prefix = generate_filename_prefix(batch_size, batch_num, epochs, lr, sample_num, seed)

    # Create the log file.
    log_filename = f'{filename_prefix}training_log.txt'
    with open(log_filename, 'w') as log_file:
        log_file.write(f"Training Log - {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
        log_file.write(f"Parameters: train_SNR={train_SNR}, batch_size={batch_size}, batch_num={batch_num}, "
                       f"lr={lr}, epochs={epochs}, TDL_dB={TDL_dB}, delay={delay}, sample_num={sample_num}\n\n")

    net = Net()
    net.to(device)
    if os.path.exists('best.pth'):
        # map_location=device (rather than forcing storage.cuda(0)) also
        # works on CPU-only machines.
        net.load_state_dict(torch.load('best.pth', map_location=device))
        log_message = "Loaded pre trained model 'best.pth'"
    else:
        log_message = "No pre trained model found. Starting training from scratch."
    print(log_message)
    _append_log(log_filename, log_message)

    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    # `verbose` is deprecated in recent PyTorch; LR changes are printed below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=6,
        threshold=0.0001, threshold_mode='rel',
        cooldown=0, min_lr=1e-6, eps=1e-08)
    criterion = torch.nn.BCELoss()

    train_losses = []
    val_losses = []
    ber_records = []
    best_loss = float('inf')
    best_train_loss = float('inf')
    best_ber = 1.0
    epochs_no_improve = 0

    start_train_time = time.time()
    last_lr = lr
    for epoch in range(epochs):
        start_time = time.time()

        # ---- training ----
        net.train()
        train_loss = 0.0
        for _ in range(batch_num):
            train_data = torch.tensor(np.random.randint(0, 2, [batch_size, 128])).float().to(device)
            s = train_data.reshape(batch_size, 1, -1)
            optimizer.zero_grad()
            BCE_loss = 0.0
            # Same bits, several independent channel draws, one step.
            for _ in range(sample_num):
                H_r, H_i = H_sample(TDL_dB, delay, batch_size)
                r = net(s, train_SNR, H_r, H_i, batch_size)
                BCE_loss += criterion(r, s)
            BCE_loss.backward()
            optimizer.step()
            train_loss += BCE_loss.item() / sample_num

        train_loss /= batch_num
        train_losses.append(train_loss)

        # ---- validation ----
        net.eval()
        val_loss = 0.0
        total_errors = 0
        total_bits = 0
        with torch.no_grad():
            for _ in range(batch_num):
                val_data = torch.tensor(np.random.randint(0, 2, [batch_size, 128])).float().to(device)
                s = val_data.reshape(batch_size, 1, -1)
                H_r, H_i = H_sample(TDL_dB, delay, batch_size)
                r = net(s, train_SNR, H_r, H_i, batch_size)
                predicted = torch.round(r)
                errors = torch.sum(torch.abs(predicted - s)).item()
                val_loss += criterion(r, s).item()
                total_errors += errors
                total_bits += s.numel()

        val_loss /= batch_num
        val_losses.append(val_loss)
        ber = total_errors / total_bits
        ber_records.append(ber)

        # ---- timing / ETA (over the configured number of epochs) ----
        epoch_duration = time.time() - start_time
        elapsed_time = time.time() - start_train_time
        average_time_per_epoch = elapsed_time / (epoch + 1)
        remaining_time = average_time_per_epoch * (epochs - epoch - 1)

        current_time = time.strftime("%H:%M", time.localtime())
        estimated_end_time = time.strftime("%H:%M", time.localtime(time.time() + remaining_time))

        log_message = (f"Epoch: {epoch}, "
                       f"Train loss: {train_loss:.8f}, "
                       f"Val loss: {val_loss:.8f}, "
                       f"Val BER: {ber:.9f}, "
                       f"Time: {epoch_duration:.1f}s, "
                       f"Current: {current_time}, "
                       f"ETA: {estimated_end_time}")
        print(log_message)
        _append_log(log_filename, log_message)

        # ---- checkpoints (three criteria) ----
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(net.state_dict(), f'{filename_prefix}_best_val_loss.pth')
            _append_log(log_filename, f"Best val loss model saved at epoch {epoch}")

        if train_loss < best_train_loss:
            best_train_loss = train_loss
            torch.save(net.state_dict(), f'{filename_prefix}_best_train_loss.pth')
            _append_log(log_filename, f"Best train loss model saved at epoch {epoch}")

        # Early-stopping counter: reset on a new best BER, otherwise count up.
        # (Previously the `else` hung off the threshold check below, so the
        # counter was incremented even right after being reset.)
        if ber < best_ber:
            best_ber = ber
            torch.save(net.state_dict(), f'{filename_prefix}_best.pth')
            _append_log(log_filename, f"Best BER model saved at epoch {epoch}")
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        # One-time notification when the BER first drops below the threshold.
        if ber < ber_threshold and not threshold_reached:
            threshold_reached = True
            log_message = f"BER首次达到阈值 {ber_threshold} (当前值: {ber:.2e})"
            print(log_message)
            _append_log(log_filename, log_message)
            send_notification_email(filename_prefix, ber, epoch)

        if epochs_no_improve >= stop_patience:
            stop_message = 'Early stopping triggered!'
            print(stop_message)
            _append_log(log_filename, stop_message)
            break

        # Periodic plots.
        if epoch % 5 == 4:
            _save_curve_plots(train_losses, val_losses, ber_records,
                              f'{filename_prefix}_loss_epoch.png',
                              f'{filename_prefix}_ber_epoch.png',
                              f'Loss Curves (Epoch {epoch+1})',
                              f'BER Curve (Epoch {epoch+1})')

        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        if epoch == 0:
            last_lr = current_lr
            print(f"Initial Learning Rate: {current_lr:.2e}")
        elif current_lr != last_lr:
            print(f"\nLearning Rate changed from {last_lr:.2e} to {current_lr:.2e} at Epoch {epoch}")
            last_lr = current_lr

    # Final plots.
    _save_curve_plots(train_losses, val_losses, ber_records,
                      f'{filename_prefix}_loss_curve.png',
                      f'{filename_prefix}_val_BER.png',
                      'Training and Validation Loss',
                      'Bit Error Rate over Epochs')

    end_message = f'Training completed at {time.strftime("%Y-%m-%d %H:%M:%S")}'
    print(end_message)
    with open(log_filename, 'a') as log_file:
        log_file.write("\n" + end_message + "\n")
        log_file.write(f"Total training time: {time.time() - start_train_time:.2f} seconds\n")
        log_file.write(f"Best validation loss: {best_loss:.8f}\n")
        if ber_records:  # empty when epochs == 0
            log_file.write(f"Final BER: {ber_records[-1]:.9f}\n")

    print('Training finished')

def send_test_results_email(filename_prefix, SNR_range, BER_results):
    """E-mail a BER-vs-SNR results table for a test run (best effort).

    The sender, receiver and SMTP authorization code come from the
    environment variables ``NOTIFY_EMAIL_SENDER``, ``NOTIFY_EMAIL_RECEIVER``
    and ``NOTIFY_EMAIL_PASSWORD``. The receiver and password were left
    blank in the source, which was a syntax error. If either is missing, a
    message is printed and nothing is sent. Any other failure (including
    empty results) is printed and swallowed.

    Args:
        filename_prefix: Run identifier, used in the subject and the body.
        SNR_range: SNR values (dB), paired element-wise with ``BER_results``.
        BER_results: Measured BER per SNR; any indexable sequence.
    """
    import os

    sender = os.environ.get('NOTIFY_EMAIL_SENDER', '@163.com')
    receiver = os.environ.get('NOTIFY_EMAIL_RECEIVER')
    password = os.environ.get('NOTIFY_EMAIL_PASSWORD')  # SMTP authorization code
    if not receiver or not password:
        print("发送测试结果邮件失败: NOTIFY_EMAIL_RECEIVER / NOTIFY_EMAIL_PASSWORD not set")
        return

    try:
        # Results table.
        result_table = "SNR(dB)\tBER\n" + "-------\t-------\n" + "".join(
            f"{snr}\t{ber:.4e}\n" for snr, ber in zip(SNR_range, BER_results))

        # First index of the minimum BER; works for lists and arrays alike
        # (raises ValueError on empty results, handled below).
        best_idx = min(range(len(BER_results)), key=lambda i: BER_results[i])

        subject = f"模型测试结果 - {filename_prefix}"
        body = (f"模型测试已完成!\n\n"
                f"测试标识: {filename_prefix}\n"
                f"测试时间: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n"
                f"测试结果:\n{result_table}\n"
                f"最佳BER: {BER_results[best_idx]:.4e} (在SNR {SNR_range[best_idx]}dB时)")

        msg = MIMEText(body, 'plain', 'utf-8')
        msg['From'] = sender
        msg['To'] = receiver
        msg['Subject'] = subject

        with smtplib.SMTP_SSL('smtp.163.com', 465, timeout=30) as server:
            server.login(sender, password)
            server.sendmail(sender, receiver, msg.as_string())

        print("测试结果邮件已发送")
    except Exception as e:
        # Deliberate best effort: report, never raise into the test driver.
        print(f"发送测试结果邮件失败: {e}")


def _print_test_eta(start_time, completed, idx, num_snr, snr, batch_idx, total_batches):
    """Print overall test progress and a linearly extrapolated finish time.

    ``completed`` is the fraction (0, 1] of all test frames processed so far.
    """
    now = time.time()
    elapsed = now - start_time
    total_estimated_time = elapsed / completed
    remaining_time = total_estimated_time - elapsed
    # time.localtime uses the machine's local timezone (the output labels it Beijing time).
    finish_time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(now + remaining_time))
    current_time_str = time.strftime("%H:%M:%S", time.localtime(now))

    print(f"\n[进度更新 {current_time_str}]")
    print(f"当前进度: {completed*100:.1f}% | 已完成 {idx+1}/{num_snr} SNR点")
    print(f"当前SNR: {snr}dB | 当前批次: {batch_idx+1}/{total_batches}")
    print(f"预计剩余时间: {remaining_time/60:.1f}分钟 | 总预计耗时: {total_estimated_time/60:.1f}分钟")
    print(f"最新预计结束时间(北京时间): {finish_time_str}")


def test_model(net, SNR_range, filename_prefix, num_samples=40000000, batch_size=800, log_filename=None):
    """Monte-Carlo BER test of ``net`` over ``SNR_range``; save results and email them.

    For every SNR point, ``num_samples`` random 128-bit frames are sent through
    the model in batches of ``batch_size``. Each batch uses a fresh channel
    from ``H_sample``: 8 taps at 0 dB, delays 0..7. The BER is the fraction of
    rounded outputs that differ from the transmitted bits.

    Args:
        net: trained model, called as ``net(s, snr, H_r, H_i, batch)``. The
            caller is responsible for ``net.eval()`` and device placement.
        SNR_range: SNR points (dB) to test.
        filename_prefix: run prefix. It names the log file when ``log_filename``
            is None and is used in the notification e-mail. May be None.
        num_samples: frames per SNR point; must be > 0.
        batch_size: frames per forward pass; must be > 0.
        log_filename: explicit output path; takes precedence over the prefix.

    Returns:
        list of BER values, one per SNR point, in ``SNR_range`` order. The
        same pairs are written as "snr ber" lines to the log file.

    Raises:
        ValueError: if ``SNR_range`` is empty or ``num_samples``/``batch_size``
            is not positive (previously these ended in ZeroDivisionError).
    """
    SNR_range = list(SNR_range)
    if not SNR_range:
        raise ValueError("SNR_range must contain at least one SNR point")
    if num_samples <= 0 or batch_size <= 0:
        raise ValueError("num_samples and batch_size must be positive")

    BER = []
    TDL_dB = [0, 0, 0, 0, 0, 0, 0, 0]
    delay = [0, 1, 2, 3, 4, 5, 6, 7]
    # Ceiling division: the last batch may be smaller than batch_size.
    total_batches = (num_samples + batch_size - 1) // batch_size
    total_samples = len(SNR_range) * num_samples

    test_start_time = time.time()
    print(f"测试开始时间(北京时间): {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(test_start_time))}")

    if log_filename is None:
        if filename_prefix is not None:
            log_filename = f'{filename_prefix}BER_results.txt'
        else:
            log_filename = 'BER_results.txt'
            print(f"警告: 未指定日志文件名，使用默认: {log_filename}")

    # Inference only: no autograd graph is needed anywhere in the test loop.
    with torch.no_grad():
        for idx, snr in enumerate(SNR_range):
            snr_start_time = time.time()
            total_errors = 0
            total_bits = 0
            # The first SNR point reports progress more often so an ETA is available early.
            checkpoints = {1, 200, 400, 600, 800} if idx == 0 else {1, 500, 1000, 1500, 2000}

            for batch_idx in range(total_batches):
                current_batch_size = min(batch_size, num_samples - batch_idx * batch_size)

                # Random source bits. Named `bits` so the module-level `data`
                # import (torch.utils.data) is not shadowed.
                bits = torch.tensor(np.random.randint(0, 2, [current_batch_size, 128])).float().to(device)
                s = bits.reshape(current_batch_size, 1, -1)

                H_r, H_i = H_sample(TDL_dB, delay, current_batch_size)
                r = net(s, snr, H_r, H_i, current_batch_size)

                # Count bits whose hard decision differs from the sent bit. A
                # mismatch test stays correct even if an output rounds outside
                # {0, 1} or is NaN; abs-difference would over-count or yield NaN.
                predicted = torch.round(r)
                total_errors += (predicted != s).sum().item()
                total_bits += s.numel()

                if batch_idx + 1 in checkpoints or batch_idx + 1 == total_batches:
                    completed_samples = idx * num_samples + batch_idx * batch_size + current_batch_size
                    _print_test_eta(test_start_time, completed_samples / total_samples,
                                    idx, len(SNR_range), snr, batch_idx, total_batches)

            ber = total_errors / total_bits
            BER.append(ber)

            snr_time = time.time() - snr_start_time
            # `>3` instead of `3d`: same output for ints, and float SNRs no longer raise.
            print(f"\nSNR点 {snr:>3} dB 测试完成:")
            print(f"BER: {ber:.4e} | 当前SNR点耗时: {snr_time/60:.1f}分钟")
            print(f"平均每批次时间: {snr_time/total_batches:.3f}s")

    test_end_time = time.time()
    total_test_time = test_end_time - test_start_time
    print(f"\n测试实际结束时间(北京时间): {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(test_end_time))}")
    print(f"总测试耗时: {total_test_time/3600:.2f}小时 ({total_test_time/60:.1f}分钟)")
    print(f"平均每个SNR点耗时: {total_test_time/len(SNR_range)/60:.1f}分钟")

    # One "snr ber" line per tested point.
    with open(log_filename, 'w') as f:
        f.writelines(f"{snr} {ber}\n" for snr, ber in zip(SNR_range, BER))

    if filename_prefix is not None:
        send_test_results_email(filename_prefix, SNR_range, BER)
    else:
        print("未指定filename_prefix，跳过发送邮件")

    return BER

def plot_ber(SNR_range, BER, batch_size, epochs, lr, batch_num=None, sample_num=None, filename_prefix=None):
    """Plot this model's BER-vs-SNR curve against reference curves and save it.

    The reference (baseline) curves are hard-coded 5-point BER lists. They are
    drawn only when ``SNR_range`` has the same number of points; presumably
    they were measured on the same SNR grid (TODO confirm). Otherwise
    ``plt.semilogy`` would raise on the length mismatch.

    Args:
        SNR_range: SNR points (dB) of ``BER``.
        BER: this model's BER per SNR point.
        batch_size, epochs, lr: run hyper-parameters used to locate the output folder.
        batch_num, sample_num: also used for the output folder. When omitted
            they are read from module globals (set by the ``__main__`` block),
            as the original code did.
        filename_prefix: explicit output prefix. When given, the
            hyper-parameters are not used to build it.

    Side effects:
        Saves ``{filename_prefix}_BER_curve.png`` and calls ``plt.show()``.
    """
    if filename_prefix is None:
        if batch_num is None:
            batch_num = globals().get('batch_num')
        if sample_num is None:
            sample_num = globals().get('sample_num')
        if batch_num is None or sample_num is None:
            raise ValueError("plot_ber needs batch_num and sample_num (or filename_prefix); "
                             "they are only defined globally when the script runs as __main__")
        filename_prefix = generate_filename_prefix(batch_size, batch_num, epochs, lr, sample_num, seed)

    plt.figure()

    # (legend label, BER values, line style) for each reference curve, in legend order.
    baselines = [
        ('LS', [0.5, 0.35, 0.1, 0.0055, 3e-4],
         dict(marker='s', linestyle='-', color='purple')),
        ('MMSE', [0.45, 0.3, 0.04, 1.5e-3, 5.5e-5],
         dict(marker='o', linestyle='-', color='blue')),
        ('E2E', [0.3, 0.15, 0.03, 1.5e-3, 1e-4],
         dict(marker='o', linestyle='-', color='magenta', markerfacecolor='none')),
        ('E2E-GAN', [0.26, 0.12, 0.018, 6.6e-4, 3.3e-5],
         dict(marker='s', linestyle='-', color='cyan', markerfacecolor='none')),
        ('GT-E2E-no Dense', [0.133296875, 0.042453125, 0.007496875, 0.000915625, 9.0625e-05],
         dict(marker='^', linestyle='--', color='red')),
        ('GT-E2E-no Ghost', [0.17920625, 0.06485, 0.010740625, 0.00100625, 0.000134375],
         dict(marker='v', linestyle='--', color='red')),
    ]
    for label, values, style in baselines:
        if len(values) != len(SNR_range):
            print(f"Skipping baseline '{label}': it has {len(values)} points, "
                  f"SNR_range has {len(SNR_range)}")
            continue
        plt.semilogy(SNR_range, values, label=label, **style)

    # This model's measured curve.
    plt.semilogy(SNR_range, BER, marker='+', label='GT-moeE2E', linestyle='-', color='red')

    plt.xlabel('SNR (dB)')
    plt.ylabel('BER')
    plt.title('BER vs SNR')
    plt.grid(True)
    plt.ylim(1e-5, 1e-0)
    plt.legend()
    plt.savefig(f'{filename_prefix}_BER_curve.png')
    plt.show()

def generate_filename_prefix(batch_size, batch_num, epochs, lr, sample_num, seed, *,
                             base_dir="./models", tag="28-12-lay2dim32-7035change_decoder"):
    """Create (if needed) and return the per-run output folder, ending in a separator.

    The folder name encodes the run's hyper-parameters:
    ``{tag}bs{batch_size}_bn{batch_num}_ep{epochs}_lr{lr}_sm{sample_num}_seed{seed}``
    under ``base_dir``. Callers append file names directly to the result
    (e.g. ``f'{prefix}_best.pth'``), so the result always ends with exactly
    one ``os.sep``.

    Args:
        batch_size, batch_num, epochs, lr, sample_num, seed: run hyper-parameters.
        base_dir: parent directory of all run folders.
        tag: experiment label prepended to the folder name.

    Returns:
        The folder path with a single trailing ``os.sep``.
    """
    subfolder_name = f"{tag}bs{batch_size}_bn{batch_num}_ep{epochs}_lr{lr}_sm{sample_num}_seed{seed}"
    subfolder_path = os.path.join(base_dir, subfolder_name)

    os.makedirs(subfolder_path, exist_ok=True)

    # The name no longer embeds a literal '/'. Previously Windows ended up with a
    # mixed "/\" suffix because endswith(os.sep) did not match the '/'.
    return subfolder_path + os.sep

if __name__ == "__main__":
    TDL_dB = [0, 0, 0, 0, 0, 0, 0, 0]
    delay = [0, 1, 2, 3, 4, 5, 6, 7]
    train_SNR =15.0
    batch_size = 800
    batch_num = 300
    lr = 1e-3
    epochs = 1000
    sample_num = 1
    SNR_testrange = [20]
    # 生成文件名前缀
    filename_prefix = generate_filename_prefix(batch_size, batch_num, epochs, lr, sample_num, seed)
    
#      训练模型
    train(train_SNR, batch_size, batch_num, lr, epochs, TDL_dB, delay, sample_num)
    #     # 准备测试数据
    test_data = torch.tensor(np.random.randint(0, 2, [1, 128])).float().to(device)
    test_data = test_data.reshape(1, 1, -1)
    h_r, h_i = H_sample(TDL_dB, delay, 1)
    
    # 测试最佳BER模型
    # 测试最佳BER模型
    print("\n" + "="*50)
    print("Testing Best BER Model")
    net_ber = Net()
    net_ber.to(device)
    net_ber.load_state_dict(torch.load(f'{filename_prefix}_best.pth', 
                                     map_location=lambda storage, loc: storage.cuda(0)))
    net_ber.eval()
    BER_ber = test_model(net_ber, SNR_testrange, filename_prefix)
    print("BER results:", BER_ber)

    # 模型统计信息（以最佳BER模型为例）
    print("\n" + "="*50)
    print("Model Summary (Best BER Model)")
    model_summary = summary(net_ber, input_data=(test_data, 15, h_r, h_i, 1), verbose=0)
    print(model_summary)
    
    flops, params = thop.profile(net_ber, inputs=(test_data, 15, h_r, h_i, 1), verbose=False)
    flops, params = thop.clever_format([flops, params], "%.3f")
    print("thop-flops & trainable parameters:", [flops, params])
    


No pre trained model found. Starting training from scratch.




Epoch: 0, Train loss: 0.60060253, Val loss: 0.46909063, Val BER: 0.233184408, Time: 68.5s, Current: 00:04, ETA: 04:49
Initial Learning Rate: 1.00e-03
Epoch: 1, Train loss: 0.32404624, Val loss: 0.19494186, Val BER: 0.082079590, Time: 67.1s, Current: 00:06, ETA: 04:46
Epoch: 2, Train loss: 0.11422776, Val loss: 0.04642085, Val BER: 0.016456641, Time: 67.1s, Current: 00:07, ETA: 04:45
Epoch: 3, Train loss: 0.03835401, Val loss: 0.02097795, Val BER: 0.007225033, Time: 67.2s, Current: 00:08, ETA: 04:45
Epoch: 4, Train loss: 0.02263247, Val loss: 0.01354838, Val BER: 0.004598047, Time: 67.8s, Current: 00:09, ETA: 04:46
Epoch: 5, Train loss: 0.01585298, Val loss: 0.00924480, Val BER: 0.003113151, Time: 68.1s, Current: 00:10, ETA: 04:46
Epoch: 6, Train loss: 0.01207658, Val loss: 0.00698577, Val BER: 0.002344727, Time: 68.2s, Current: 00:11, ETA: 04:47
Epoch: 7, Train loss: 0.00948801, Val loss: 0.00537842, Val BER: 0.001794010, Time: 68.1s, Current: 00:12, ETA: 04:47
Epoch: 8, Train loss: 0.