# 🧠 GDM-Net Google Colab 训练

Graph-Augmented Dual Memory Network for Multi-Document Understanding

本笔记本将帮助您在Google Colab上训练GDM-Net模型。

## 📋 使用前准备
1. 确保选择了GPU运行时：Runtime → Change runtime type → GPU
2. 准备好您的数据文件
3. 上传项目代码文件

## 🔧 1. 环境检查和设置

In [None]:
# 检查GPU可用性
import torch
import os

print("🔍 系统信息检查")
print("=" * 40)
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    print(f"PyTorch version: {torch.__version__}")
else:
    print("⚠️ No GPU available, using CPU")

# 显示详细GPU信息
!nvidia-smi

In [None]:
# 安装依赖包
print("📦 安装依赖包...")

# 安装PyTorch和相关包
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install torch-geometric
!pip install transformers>=4.20.0
!pip install pytorch-lightning>=1.7.0
!pip install datasets>=2.0.0
!pip install PyYAML>=6.0
!pip install tensorboard>=2.8.0
!pip install wandb>=0.12.0
!pip install tqdm>=4.64.0
!pip install scikit-learn>=1.1.0
!pip install matplotlib>=3.5.0
!pip install seaborn>=0.11.0

print("✅ 依赖安装完成")

## 📁 2. 项目文件准备

In [None]:
# 挂载Google Drive（可选）
from google.colab import drive
# Path inside the Colab VM where Drive is attached
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

# If the project files live in Google Drive, they can be copied into Colab:
# !cp -r /content/drive/MyDrive/GDM-Net/* /content/

print("✅ Google Drive 挂载完成")

In [None]:
# 创建项目目录结构
import os

# Folders the later cells read from or write to, relative to the working directory
directories = ['gdmnet', 'train', 'config', 'data', 'checkpoints', 'logs', 'examples']

# exist_ok=True keeps this cell safe to re-run
for folder in directories:
    os.makedirs(folder, exist_ok=True)
    print(f"✅ 创建目录: {folder}")

print("\n📁 项目结构创建完成")

## ⚙️ 3. 配置文件创建

In [None]:
# Write the Colab-tuned training configuration to config/colab_config.yaml.
import os

# NOTE: learning_rate is spelled 2.0e-5, not 2e-5. PyYAML resolves scalars with
# the YAML 1.1 float grammar, which requires a decimal point in the mantissa;
# a bare `2e-5` is loaded as the string "2e-5". This matters if train.py reads
# the file with PyYAML (installed by this notebook) — verify against train.py.
colab_config = """
# GDM-Net Colab Configuration - GPU Optimized

seed: 42

model:
  bert_model_name: "bert-base-uncased"
  hidden_size: 768
  num_entities: 8
  num_relations: 4
  num_classes: 5
  gnn_type: "rgcn"
  num_gnn_layers: 2
  num_reasoning_hops: 3
  fusion_method: "gate"
  learning_rate: 2.0e-5
  dropout_rate: 0.1

data:
  train_path: "data/hotpotqa_train.json"
  val_path: "data/hotpotqa_val.json"
  test_path: "data/hotpotqa_val.json"
  max_length: 512
  max_query_length: 64

training:
  max_epochs: 10
  batch_size: 8
  num_workers: 2
  accelerator: "gpu"
  devices: 1
  precision: 16
  gradient_clip_val: 1.0
  accumulate_grad_batches: 1
  val_check_interval: 0.5
  log_every_n_steps: 50
  checkpoint_dir: "checkpoints"
  early_stopping: true
  patience: 3

logging:
  type: "tensorboard"
  save_dir: "logs"
  name: "gdmnet-colab"
"""

# Create the target folder here so this cell also works when run on its own.
os.makedirs('config', exist_ok=True)

# Save the config; the explicit encoding keeps the output locale-independent.
# strip() drops the leading/trailing newlines of the triple-quoted literal.
with open('config/colab_config.yaml', 'w', encoding='utf-8') as f:
    f.write(colab_config.strip())

print("✅ Colab配置文件创建完成")
print("📄 配置文件路径: config/colab_config.yaml")

## 📊 4. 数据准备

**请上传您的数据文件到 `data/` 目录，或从Google Drive复制。**

In [None]:
# 检查数据文件
import json
import os

def check_data_files():
    """Report whether the train/validation JSON files exist and preview them.

    For each expected path (relative to the working directory) prints either
    a "missing" line, or the number of loaded records followed by a short
    preview of the first record when the file is non-empty. The preview reads
    the 'document', 'query', 'entities' and 'relations' keys of that record.
    """
    expected_paths = (
        'data/hotpotqa_train.json',
        'data/hotpotqa_val.json',
    )

    for path in expected_paths:
        if not os.path.exists(path):
            print(f"❌ {path}: 文件不存在")
            continue

        with open(path, 'r') as handle:
            records = json.load(handle)
        print(f"✅ {path}: {len(records)} 样本")

        # Nothing to preview for an empty file
        if not records:
            continue

        first = records[0]
        doc_preview = first['document'][:100]
        print(f"   文档: {doc_preview}...")
        print(f"   查询: {first['query']}")
        entity_count = len(first['entities'])
        relation_count = len(first['relations'])
        print(f"   实体: {entity_count}, 关系: {relation_count}")
        print()

# Run the data check
check_data_files()

# When the training file is absent, print instructions for getting data into data/
if not os.path.exists('data/hotpotqa_train.json'):
    upload_hints = (
        "\n📤 请上传数据文件:",
        "1. 使用左侧文件面板上传到 data/ 目录",
        "2. 或从Google Drive复制:",
        "   !cp /content/drive/MyDrive/path/to/your/data/* ./data/",
    )
    for hint in upload_hints:
        print(hint)

## 🧠 5. 模型代码部署

**请确保已上传所有GDM-Net模型文件到对应目录。**

In [None]:
# 检查模型文件
import sys
sys.path.append('/content')

def check_model_files():
    """Report which required project source files are present.

    Prints one ✅/❌ line per expected path, resolved relative to the current
    working directory.

    Returns:
        True when every listed file exists, otherwise False.
    """
    required_files = (
        'gdmnet/__init__.py',
        'gdmnet/model.py',
        'gdmnet/encoder.py',
        'gdmnet/extractor.py',
        'gdmnet/graph_memory.py',
        'gdmnet/reasoning.py',
        'train/train.py',
        'train/dataset.py',
    )

    nothing_missing = True
    for path in required_files:
        if not os.path.exists(path):
            print(f"❌ {path}")
            nothing_missing = False
            continue
        print(f"✅ {path}")

    return nothing_missing

# Verify the uploaded project files
files_ok = check_model_files()

if not files_ok:
    print("\n❌ 缺少必要的模型文件，请上传完整的项目代码")
else:
    # All files are present: confirm the package can actually be imported
    try:
        from gdmnet import GDMNet
    except ImportError as e:
        print(f"\n❌ 模型导入失败: {e}")
    else:
        print("\n✅ GDM-Net模型导入成功")

## 🏋️ 6. 开始训练

In [None]:
# 设置训练环境
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# 开始训练
print("🚀 启动GDM-Net训练...")
print("=" * 50)

!python train/train.py --config config/colab_config.yaml --mode train

print("\n🎉 训练完成！")

## 📈 7. 监控训练进度

In [None]:
# Load the TensorBoard notebook extension and display it for the logs/ directory
%load_ext tensorboard
%tensorboard --logdir logs/

print("📊 TensorBoard已启动，您可以在上方看到训练曲线")

In [None]:
# 检查训练进度
import glob
import os

def check_training_progress():
    """Print a summary of saved checkpoints and logger output directories.

    Lists every ``checkpoints/*.ckpt`` file with its size in MB, reports the
    checkpoint with the lowest ``val_loss`` value embedded in its file name,
    and counts the ``logs/gdmnet-colab/version_*`` directories. All paths are
    relative to the current working directory.

    Checkpoints whose names carry no parsable ``val_loss=<number>`` (for
    example ``last.ckpt``) are listed but ignored when picking the best one;
    if none can be scored, a warning is printed instead of raising.
    """
    import re  # local import: the enclosing cell only imports glob and os

    print("📊 训练进度检查")
    print("=" * 30)

    # Checkpoints
    checkpoints = sorted(glob.glob('checkpoints/*.ckpt'))
    if checkpoints:
        print(f"💾 找到 {len(checkpoints)} 个检查点:")
        for ckpt in checkpoints:
            size = os.path.getsize(ckpt) / (1024*1024)
            print(f"   {os.path.basename(ckpt)} ({size:.1f} MB)")

        # Parse the number with a regex rather than splitting on '-': the value
        # may be the last field ("...-val_loss=0.45.ckpt") or missing entirely.
        loss_pattern = re.compile(r'val_loss=(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)')
        scored = []
        for ckpt in checkpoints:
            match = loss_pattern.search(os.path.basename(ckpt))
            if match:
                scored.append((float(match.group(1)), ckpt))

        if scored:
            best_model = min(scored, key=lambda item: item[0])[1]
            print(f"\n🏆 最佳模型: {os.path.basename(best_model)}")
        else:
            print("\n⚠️ 检查点文件名中未包含 val_loss，无法确定最佳模型")
    else:
        print("❌ 未找到检查点文件")

    # Logs
    log_dirs = glob.glob('logs/gdmnet-colab/version_*')
    if log_dirs:
        print(f"\n📋 找到 {len(log_dirs)} 个日志目录")
    else:
        print("\n❌ 未找到日志目录")

# Print the checkpoint and log summary for the current working directory
check_training_progress()

## 🧪 8. 模型测试和推理

In [None]:
# 加载训练好的模型
import torch
import glob
from gdmnet import GDMNet

def load_best_model():
    """Load the checkpoint with the lowest validation loss.

    Scans ``checkpoints/*.ckpt`` (relative to the working directory), picks
    the file whose name embeds the smallest ``val_loss=<number>``, loads it
    with ``GDMNet.load_from_checkpoint``, switches it to eval mode and moves
    it to the GPU when CUDA is available.

    Returns:
        The loaded model, or None when there are no checkpoints or none of
        their names contains a parsable ``val_loss`` value (e.g. only
        ``last.ckpt`` exists).
    """
    import re  # local import: the enclosing cell does not import re

    checkpoints = glob.glob('checkpoints/*.ckpt')
    if not checkpoints:
        print("❌ 未找到检查点文件")
        return None

    # Parse the number with a regex rather than splitting on '-': the value
    # may be the last field ("...-val_loss=0.45.ckpt") or missing entirely.
    loss_pattern = re.compile(r'val_loss=(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)')
    scored = []
    for path in checkpoints:
        match = loss_pattern.search(path)
        if match:
            scored.append((float(match.group(1)), path))

    if not scored:
        print("❌ 检查点文件名中未包含 val_loss，无法确定最佳模型")
        return None

    # Lowest validation loss wins
    best_model_path = min(scored, key=lambda item: item[0])[1]
    print(f"🧠 加载最佳模型: {best_model_path}")

    # Load the model and put it in inference mode
    model = GDMNet.load_from_checkpoint(best_model_path)
    model.eval()

    if torch.cuda.is_available():
        model = model.cuda()
        print("🔥 模型已移至GPU")

    print("✅ 模型加载成功")
    print(f"📊 模型参数: {sum(p.numel() for p in model.parameters()):,}")

    return model

# Load the model; trained_model is None when no usable checkpoint is found
trained_model = load_best_model()

In [None]:
# 运行推理示例
from transformers import BertTokenizer
import torch.nn.functional as F

def run_inference_demo(model):
    """Run two hard-coded document/query examples through ``model`` and print results.

    Args:
        model: A loaded model, or None. It is called with ``input_ids``,
            ``attention_mask``, ``query`` and ``return_intermediate=True``;
            its output is indexed by 'logits', 'entities' and 'relations'.

    Returns:
        None. Prints a notice and returns immediately when ``model`` is None.
    """
    if model is None:
        print("❌ 模型未加载")
        return

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def encode(text, limit):
        # Pad/truncate to a fixed length and return PyTorch tensors
        return tokenizer(
            text,
            max_length=limit,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

    # Demo inputs
    examples = [
        {
            "document": "Apple Inc. is a technology company founded by Steve Jobs. Tim Cook is the current CEO.",
            "query": "Who is the CEO of Apple?"
        },
        {
            "document": "Microsoft Corporation was founded by Bill Gates. Satya Nadella is the current CEO.",
            "query": "Who founded Microsoft?"
        }
    ]

    print("🔍 推理演示")
    print("=" * 40)

    for index, sample in enumerate(examples, start=1):
        print(f"\n📋 示例 {index}:")
        print(f"文档: {sample['document']}")
        print(f"查询: {sample['query']}")

        # Tokenize the document (up to 512 tokens) and the query (up to 64)
        doc_inputs = encode(sample['document'], 512)
        query_inputs = encode(sample['query'], 64)

        # Move every tensor to the GPU when one is available
        if torch.cuda.is_available():
            doc_inputs = {name: tensor.cuda() for name, tensor in doc_inputs.items()}
            query_inputs = {name: tensor.cuda() for name, tensor in query_inputs.items()}

        # Forward pass without gradient tracking
        with torch.no_grad():
            outputs = model(
                input_ids=doc_inputs['input_ids'],
                attention_mask=doc_inputs['attention_mask'],
                query=query_inputs['input_ids'],
                return_intermediate=True
            )

        # Turn logits into a predicted class and its softmax probability
        logits = outputs['logits']
        probabilities = F.softmax(logits, dim=-1)
        prediction = torch.argmax(logits, dim=-1)
        confidence = probabilities.max()

        print(f"🎯 预测类别: {prediction.item()}")
        print(f"📊 置信度: {confidence.item():.3f}")
        # Only the first element of each list is inspected
        entity_total = len(outputs['entities'][0])
        print(f"🔍 提取实体: {entity_total}")
        relation_total = len(outputs['relations'][0])
        print(f"🔗 提取关系: {relation_total}")

# Run the demo only when an earlier cell left a loaded model in `trained_model`
# (at notebook top level the local and global namespaces are the same dict)
if globals().get('trained_model') is None:
    print("⚠️ 请先加载训练好的模型")
else:
    run_inference_demo(trained_model)

## 💾 9. 保存和下载结果

In [None]:
# 保存结果到Google Drive
def save_to_drive(result_dir='/content/drive/MyDrive/GDM-Net-Results'):
    """Copy checkpoints, logs and the config file into ``result_dir``.

    The contents of ``checkpoints/`` and ``logs/`` are merged directly into
    ``result_dir`` (the same flat layout the previous shell ``cp -r dir/*``
    produced), followed by ``config/colab_config.yaml``. Source paths are
    relative to the current working directory. Copy failures are reported
    and do not raise.

    Args:
        result_dir: Destination folder; created if missing. Defaults to the
            Google Drive results folder used by this notebook.

    Returns:
        None.
    """
    # Local imports: the enclosing cell has no import statements of its own
    import os
    import shutil

    drive_root = '/content/drive'
    # Creating folders under an unmounted /content/drive would write to the
    # local disk instead of Drive, so stop early when MyDrive is not there.
    targets_drive = result_dir.startswith(drive_root + '/')
    if targets_drive and not os.path.isdir(os.path.join(drive_root, 'MyDrive')):
        print("❌ Google Drive 未挂载，请先运行挂载单元格")
        return

    os.makedirs(result_dir, exist_ok=True)

    print("💾 保存结果到Google Drive...")

    all_saved = True

    # (source directory, success message); copytree with dirs_exist_ok merges
    # into the existing destination and also copes with an empty source.
    directory_jobs = (
        ('checkpoints', "✅ 检查点已保存"),
        ('logs', "✅ 日志已保存"),
    )
    for source_dir, done_message in directory_jobs:
        if not os.path.isdir(source_dir):
            continue
        try:
            shutil.copytree(source_dir, result_dir, dirs_exist_ok=True)
        except OSError as err:  # shutil.Error is a subclass of OSError
            all_saved = False
            print(f"❌ 复制 {source_dir} 失败: {err}")
        else:
            print(done_message)

    # Config file
    config_path = 'config/colab_config.yaml'
    if os.path.isfile(config_path):
        try:
            shutil.copy(config_path, result_dir)
        except OSError as err:
            all_saved = False
            print(f"❌ 复制 {config_path} 失败: {err}")
        else:
            print("✅ 配置文件已保存")
    else:
        all_saved = False
        print(f"⚠️ 未找到配置文件: {config_path}")

    if all_saved:
        print(f"\n🎉 所有结果已保存到: {result_dir}")
    else:
        print(f"\n⚠️ 部分结果未能保存到: {result_dir}")

# Run the save step
save_to_drive()

In [None]:
# 下载最佳模型到本地
from google.colab import files
import glob

def download_best_model():
    """Download the checkpoint with the lowest validation loss to the local machine.

    Scans ``checkpoints/*.ckpt`` (relative to the working directory), picks
    the file whose name embeds the smallest ``val_loss=<number>`` and passes
    it to ``files.download``. Prints a message and does nothing when there
    are no checkpoints or none of their names contains a parsable
    ``val_loss`` value (e.g. only ``last.ckpt`` exists).

    Returns:
        None.
    """
    import re  # local import: the enclosing cell does not import re

    checkpoints = glob.glob('checkpoints/*.ckpt')
    if not checkpoints:
        print("❌ 未找到检查点文件")
        return

    # Parse the number with a regex rather than splitting on '-': the value
    # may be the last field ("...-val_loss=0.45.ckpt") or missing entirely.
    loss_pattern = re.compile(r'val_loss=(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)')
    scored = []
    for path in checkpoints:
        match = loss_pattern.search(path)
        if match:
            scored.append((float(match.group(1)), path))

    if not scored:
        print("❌ 检查点文件名中未包含 val_loss，无法确定最佳模型")
        return

    best_model = min(scored, key=lambda item: item[0])[1]
    print(f"📥 下载最佳模型: {best_model}")
    files.download(best_model)
    print("✅ 下载完成")

# 下载模型（可选）
# download_best_model()

## 🎉 训练完成！

恭喜您成功在Google Colab上训练了GDM-Net模型！

### 📋 后续步骤：
1. 查看TensorBoard中的训练曲线
2. 使用训练好的模型进行推理
3. 保存重要结果到Google Drive
4. 下载最佳模型到本地

### 🔧 进一步优化：
- 调整超参数（学习率、批次大小等）
- 尝试不同的融合策略
- 使用更大的数据集
- 实验不同的GNN架构

### 📞 获取帮助：
如果遇到问题，请检查：
1. GPU是否正确启用
2. 所有依赖是否正确安装
3. 数据文件格式是否正确
4. 模型文件是否完整上传

祝您研究顺利！🚀