In [None]:
import os
# FASTA file used as the preprocessing input; a relative path here is resolved
# against the notebook's working directory when the config cell runs.
concat_path = "XTT22_train.fa"


In [None]:
# Build the preprocessing config and write it to preprocess_config.yaml, which
# is the file later passed to `preprocess_evo2 --config`.
# Absolute paths are embedded so the config does not depend on the working
# directory of whatever process reads it.
full_fasta_path = os.path.abspath(concat_path)
output_dir = os.path.abspath("preprocessed_data")
# The YAML is a one-element list: a single dataset entry that reads the FASTA
# above and is split 0.9 / 0.05 / 0.05 into train / valid / test.
# NOTE(review): the paths are interpolated into double-quoted YAML scalars, so a
# path containing a backslash or a double quote would yield invalid YAML.
output_yaml = f"""
- datapaths: ["{full_fasta_path}"]
  output_dir: "{output_dir}"
  output_prefix: XTT22_train
  train_split: 0.9
  valid_split: 0.05
  test_split: 0.05
  overwrite: True
  embed_reverse_complement: true
  random_reverse_complement: 0.0
  random_lineage_dropout: 0.0
  include_sequence_id: false
  transcribe: "back_transcribe"
  force_uppercase: false
  indexed_dataset_dtype: "uint8"
  tokenizer_type: "Byte-Level"
  vocab_file: null
  vocab_size: null
  merges_file: null
  pretrained_tokenizer_model: null
  special_tokens: null
  fast_hf_tokenizer: true
  append_eod: true
  enforce_sample_length: null
  ftfy: false
  workers: 1
  preproc_concurrency: 100000
  chunksize: 25
  drop_empty_sequences: true
  nnn_filter: false  # If you split your fasta on NNN (in human these are contigs), then you should set this to true.
  seed: 12342  # Not relevant because we are not using random reverse complement or lineage dropout.
"""
# print() appends one extra newline after the template's own trailing newline.
with open("preprocess_config.yaml", "w") as f:
    print(output_yaml, file=f)


In [None]:
# Run the preprocessing CLI against preprocess_config.yaml in the current
# working directory (the file written by the config cell of this notebook).
!preprocess_evo2 --config preprocess_config.yaml


In [None]:
# Sanity check: list the contents of preprocessed_data/ with human-readable sizes.
!ls -lh preprocessed_data/


In [None]:
# Overwrite the installed NeMo hyena.py with the modified copy from /workspace.
# NOTE(review): no backup of the original file is taken in this cell, and the
# destination hard-codes the python3.12 dist-packages path — confirm both are
# acceptable for the target environment.
!cp /workspace/hyena_modified.py /usr/local/lib/python3.12/dist-packages/nemo/collections/llm/gpt/model/hyena.py


In [None]:
# Convert the savanna_evo2_7b.pt checkpoint with the evo2_convert_to_nemo2 CLI,
# writing to nemo2_evo2_7b — the same directory name the training cell passes
# as --ckpt-dir.
!evo2_convert_to_nemo2 \
  --model-path /workspace/savanna_evo2_7b/savanna_evo2_7b.pt \
  --model-size 7b --output-dir nemo2_evo2_7b


In [None]:
# ==================== NCCL超时问题解决方案 ====================
import os
import subprocess
import time
from datetime import datetime

# 1. 备份和替换训练脚本
print("🔧 备份并替换训练脚本...")
!cp /usr/local/lib/python3.12/dist-packages/bionemo/evo2/run/train.py /usr/local/lib/python3.12/dist-packages/bionemo/evo2/run/train.py.backup 2>/dev/null || echo "备份文件已存在或路径不存在"
!cp /workspace/bionemo_train.py /usr/local/lib/python3.12/dist-packages/bionemo/evo2/run/train.py 2>/dev/null || echo "自定义训练脚本不存在，使用默认版本"

# 2. 设置NCCL和分布式环境变量
print("🔧 配置NCCL超时和优化参数...")

# NCCL超时设置 - 增加到2小时
os.environ['NCCL_TIMEOUT'] = '7200'  # 2小时超时
os.environ['TORCH_NCCL_BLOCKING_WAIT'] = '1'  # 使用新的环境变量名
os.environ['TORCH_NCCL_ASYNC_ERROR_HANDLING'] = '1'  # 使用新的环境变量名
os.environ['NCCL_DEBUG'] = 'INFO'  # 启用详细调试信息

# PyTorch分布式超时设置
os.environ['TORCH_DISTRIBUTED_TIMEOUT'] = '7200'  # PyTorch分布式超时
os.environ['TORCH_NCCL_TRACE_BUFFER_SIZE'] = '1024'  # 启用NCCL跟踪

# 数据加载和通信优化
os.environ['NCCL_BUFFSIZE'] = '8388608'  # 增加缓冲区大小到8MB
os.environ['NCCL_NTHREADS'] = '8'  # 增加NCCL线程数
os.environ['NCCL_MIN_NTHREADS'] = '4'  # 最小线程数

# 避免内存碎片和并行冲突
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'  # 避免tokenizer并行冲突
os.environ['OMP_NUM_THREADS'] = '4'  # 限制OpenMP线程数

# 数据集准备优化
os.environ['NCCL_P2P_DISABLE'] = '0'  # 确保P2P通信启用
os.environ['NCCL_SHM_DISABLE'] = '0'  # 确保共享内存通信启用

print("环境变量设置完成:")
for key in ['NCCL_TIMEOUT', 'TORCH_DISTRIBUTED_TIMEOUT', 'NCCL_DEBUG', 'NCCL_BUFFSIZE']:
    print(f"  {key}: {os.environ.get(key)}")

# 3. Define the training launcher with live output monitoring
def run_training_with_monitoring(stall_warning_seconds=300):
    """Launch ``train_evo2`` as a subprocess and stream its output with annotations.

    Every output line is echoed with the elapsed time since launch. Lines that
    mention dataset preparation, NCCL, or errors are echoed a second time with
    a marker prefix. While the child process is silent, a reminder is printed
    every ``stall_warning_seconds`` seconds.

    The child's stdout is read on a background thread and handed over through a
    queue. Reading it directly in the main loop would block inside
    ``readline()``, so a silence check placed in that loop could only run when
    a line arrives, which is exactly when it is not needed.

    Args:
        stall_warning_seconds: Seconds without any output before a reminder is
            printed. ``None`` or a non-positive value disables the reminders.

    Returns:
        None. Success, a non-zero exit code, a user interrupt, and launch
        errors are all reported through printed messages.
    """
    # Function-scope imports: only this function needs these modules, so the
    # cell-level import list stays untouched.
    import queue
    import threading

    # Resolve the dataset directory against the current working directory.
    preprocessed_data = os.path.abspath("preprocessed_data")
    print(f"📁 数据集目录: {preprocessed_data}")

    # Training hyperparameters and paths.
    # NOTE(review): seq_length is 1 here, which looks unusually small — confirm
    # it is intentional and not a placeholder.
    training_config = {
        'data_config': 'training_data_config.yaml',
        'dataset_dir': preprocessed_data,  # absolute path resolved above
        'model_size': '7b',
        'devices': 2,
        'num_nodes': 1,
        'seq_length': 1,
        'micro_batch_size': 1,
        'lr': 0.0001,
        'warmup_steps': 5,
        'max_steps': 200000,
        'ckpt_dir': 'nemo2_evo2_7b',
        'clip_grad': 1,
        'wd': 0.01,
        'activation_checkpoint_recompute_num_layers': 1,
        'val_check_interval': 1000
    }

    # Build the argv list; every value is stringified for the CLI.
    cmd_parts = [
        'train_evo2',
        '-d', training_config['data_config'],
        '--dataset-dir', training_config['dataset_dir'],
        '--model-size', training_config['model_size'],
        '--devices', str(training_config['devices']),
        '--num-nodes', str(training_config['num_nodes']),
        '--seq-length', str(training_config['seq_length']),
        '--micro-batch-size', str(training_config['micro_batch_size']),
        '--lr', str(training_config['lr']),
        '--warmup-steps', str(training_config['warmup_steps']),
        '--max-steps', str(training_config['max_steps']),
        '--ckpt-dir', training_config['ckpt_dir'],
        '--clip-grad', str(training_config['clip_grad']),
        '--wd', str(training_config['wd']),
        '--activation-checkpoint-recompute-num-layers', str(training_config['activation_checkpoint_recompute_num_layers']),
        '--val-check-interval', str(training_config['val_check_interval']),
        '--ckpt-async-save'
    ]

    # Display-only rendering of the command; the list itself is what gets executed.
    cmd = ' '.join(cmd_parts)

    print(f"🚀 开始训练时间: {datetime.now()}")
    print(f"📋 训练命令: {cmd}")
    print("="*80)

    # queue.get(timeout=None) blocks indefinitely, which disables the watchdog.
    if stall_warning_seconds is not None and stall_warning_seconds > 0:
        watchdog_timeout = stall_warning_seconds
    else:
        watchdog_timeout = None

    dataset_keywords = ['dataset', 'preparing', 'loading', 'barrier', 'build', 'index']
    error_keywords = ['error', 'timeout', 'fail']
    end_of_stream = object()  # sentinel the reader thread enqueues at EOF

    # Bound before the try block so the cleanup code can always inspect them,
    # even if the launch itself fails or is interrupted.
    process = None
    reader = None

    try:
        # Launch with an argv list (no shell), merging stderr into stdout.
        process = subprocess.Popen(
            cmd_parts,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            errors="replace",  # undecodable bytes must not abort the monitor
            bufsize=1
        )

        lines = queue.Queue()
        child_stdout = process.stdout

        def _pump_output():
            """Forward the child's output lines to the queue, then signal EOF."""
            try:
                for raw_line in iter(child_stdout.readline, ''):
                    lines.put(raw_line)
            except (ValueError, OSError):
                # The pipe was closed underneath us during shutdown.
                pass
            finally:
                lines.put(end_of_stream)

        reader = threading.Thread(target=_pump_output, daemon=True)
        reader.start()

        start_time = time.time()
        last_output_time = start_time
        dataset_preparation_detected = False

        while True:
            try:
                line = lines.get(timeout=watchdog_timeout)
            except queue.Empty:
                # Nothing arrived within the watchdog window: report the silence.
                elapsed_no_output = time.time() - last_output_time
                if dataset_preparation_detected:
                    print(f"\n⏳ [数据集准备] 已有 {elapsed_no_output:.1f} 秒无输出，数据集构建中...")
                else:
                    print(f"\n⏳ [等待中] 已有 {elapsed_no_output:.1f} 秒无输出...")
                continue

            if line is end_of_stream:
                break

            current_time = time.time()
            last_output_time = current_time  # any line counts as activity
            elapsed = current_time - start_time
            text = line.rstrip()
            lowered = line.lower()

            # Echo the line with its timestamp.
            print(f"[{elapsed:.1f}s] {text}")

            # Highlight dataset-preparation progress.
            if any(keyword in lowered for keyword in dataset_keywords):
                print(f"📊 数据集准备阶段: {text}")
                dataset_preparation_detected = True

            # Highlight NCCL communication messages.
            if 'nccl' in lowered:
                print(f"🔗 NCCL通信: {text}")

            # Highlight anything that looks like an error.
            if any(err in lowered for err in error_keywords):
                print(f"❌ 错误信息: {text}")

        # Output is exhausted; collect the exit status.
        return_code = process.wait()

        if return_code == 0:
            print(f"\n✅ 训练成功完成! 总耗时: {time.time() - start_time:.1f} 秒")
        else:
            print(f"\n❌ 训练失败，返回码: {return_code}")

    except KeyboardInterrupt:
        print("\n🛑 用户中断训练")
    except Exception as e:
        print(f"\n💥 训练出错: {e}")
    finally:
        if process is not None:
            # Never leave the trainer running after an interrupt or an error.
            if process.poll() is None:
                process.terminate()
                try:
                    process.wait(timeout=30)
                except subprocess.TimeoutExpired:
                    process.kill()
                    process.wait()
            # Close the pipe only once the reader is done with it; closing a
            # stream another thread is still blocked on could hang.
            if reader is not None:
                reader.join(timeout=1)
            if (reader is None or not reader.is_alive()) and process.stdout is not None:
                process.stdout.close()

# 4. Launch training
print("🎯 启动带监控的训练...")
run_training_with_monitoring()
