In [None]:
# 加载spark配置
from xinghe.spark import *
from app.common.json_util import *
from xinghe.s3 import *
from xinghe.s3.read import *
from xinghe.ops.spark import spark_resize_file

import os
os.environ["LLM_WEB_KIT_CFG_PATH"] = "/share/xxx/.llm-web-kit.jsonc"

import heapq
import math
import uuid
import traceback
import pandas as pd
from typing import List, Dict
from copy import copy
from collections import defaultdict
from func_timeout import FunctionTimedOut, func_timeout
from datetime import datetime

from llm_web_kit.html_layout.html_layout_cosin import cluster_html_struct, get_feature, similarity, sum_tags

from pyspark.sql.window import Window
from pyspark.sql import Row, DataFrame
from pyspark.sql.functions import when, spark_partition_id, row_number, col, collect_list, struct, from_json, expr, count as _count, pandas_udf, PandasUDFType, round as _round, lit, to_json, explode, min as _min, sum as _sum, first, max as _max
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, LongType


# Shared options dict. It is passed to new_spark_session() when a session is
# created and to write_any_path() when results are written; the write helpers
# in this file also set "skip_output_version" / "skip_output_check" on it
# in place before writing.
config = {
    "spark_conf_name": "spark_4",
    "skip_success_check": True,
    "spark.storage.memoryFraction": 0.8,
}

# utils

In [None]:
def create_spark(spark_name: str):
    """Create a Spark session and publish it through module globals.

    The session is named ``cluster.layout.<spark_name>`` and built from the
    module-level ``config``. After the call, ``spark`` holds the session and
    ``sc`` its SparkContext, with the log level set to ERROR.
    """
    global spark, sc
    app_name = f"cluster.layout.{spark_name}"
    spark = new_spark_session(app_name, config)
    sc = spark.sparkContext
    sc.setLogLevel("ERROR")

def close_spark():
    """Stop the Spark session held in the module-level ``spark`` global."""
    active_session = spark
    active_session.stop()

def get_s3_doctor(target_theme):
    """Return an ``S3DocWriter`` pointing at a fresh error-log file.

    The file lives under ``ERROR_PATH/<target_theme>/<YYYYMMDD>/`` and is named
    with a random UUID, so every call gets its own ``.jsonl`` target.
    """
    file_id = str(uuid.uuid4())
    day_stamp = datetime.now().strftime("%Y%m%d")
    # TODO: decide where error logs should finally be stored
    return S3DocWriter(path=f"{ERROR_PATH}{target_theme}/{day_stamp}/{file_id}.jsonl")

def crush_read_path(path):
    """Rewrite ``s3:`` to ``s3a:`` in a single path or in every path of a list.

    A list input produces a new list; any other input is treated as a string
    and the rewritten string is returned.
    """
    if not isinstance(path, list):
        return path.replace('s3:', 's3a:')
    return [single.replace('s3:', 's3a:') for single in path]

# Offset added to a row's partition_id when naming the per-partition output
# file in write_by_partitionid().
BASE_PARTITION_ID = 0

# Domain-size buckets as (label, lower, upper). filter_batch_df() keeps rows
# with count > lower and count <= upper; an upper of None means "no upper bound".
COUNT_MAP = [
    ("1-1600", 1, 1600),
    ("1600-1.5w", 1600, 15000),
    ("1.5w-10w", 15000, 100000),
    ("10w", 100000, None)
]
# parse_partition_id() uses ceil(total_count / DATA_SIZE_PER_BATCH) as the
# number of partitions for a bucket.
DATA_SIZE_PER_BATCH = 1000000000

# Layout merging is skipped when either side has more layouts than this.
MAX_LAYOUTLIST_SIZE = 200
# Two page features at or above this similarity count as the same layout.
SIMILARITY_THRESHOLD = 0.95
TIMEOUT_SECONDS = 3600 * 5  # timeout (seconds) passed to func_timeout for layout clustering
SIM_TIMEOUT_SECONDS = 60 * 2  # timeout (seconds); not referenced in the visible part of this file
# Upper bound compared against len() of a serialized JSON row before it is emitted.
MAX_OUTPUT_ROW_SIZE = 1024 * 1024 * 1024 * 1.7
MAX_OUTPUT_FILE_SIZE = 1024 * 1024 * 1024 * 1  # output file size limit
# Partition count for the feature-extraction repartition and cap for parallelize().
NUM_PARTITIONS = 4000
WRITE_NUM_PARTITIONS = 2000
# Sampling rate per bucket label: valid_count = round(count * rate).
RATE_MAP = {"10w": 0.1, "1.5w-10w": 0.2, "1600-1.5w": 0.5, "1-1600": 0.7}

# Base location for the error logs created by get_s3_doctor().
ERROR_PATH = "s3://xxx/"

# Input listing root and output root of the choose_domain stage.
INPUT_PATH = "s3a://xxxx/"
CHOOSE_DOMAIN_OUTPUT_PATH = "s3://xxx/"

# Roots used by parse_cluster_path() to derive OUTPUT_PATH, DOMAIN_PATH and
# BATCH_PATH for one batch.
CLUSTER_LAYOUT_BASE_OUTPUT_PATH = "s3://xxx/"
BASE_DOMAIN_PATH = "s3://xxx/"
BASE_BATCH_PATH = "s3://xxx/"

# Not referenced in the visible part of this file.
LAYOUT_SIM_BASE_OUTPUT_PATH = "s3://xxx/"

# Not referenced in the visible part of this file.
LAYOUT_INDEX_BASE_OUTPUT_PATH = "s3://xxx/"

# choose_domain

## utils

In [None]:
def divide_list_to_chunks(values, n_chunks):
    """Greedily spread weighted domains over ``n_chunks`` chunks.

    ``values`` is a sequence of records exposing ``["count"]`` (the weight) and
    ``["domain"]``. Domains are visited from heaviest to lightest (ties keep
    their input order) and each one goes to the chunk with the smallest
    accumulated weight so far. One ``Row(domain, partition_id)`` is yielded per
    input record.
    """
    # Min-heap of (accumulated weight, chunk id): the root is the lightest chunk.
    load_heap = [(0, chunk_no) for chunk_no in range(n_chunks)]
    heapq.heapify(load_heap)

    weighted = [(item["count"], item["domain"]) for item in values]
    # sorted() is stable, so equal weights stay in their original order.
    for weight, name in sorted(weighted, key=lambda pair: -pair[0]):
        load, chunk_no = heapq.heappop(load_heap)
        heapq.heappush(load_heap, (load + weight, chunk_no))
        yield Row(domain=name, partition_id=chunk_no)

def write_by_partitionid(_iter):
    """Write every row of one partition group to a single S3 ``.jsonl`` file.

    ``_iter`` is a pandas DataFrame with the columns ``domain``, ``count``,
    ``partition_id`` and ``files``. The output file is opened lazily on the
    first row; its name is derived from that row's ``partition_id`` plus
    ``BASE_PARTITION_ID`` and from the module globals ``count_data`` and
    ``total_count``.

    Yields a single dict ``{"write_size": <rows written>}``. An empty frame
    yields ``{"write_size": 0}`` and opens no file.
    """
    s3_writer = None
    rows_written = 0
    for _, detail_data in _iter.iterrows():
        line = {
            "domain": detail_data["domain"],
            "count": detail_data["count"],
            "partition_id": detail_data["partition_id"],
            "files": detail_data["files"].tolist(),
        }
        if s3_writer is None:
            # First row of the group decides the output file.
            partition_id = detail_data["partition_id"] + BASE_PARTITION_ID
            output_file = f"{CHOOSE_DOMAIN_OUTPUT_PATH}{count_data[0]}_{total_count}/{partition_id}.jsonl"
            s3_writer = S3DocWriter(output_file)
        s3_writer.write(line)
        rows_written += 1

    if s3_writer is not None:
        s3_writer.flush()
    yield {"write_size": rows_written}

## choose domain

In [None]:
def get_choose_input_path(INPUT_PATH):
    """Collect not-yet-processed ``.jsonl`` inputs and hand them to Spark.

    ``./already_exist.txt`` stores the ``repr`` of the list of ``s3a`` paths
    that were already submitted. Every ``.jsonl`` object under ``INPUT_PATH``
    whose ``s3a`` form is not recorded there is appended to the record file and
    then passed to ``get_choose_input_df``. Nothing happens when there is no new
    file.

    NOTE(review): paths are recorded before they are processed, so a failed run
    will not retry them — confirm this is intended.
    """
    import ast  # local import: only needed to parse the bookkeeping file

    record_file = "./already_exist.txt"
    try:
        with open(record_file, "r", encoding="utf-8") as record:
            content = record.read()
        # literal_eval only accepts Python literals; unlike eval() it cannot
        # execute code that ended up in the record file.
        already_exist = list(ast.literal_eval(content)) if content else []
        seen = set(already_exist)
    except (OSError, ValueError, SyntaxError, TypeError):
        # Missing or unreadable record file: treat everything as new.
        already_exist = []
        seen = set()

    input_path_lst = []
    for obj_path in list_s3_objects(INPUT_PATH, recursive=True):
        if not obj_path.endswith(".jsonl"):
            continue
        # Compare in the same (s3a) form that is stored in the record file.
        read_path = crush_read_path(obj_path)
        if read_path in seen or obj_path in seen:
            continue
        input_path_lst.append(read_path)

    if not input_path_lst:
        return

    with open(record_file, "w", encoding="utf-8") as record:
        record.write(str(already_exist + input_path_lst))
    get_choose_input_df(input_path_lst)

def get_choose_input_df(input_path_lst):
    """Load the given JSON files into a cached DataFrame and start bucketing.

    Paths are normalised with ``crush_read_path`` first; the cached frame is
    passed on to ``filter_batch_df``.
    """
    read_paths = crush_read_path(input_path_lst)
    domain_stats_df = spark.read.format("json").load(read_paths)
    domain_stats_df.cache()
    filter_batch_df(domain_stats_df)

def filter_batch_df(df: DataFrame):
    """Process the domain-size buckets of ``COUNT_MAP`` from largest to smallest.

    For each bucket the module global ``count_data`` is set to the bucket tuple,
    rows with ``lower < count <= upper`` are selected (no upper bound when it is
    ``None``), ``total_batch_count`` stores their summed count in the global
    ``total_count`` and, unless the bucket is empty, ``parse_partition_id`` is
    called for it.
    """
    global count_data
    for i in [3, 2, 1, 0]:
        count_data = COUNT_MAP[i]
        _, lower, upper = count_data
        filter_df = df.filter(col("count") > lower)
        if upper is not None:
            filter_df = filter_df.filter(col("count") <= upper)
        total_batch_count(filter_df)
        # The sum over an empty bucket is NULL (None in Python); comparing None
        # with an int would raise TypeError, so treat it as an empty bucket.
        if total_count is None or total_count < 1:
            continue
        parse_partition_id(filter_df, i)

def total_batch_count(filter_df: DataFrame):
    """Store the sum of the ``count`` column in the global ``total_count``.

    Spark returns NULL for a sum over zero rows; that case is stored as ``0`` so
    callers can compare the total numerically.
    """
    global total_count
    summed = filter_df.select(_sum("count")).collect()[0][0]
    total_count = summed if summed is not None else 0

def parse_partition_id(filter_df: DataFrame, i: int):
    """Assign a ``partition_id`` to every domain of one bucket and write it out.

    The number of partitions is ``ceil(total_count / DATA_SIZE_PER_BATCH)``.
    Bucket index 0 is partitioned by hashing ``domain`` through Spark's
    ``repartition``; every other bucket is balanced by weight on the driver
    with ``divide_list_to_chunks``. The mapping is passed to
    ``join_to_write_choose`` together with ``filter_df``.
    """
    # The partition count follows total_count and the per-batch data volume.
    n_batches = math.ceil(total_count / DATA_SIZE_PER_BATCH)
    if i != 0:
        weight_rows = filter_df.select(["domain", "count"]).collect()
        assigned = list(divide_list_to_chunks(weight_rows, n_batches))
        partition_df = spark.createDataFrame(assigned)
    else:
        hashed_df = filter_df.select(["domain"]).repartition(n_batches, col("domain"))
        partition_df = hashed_df.withColumn("partition_id", spark_partition_id())
    join_to_write_choose(filter_df, partition_df)

def join_to_write_choose(filter_df: DataFrame, partition_df: DataFrame, ):
    """Join domains with their partition ids and write one file per partition.

    ``filter_df`` and ``partition_df`` are joined on ``domain``; each
    ``partition_id`` group is then handed to ``write_by_partitionid`` through a
    grouped-map pandas UDF that returns the ``write_size`` summary.

    The grouped result is collected on the driver. Spark evaluates lazily, so
    without an action the UDF would never run and no file would be written.
    """
    df_with_weight = filter_df.join(partition_df, on="domain")

    output_schema = StructType([
        StructField('write_size', LongType(), True),
    ])

    @pandas_udf(output_schema, PandasUDFType.GROUPED_MAP)
    def pandas_udf_repartition(data_collected_series):
        # write_by_partitionid is a generator; materialise its summary rows.
        return pd.DataFrame(list(write_by_partitionid(data_collected_series)))

    output_df = df_with_weight.groupby('partition_id').apply(pandas_udf_repartition)
    # Action: one tiny summary row per partition, and it triggers the writes.
    output_df.collect()

def choose_main():
    """Run the choose_domain stage inside its own Spark session.

    The session is stopped even when the stage raises, so a failed run does not
    leave a session behind.
    """
    spark_name = "choose_domain"
    create_spark(spark_name)
    try:
        get_choose_input_path(INPUT_PATH)
    finally:
        close_spark()

# cluster_layout

## utils

In [None]:
def get_all_domain_data_cluster(_iter):
    """Extract layout features for every page referenced by a partition of rows.

    Each input row is read through ``row.domain``, ``row.valid_count`` and a
    ``row.file`` struct with ``filepath``, ``offset`` and ``record_count``. The
    referenced records are read from S3, ``get_feature`` is applied to each
    page's HTML, and one ``Row(value, domain, valid_count)`` is yielded per
    page, where ``value`` is the JSON-serialized feature record.

    Pages without feature tags are skipped silently. Pages that raise, or whose
    serialized record is not smaller than ``MAX_OUTPUT_ROW_SIZE``, go to the
    error log instead; the log is flushed only if something was recorded.
    """
    s3_doc_writer = get_s3_doctor("get_feature")
    error_info = None

    def _report(kind, message, payload):
        # Build one error record and send it to the error log.
        entry = {
            "error_type": kind,
            "error_message": message,
            "traceback": traceback.format_exc(),
            "input_data": payload,
            "timestamp": datetime.now().isoformat()
        }
        s3_doc_writer.write(entry)
        return entry

    for row in _iter:
        valid_count = row.valid_count
        file_d = row.file
        offset = file_d.offset
        record_count = file_d.record_count
        try:
            for detail_data in read_s3_by_offset_limit(file_d.filepath, offset, limit=record_count):
                try:
                    detail_data = json_loads(detail_data.value)
                    feature = get_feature(detail_data["html"])
                    if feature is None or not feature.get("tags"):
                        continue
                    layer_n, total_n = sum_tags(feature["tags"])
                    line = json_dumps({
                        "date": detail_data["date"],
                        "track_id": detail_data["track_id"],
                        "url": detail_data["url"],
                        "raw_warc_path": detail_data["raw_warc_path"],
                        "domain": row.domain,
                        "sub_path": row.domain,
                        "valid_count": valid_count,
                        "feature": feature,
                        "layer_n": layer_n,
                        "total_n": total_n
                    })
                    if len(line) >= MAX_OUTPUT_ROW_SIZE:
                        # Too large to emit as one row: log it instead.
                        error_info = _report(
                            "EOFError",
                            "Memory more than required for vector is (2147483648)",
                            detail_data,
                        )
                        continue
                    yield Row(value=line, domain=row.domain, valid_count=valid_count)
                except Exception as e:
                    error_info = _report(type(e).__name__, str(e), detail_data)
        except Exception as e:
            # Reading the file slice itself failed: log the whole input row.
            error_info = _report(type(e).__name__, str(e), str(row))

    if error_info:
        s3_doc_writer.flush()

def crush_output_data_cluster(output_data):
    """Serialize ``output_data`` to JSON unless the result would be too large.

    Returns the JSON string when its length is below ``MAX_OUTPUT_ROW_SIZE``;
    otherwise returns ``output_data`` unchanged so the caller can detect the
    oversize case by type.
    """
    serialized = json_dumps(output_data)
    return serialized if len(serialized) < MAX_OUTPUT_ROW_SIZE else output_data

def parse_batch_data_cluster(fpath):
    """Stream feature records from one S3 file in batches.

    Every row's ``value`` is parsed as JSON and reduced to the fields needed
    for clustering. The accumulated batch is yielded each time the running row
    number reaches a multiple of 800 (before that row is added), and whatever
    is left is yielded at the end. Records missing a required field are
    skipped.
    """
    wanted_fields = (
        "feature", "layer_n", "total_n", "track_id",
        "url", "domain", "raw_warc_path", "date",
    )
    sample_list = []
    for index, domain_v in enumerate(read_s3_rows(fpath, use_stream=True), start=1):
        if index % 800 == 0:
            yield sample_list
            sample_list = []
        domain_data = json_loads(domain_v.value)
        try:
            sample_list.append({key: domain_data[key] for key in wanted_fields})
        except Exception:
            # Best effort: drop malformed records, but no longer swallow
            # KeyboardInterrupt / SystemExit the way a bare ``except`` did.
            continue
    if sample_list:
        yield sample_list

def calculating_layout_cluster(current_host_name, sample_list):
    """Cluster one batch of pages and yield the layout summary for its host.

    ``cluster_html_struct`` returns the per-page cluster results and the list
    of layout ids. Up to three sample pages are kept per layout id (id ``-1``
    is ignored). When at least one layout exists, a single item is yielded: the
    summary passed through ``crush_output_data_cluster`` (a JSON string, or the
    raw dict if it is too large).
    """
    cluster_datas, layout_list = cluster_html_struct(sample_list)
    samples_by_layout = defaultdict(list)
    max_layer_n = cluster_datas[0]["max_layer_n"]
    # Keep three pages per layout class.
    for page in cluster_datas:
        layout_id = page["layout_id"]
        if layout_id != -1:
            picked = samples_by_layout[layout_id]
            if len(picked) < 3:
                picked.append(copy(page))
    if not layout_list:
        return
    summary = {
        "domain": current_host_name,
        "sub_path": current_host_name,
        "feature_dict": dict(samples_by_layout),
        "layout_list": layout_list,
        "max_layer_n": max_layer_n,
    }
    yield crush_output_data_cluster(summary)

def parse_layout_cluster(domain_list):
    """Cluster the layouts of every feature file in ``domain_list``.

    Each file is read in batches via ``parse_batch_data_cluster``; batches with
    more than one page are clustered under a ``TIMEOUT_SECONDS`` limit. Every
    JSON summary is yielded as ``{"value": <json>}``. Oversized summaries
    (returned as dicts), timeouts and other failures are written to the error
    log, which is flushed at the end if anything was recorded.
    """
    s3_doc_writer = get_s3_doctor("parse_layout")
    error_info = None

    def _cluster_batch(host_name, samples):
        # calculating_layout_cluster is a generator function: calling it only
        # creates the generator. Drain it here so the clustering work really
        # runs inside the call that func_timeout supervises.
        return list(calculating_layout_cluster(host_name, samples))

    def _report(kind, message, payload):
        # Build one error record and send it to the error log.
        entry = {
            "error_type": kind,
            "error_message": message,
            "traceback": traceback.format_exc(),
            "input_data": payload,
            "timestamp": datetime.now().isoformat()
        }
        s3_doc_writer.write(entry)
        return entry

    for fpath in domain_list:
        try:
            for sample_list in parse_batch_data_cluster(fpath):
                try:
                    if len(sample_list) <= 1:
                        continue
                    current_host_name = sample_list[0]["domain"]
                    lines = func_timeout(
                        TIMEOUT_SECONDS, _cluster_batch, (current_host_name, sample_list,)
                    )
                    for line in lines:
                        if isinstance(line, str):
                            yield {"value": line}
                        elif isinstance(line, dict):
                            # Summary was too large to serialize into one row.
                            error_info = _report("MemoryError by row", "MemoryError by row", line)
                except (FunctionTimedOut, Exception) as e:
                    # FunctionTimedOut does not derive from Exception, so it is
                    # listed explicitly.
                    error_info = _report(type(e).__name__, str(e), str(sample_list))
        except Exception as e:
            error_info = _report(type(e).__name__, str(e), fpath)
    if error_info:
        s3_doc_writer.flush()

def layout_similarity_cluster(layout_d1, layout_d2):
    """Merge the layouts of ``layout_d2`` into ``layout_d1`` and return the result.

    ``layout_d1`` is modified in place: its ``max_layer_n`` becomes the larger
    of the two, and every layout of ``layout_d2`` whose sample pages are all
    below ``SIMILARITY_THRESHOLD`` against the pages ``layout_d1`` had before
    the merge is appended under a new layout id (``max(layout_list) + 1``
    onwards, stored as a string key in ``feature_dict``).

    If either side has more than ``MAX_LAYOUTLIST_SIZE`` layouts, no merging is
    done: ``layout_d1`` is returned when it has strictly more layouts,
    otherwise ``layout_d2``.
    """
    max_layer_n = max(layout_d1["max_layer_n"], layout_d2["max_layer_n"])
    layout_d1["max_layer_n"] = max_layer_n
    ids_1 = layout_d1["layout_list"]
    ids_2 = layout_d2["layout_list"]
    if len(ids_1) > MAX_LAYOUTLIST_SIZE or len(ids_2) > MAX_LAYOUTLIST_SIZE:
        if len(ids_1) > len(ids_2):
            return layout_d1
        return layout_d2
    next_layout_id = max(ids_1)
    features_1 = layout_d1["feature_dict"]
    features_2 = layout_d2["feature_dict"]
    # Snapshot of the pages already known before anything is added.
    known_pages = [page for pages in features_1.values() for page in pages]
    for candidate_pages in features_2.values():
        already_known = any(
            similarity(candidate["feature"], known["feature"], max_layer_n) >= SIMILARITY_THRESHOLD
            for candidate in candidate_pages
            for known in known_pages
        )
        if already_known:
            continue
        next_layout_id += 1
        layout_d1["feature_dict"][str(next_layout_id)] = candidate_pages
        layout_d1["layout_list"].append(next_layout_id)
    return layout_d1
                    
def merge_layout_cluster(domain_list):
    """Fold every layout summary found under the given S3 prefixes into one.

    All ``.jsonl`` files below each prefix are streamed; the first summary that
    parses becomes the accumulator and every following one is merged into it
    with ``layout_similarity_cluster``. If the merged summary serializes to
    fewer than ``MAX_OUTPUT_ROW_SIZE`` characters, one dict
    ``{"layout_dict": <json>, "domain": <domain>}`` is yielded. Failures are
    written to the error log, which is flushed if anything was recorded.
    """
    s3_doc_writer = get_s3_doctor("merge_layout")
    error_info = None
    pre_domain = {}
    domain_v = None

    def _report(kind, message, payload):
        # Build one error record and send it to the error log.
        entry = {
            "error_type": kind,
            "error_message": message,
            "traceback": traceback.format_exc(),
            "input_data": payload,
            "timestamp": datetime.now().isoformat()
        }
        s3_doc_writer.write(entry)
        return entry

    # Merge pairwise: accumulator + next summary.
    for domain_f in domain_list:
        domain_paths = [f for f in list_s3_objects(domain_f, recursive=True) if f.endswith(".jsonl")]
        for fpath in domain_paths:
            try:
                for domain_v in read_s3_rows(fpath, use_stream=True):
                    if not pre_domain:
                        # No accumulator yet (first row, or earlier rows failed
                        # to parse): start from this summary.
                        pre_domain = json_loads(domain_v.value)
                        continue
                    try:
                        pre_domain = layout_similarity_cluster(pre_domain, json_loads(domain_v.value))
                    except Exception as e:
                        error_info = _report(type(e).__name__, str(e), domain_v.value)
            except Exception as e:
                # domain_v is still None when the file fails before its first
                # row; fall back to the path so logging itself cannot raise.
                payload = domain_v.value if domain_v is not None else fpath
                error_info = _report(type(e).__name__, str(e), payload)

    if pre_domain:
        output_data = json_dumps(pre_domain)
        if len(output_data) < MAX_OUTPUT_ROW_SIZE:
            yield {"layout_dict": output_data, "domain": pre_domain["domain"]}
        else:
            error_info = _report(
                "EOFError",
                "Memory more than required for vector is (2147483648)",
                domain_v.value if domain_v is not None else None,
            )
    if error_info:
        s3_doc_writer.flush()


## main func

In [None]:
# get data by domain
def get_cluster_input_df(batch: str, path_list: List):
    """Load one choose_domain batch file and start the clustering pipeline.

    The batch is read as JSON, ``valid_count`` is added by
    ``parse_valid_data_cluster``, and the frame plus its row count are passed
    to ``parse_explode_df_cluster``.
    """
    read_path = crush_read_path(batch)
    raw_df = spark.read.format("json").load(read_path)
    domain_df = parse_valid_data_cluster(raw_df, path_list)
    parse_explode_df_cluster(domain_df, domain_df.count())
    
def parse_valid_data_cluster(input_df: DataFrame, path_list: List):
    """Add the ``valid_count`` column: how many pages to sample per domain.

    The bucket label is the part of ``path_list[-2]`` before the first ``_``;
    its rate is looked up in ``RATE_MAP`` (default 1) and
    ``valid_count = round(count * rate)`` is stored as an integer.
    """
    bucket_label = path_list[-2].split("_")[0]
    rate = RATE_MAP.get(bucket_label, 1)
    valid_count_col = _round(col("count") * rate).cast("integer")
    return input_df.withColumn("valid_count", valid_count_col)

def parse_explode_df_cluster(domain_df: DataFrame, domain_count):
    """Explode ``files`` into one row per file entry and continue the pipeline.

    The exploded element is stored in ``file`` and the original ``files``
    column is dropped before calling ``parse_get_feature_df_cluster``.
    """
    per_file_df = domain_df.withColumn("file", explode(col("files")))
    parse_get_feature_df_cluster(per_file_df.drop("files"), domain_df, domain_count)

def parse_get_feature_df_cluster(explode_df: DataFrame, domain_df: DataFrame, domain_count):
    """Run feature extraction over all file slices and continue with sampling.

    ``explode_df`` is repartitioned into ``NUM_PARTITIONS`` partitions and each
    one is processed by ``get_all_domain_data_cluster``; the result is a
    DataFrame with the columns ``value``, ``domain`` and ``valid_count``.
    """
    feature_fields = [
        StructField('value', StringType(), True),
        StructField('domain', StringType(), True),
        StructField('valid_count', IntegerType(), True),
    ]
    schema = StructType(feature_fields)
    spread_rdd = explode_df.repartition(NUM_PARTITIONS).rdd
    feature_df = spread_rdd.mapPartitions(get_all_domain_data_cluster).toDF(schema)
    sample_by_valid_count_cluster(feature_df, domain_df, domain_count)

def sample_by_valid_count_cluster(feature_df: DataFrame, domain_df: DataFrame, domain_count):
    """Randomly keep at most ``valid_count`` pages per domain and persist them.

    Rows get a random value and are numbered inside their domain in that random
    order; rows whose number exceeds ``valid_count`` are dropped. Only the
    ``value`` column survives and is written by ``write_domain_data``, after
    which ``calculating_layout_every_batch_cluster`` runs.
    """
    random_order_in_domain = Window.partitionBy("domain").orderBy(col("rand"))
    ranked_df = (
        feature_df
        .withColumn("rand", expr("rand()"))
        .withColumn("row_num", row_number().over(random_order_in_domain))
    )
    kept_df = ranked_df.filter(col("row_num") <= col("valid_count"))
    domain_sample_df = kept_df.drop("rand", "row_num", "valid_count", "domain")
    write_domain_data(domain_sample_df)
    calculating_layout_every_batch_cluster(domain_df, domain_count)

def write_domain_data(domain_sample_df: DataFrame):
    """Resize the sampled pages with ``spark_resize_file(0.3)`` and write them to ``DOMAIN_PATH``.

    Sets ``skip_output_version`` and ``skip_output_check`` on the shared
    ``config`` before calling ``write_any_path``.
    """
    output_file_size_gb = 0.3
    new_output_df = spark_resize_file(output_file_size_gb)(domain_sample_df)

    config.update({"skip_output_version": True, "skip_output_check": True})
    write_any_path(new_output_df, DOMAIN_PATH, config)
    
def calculating_layout_every_batch_cluster(domain_df: DataFrame, domain_count):
    """Cluster layouts for every sampled-page file under ``DOMAIN_PATH``.

    The ``.jsonl`` files are distributed over at most ``NUM_PARTITIONS`` slices
    (one file per slice when there are fewer files) and processed by
    ``parse_layout_cluster``. The resulting ``value`` rows are written by
    ``write_merge_layout_data``; then ``merge_layout_by_layout_id_cluster``
    runs.
    """
    output_schema = StructType([
        StructField('value', StringType(), True),
    ])
    domain_lst = [f for f in list_s3_objects(DOMAIN_PATH, recursive=True) if f.endswith(".jsonl")]
    # parallelize() divides by the slice count, so an empty listing must not
    # produce 0 slices.
    num_slices = max(1, min(len(domain_lst), NUM_PARTITIONS))
    page_content = sc.parallelize(domain_lst, num_slices)
    layout_df = page_content.mapPartitions(parse_layout_cluster).toDF(output_schema)
    write_merge_layout_data(layout_df)
    merge_layout_by_layout_id_cluster(domain_df, domain_count)

def write_merge_layout_data(layout_df: DataFrame):
    """Write the per-batch layout summaries to ``BATCH_PATH``.

    Sets ``skip_output_version`` and ``skip_output_check`` on the shared
    ``config`` before calling ``write_any_path``.
    """
    config.update({"skip_output_version": True, "skip_output_check": True})
    write_any_path(layout_df, BATCH_PATH, config)

def merge_layout_by_layout_id_cluster(domain_df: DataFrame, domain_count):
    """Merge the layout summaries under ``BATCH_PATH`` and join them to the domains.

    Each direct entry of ``BATCH_PATH`` is handled in its own slice by
    ``merge_layout_cluster``, giving ``(layout_dict, domain)`` rows that are
    passed to ``join_to_write_cluster`` with ``domain_df``.
    """
    mer_output_schema = StructType([
        StructField('layout_dict', StringType(), True),
        StructField('domain', StringType(), True),
    ])

    batch_lst = list(list_s3_objects(BATCH_PATH, recursive=False))
    # parallelize() divides by the slice count, so never pass 0 slices.
    batch_page_content = sc.parallelize(batch_lst, max(1, len(batch_lst)))
    merge_layout_df = batch_page_content.mapPartitions(merge_layout_cluster).toDF(mer_output_schema)
    join_to_write_cluster(merge_layout_df, domain_df)

def join_to_write_cluster(merge_layout_df: DataFrame, domain_df: DataFrame):
    """Attach the merged layouts to every domain and write the result.

    A left join on ``domain`` keeps domains without a layout. Each row is
    serialized to JSON from ``domain``, ``count``, ``files`` and
    ``layout_dict`` into a single ``value`` column and written to
    ``OUTPUT_PATH``.
    """
    join_df = domain_df.join(merge_layout_df, on="domain", how="left")

    payload_columns = [join_df[name] for name in ("domain", "count", "files", "layout_dict")]
    output_df = join_df.withColumn("value", to_json(struct(*payload_columns))).select("value")

    config.update({"skip_output_version": True, "skip_output_check": True})
    write_any_path(output_df, OUTPUT_PATH, config)

def parse_cluster_path(batch):
    """Derive the per-batch output locations from a batch file path.

    The last two path components (``<bucket dir>/<file>.jsonl``) are appended,
    without the ``.jsonl`` suffix, to the three base paths, and the results are
    stored in the globals ``OUTPUT_PATH``, ``DOMAIN_PATH`` and ``BATCH_PATH``.
    Returns the path split on ``/``.
    """
    global OUTPUT_PATH, DOMAIN_PATH, BATCH_PATH
    path_list = batch.split('/')
    batch_dir = f"{path_list[-2]}/{path_list[-1].replace('.jsonl', '')}/"
    OUTPUT_PATH = f"{CLUSTER_LAYOUT_BASE_OUTPUT_PATH}{batch_dir}"
    DOMAIN_PATH = f"{BASE_DOMAIN_PATH}{batch_dir}"
    BATCH_PATH = f"{BASE_BATCH_PATH}{batch_dir}"
    return path_list

def parse_cluster_input_path(input_path):
    """List the ``.jsonl`` batch files under ``input_path`` that are still to do.

    ``./is_layout_complated.txt`` holds the comma-separated paths of finished
    batches; a missing or unreadable file means nothing is finished. Paths are
    returned in their ``s3a:`` form, which is also the form compared against
    the record file.
    """
    try:
        with open("./is_layout_complated.txt", "r", encoding="utf-8") as f:
            content = f.read()
    except (OSError, ValueError):
        # No (readable) record file yet.
        content = ""
    already_exist = {entry for entry in content.split(",") if entry}

    input_path_lst = []
    for obj_path in list_s3_objects(input_path, recursive=True):
        if not obj_path.endswith(".jsonl"):
            continue
        # Only rewrite the scheme: a blanket replace("s3", "s3a") would also
        # touch "s3" inside bucket/key names and turn "s3a://" into "s3aa://".
        read_path = crush_read_path(obj_path)
        if read_path not in already_exist:
            input_path_lst.append(read_path)
    return input_path_lst

def layout_main(big_batch):
    """Run the layout clustering pipeline for every unfinished batch file.

    Each batch gets its own Spark session, which is stopped even if the batch
    fails. Only after a batch finishes successfully is its path appended to
    ``./is_layout_complated.txt`` (comma-separated), so it is skipped next time.
    """
    input_path_lst = parse_cluster_input_path(big_batch)
    for batch in input_path_lst:
        path_list = parse_cluster_path(batch)
        spark_name = '_'.join([path_list[-2], path_list[-1].replace('.jsonl', '')])
        create_spark(spark_name)
        try:
            get_cluster_input_df(batch, path_list)
        finally:
            close_spark()
        with open("./is_layout_complated.txt", "a", encoding="utf-8") as f:
            # ``batch`` is one path string; joining it with "," would put a
            # comma between every character and corrupt the record.
            f.write(batch + ",")
    

# cluster_layout_little

## utils

In [None]:
def get_all_domain_data_cluster_little(_iter):
    """Extract layout features for every page referenced by a partition of rows.

    Each input row is read through ``row.domain``, ``row.valid_count`` and a
    ``row.file`` struct with ``filepath``, ``offset`` and ``record_count``. The
    referenced records are read from S3, ``get_feature`` is applied to each
    page's HTML, and one ``Row(value, domain, valid_count)`` is yielded per
    page, where ``value`` is the JSON-serialized feature record.

    Pages without feature tags are skipped silently. Pages that raise, or whose
    serialized record is not smaller than ``MAX_OUTPUT_ROW_SIZE``, go to the
    error log instead; the log is flushed only if something was recorded.
    """
    s3_doc_writer = get_s3_doctor("get_feature")
    error_info = None

    def _report(kind, message, payload):
        # Build one error record and send it to the error log.
        entry = {
            "error_type": kind,
            "error_message": message,
            "traceback": traceback.format_exc(),
            "input_data": payload,
            "timestamp": datetime.now().isoformat()
        }
        s3_doc_writer.write(entry)
        return entry

    for row in _iter:
        valid_count = row.valid_count
        file_d = row.file
        offset = file_d.offset
        record_count = file_d.record_count
        try:
            for detail_data in read_s3_by_offset_limit(file_d.filepath, offset, limit=record_count):
                try:
                    detail_data = json_loads(detail_data.value)
                    feature = get_feature(detail_data["html"])
                    if feature is None or not feature.get("tags"):
                        continue
                    layer_n, total_n = sum_tags(feature["tags"])
                    line = json_dumps({
                        "date": detail_data["date"],
                        "track_id": detail_data["track_id"],
                        "url": detail_data["url"],
                        "raw_warc_path": detail_data["raw_warc_path"],
                        "domain": row.domain,
                        "sub_path": row.domain,
                        "valid_count": valid_count,
                        "feature": feature,
                        "layer_n": layer_n,
                        "total_n": total_n
                    })
                    if len(line) >= MAX_OUTPUT_ROW_SIZE:
                        # Too large to emit as one row: log it instead.
                        error_info = _report(
                            "EOFError",
                            "Memory more than required for vector is (2147483648)",
                            detail_data,
                        )
                        continue
                    yield Row(value=line, domain=row.domain, valid_count=valid_count)
                except Exception as e:
                    error_info = _report(type(e).__name__, str(e), detail_data)
        except Exception as e:
            # Reading the file slice itself failed: log the whole input row.
            error_info = _report(type(e).__name__, str(e), str(row))

    if error_info:
        s3_doc_writer.flush()

def crush_output_data_cluster_little(output_data):
    """Serialize ``output_data`` to JSON unless the result would be too large.

    Returns the JSON string when its length is below ``MAX_OUTPUT_ROW_SIZE``;
    otherwise returns ``output_data`` unchanged so the caller can detect the
    oversize case by type.
    """
    serialized = json_dumps(output_data)
    return serialized if len(serialized) < MAX_OUTPUT_ROW_SIZE else output_data

def parse_batch_data_cluster_little(domain_list):
    """Split one domain's feature rows (a pandas DataFrame) into batches.

    Every row's ``value`` is parsed as JSON and reduced to the fields needed
    for clustering. The accumulated batch is yielded whenever the row's index
    label is a non-zero multiple of 800, and whatever is left is yielded at
    the end. Records missing a required field are skipped.
    """
    wanted_fields = (
        "feature", "layer_n", "total_n", "track_id",
        "url", "domain", "raw_warc_path", "date",
    )
    sample_list = []
    for index, domain_v in domain_list.iterrows():
        if index != 0 and index % 800 == 0:
            yield sample_list
            sample_list = []
        domain_data = json_loads(domain_v.value)
        try:
            sample_list.append({key: domain_data[key] for key in wanted_fields})
        except Exception:
            # Best effort: drop malformed records, but no longer swallow
            # KeyboardInterrupt / SystemExit the way a bare ``except`` did.
            continue
    if sample_list:
        yield sample_list

def calculating_layout_cluster_little(current_host_name, sample_list):
    """Cluster one batch of pages and yield the layout summary for its host.

    ``cluster_html_struct`` returns the per-page cluster results and the list
    of layout ids. Up to three sample pages are kept per layout id (id ``-1``
    is ignored). When at least one layout exists, a single item is yielded: the
    summary passed through ``crush_output_data_cluster_little`` (a JSON string,
    or the raw dict if it is too large).
    """
    cluster_datas, layout_list = cluster_html_struct(sample_list)
    samples_by_layout = defaultdict(list)
    max_layer_n = cluster_datas[0]["max_layer_n"]
    # Keep three pages per layout class.
    for page in cluster_datas:
        layout_id = page["layout_id"]
        if layout_id != -1:
            picked = samples_by_layout[layout_id]
            if len(picked) < 3:
                picked.append(copy(page))
    if not layout_list:
        return
    summary = {
        "domain": current_host_name,
        "sub_path": current_host_name,
        "feature_dict": dict(samples_by_layout),
        "layout_list": layout_list,
        "max_layer_n": max_layer_n,
    }
    yield crush_output_data_cluster_little(summary)

def parse_layout_cluster_little(domain_list):
    """Cluster the layouts of one domain's feature rows (a pandas DataFrame).

    The rows are batched via ``parse_batch_data_cluster_little``; batches with
    more than one page are clustered under a ``TIMEOUT_SECONDS`` limit. Every
    JSON summary is yielded as ``{"value": <json>, "domain": <host>}``.
    Oversized summaries (returned as dicts), timeouts and other failures are
    written to the error log, which is flushed at the end if anything was
    recorded.
    """
    s3_doc_writer = get_s3_doctor("parse_layout")
    error_info = None
    current_host_name = domain_list["domain"].unique()[0]

    def _cluster_batch(host_name, samples):
        # calculating_layout_cluster_little is a generator function: calling it
        # only creates the generator. Drain it here so the clustering work
        # really runs inside the call that func_timeout supervises.
        return list(calculating_layout_cluster_little(host_name, samples))

    def _report(kind, message, payload):
        # Build one error record and send it to the error log.
        entry = {
            "error_type": kind,
            "error_message": message,
            "traceback": traceback.format_exc(),
            "input_data": payload,
            "timestamp": datetime.now().isoformat()
        }
        s3_doc_writer.write(entry)
        return entry

    try:
        for sample_list in parse_batch_data_cluster_little(domain_list):
            try:
                if len(sample_list) <= 1:
                    continue
                current_host_name = sample_list[0]["domain"]
                lines = func_timeout(
                    TIMEOUT_SECONDS, _cluster_batch, (current_host_name, sample_list,)
                )
                for line in lines:
                    if isinstance(line, str):
                        yield {"value": line, "domain": current_host_name}
                    elif isinstance(line, dict):
                        # Summary was too large to serialize into one row.
                        error_info = _report("MemoryError by row", "MemoryError by row", line)
            except (FunctionTimedOut, Exception) as e:
                # FunctionTimedOut does not derive from Exception, so it is
                # listed explicitly.
                error_info = _report(type(e).__name__, str(e), str(sample_list))
    except Exception as e:
        error_info = _report(type(e).__name__, str(e), current_host_name)
    if error_info:
        s3_doc_writer.flush()

def layout_similarity_cluster_little(layout_d1, layout_d2):
    """Merge the layouts of ``layout_d2`` into ``layout_d1`` and return the result.

    ``layout_d1`` is modified in place: its ``max_layer_n`` becomes the larger
    of the two, and every layout of ``layout_d2`` whose sample pages are all
    below ``SIMILARITY_THRESHOLD`` against the pages ``layout_d1`` had before
    the merge is appended under a new layout id (``max(layout_list) + 1``
    onwards, stored as a string key in ``feature_dict``).

    If either side has more than ``MAX_LAYOUTLIST_SIZE`` layouts, no merging is
    done: ``layout_d1`` is returned when it has strictly more layouts,
    otherwise ``layout_d2``.
    """
    max_layer_n = max(layout_d1["max_layer_n"], layout_d2["max_layer_n"])
    layout_d1["max_layer_n"] = max_layer_n
    ids_1 = layout_d1["layout_list"]
    ids_2 = layout_d2["layout_list"]
    if len(ids_1) > MAX_LAYOUTLIST_SIZE or len(ids_2) > MAX_LAYOUTLIST_SIZE:
        if len(ids_1) > len(ids_2):
            return layout_d1
        return layout_d2
    next_layout_id = max(ids_1)
    features_1 = layout_d1["feature_dict"]
    features_2 = layout_d2["feature_dict"]
    # Snapshot of the pages already known before anything is added.
    known_pages = [page for pages in features_1.values() for page in pages]
    for candidate_pages in features_2.values():
        already_known = any(
            similarity(candidate["feature"], known["feature"], max_layer_n) >= SIMILARITY_THRESHOLD
            for candidate in candidate_pages
            for known in known_pages
        )
        if already_known:
            continue
        next_layout_id += 1
        layout_d1["feature_dict"][str(next_layout_id)] = candidate_pages
        layout_d1["layout_list"].append(next_layout_id)
    return layout_d1
                    
def merge_layout_cluster_little(domain_list):
    """Merge all partial layout dicts of one domain into a single output row.

    Used as the ``applyInPandas`` function of a ``groupby('domain')``.

    Args:
        domain_list: pandas DataFrame holding the rows of one domain; the
            ``value`` column contains JSON-encoded layout dicts.

    Returns:
        A one-row DataFrame with the columns ``layout_dict`` (JSON string of
        the merged dict) and ``domain``. An empty DataFrame with the same
        columns is returned when there is nothing to emit or when the merged
        JSON is ``MAX_OUTPUT_ROW_SIZE`` characters or longer; grouped-map
        functions must always return a DataFrame, never ``None``.

    Merge failures are written to the error log and the offending row is
    skipped; the log is flushed on every exit path once something was written.
    """
    s3_doc_writer = get_s3_doctor("merge_layout")
    error_info = None
    pre_domain = {}
    domain_v = None
    # Default result for the "nothing to emit" paths.
    result = pd.DataFrame(columns=["layout_dict", "domain"])

    # Merge pairwise: the first row seeds the accumulator, every later row is
    # folded into it. Position is used instead of the index label so the seed
    # row is found even if the frame's index does not start at 0.
    for position, (_, domain_v) in enumerate(domain_list.iterrows()):
        if position == 0:
            pre_domain = json_loads(domain_v.value)
            continue
        try:
            pre_domain = layout_similarity_cluster_little(pre_domain, json_loads(domain_v.value))
        except Exception as e:
            error_info = {
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
                "input_data": domain_v.value,
                "timestamp": datetime.now().isoformat()
            }
            s3_doc_writer.write(error_info)

    if pre_domain:
        output_data = json_dumps(pre_domain)
        if len(output_data) < MAX_OUTPUT_ROW_SIZE:
            result = pd.DataFrame({"layout_dict": output_data, "domain": pre_domain["domain"]}, index=[0])
        else:
            # The merged row is too large to emit; record it and drop it.
            error_info = {
                "error_type": "EOFError",
                "error_message": "Memory more than required for vector is (2147483648)",
                "traceback": traceback.format_exc(),
                "input_data": domain_v.value,
                "timestamp": datetime.now().isoformat()
            }
            s3_doc_writer.write(error_info)

    # Flush before returning so errors logged during a successful merge are
    # not lost.
    if error_info:
        s3_doc_writer.flush()
    return result


## main func

In [None]:
# get data by domain
def get_input_df_cluster_little(batch: str):
    """Load one batch of domain records and run it through the layout pipeline.

    The path is rewritten with ``crush_read_path``, read as JSON, counted,
    extended by ``parse_valid_data_cluster_little`` and then handed to
    ``parse_explode_df_cluster_little`` together with the row count.
    """
    read_path = crush_read_path(batch)
    raw_df = spark.read.format("json").load(read_path)
    total_rows = raw_df.count()
    parse_explode_df_cluster_little(
        parse_valid_data_cluster_little(raw_df, read_path),
        total_rows,
    )
    
def parse_valid_data_cluster_little(input_valid_df: DataFrame, batch: str):
    """Return ``input_valid_df`` with an added integer-valued ``valid_count`` column.

    The sampling rate is looked up in ``RATE_MAP`` using the text before the
    first ``_`` of the second-to-last ``/`` segment of ``batch`` (default 1).
    Rows with ``count`` <= 800 keep their count; all others get
    ``round(count * rate)`` cast to integer.
    """
    parent_dir = batch.split("/")[-2]
    range_key = parent_dir.split("_")[0]
    rate = RATE_MAP.get(range_key, 1)

    scaled_count = _round(col("count") * rate).cast("integer")
    valid_count = when(col("count") <= 800, col("count")).otherwise(scaled_count)
    return input_valid_df.withColumn("valid_count", valid_count)

def parse_explode_df_cluster_little(domain_df: DataFrame, domain_count):
    """Explode ``files`` into one row per element (column ``file``) and continue.

    The exploded frame, the untouched ``domain_df`` and ``domain_count`` are
    passed on to ``parse_get_feature_df_cluster_little``.
    """
    per_file_df = domain_df.withColumn("file", explode(col("files")))
    per_file_df = per_file_df.drop("files")
    parse_get_feature_df_cluster_little(per_file_df, domain_df, domain_count)

def parse_get_feature_df_cluster_little(explode_df: DataFrame, domain_df: DataFrame, domain_count):
    """Run ``get_all_domain_data_cluster_little`` over every partition.

    The exploded rows are repartitioned into ``NUM_PARTITIONS`` partitions,
    mapped partition-wise, converted back to a DataFrame with the columns
    ``value``, ``domain`` and ``valid_count``, and handed to
    ``sample_by_valid_count_cluster_little``.
    """
    feature_schema = (
        StructType()
        .add('value', StringType(), True)
        .add('domain', StringType(), True)
        .add('valid_count', IntegerType(), True)
    )
    partitioned_rdd = explode_df.repartition(NUM_PARTITIONS).rdd
    feature_rows = partitioned_rdd.mapPartitions(get_all_domain_data_cluster_little)
    feature_df = feature_rows.toDF(feature_schema)
    sample_by_valid_count_cluster_little(feature_df, domain_df, domain_count)

def sample_by_valid_count_cluster_little(feature_df: DataFrame, domain_df: DataFrame, domain_count):
    """Keep at most ``valid_count`` randomly ordered rows per domain.

    Each row gets a random sort key; rows are numbered within their domain in
    that order and only those whose number does not exceed ``valid_count``
    survive. The helper columns are dropped before the next stage is called.
    """
    random_order_in_domain = Window.partitionBy("domain").orderBy(col("rand"))
    ranked_df = (
        feature_df
        .withColumn("rand", expr("rand()"))
        .withColumn("row_num", row_number().over(random_order_in_domain))
    )
    within_quota = col("row_num") <= col("valid_count")
    sampled_df = ranked_df.filter(within_quota).drop("rand", "row_num")
    calculating_layout_every_batch_cluster_little(sampled_df, domain_df, domain_count)

def calculating_layout_every_batch_cluster_little(domain_sample_df: DataFrame, domain_df: DataFrame, domain_count):
    """Cluster the sampled pages of every domain into layouts.

    The sample is repartitioned by ``domain`` into ``NUM_PARTITIONS``
    partitions and each domain group is processed by
    ``parse_layout_cluster_little``. The resulting frame (columns ``value``
    and ``domain``) is passed to ``merge_layout_by_layout_id_cluster_little``.

    Args:
        domain_sample_df: sampled feature rows, one group per ``domain``.
        domain_df: original per-domain frame, forwarded unchanged.
        domain_count: row count of the input batch, forwarded unchanged.
    """
    output_schema = StructType([
        StructField('value', StringType(), True),
        StructField('domain', StringType(), True),
    ])

    def cluster_one_domain(data_collected_series):
        result = parse_layout_cluster_little(data_collected_series)
        if result:
            return pd.DataFrame(result)
        # A grouped-map function must return a DataFrame; returning None makes
        # Spark fail the task. Emit an empty frame with the schema's columns
        # so a domain without a result is simply absent from the output.
        return pd.DataFrame(columns=["value", "domain"])

    rep_df = domain_sample_df.repartition(NUM_PARTITIONS, col("domain"))
    # applyInPandas replaces the deprecated pandas_udf(GROUPED_MAP) + apply()
    # pair and matches the merge stage of this pipeline.
    layout_df = rep_df.groupby('domain').applyInPandas(cluster_one_domain, output_schema)
    merge_layout_by_layout_id_cluster_little(layout_df, domain_df, domain_count)


def merge_layout_by_layout_id_cluster_little(layout_df: DataFrame, domain_df: DataFrame, domain_count):
    """Merge the per-domain layout rows into one row per domain and write them.

    ``layout_df`` is repartitioned by ``domain`` into ``NUM_PARTITIONS``
    partitions and every domain group is reduced with
    ``merge_layout_cluster_little`` (output columns ``layout_dict`` and
    ``domain``). The result goes to ``join_to_write_cluster_little``;
    ``domain_count`` is accepted but not used here.
    """
    merged_schema = (
        StructType()
        .add('layout_dict', StringType(), True)
        .add('domain', StringType(), True)
    )
    by_domain = layout_df.repartition(NUM_PARTITIONS, col("domain")).groupby('domain')
    merged_df = by_domain.applyInPandas(merge_layout_cluster_little, merged_schema)
    join_to_write_cluster_little(merged_df, domain_df)

def join_to_write_cluster_little(merge_layout_df: DataFrame, domain_df: DataFrame):
    """Left-join the merged layouts onto the domain frame and write JSON rows.

    Every output row is a single ``value`` column holding the JSON encoding of
    ``domain``, ``count``, ``files`` and ``layout_dict``. The module-level
    ``config`` gets ``skip_output_version`` and ``skip_output_check`` set to
    ``True`` before the frame is written to ``OUTPUT_PATH``.
    """
    joined = domain_df.join(merge_layout_df, on="domain", how="left")
    payload_columns = [joined[name] for name in ("domain", "count", "files", "layout_dict")]
    output_df = joined.withColumn("value", to_json(struct(*payload_columns))).select("value")

    config.update({"skip_output_version": True, "skip_output_check": True})
    write_any_path(output_df, OUTPUT_PATH, config)

def parse_path_cluster_little(batch):
    """Split ``batch`` on ``/`` and derive the global ``OUTPUT_PATH`` from it.

    ``OUTPUT_PATH`` becomes ``CLUSTER_LAYOUT_BASE_OUTPUT_PATH`` followed by the
    second-to-last segment, the last segment with ``.jsonl`` removed, and a
    trailing slash. Returns the list of path segments.
    """
    global OUTPUT_PATH
    segments = batch.split('/')
    parent_dir = segments[-2]
    file_stem = segments[-1].replace('.jsonl', '')
    OUTPUT_PATH = f"{CLUSTER_LAYOUT_BASE_OUTPUT_PATH}{parent_dir}/{file_stem}/"
    return segments

def parse_input_path_cluster_little(input_path):
    """Return the ``.jsonl`` inputs under ``input_path`` not yet marked as done.

    Finished batches are recorded as a comma-separated list in
    ``./is_little_layout_complated.txt``; a missing file means nothing has
    been processed yet.

    Args:
        input_path: S3 prefix that is listed recursively.

    Returns:
        List of ``s3a:`` paths, in listing order, that are absent from the
        checkpoint file.
    """
    try:
        with open("./is_little_layout_complated.txt", "r", encoding="utf-8") as f:
            already_exist = {i for i in f.read().split(",") if i}
    except FileNotFoundError:
        # First run: no checkpoint has been written yet.
        already_exist = set()

    input_path_lst = []
    for object_path in list_s3_objects(input_path, recursive=True):
        if not object_path.endswith(".jsonl"):
            continue
        # Rewrite only the URL scheme; a plain replace("s3", "s3a") would also
        # corrupt any other "s3" substring in the bucket or key.
        s3a_path = crush_read_path(object_path)
        if s3a_path not in already_exist:
            input_path_lst.append(s3a_path)
    return input_path_lst

def layout_little_main(big_batch):
    """Run the layout clustering pipeline for every pending batch of ``big_batch``.

    For each pending ``.jsonl`` file a dedicated Spark session is created, the
    pipeline is executed and, on success, the batch is appended to
    ``./is_little_layout_complated.txt`` so a rerun skips it.

    Args:
        big_batch: S3 prefix whose pending files are resolved by
            ``parse_input_path_cluster_little``.
    """
    input_path_lst = parse_input_path_cluster_little(big_batch)
    for batch in input_path_lst:
        path_list = parse_path_cluster_little(batch)
        spark_name = '_'.join([path_list[-2], path_list[-1].replace('.jsonl', '')])
        create_spark(spark_name)
        try:
            get_input_df_cluster_little(batch)
        finally:
            # Stop the session even when the batch fails so it is not leaked;
            # the exception still propagates and the batch is not marked done.
            close_spark()
        with open("./is_little_layout_complated.txt", "a", encoding="utf-8") as f:
            f.write(batch + ",")
            

# layout_sim

## utils

In [None]:
def parse_output_data(row_data):
    """Qualify the layout id with the host name and serialise the record.

    ``row_data["layout_id"]`` is replaced in place by
    ``"<url_host_name>_<layout_id>"``. Returns a dict with the JSON string
    under ``value`` and the qualified id under ``layout_id``, or ``None`` when
    the JSON string is ``MAX_OUTPUT_ROW_SIZE`` characters or longer.
    """
    qualified_id = '_'.join([row_data["url_host_name"], str(row_data["layout_id"])])
    row_data["layout_id"] = qualified_id
    serialized = json_dumps(row_data)
    if len(serialized) >= MAX_OUTPUT_ROW_SIZE:
        return None
    return {"value": serialized, "layout_id": qualified_id}
        
def calculating_similarity(feature_dict, feature, max_layer_n):
    """Return the id of the first layout that contains a feature similar to ``feature``.

    ``feature_dict`` maps layout ids (as strings) to lists of dicts carrying a
    ``feature`` entry. Layouts are checked in dict order; the first one with a
    member whose similarity reaches ``SIMILARITY_THRESHOLD`` wins and its key
    is returned as ``int``. ``-2`` is returned when no layout matches.
    """
    for layout_key, members in feature_dict.items():
        for member in members:
            if similarity(feature, member["feature"], max_layer_n) >= SIMILARITY_THRESHOLD:
                return int(layout_key)
    return -2

def parse_similarity(_iter):
    """Assign a layout id to every page referenced by the rows of a partition.

    Each row provides ``layout_dict`` (JSON string, may be empty), ``domain``
    and a ``file`` struct with ``filepath``, ``offset`` and ``record_count``.
    The referenced records are read from S3 and matched against the domain's
    ``feature_dict``.

    Layout id before it is qualified by ``parse_output_data``:
        -1  the domain has no usable layout list (empty, or exactly ``[-1]``)
        -2  no feature or no tags, no similar layout, a timeout, or any error
            raised while matching
        otherwise the matching ``feature_dict`` key as ``int``

    Yields:
        The dicts returned by ``parse_output_data``; records for which it
        returns ``None`` are skipped. A failure while reading or emitting a
        file is logged and the rest of that file is skipped.
    """
    s3_doc_writer = get_s3_doctor("parse_similarity")
    error_info = None

    def _report(exc, payload):
        # Must be called from inside an ``except`` block: format_exc() reads
        # the exception that is currently being handled.
        record = {
            "error_type": type(exc).__name__,
            "error_message": str(exc),
            "traceback": traceback.format_exc(),
            "input_data": payload,
            "timestamp": datetime.now().isoformat()
        }
        s3_doc_writer.write(record)
        return record

    for row in _iter:
        layout_dict = json_loads(row.layout_dict) if row.layout_dict else {}
        layout_list = layout_dict.get("layout_list", [])
        has_no_layout = bool(not layout_list or (len(layout_list) == 1 and layout_list[0] == -1))
        feature_dict = layout_dict.get("feature_dict", {})
        max_layer_n = layout_dict.get("max_layer_n", 5)
        domain = row.domain
        file_d = row.file
        offset = file_d.offset
        record_count = file_d.record_count
        try:
            for raw_record in read_s3_by_offset_limit(file_d.filepath, offset, limit=record_count):
                detail_data = json_loads(raw_record.value)
                if has_no_layout:
                    layout_id = -1
                else:
                    try:
                        feature = get_feature(detail_data["html"])
                        if feature is None or not feature.get("tags"):
                            layout_id = -2
                        else:
                            layout_id = func_timeout(SIM_TIMEOUT_SECONDS, calculating_similarity, (feature_dict, feature, max_layer_n,))
                    except (FunctionTimedOut, Exception) as e:
                        # FunctionTimedOut is listed explicitly so a timeout is
                        # handled the same way as any other matching failure.
                        error_info = _report(e, str(detail_data))
                        layout_id = -2
                json_line = parse_output_data({
                    "track_id": detail_data["track_id"],
                    "html": detail_data["html"],
                    "url": detail_data["url"],
                    "layout_id": layout_id,
                    "max_layer_n": max_layer_n,
                    "url_host_name": domain,
                    "raw_warc_path": detail_data["raw_warc_path"]
                })
                if json_line is not None:
                    yield json_line
        except Exception as e:
            error_info = _report(e, str(row))

    if error_info:
        s3_doc_writer.flush()

def save_s3_by_layout(outdata_list):
    """Write the rows of one partition to size-capped ``.jsonl.gz`` files on S3.

    Every row's ``value`` is parsed as JSON, stamped with the current
    ``offset`` inside the open file and written through an ``S3DocWriter``.
    Files are named ``{OUTPUT_PATH}{partition_id}_{n}.jsonl.gz``; a new file is
    started once the running offset exceeds ``MAX_OUTPUT_FILE_SIZE``. Rows that
    fail are logged to the error writer and skipped.

    Args:
        outdata_list: iterator of rows with ``value`` and ``partition_id``.

    Yields:
        A single dict ``{"write_size": n}`` where ``n`` is the number of rows
        written successfully.
    """
    s3_doc_writer = get_s3_doctor("similarity_write")
    error_info = None
    s3_writer = None
    offset = 0
    offset_index = 0
    written = 0
    for row in outdata_list:
        try:
            # Rotate: close the current file once it has grown past the cap.
            if s3_writer is not None and offset > MAX_OUTPUT_FILE_SIZE:
                s3_writer.flush()
                s3_writer = None
                offset = 0
            json_line = json_loads(row.value)
            json_line["offset"] = offset
            if s3_writer is None:
                partition_id = row.partition_id
                offset_index = offset_index + 1
                output_file = f"{OUTPUT_PATH}{partition_id}_{offset_index}.jsonl.gz"
                s3_writer = S3DocWriter(output_file)
            offset += s3_writer.write(json_line)
            written += 1
        except Exception as e:
            error_info = {
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
                "input_data": row.value,
                "timestamp": datetime.now().isoformat()
            }
            s3_doc_writer.write(error_info)

    # Guard on the writer itself: after a rotation followed by a failed row
    # there is no open writer even though earlier rows were written.
    if s3_writer is not None:
        s3_writer.flush()
    if error_info:
        s3_doc_writer.flush()
    # Report the number of rows written rather than the last loop index, which
    # was one too small and could not tell an empty partition from one row.
    yield {"write_size": written}


## main func

In [None]:
def parse_input_path_sim(input_path):
    """Return the entries directly under ``input_path`` not yet marked as done.

    Finished batches are recorded as a comma-separated list in
    ``./is_similarity_complated.txt``; a missing file means nothing has been
    processed yet.

    Args:
        input_path: S3 prefix that is listed non-recursively.

    Returns:
        List of ``s3a:`` paths, in listing order, that are absent from the
        checkpoint file.
    """
    try:
        with open("./is_similarity_complated.txt", "r", encoding="utf-8") as f:
            already_exist = {i for i in f.read().split(",") if i}
    except FileNotFoundError:
        # First run: no checkpoint has been written yet.
        already_exist = set()

    input_path_lst = []
    for object_path in list_s3_objects(input_path, recursive=False):
        # Rewrite only the URL scheme; a plain replace("s3", "s3a") would also
        # corrupt any other "s3" substring in the bucket or key.
        s3a_path = crush_read_path(object_path)
        if s3a_path not in already_exist:
            input_path_lst.append(s3a_path)
    return input_path_lst

def parse_path_sim(batch):
    """Split ``batch`` on ``/`` and derive the global ``OUTPUT_PATH`` from it.

    ``OUTPUT_PATH`` becomes ``LAYOUT_SIM_BASE_OUTPUT_PATH`` followed by the
    text before the first ``_`` of the third-to-last segment, the
    second-to-last segment, and a trailing slash. Returns the path segments.
    """
    global OUTPUT_PATH
    segments = batch.split('/')
    range_prefix = segments[-3].split("_")[0]
    OUTPUT_PATH = f"{LAYOUT_SIM_BASE_OUTPUT_PATH}{range_prefix}/{segments[-2]}/"
    return segments

def get_domain_df(batch):
    """Read one clustered batch and start similarity matching on it.

    The path is rewritten with ``crush_read_path`` and read as JSON; ``files``
    is exploded into one row per element (column ``file``) before the frame is
    passed to ``similarity_every_domain``.
    """
    read_path = crush_read_path(batch)
    clustered_df = spark.read.format("json").load(read_path)
    per_file_df = clustered_df.withColumn("file", explode(col("files")))
    per_file_df = per_file_df.drop("files")
    similarity_every_domain(per_file_df)

def similarity_every_domain(input_df: DataFrame):
    """Run ``parse_similarity`` over every partition and write the result.

    The input is repartitioned into ``NUM_PARTITIONS`` partitions; the mapped
    rows are turned into a DataFrame with the string columns ``value`` and
    ``layout_id`` and handed to ``write_by_layoutid``.
    """
    result_schema = (
        StructType()
        .add('value', StringType(), True)
        .add('layout_id', StringType(), True)
    )
    partitioned_rdd = input_df.repartition(NUM_PARTITIONS).rdd
    matched_rows = partitioned_rdd.mapPartitions(parse_similarity)
    write_by_layoutid(matched_rows.toDF(result_schema))

def write_by_layoutid(all_domain_df: DataFrame):
    """Group rows by layout id per partition and write each partition to S3.

    Rows are repartitioned by ``layout_id`` into ``WRITE_NUM_PARTITIONS``
    partitions, tagged with their partition id, sorted by ``layout_id`` inside
    each partition and written by ``save_s3_by_layout``. ``count()`` is the
    action that triggers the write.
    """
    by_layout = all_domain_df.repartition(WRITE_NUM_PARTITIONS, col("layout_id"))
    tagged = by_layout.withColumn("partition_id", spark_partition_id())
    ordered = tagged.sortWithinPartitions(col("layout_id"))
    ordered.rdd.mapPartitions(save_s3_by_layout).count()

def sim_main():
    """Run similarity matching for every pending batch of the clustering output.

    Each entry directly under ``CLUSTER_LAYOUT_BASE_OUTPUT_PATH`` is expanded
    by ``parse_input_path_sim``. Every pending batch gets its own Spark
    session and, on success, is appended to ``./is_similarity_complated.txt``
    so a rerun skips it.
    """
    big_batch_lst = list(list_s3_objects(CLUSTER_LAYOUT_BASE_OUTPUT_PATH, recursive=False))
    for big_batch in big_batch_lst:
        input_path_lst = parse_input_path_sim(big_batch)
        for batch in input_path_lst:
            path_list = parse_path_sim(batch)
            spark_name = 'sim.' + '_'.join([path_list[-3], path_list[-2]])
            create_spark(spark_name)
            try:
                get_domain_df(batch)
            finally:
                # Stop the session even when the batch fails so it is not
                # leaked; the exception still propagates and the batch is not
                # marked done.
                close_spark()
            with open("./is_similarity_complated.txt", "a", encoding="utf-8") as f:
                f.write(batch + ",")


# layout_index

## main func

In [None]:
def parse_input_path_index(input_path: str):
    """Return the entries directly under ``input_path`` not yet marked as done.

    Finished batches are recorded as a comma-separated list in
    ``./is_index_complated.txt``; a missing file means nothing has been
    processed yet.

    Args:
        input_path: S3 prefix that is listed non-recursively.

    Returns:
        List of paths, exactly as listed and in listing order, that are absent
        from the checkpoint file.
    """
    try:
        with open("./is_index_complated.txt", "r", encoding="utf-8") as f:
            already_exist = {i for i in f.read().split(",") if i}
    except FileNotFoundError:
        # First run: no checkpoint has been written yet.
        already_exist = set()
    return [i for i in list_s3_objects(input_path, recursive=False) if i not in already_exist]

def create_index_df(batch: str, path_list: List):
    """Load the written ``.jsonl.gz`` files of ``batch`` and start indexing.

    All ``.jsonl.gz`` objects under ``batch`` are read in one call; from each
    row's JSON ``value`` only ``url_host_name``, ``layout_id`` and ``offset``
    are extracted, and the row's ``filename`` is carried along as
    ``filepath``. The projected frame goes to ``data_to_index``.
    """
    gz_paths = [p for p in list_s3_objects(batch, recursive=True) if p.endswith(".jsonl.gz")]
    raw_df = read_any_path(spark, ','.join(gz_paths), config)

    index_schema = (
        StructType()
        .add("url_host_name", StringType(), True)
        .add("layout_id", StringType(), True)
        .add("offset", LongType(), True)
    )
    parsed_df = raw_df.withColumn("json_struct", from_json(raw_df["value"], index_schema))
    projected_df = parsed_df.select("json_struct.*", col("filename").alias("filepath"))
    data_to_index(projected_df, path_list)

def data_to_index(with_length_df: DataFrame, path_list: List):
    """Aggregate record offsets into one index row per layout id.

    Step 1 groups by host, layout id and file and computes the smallest
    offset, the span between the largest and smallest offset (``length``) and
    the record count; every group is stamped with the current local time.
    Step 2 packs those values into a ``file`` struct. Step 3 groups by layout
    id, summing the record counts into ``count`` and collecting the structs
    into ``files``, ordered by ``count`` descending. The result is passed to
    ``write_by_two``.
    """
    run_stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    per_file_df = (
        with_length_df
        .groupBy(["url_host_name", "layout_id", "filepath"])
        .agg(
            (_max(col("offset")) - _min(col("offset"))).alias("length"),
            _min(col("offset")).alias("offset"),
            _count("*").alias("record_count"),
        )
        .sort("url_host_name", "layout_id", "filepath")
        .withColumn("timestamp", lit(run_stamp))
    )

    file_fields = [col("filepath").alias("filepath")]
    file_fields += [
        col(name).cast("long").alias(name)
        for name in ("length", "offset", "record_count")
    ]
    file_fields.append(col("timestamp").alias("timestamp"))
    with_file_df = per_file_df.withColumn("file", struct(*file_fields)).select(
        ["url_host_name", "layout_id", "file", "record_count"]
    )

    per_layout_df = with_file_df.groupBy("layout_id").agg(
        _sum("record_count").alias("count"),
        collect_list("file").alias("files"),
        first("url_host_name").alias("url_host_name"),
    )
    result_df = per_layout_df.orderBy("count", ascending=False)
    write_by_two(result_df, path_list)

def write_by_two(result_df: DataFrame, path_list: List):
    """Serialise the index rows to JSON and upload them as a single partition.

    Each row becomes one ``value`` string encoding ``layout_id``, ``count``,
    ``files`` and ``url_host_name``. The frame is collapsed to one partition
    and uploaded to ``OUTPUT_PATH`` as ``jsonl`` with the second-to-last
    element of ``path_list`` as prefix.
    """
    columns = [result_df[name] for name in ("layout_id", "count", "files", "url_host_name")]
    output_df = result_df.withColumn("value", to_json(struct(*columns))).select("value")

    output_acc = S3UploadAcc(spark.sparkContext)
    uploader = upload_to_s3(OUTPUT_PATH, "jsonl", output_acc, prefix=path_list[-2])
    output_df.repartition(1).foreachPartition(uploader)

def parse_path_index(batch):
    """Split ``batch`` on ``/`` and derive the global ``OUTPUT_PATH`` from it.

    ``OUTPUT_PATH`` becomes ``LAYOUT_INDEX_BASE_OUTPUT_PATH`` followed by the
    third-to-last segment and a trailing slash. Returns the path segments.
    """
    global OUTPUT_PATH
    segments = batch.split('/')
    OUTPUT_PATH = f"{LAYOUT_INDEX_BASE_OUTPUT_PATH}{segments[-3]}/"
    return segments
    
def index_main():
    """Build the layout index for every pending batch of the similarity output.

    Each entry directly under ``LAYOUT_SIM_BASE_OUTPUT_PATH`` is expanded by
    ``parse_input_path_index``. Every pending batch gets its own Spark session
    and, on success, is appended to ``./is_index_complated.txt`` so a rerun
    skips it.
    """
    big_batch_lst = list(list_s3_objects(LAYOUT_SIM_BASE_OUTPUT_PATH, recursive=False))
    for big_batch in big_batch_lst:
        batch_lst = parse_input_path_index(big_batch)
        for batch in batch_lst:
            path_list = parse_path_index(batch)
            spark_name = "index." + ".".join([path_list[-3], path_list[-2]])
            create_spark(spark_name)
            try:
                create_index_df(batch, path_list)
            finally:
                # Stop the session even when the batch fails so it is not
                # leaked; the exception still propagates and the batch is not
                # marked done.
                close_spark()
            with open("./is_index_complated.txt", "a", encoding="utf-8") as f:
                f.write(batch + ",")


# start

In [None]:
# Driver cell: runs the stages in order. choose_main() and layout_main() are
# defined outside this section of the notebook.
choose_main()

# Dispatch on the range encoded in the directory name: directories whose name
# starts with "1-1600" go through the "little" pipeline, all others through
# layout_main().
big_batch_lst = list(list_s3_objects(CHOOSE_DOMAIN_OUTPUT_PATH, recursive=False))
for big_batch in big_batch_lst:
    # Second-to-last "/" segment of the listed path.
    end_path = big_batch.split("/")[-2]
    if end_path.startswith("1-1600"):
        layout_little_main(big_batch)
    else:
        layout_main(big_batch)

# Similarity matching over the clustering output.
sim_main()

# Index building over the similarity output.
index_main()