In [None]:
! pip install torch msprime matplotlib

In [21]:
import torch
import torch.nn as nn
import numpy as np
import math
import msprime
import matplotlib.pyplot as plt
from collections import OrderedDict
from torch.utils.data import Dataset, DataLoader
import time
import os
import multiprocessing as mp
from tqdm import tqdm
from simple import SimpleCoalNN, ImprovedCoalNN, CoalescentDataset, predict_tmrca, plot_results

ImportError: cannot import name 'predict_tmrca' from 'simple' (/Users/larry/Lab/DLCoal/simple.py)

In [11]:
def generate_batch(args):
    """Simulate one batch of haplotype pairs with per-site pairwise TMRCA labels.

    Written as a single-argument worker so it can be mapped over a
    ``multiprocessing.Pool``.

    Args:
        args: tuple ``(batch_id, num_samples, sequence_length, Ne,
            mut_rate_range, rec_rate_range)`` where

            * ``batch_id`` - integer; the batch's random stream is seeded with
              ``42 + batch_id`` so every batch is different but reproducible.
            * ``num_samples`` - number of independent simulations in the batch.
            * ``sequence_length`` - integer number of sites per sequence.
            * ``Ne`` - baseline effective population size; each simulation
              scales it by a uniform factor in [0.8, 1.2].
            * ``mut_rate_range`` / ``rec_rate_range`` - ``(low, high)`` bounds
              from which the per-simulation mutation / recombination rate is
              drawn uniformly.

    Returns:
        tuple ``(haplotypes, tmrca)``:

        * ``haplotypes`` - ``int8`` array of shape
          ``(num_samples, 2, sequence_length)``; 1 marks the derived allele.
        * ``tmrca`` - float array of shape ``(num_samples, sequence_length)``
          holding the time to the most recent common ancestor of the two
          haplotypes at every site, in generations (msprime's native unit, so
          no further rescaling by the population size is applied).
    """
    batch_id, num_samples, sequence_length, Ne, mut_rate_range, rec_rate_range = args

    # A private generator keeps batches independent of each other and of the
    # parent's global NumPy state. msprime does not read NumPy's random state,
    # so explicit seeds are drawn from this generator for every simulation.
    rng = np.random.default_rng(42 + batch_id)

    # Pre-allocate the outputs: this also yields correctly shaped empty arrays
    # when num_samples is 0.
    all_haplotypes = np.zeros((num_samples, 2, sequence_length), dtype=np.int8)
    all_tmrca = np.zeros((num_samples, sequence_length))

    for i in range(num_samples):
        # Randomise rates and population size to diversify the training data.
        mut_rate = rng.uniform(*mut_rate_range)
        rec_rate = rng.uniform(*rec_rate_range)
        effective_size = Ne * rng.uniform(0.8, 1.2)

        # Pick one of three single-population demographic histories.
        scenario = int(rng.integers(3))
        if scenario == 0:
            # Constant population size.
            demography = msprime.Demography.isolated_model([effective_size])
        elif scenario == 1:
            # Expansion: exponential growth forwards in time, i.e. the
            # population shrinks going back into the past.
            demography = msprime.Demography.isolated_model(
                [effective_size], growth_rate=[0.01])
        else:
            # Bottleneck: half-sized population until 0.1 * Ne generations ago,
            # full size before that.
            # NOTE(review): the original call was not a valid msprime API, so
            # this shape is inferred from its arguments - confirm the intent.
            demography = msprime.Demography.isolated_model([0.5 * effective_size])
            demography.add_population_parameters_change(
                time=0.1 * effective_size, initial_size=effective_size)

        ancestry_seed, mutation_seed = (
            int(seed) for seed in rng.integers(1, 2**31, size=2))

        # One diploid individual = two sample genomes (nodes 0 and 1), which
        # keeps the diploid coalescent time scale.
        ts = msprime.sim_ancestry(
            samples=1,
            ploidy=2,
            demography=demography,
            sequence_length=sequence_length,
            recombination_rate=rec_rate,
            random_seed=ancestry_seed,
        )
        # The binary model keeps genotypes in {0, 1} even when a site is hit
        # by more than one mutation.
        ts = msprime.sim_mutations(
            ts,
            rate=mut_rate,
            model=msprime.BinaryMutationModel(),
            random_seed=mutation_seed,
        )

        # Scatter the variant genotypes into a dense per-site matrix.
        for variant in ts.variants():
            pos = int(variant.site.position)
            if pos < sequence_length:
                all_haplotypes[i, :, pos] = variant.genotypes

        # Every marginal tree covers [left, right); broadcast its pairwise
        # TMRCA over the sites in that interval.
        for tree in ts.trees():
            left = int(tree.interval.left)
            right = int(min(tree.interval.right, sequence_length))
            all_tmrca[i, left:right] = tree.tmrca(0, 1)

    return all_haplotypes, all_tmrca

In [12]:
def parallel_generate_data(total_samples, sequence_length, Ne=10000, 
                          mut_rate_range=(0.5e-8, 5e-8), 
                          rec_rate_range=(0.5e-8, 5e-8),
                          batch_size=1000, num_processes=None):
    """Generate simulated data in parallel by fanning batches out to a process pool.

    Args:
        total_samples: total number of simulations to produce.
        sequence_length: number of sites per simulated sequence.
        Ne: baseline effective population size forwarded to ``generate_batch``.
        mut_rate_range: ``(min, max)`` bounds for the mutation rate.
        rec_rate_range: ``(min, max)`` bounds for the recombination rate.
        batch_size: number of simulations handled by one worker task.
        num_processes: size of the process pool; defaults to the CPU count and
            is never larger than the number of batches.

    Returns:
        tuple ``(haplotypes, tmrca)`` - the per-batch arrays returned by
        ``generate_batch`` concatenated along the first axis, in batch order.
        For ``total_samples <= 0`` two empty arrays of shape
        ``(0, 2, sequence_length)`` and ``(0, sequence_length)`` are returned.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Nothing to simulate: return empty arrays instead of scheduling a batch.
    # (The previous batch arithmetic produced one full-sized batch with id -1
    # in this case.)
    if total_samples <= 0:
        return (np.zeros((0, 2, sequence_length), dtype=np.int8),
                np.zeros((0, sequence_length)))

    # One task per batch; only the last batch can be smaller than batch_size.
    args_list = [
        (batch_id, min(batch_size, total_samples - start), sequence_length,
         Ne, mut_rate_range, rec_rate_range)
        for batch_id, start in enumerate(range(0, total_samples, batch_size))
    ]

    if num_processes is None:
        num_processes = mp.cpu_count()
    # Extra workers beyond the number of tasks would only sit idle.
    num_processes = max(1, min(num_processes, len(args_list)))

    print(f"使用{num_processes}个进程并行生成数据")

    # NOTE(review): with the "spawn" start method (the default on macOS and
    # Windows) workers must be able to import generate_batch; a function that
    # only exists in a notebook cell may not be picklable there - consider
    # moving it into an importable module.
    with mp.Pool(processes=num_processes) as pool:
        # imap preserves task order, so results line up with batch ids.
        results = list(tqdm(pool.imap(generate_batch, args_list), 
                           total=len(args_list), 
                           desc="生成数据批次"))

    all_haplotypes, all_tmrca = zip(*results)
    return np.concatenate(all_haplotypes), np.concatenate(all_tmrca)

In [13]:
def save_data_in_chunks(haplotypes, tmrca, output_prefix, chunk_size=10000):
    """Persist a large dataset as a series of compressed ``.npz`` chunk files.

    Chunk ``i`` is written to ``{output_prefix}_chunk_{i}.npz`` and stores the
    matching row ranges of both inputs under the keys ``haplotypes`` and
    ``tmrca``. The row count is taken from ``haplotypes.shape[0]``.
    """
    total_rows = haplotypes.shape[0]
    # Ceiling division: the final chunk may hold fewer than chunk_size rows.
    num_chunks = (total_rows + chunk_size - 1) // chunk_size

    for i in range(num_chunks):
        first_row = i * chunk_size
        rows = slice(first_row, min(first_row + chunk_size, total_rows))

        chunk_file = f"{output_prefix}_chunk_{i}.npz"
        np.savez_compressed(
            chunk_file,
            haplotypes=haplotypes[rows],
            tmrca=tmrca[rows],
        )
        print(f"保存数据块 {i+1}/{num_chunks} 到 {chunk_file}")

In [14]:
def load_data_from_chunks(input_prefix, num_chunks):
    """Load and concatenate a dataset stored as ``{input_prefix}_chunk_{i}.npz`` files.

    Args:
        input_prefix: path prefix shared by all chunk files.
        num_chunks: number of chunk indices (``0 .. num_chunks - 1``) to look for.

    Returns:
        tuple ``(haplotypes, tmrca)`` - the ``haplotypes`` and ``tmrca`` arrays
        of every chunk that exists, concatenated along the first axis in index
        order. Missing chunks are skipped with a printed warning.

    Raises:
        ValueError: if none of the expected chunk files exists.
    """
    all_haplotypes = []
    all_tmrca = []
    
    for i in range(num_chunks):
        chunk_file = f"{input_prefix}_chunk_{i}.npz"
        if not os.path.exists(chunk_file):
            print(f"警告: 数据块 {chunk_file} 未找到")
            continue

        # The context manager closes the archive's file handle; indexing the
        # archive materialises each array in memory, so both stay valid after
        # the file is closed.
        with np.load(chunk_file) as data:
            all_haplotypes.append(data['haplotypes'])
            all_tmrca.append(data['tmrca'])
        print(f"加载数据块 {i+1}/{num_chunks} 从 {chunk_file}")

    # Fail with an actionable message rather than np.concatenate's generic
    # "need at least one array" error.
    if not all_haplotypes:
        raise ValueError(
            f"no chunk files found for prefix {input_prefix!r} "
            f"(looked for {num_chunks} chunk(s))")
    
    return np.concatenate(all_haplotypes), np.concatenate(all_tmrca)

In [15]:
def train_with_early_stopping(model, train_loader, val_loader, max_epochs, 
                             patience=5, learning_rate=0.001, 
                             device='cuda', checkpoint_dir='checkpoints'):
    """Train a model with MSE loss, early stopping and checkpointing.

    Each epoch runs one pass over ``train_loader`` (Adam updates) and one pass
    over ``val_loader`` (no gradients). The mean validation loss drives both a
    ``ReduceLROnPlateau`` scheduler and the early-stopping counter.

    Args:
        model: the model to train; moved to ``device``.
        train_loader: iterable of batches, each a mapping with ``'input'`` and
            ``'target'`` tensors.
        val_loader: validation batches in the same format.
        max_epochs: upper bound on the number of epochs.
        patience: number of consecutive epochs without validation improvement
            after which training stops.
        learning_rate: initial Adam learning rate.
        device: device on which training runs.
        checkpoint_dir: directory receiving ``checkpoint_epoch_{n}.pt`` every
            fifth epoch and ``best_model.pt`` on every improvement; created if
            it does not exist.

    Returns:
        tuple ``(model, history, best_model_state)``:

        * ``model`` - the model with the best-validation weights restored.
        * ``history`` - dict with per-epoch ``'train_loss'`` and ``'val_loss'`` lists.
        * ``best_model_state`` - dict saved as ``best_model.pt``, or ``None``
          when no epoch produced an improvement (e.g. ``max_epochs == 0`` or
          the validation loss was never a finite improvement); the model is
          then returned as it was after the last update.
    """
    import copy  # local import: only needed for snapshotting the best weights

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(checkpoint_dir, exist_ok=True)
        
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)
    criterion = nn.MSELoss()

    def _model_metadata():
        # Architecture hints stored next to the weights so a checkpoint can be
        # re-instantiated; attributes the model lacks are recorded as None.
        return {
            'focus_input_size': getattr(model, 'focus_input_size', None),
            'input_features': 7,  # the pipeline currently uses 7 input features
            'root_time': getattr(model, 'root_time', None),
        }
    
    history = {
        'train_loss': [],
        'val_loss': []
    }
    
    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0
    
    for epoch in range(max_epochs):
        start_time = time.time()
        
        # ---- training phase ----
        model.train()
        train_loss = 0.0
        
        for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{max_epochs} [Train]")):
            inputs = batch['input'].to(device)
            targets = batch['target'].to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            
            # Periodic progress line.
            if batch_idx % 100 == 0:
                print(f"  Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")
        
        # Mean of the per-batch losses.
        train_loss /= len(train_loader)
        
        # ---- validation phase ----
        model.eval()
        val_loss = 0.0
        
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{max_epochs} [Val]"):
                inputs = batch['input'].to(device)
                targets = batch['target'].to(device)
                
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                
                val_loss += loss.item()
        
        val_loss /= len(val_loader)
        
        # Let the scheduler lower the learning rate when validation stalls.
        scheduler.step(val_loss)
        
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        
        epoch_time = time.time() - start_time
        print(f'Epoch {epoch+1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Time: {epoch_time:.2f}s')
        
        # Periodic full checkpoint (weights + optimizer + scheduler state).
        if (epoch + 1) % 5 == 0:
            checkpoint = {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                **_model_metadata(),
            }
            torch.save(checkpoint, os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pt'))
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            best_model_state = {
                # state_dict() hands out references to the live parameters, so
                # it must be deep-copied; otherwise later optimizer steps would
                # silently overwrite this "best" snapshot.
                'model': copy.deepcopy(model.state_dict()),
                **_model_metadata(),
                'epoch': epoch
            }
            torch.save(best_model_state, os.path.join(checkpoint_dir, 'best_model.pt'))
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"提前停止! {patience}轮验证损失没有改善")
                break
    
    # Restore the best weights; nothing to restore if no epoch ever improved.
    if best_model_state is not None:
        model.load_state_dict(best_model_state['model'])
    
    return model, history, best_model_state

In [16]:
def _load_or_generate_split(label, prefix, num_samples, sequence_length, chunk_size):
    """Return ``(haplotypes, tmrca)`` for one data split, using cached chunks when complete.

    If every expected ``{prefix}_chunk_{i}.npz`` file exists the split is loaded
    from disk; otherwise it is simulated from scratch and written to disk.
    ``label`` is only used in the progress messages.
    """
    num_chunks = (num_samples + chunk_size - 1) // chunk_size
    if all(os.path.exists(f"{prefix}_chunk_{i}.npz") for i in range(num_chunks)):
        print(f"加载已存在的{label}数据...")
        return load_data_from_chunks(prefix, num_chunks)

    print(f"生成新的{label}数据...")
    haplotypes, tmrca = parallel_generate_data(
        num_samples,
        sequence_length,
        mut_rate_range=(0.5e-8, 5e-8),  # wide mutation-rate range
        rec_rate_range=(0.5e-8, 5e-8),
        batch_size=5000,
        num_processes=min(16, mp.cpu_count()),  # cap the number of workers
    )
    save_data_in_chunks(haplotypes, tmrca, prefix, chunk_size)
    return haplotypes, tmrca


if __name__ == "__main__":
    from torch.utils.data import Subset

    # Seed torch and NumPy so runs are repeatable.
    torch.manual_seed(42)
    np.random.seed(42)
    
    # Pick the best available accelerator: CUDA, then Apple MPS, then CPU.
    # (The previous code tested for CUDA but then requested 'mps'.)
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    print(f'使用设备: {device}')
    
    # Dataset sizes.
    # NOTE(review): generate_batch stores TMRCA as float64, so the training
    # split alone is 200000 x 20000 x 8 bytes (about 32 GB) once concatenated
    # in memory - confirm these sizes are intended.
    train_samples = 200000   # 200k training samples
    val_samples = 20000      # 20k validation samples
    test_samples = 1000      # 1k test samples (the old comment said 10k - confirm)
    sequence_length = 20000  # 20k sites per sequence (the old comment said 2000 - confirm)
    chunk_size = 10000       # samples per on-disk chunk
    
    data_dir = 'large_data'
    os.makedirs(data_dir, exist_ok=True)
    
    train_prefix = os.path.join(data_dir, 'train')
    val_prefix = os.path.join(data_dir, 'val')
    test_prefix = os.path.join(data_dir, 'test')
    
    # Load each split from its cache, or simulate and cache it.
    train_haplotypes, train_tmrca = _load_or_generate_split(
        "训练", train_prefix, train_samples, sequence_length, chunk_size)
    val_haplotypes, val_tmrca = _load_or_generate_split(
        "验证", val_prefix, val_samples, sequence_length, chunk_size)
    test_haplotypes, test_tmrca = _load_or_generate_split(
        "测试", test_prefix, test_samples, sequence_length, chunk_size)
    
    print(f'训练样本: {train_haplotypes.shape[0]}')
    print(f'验证样本: {val_haplotypes.shape[0]}')
    print(f'测试样本: {test_haplotypes.shape[0]}')
    print(f'序列长度: {sequence_length}')
    
    # Datasets; a larger window is used for the local features.
    window_size = 30
    train_dataset = CoalescentDataset(train_haplotypes, train_tmrca, window_size=window_size)
    val_dataset = CoalescentDataset(val_haplotypes, val_tmrca, window_size=window_size)
    test_dataset = CoalescentDataset(test_haplotypes, test_tmrca, window_size=window_size)
    
    # Data loaders with a large batch size and several workers.
    batch_size = 256
    num_workers = min(8, os.cpu_count() or 1)
    # Pinned host memory only helps host-to-GPU copies on CUDA.
    pin_memory = device.type == 'cuda'
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 
                             num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, 
                           num_workers=num_workers, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, 
                            num_workers=num_workers)
    
    # Checkpoint directory.
    checkpoint_dir = 'large_checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # The architecture is the same whether weights are loaded or trained, so
    # the model is built once.
    model = ImprovedCoalNN(focus_input_size=sequence_length,
                           input_features=7,  # 7 input features
                           hidden_dims=[64, 128, 256, 512, 1024],  # larger network
                           kernel_sizes=[9, 7, 5, 3, 3],  # larger receptive field
                           num_residual_blocks=3)

    model_path = os.path.join(checkpoint_dir, 'best_model.pt')
    if os.path.exists(model_path) and os.path.getsize(model_path) > 0:
        # Reuse previously trained weights.
        print(f'加载已训练模型 {model_path}')
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model'])
        model = model.to(device)
    else:
        print('开始训练大规模CoalNN模型...')
        model, history, best_model = train_with_early_stopping(
            model, 
            train_loader, 
            val_loader, 
            max_epochs=100,        # more epochs for the large run
            patience=10,           # early-stopping patience
            learning_rate=0.0005,  # lower learning rate for smoother convergence
            device=device,
            checkpoint_dir=checkpoint_dir
        )
        
        # Loss curves.
        plt.figure(figsize=(12, 8))
        plt.plot(history['train_loss'], label='Train Loss')
        plt.plot(history['val_loss'], label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training and Validation Loss (Large Scale)')
        plt.grid(True)
        plt.legend()
        plt.savefig(os.path.join(checkpoint_dir, 'large_scale_loss.png'), dpi=300)
        plt.show()
    
    # Evaluate on the test set in slices to bound memory use.
    print('在大规模测试集上评估模型...')
    eval_chunk_size = 1000
    all_predictions = []
    all_targets = []
    
    for i in range(0, len(test_dataset), eval_chunk_size):
        end_idx = min(i + eval_chunk_size, len(test_dataset))
        subset = Subset(test_dataset, range(i, end_idx))
        subset_loader = DataLoader(subset, batch_size=128, num_workers=num_workers)
        
        print(f"评估子集 {i}-{end_idx}...")
        # NOTE(review): evaluate_model is neither defined nor imported in the
        # visible cells - presumably it lives in simple.py; confirm and add it
        # to the import cell, otherwise this line raises NameError.
        pred, targ = evaluate_model(model, subset_loader, device=device)
        all_predictions.append(pred)
        all_targets.append(targ)
    
    # Merge the per-slice results.
    predictions = np.concatenate(all_predictions)
    targets = np.concatenate(all_targets)
    
    # Plot and report the metrics.
    print('可视化大规模结果...')
    mse, mae, r2 = plot_results(predictions, targets, log_transform=True)
    
    print(f'大规模优化版CoalNN评估完成!')
    print(f'MSE: {mse:.4f}, MAE: {mae:.4f}, R²: {r2:.4f}')


使用设备: cpu
生成新的训练数据...


NameError: name 'mp' is not defined