In [16]:
# 导入必要的库
import dgl
import torch
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
import os
from collections import defaultdict
import pickle
import logging
import time
import gc
import psutil
from IPython.display import clear_output

# 设置随机种子
torch.manual_seed(42)
np.random.seed(42)

# 配置日志
logging.basicConfig(filename='process.log', level=logging.INFO, 
                    format='%(asctime)s - %(message)s')
logger = logging.getLogger()

# 数据集路径
DATA_PATH = r".\Data\StreamSpot\all.tsv"
OUTPUT_DIR = r".\Data\StreamSpot\processed"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# 检查环境
print(f"PyTorch版本: {torch.__version__}")
print(f"DGL版本: {dgl.__version__}")
print(f"文件是否存在: {os.path.exists(DATA_PATH)}")
print(f"文件大小: {os.path.getsize(DATA_PATH) / 1024**3:.2f} GB")
print(f"初始内存使用: {psutil.virtual_memory().used / 1024**3:.2f} GB")
print(f"可用CPU核数: {os.cpu_count()}")
logger.info("环境检查完成")
start_time = time.time()
print(f"开始时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
clear_output(wait=True)

PyTorch版本: 2.7.0+cpu
DGL版本: 2.0.0
文件是否存在: True
文件大小: 2.13 GB
初始内存使用: 15.04 GB
可用CPU核数: 22
开始时间: 2025-06-29 20:21:14


In [17]:
# 加载和预处理StreamSpot数据集
def preprocess_data_chunks(data_path, chunk_size=500000):
    """
    分块加载TSV文件，编码节点、节点类型和边类型
    参数：
        data_path: TSV文件路径
        chunk_size: 每次读取的行数
    返回：
        节点编码器、节点类型编码器、边类型编码器
    """
    start_preprocess = time.time()
    logger.info("开始预处理")
    print("开始预处理...")
    
    node_encoder = LabelEncoder()
    node_type_encoder = LabelEncoder()
    edge_encoder = LabelEncoder()
    
    nodes = set()
    node_types = set()
    edge_types = set()
    
    for chunk_idx, chunk in enumerate(pd.read_csv(data_path, sep='\t', header=None,
                                                 names=["source_id", "source_type", "dest_id", 
                                                        "dest_type", "edge_type", "graph_id"],
                                                 dtype={"source_id": str, "dest_id": str, "graph_id": int},
                                                 chunksize=chunk_size)):
        chunk_start = time.time()
        nodes.update(chunk["source_id"].astype(str))
        nodes.update(chunk["dest_id"].astype(str))
        node_types.update(chunk["source_type"])
        node_types.update(chunk["dest_type"])
        edge_types.update(chunk["edge_type"])
        print(f"处理chunk {chunk_idx}，行数: {len(chunk)}，耗时: {time.time() - chunk_start:.2f}秒")
        logger.info(f"处理chunk {chunk_idx}，行数: {len(chunk)}，内存: {psutil.virtual_memory().used / 1024**3:.2f} GB")
        gc.collect()
    
    node_encoder.fit(list(nodes))
    node_type_encoder.fit(list(node_types))
    edge_encoder.fit(list(edge_types))
    
    with open(os.path.join(OUTPUT_DIR, "node_encoder.pkl"), "wb") as f:
        pickle.dump(node_encoder, f)
    with open(os.path.join(OUTPUT_DIR, "node_type_encoder.pkl"), "wb") as f:
        pickle.dump(node_type_encoder, f)
    with open(os.path.join(OUTPUT_DIR, "edge_encoder.pkl"), "wb") as f:
        pickle.dump(edge_encoder, f)
    
    print(f"节点数: {len(nodes)}, 节点类型数: {len(node_types)}, 边类型数: {len(edge_types)}")
    print(f"预处理总耗时: {time.time() - start_preprocess:.2f}秒")
    print(f"内存使用: {psutil.virtual_memory().used / 1024**3:.2f} GB")
    logger.info(f"预处理完成，节点数: {len(nodes)}，耗时: {time.time() - start_preprocess:.2f}秒")
    return node_encoder, node_type_encoder, edge_encoder

# 执行预处理
print("执行预处理...")
node_encoder, node_type_encoder, edge_encoder = preprocess_data_chunks(DATA_PATH)
clear_output(wait=True)

执行预处理...
开始预处理...
处理chunk 0，行数: 500000，耗时: 0.20秒
处理chunk 1，行数: 500000，耗时: 0.19秒
处理chunk 2，行数: 500000，耗时: 0.19秒
处理chunk 3，行数: 500000，耗时: 0.19秒
处理chunk 4，行数: 500000，耗时: 0.23秒
处理chunk 5，行数: 500000，耗时: 0.22秒
处理chunk 6，行数: 500000，耗时: 0.21秒
处理chunk 7，行数: 500000，耗时: 0.20秒
处理chunk 8，行数: 500000，耗时: 0.20秒
处理chunk 9，行数: 500000，耗时: 0.21秒
处理chunk 10，行数: 500000，耗时: 0.19秒
处理chunk 11，行数: 500000，耗时: 0.20秒
处理chunk 12，行数: 500000，耗时: 0.19秒
处理chunk 13，行数: 500000，耗时: 0.26秒
处理chunk 14，行数: 500000，耗时: 0.21秒
处理chunk 15，行数: 500000，耗时: 0.20秒
处理chunk 16，行数: 500000，耗时: 0.24秒
处理chunk 17，行数: 500000，耗时: 0.24秒
处理chunk 18，行数: 500000，耗时: 0.20秒
处理chunk 19，行数: 500000，耗时: 0.22秒
处理chunk 20，行数: 500000，耗时: 0.28秒
处理chunk 21，行数: 500000，耗时: 0.21秒
处理chunk 22，行数: 500000，耗时: 0.22秒
处理chunk 23，行数: 500000，耗时: 0.24秒
处理chunk 24，行数: 500000，耗时: 0.21秒
处理chunk 25，行数: 500000，耗时: 0.21秒
处理chunk 26，行数: 500000，耗时: 0.24秒
处理chunk 27，行数: 500000，耗时: 0.27秒
处理chunk 28，行数: 500000，耗时: 0.23秒
处理chunk 29，行数: 500000，耗时: 0.26秒
处理chunk 30，行数: 500000，耗时: 0.21秒


In [18]:
# 构建DGL图
def build_dgl_graph(src_ids, dst_ids, edge_types, node_types):
    """
    从边列表和节点类型构建DGL图
    参数：
        src_ids: 源节点ID列表
        dst_ids: 目标节点ID列表
        edge_types: 边类型列表
        node_types: 节点ID到节点类型的字典
    返回：
        DGL图
    """
    start_build = time.time()
    
    # 重新映射节点ID为连续整数
    unique_nodes = set(src_ids) | set(dst_ids)
    node_map = {nid: i for i, nid in enumerate(unique_nodes)}
    src_ids_mapped = [node_map[nid] for nid in src_ids]
    dst_ids_mapped = [node_map[nid] for nid in dst_ids]
    
    g = dgl.graph((torch.tensor(src_ids_mapped, dtype=torch.int64), 
                   torch.tensor(dst_ids_mapped, dtype=torch.int64)))
    g.edata["type"] = torch.tensor(edge_types, dtype=torch.int64)
    
    # 设置节点类型
    node_type_list = [node_types.get(nid, 0) for nid in unique_nodes]
    g.ndata["type"] = torch.tensor(node_type_list, dtype=torch.int64)
    g.ndata["id"] = torch.tensor(list(unique_nodes), dtype=torch.int64)
    
    print(f"构建图，节点数: {len(unique_nodes)}，边数: {len(src_ids)}，耗时: {time.time() - start_build:.2f}秒")
    logger.info(f"构建图，节点数: {len(unique_nodes)}，耗时: {time.time() - start_build:.2f}秒")
    return g

In [35]:
# 优化后的快照生成
import time
import gc
import logging
import pandas as pd
import torch
import os
import psutil
import numpy as np
from IPython.display import clear_output
from collections import defaultdict

# 配置全局开始时间（需与 Cell 1 一致）
try:
    start_time
except NameError:
    start_time = time.time()

# 配置日志到 OUTPUT_DIR、备用日志和控制台
OUTPUT_DIR = r".\Data\StreamSpot\processed"
log_path = os.path.abspath(os.path.join(OUTPUT_DIR, "process.log"))
fallback_log_path = os.path.abspath("fallback_process.log")
try:
    logging.basicConfig(filename=log_path, level=logging.INFO, format='%(asctime)s - %(message)s', filemode='w')
    logger = logging.getLogger()
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
    logger.addHandler(console_handler)
    logger.info("主日志配置成功")
    print(f"主日志配置成功: {log_path}")
except Exception as e:
    print(f"主日志配置失败: {e}")
    try:
        logging.basicConfig(filename=fallback_log_path, level=logging.INFO, format='%(asctime)s - %(message)s', filemode='w')
        logger = logging.getLogger()
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
        logger.addHandler(console_handler)
        logger.info("备用日志配置成功")
        print(f"备用日志配置成功: {fallback_log_path}")
    except Exception as e:
        print(f"备用日志配置失败: {e}")
        logger = logging.getLogger()
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
        logger.addHandler(console_handler)
        logger.info("仅控制台日志")

def generate_snapshots_chunks(data_path, node_encoder, node_type_encoder, edge_encoder, 
                             n=300, fr=1/3, chunk_size=500000, output_dir=OUTPUT_DIR):
    """
    分块生成时间序列快照，保存为DGL图，处理所有边数 <= 50000 的图
    参数：
        data_path: TSV文件路径
        node_encoder: 节点编码器
        node_type_encoder: 节点类型编码器
        edge_encoder: 边类型编码器
        n: 快照大小
        fr: 遗忘率
        chunk_size: 每次读取的行数
        output_dir: 输出目录
    返回：
        生成的快照数量
    """
    start_snapshots = time.time()
    logger.info("开始生成快照")
    print("开始生成快照...")
    clear_output(wait=True)
    
    cache_graph = defaultdict(list)
    node_timestamps = {}
    node_types = {}
    node_count = 0
    snapshot_id = 0
    edge_batch_count = 0
    
    def print_status(message):
        """刷新式打印状态"""
        clear_output(wait=True)
        print(f"总时间: {time.time() - start_snapshots:.2f}秒, 内存: {psutil.virtual_memory().used / 1024**3:.2f} GB")
        print(message)
        logger.info(message)
    
    def save_snapshot(g, sid):
        """保存快照到文件"""
        start_save = time.time()
        try:
            if not os.access(output_dir, os.W_OK):
                raise PermissionError(f"无写入权限: {output_dir}")
            if psutil.disk_usage(output_dir).free < 1024**3:
                raise RuntimeError(f"磁盘空间不足: {output_dir}")
            snapshot_path = os.path.join(output_dir, f"snapshot_{sid}.pt")
            torch.save(g, snapshot_path)
            if not os.path.exists(snapshot_path):
                raise RuntimeError(f"文件未生成: {snapshot_path}")
            print_status(f"保存快照 {sid}, 耗时: {time.time() - start_save:.2f}秒, 文件: {snapshot_path}")
            logger.info(f"保存快照 {sid}, 耗时: {time.time() - start_save:.2f}秒, 文件: {snapshot_path}")
        except Exception as e:
            print_status(f"保存快照 {sid} 失败: {e}")
            logger.error(f"保存快照 {sid} 失败: {e}")
    
    def generate_snapshot():
        """从缓存图生成快照"""
        nonlocal snapshot_id
        if not cache_graph:
            print_status(f"尝试生成快照 {snapshot_id}, 但cache_graph为空")
            logger.warning(f"尝试生成快照 {snapshot_id}, 但cache_graph为空")
            return False
        src_ids, dst_ids, edge_types = [], [], []
        for (s, d), etypes in cache_graph.items():
            src_ids.append(s)
            dst_ids.append(d)
            edge_types.append(max(etypes))
        print_status(f"生成快照 {snapshot_id}, cache_graph大小: {len(cache_graph)}, 节点数: {node_count}, 边数: {len(src_ids)}")
        logger.info(f"生成快照 {snapshot_id}, cache_graph大小: {len(cache_graph)}, 节点数: {node_count}, 边数: {len(src_ids)}")
        try:
            g = build_dgl_graph(src_ids, dst_ids, edge_types, node_types)
            print_status(f"生成快照 {snapshot_id}, DGL图: 节点数={g.num_nodes()}, 边数={g.num_edges()}")
            logger.info(f"生成快照 {snapshot_id}, DGL图: 节点数={g.num_nodes()}, 边数={g.num_edges()}")
            save_snapshot(g, snapshot_id)
            snapshot_id += 1
            cache_graph.clear()  # 清空cache_graph
            return True
        except Exception as e:
            print_status(f"生成快照 {snapshot_id} 失败: {e}")
            logger.error(f"生成快照 {snapshot_id} 失败: {e}")
            return False
    
    processed_graph_ids = set()
    for chunk_idx, chunk in enumerate(pd.read_csv(data_path, sep='\t', header=None,
                                                 names=["source_id", "source_type", "dest_id", 
                                                        "dest_type", "edge_type", "graph_id"],
                                                 dtype={"source_id": str, "dest_id": str, "graph_id": int},
                                                 chunksize=chunk_size)):
        chunk_start = time.time()
        print_status(f"处理chunk {chunk_idx}...")
        logger.info(f"处理chunk {chunk_idx}")
        
        # 预编码
        try:
            chunk["source_id_enc"] = node_encoder.transform(chunk["source_id"].astype(str))
            chunk["dest_id_enc"] = node_encoder.transform(chunk["dest_id"].astype(str))
            chunk["source_type_enc"] = node_type_encoder.transform(chunk["source_type"])
            chunk["dest_type_enc"] = node_type_encoder.transform(chunk["dest_type"])
            chunk["edge_type_enc"] = edge_encoder.transform(chunk["edge_type"])
            print_status(f"chunk {chunk_idx} 编码完成，行数: {len(chunk)}, 样本: {chunk[['source_id_enc', 'dest_id_enc', 'edge_type_enc']].head(1).to_string()}")
            logger.info(f"chunk {chunk_idx} 编码完成，行数: {len(chunk)}, 样本: {chunk[['source_id_enc', 'dest_id_enc', 'edge_type_enc']].head(1).to_string()}")
        except Exception as e:
            print_status(f"chunk {chunk_idx} 编码失败: {e}")
            logger.error(f"chunk {chunk_idx} 编码失败: {e}")
            continue
        
        graph_id_counts = chunk["graph_id"].value_counts()
        valid_graph_ids = [gid for gid in graph_id_counts.index if graph_id_counts[gid] <= 50000 and gid not in processed_graph_ids]
        print_status(f"chunk {chunk_idx} 有效graph_id: {valid_graph_ids}")
        logger.info(f"chunk {chunk_idx} 有效graph_id: {valid_graph_ids}")
        
        for graph_id in valid_graph_ids:
            processed_graph_ids.add(graph_id)
            edge_count = graph_id_counts[graph_id]
            print_status(f"开始处理graph_id: {graph_id}, 边数: {edge_count}")
            logger.info(f"开始处理graph_id: {graph_id}, 边数: {edge_count}")
            
            sub_df = chunk[chunk["graph_id"] == graph_id]
            try:
                edge_start = time.time()
                for idx in range(len(sub_df)):
                    if idx % 100 == 0:
                        print_status(f"处理graph_id {graph_id}，进度: {idx}/{len(sub_df)}，节点数: {node_count}, cache_graph大小: {len(cache_graph)}")
                        logger.info(f"graph_id {graph_id}，进度: {idx}/{len(sub_df)}，节点数: {node_count}, cache_graph大小: {len(cache_graph)}")
                    
                    row = sub_df.iloc[idx]
                    src = row["source_id_enc"]
                    dst = row["dest_id_enc"]
                    edge_type = row["edge_type_enc"]
                    src_type = row["source_type_enc"]
                    dst_type = row["dest_type_enc"]
                    
                    if not isinstance(src, (int, np.integer)) or not isinstance(dst, (int, np.integer)):
                        print_status(f"graph_id {graph_id}，无效编码: src={src}, dst={dst}")
                        logger.error(f"graph_id {graph_id}，无效编码: src={src}, dst={dst}")
                        continue
                    
                    if src not in node_timestamps:
                        node_timestamps[src] = len(node_timestamps)
                        node_types[src] = src_type
                        node_count += 1
                    if dst not in node_timestamps:
                        node_timestamps[dst] = len(node_timestamps)
                        node_types[dst] = dst_type
                        node_count += 1
                    cache_graph[(src, dst)].append(edge_type)
                    edge_batch_count += 1
                    
                    # 每500条边尝试生成快照
                    if edge_batch_count >= 500:
                        if generate_snapshot():
                            sorted_nodes = sorted(node_timestamps.items(), key=lambda x: x[1])
                            remove_count = int(n * fr)
                            for node, _ in sorted_nodes[:remove_count]:
                                del node_timestamps[node]
                                del node_types[node]
                                cache_graph.pop((node, None), None)
                                cache_graph.pop((None, node), None)
                            node_count = len(node_timestamps)
                            gc.collect()
                            print_status(f"清理节点后，剩余节点: {node_count}, 内存: {psutil.virtual_memory().used / 1024**3:.2f} GB")
                            logger.info(f"快照 {snapshot_id} 生成，剩余节点: {node_count}")
                        edge_batch_count = 0
                
                print_status(f"处理graph_id {graph_id} 边耗时: {time.time() - edge_start:.2f}秒")
                logger.info(f"处理graph_id {graph_id} 边耗时: {time.time() - edge_start:.2f}秒")
                
                # graph_id 结束时强制生成快照
                if cache_graph:
                    generate_snapshot()
                    sorted_nodes = sorted(node_timestamps.items(), key=lambda x: x[1])
                    remove_count = int(n * fr)
                    for node, _ in sorted_nodes[:remove_count]:
                        del node_timestamps[node]
                        del node_types[node]
                        cache_graph.pop((node, None), None)
                        cache_graph.pop((None, node), None)
                    node_count = len(node_timestamps)
                    gc.collect()
                    print_status(f"graph_id {graph_id} 清理节点后，剩余节点: {node_count}, 内存: {psutil.virtual_memory().used / 1024**3:.2f} GB")
                
                print_status(f"处理graph_id {graph_id} 完成，耗时: {time.time() - chunk_start:.2f}秒")
                logger.info(f"处理graph_id {graph_id} 完成，耗时: {time.time() - chunk_start:.2f}秒")
                
            except Exception as e:
                print_status(f"处理graph_id {graph_id} 失败: {e}")
                logger.error(f"处理graph_id {graph_id} 失败: {e}")
        
        # chunk 结束时强制生成快照
        if cache_graph:
            generate_snapshot()
        
        print_status(f"处理chunk {chunk_idx}，耗时: {time.time() - chunk_start:.2f}秒")
        logger.info(f"处理chunk {chunk_idx}，耗时: {time.time() - chunk_start:.2f}秒")
        gc.collect()
    
    if cache_graph:
        generate_snapshot()
    
    print_status(f"生成 {snapshot_id} 个快照，保存到 {output_dir}")
    print(f"快照生成总耗时: {time.time() - start_snapshots:.2f}秒")
    print(f"总运行时间: {time.time() - start_time:.2f}秒")
    logger.info(f"生成 {snapshot_id} 个快照，总耗时: {time.time() - start_snapshots:.2f}秒")
    
    # 验证日志和快照文件
    for path in [log_path, fallback_log_path]:
        if os.path.exists(path):
            print_status(f"日志文件生成: {path}")
            logger.info(f"日志文件生成: {path}")
            with open(path, "r", encoding="utf-8") as f:
                print_status(f"{path} 最后10行:\n{''.join(f.readlines()[-10:])}")
        else:
            print_status(f"日志文件未生成: {path}")
            logger.error(f"日志文件未生成: {path}")
    
    pt_files = [f for f in os.listdir(output_dir) if f.endswith('.pt') and f.startswith('snapshot_')]
    print_status(f"生成 {len(pt_files)} 个快照文件: {pt_files[:5]}")
    logger.info(f"生成 {len(pt_files)} 个快照文件: {pt_files[:5]}")
    
    return snapshot_id

# 执行快照生成
print("执行快照生成...")
snapshot_count = generate_snapshots_chunks(DATA_PATH, node_encoder, node_type_encoder, edge_encoder)
clear_output(wait=True)

2025-06-30 01:04:26,849 - 生成 16625 个快照文件: ['snapshot_0.pt', 'snapshot_1.pt', 'snapshot_10.pt', 'snapshot_100.pt', 'snapshot_1000.pt']
2025-06-30 01:04:26,849 - 生成 16625 个快照文件: ['snapshot_0.pt', 'snapshot_1.pt', 'snapshot_10.pt', 'snapshot_100.pt', 'snapshot_1000.pt']
2025-06-30 01:04:26,850 - 生成 16625 个快照文件: ['snapshot_0.pt', 'snapshot_1.pt', 'snapshot_10.pt', 'snapshot_100.pt', 'snapshot_1000.pt']
2025-06-30 01:04:26,850 - 生成 16625 个快照文件: ['snapshot_0.pt', 'snapshot_1.pt', 'snapshot_10.pt', 'snapshot_100.pt', 'snapshot_1000.pt']


总时间: 12893.26秒, 内存: 16.38 GB
生成 16625 个快照文件: ['snapshot_0.pt', 'snapshot_1.pt', 'snapshot_10.pt', 'snapshot_100.pt', 'snapshot_1000.pt']


In [26]:
import torch
import dgl
import os
OUTPUT_DIR = r".\Data\StreamSpot\processed"
g = dgl.graph(([0, 1], [1, 2]))
try:
    torch.save(g, os.path.join(OUTPUT_DIR, "test.pt"))
    print(f"测试保存成功: {os.path.exists(os.path.join(OUTPUT_DIR, 'test.pt'))}")
except Exception as e:
    print(f"测试保存失败: {e}")

测试保存成功: True


In [30]:
import logging
import os
OUTPUT_DIR = r".\Data\StreamSpot\processed"
log_path = os.path.abspath(os.path.join(OUTPUT_DIR, "process.log"))
fallback_log_path = os.path.abspath("fallback_process.log")
try:
    logging.basicConfig(filename=log_path, level=logging.INFO, format='%(asctime)s - %(message)s', filemode='w')
    logger = logging.getLogger()
    logger.info("测试日志写入")
    print(f"主日志文件是否存在: {os.path.exists(log_path)}")
    if os.path.exists(log_path):
        with open(log_path, "r", encoding="utf-8") as f:
            print(f"主日志内容:\n{f.read()}")
except Exception as e:
    print(f"主日志配置失败: {e}")
    try:
        logging.basicConfig(filename=fallback_log_path, level=logging.INFO, format='%(asctime)s - %(message)s', filemode='w')
        logger = logging.getLogger()
        logger.info("测试备用日志写入")
        print(f"备用日志文件是否存在: {os.path.exists(fallback_log_path)}")
        if os.path.exists(fallback_log_path):
            with open(fallback_log_path, "r", encoding="utf-8") as f:
                print(f"备用日志内容:\n{f.read()}")
    except Exception as e:
        print(f"备用日志配置失败: {e}")
print(f"OUTPUT_DIR 绝对路径: {os.path.abspath(OUTPUT_DIR)}")
print(f"当前工作目录: {os.getcwd()}")

主日志文件是否存在: False
OUTPUT_DIR 绝对路径: C:\Users\cli305\Codes\Jupyter\Provenance Graph Embedding\Prographer on Provenance Graph Embedding\Data\StreamSpot\processed
当前工作目录: C:\Users\cli305\Codes\Jupyter\Provenance Graph Embedding\Prographer on Provenance Graph Embedding


In [31]:
import pandas as pd
DATA_PATH = r".\Data\StreamSpot\all.tsv"
chunk_size = 500000
chunk = next(pd.read_csv(DATA_PATH, sep='\t', header=None,
                         names=["source_id", "source_type", "dest_id", 
                                "dest_type", "edge_type", "graph_id"],
                         dtype={"source_id": str, "dest_id": str, "graph_id": int},
                         chunksize=chunk_size))
graph_id_counts = chunk["graph_id"].value_counts()
valid_graph_ids = [gid for gid in graph_id_counts.index if 0 <= gid <= 9]
print(f"第一个chunk的graph_id: {list(graph_id_counts.index)}")
print(f"graph_id 0-9 边数: {[f'graph_id {gid}: {graph_id_counts.get(gid, 0)}' for gid in range(10)]}")
print(f"有效graph_id (边数 <= 20000): {[gid for gid in valid_graph_ids if graph_id_counts[gid] <= 20000]}")

第一个chunk的graph_id: [2, 0, 3, 1, 4]
graph_id 0-9 边数: ['graph_id 0: 104968', 'graph_id 1: 76512', 'graph_id 2: 210845', 'graph_id 3: 79365', 'graph_id 4: 28310', 'graph_id 5: 0', 'graph_id 6: 0', 'graph_id 7: 0', 'graph_id 8: 0', 'graph_id 9: 0']
有效graph_id (边数 <= 20000): []


In [32]:
import pandas as pd
from collections import defaultdict
DATA_PATH = r".\Data\StreamSpot\all.tsv"
chunk_size = 500000
chunk = next(pd.read_csv(DATA_PATH, sep='\t', header=None,
                         names=["source_id", "source_type", "dest_id", 
                                "dest_type", "edge_type", "graph_id"],
                         dtype={"source_id": str, "dest_id": str, "graph_id": int},
                         chunksize=chunk_size))
chunk["source_id_enc"] = node_encoder.transform(chunk["source_id"].astype(str))
chunk["dest_id_enc"] = node_encoder.transform(chunk["dest_id"].astype(str))
chunk["source_type_enc"] = node_type_encoder.transform(chunk["source_type"])
chunk["dest_type_enc"] = node_type_encoder.transform(chunk["dest_type"])
chunk["edge_type_enc"] = edge_encoder.transform(chunk["edge_type"])
graph_id_counts = chunk["graph_id"].value_counts()
valid_graph_ids = [gid for gid in graph_id_counts.index if 0 <= gid <= 9 and graph_id_counts[gid] <= 50000]
if not valid_graph_ids:
    print("无有效graph_id")
else:
    graph_id = valid_graph_ids[0]
    sub_df = chunk[chunk["graph_id"] == graph_id]
    cache_graph = defaultdict(list)
    node_timestamps = {}
    node_types = {}
    node_count = 0
    edge_batch_count = 0
    for idx in range(min(1000, len(sub_df))):
        row = sub_df.iloc[idx]
        src = row["source_id_enc"]
        dst = row["dest_id_enc"]
        edge_type = row["edge_type_enc"]
        src_type = row["source_type_enc"]
        dst_type = row["dest_type_enc"]
        if not isinstance(src, (int, np.integer)) or not isinstance(dst, (int, np.integer)):
            print(f"无效编码: src={src}, dst={dst}")
            continue
        if src not in node_timestamps:
            node_timestamps[src] = len(node_timestamps)
            node_types[src] = src_type
            node_count += 1
        if dst not in node_timestamps:
            node_timestamps[dst] = len(node_timestamps)
            node_types[dst] = dst_type
            node_count += 1
        cache_graph[(src, dst)].append(edge_type)
        edge_batch_count += 1
    print(f"graph_id {graph_id}, 边数: {len(sub_df)}, cache_graph大小: {len(cache_graph)}, 节点数: {node_count}, 边批次: {edge_batch_count}")

graph_id 4, 边数: 28310, cache_graph大小: 557, 节点数: 558, 边批次: 1000


In [33]:
import dgl
import torch
import os
from collections import defaultdict
OUTPUT_DIR = r".\Data\StreamSpot\processed"
cache_graph = defaultdict(list)
node_types = {}
for i in range(500):
    src, dst = i, i + 1
    cache_graph[(src, dst)].append(0)
    node_types[src] = 0
    node_types[dst] = 0
src_ids, dst_ids, edge_types = [], [], []
for (s, d), etypes in cache_graph.items():
    src_ids.append(s)
    dst_ids.append(d)
    edge_types.append(max(etypes))
try:
    g = dgl.graph((torch.tensor(src_ids, dtype=torch.int64), torch.tensor(dst_ids, dtype=torch.int64)))
    g.edata["type"] = torch.tensor(edge_types, dtype=torch.int64)
    node_type_list = [node_types.get(i, 0) for i in range(g.num_nodes())]
    g.ndata["type"] = torch.tensor(node_type_list, dtype=torch.int64)
    g.ndata["id"] = torch.arange(g.num_nodes(), dtype=torch.int64)
    print(f"DGL图生成成功: 节点数={g.num_nodes()}, 边数={g.num_edges()}")
    snapshot_path = os.path.join(OUTPUT_DIR, "snapshot_test.pt")
    torch.save(g, snapshot_path)
    print(f"快照保存成功: {os.path.exists(snapshot_path)}")
except Exception as e:
    print(f"快照生成或保存失败: {e}")

DGL图生成成功: 节点数=501, 边数=500
快照保存成功: True
