In [None]:
from xinghe.spark import *
from app.common.json_util import *
from xinghe.s3 import *
from xinghe.s3.read import *
from xinghe.ops.spark import spark_resize_file

# Options handed to new_spark_session() in create_spark(). write_domain_data()
# and join_to_write() later add "skip_output_version" / "skip_output_check" to
# this same dict before calling write_any_path().
config = dict(
    spark_conf_name="spark_4",
    skip_success_check=True,
    **{
        # Same path the driver sets via os.environ further down; presumably so
        # executors load the same llm_web_kit config — confirm.
        "spark.executorEnv.LLM_WEB_KIT_CFG_PATH": "/share/xxx.jsonc",
        "spark.yarn.queue": "pipeline.clean",
        "spark.executor.memory": "40g",
    },
)
from llm_web_kit.libs.standard_utils import compress_and_decompress_str
from llm_web_kit.html_layout.html_layout_cosin import cluster_html_struct, get_feature, similarity, sum_tags

import base64
import re
import random
import time
import sys
import pickle
import zlib
import uuid
import traceback
import numpy as np
import pandas as pd
from typing import List, Dict, Union
from copy import deepcopy, copy
from urllib.parse import quote, unquote, urlparse, parse_qs
from datetime import datetime
from lxml import html
from collections import defaultdict
from func_timeout import FunctionTimedOut, func_timeout

from pyspark.sql import Row, DataFrame
from pyspark.sql.functions import row_number, col, collect_list, struct, expr, rand, count, pandas_udf, PandasUDFType, \
    round as _round, lit, to_json, explode
from pyspark.sql.types import StructType, StructField, IntegerType, BinaryType, StringType

import os

# Driver-side copy of the llm_web_kit config path that executors receive through
# "spark.executorEnv.LLM_WEB_KIT_CFG_PATH" in ``config``.
os.environ.update({"LLM_WEB_KIT_CFG_PATH": "/share/xxx.jsonc"})

# Timeout (seconds, 5 hours) given to func_timeout for one clustering batch.
TIMEOUT_SECONDS = 5 * 60 * 60
# Upper bound (exclusive) on the length of one serialized output row; kept below
# the 2147483648 limit quoted in the oversize error records.
MAX_OUTPUT_ROW_SIZE = 1.7 * 1024 ** 3
# Two layouts count as the same when similarity() is at or above this value.
SIMILARITY_THRESHOLD = 0.95
# Sampling rate keyed by the prefix (before the first "_") of a batch's parent
# directory name; parse_valid_data falls back to 1 for unknown keys.
RATE_MAP = {"10w": 0.1, "1.5w-10w": 0.2, "1600-1.5w": 0.5, "1-1600": 1, "500-1600": 1}
# Partition count for feature extraction, and cap for the per-domain merge.
NUM_PARTITIONS = 100_000
ERROR_PATH = "s3://xxx/"
INPUT_PATH = "s3://xxx/"
BASE_OUTPUT_PATH = "s3://xxx/"
BASE_DOMAIN_PATH = "s3://xxx/"

# utils

In [None]:
def get_s3_doctor(target_theme):
    """Return an S3DocWriter for the error records of one processing stage.

    Every call targets a fresh object named
    ``{ERROR_PATH}{target_theme}/{YYYYMMDD}/{uuid4}.jsonl``, so concurrent
    partitions never share a file.
    """
    day_stamp = datetime.now().strftime("%Y%m%d")
    unique_name = uuid.uuid4()
    return S3DocWriter(path=f"{ERROR_PATH}{target_theme}/{day_stamp}/{unique_name}.jsonl")


def get_all_domain_data(_iter):
    """mapPartitions worker: turn sampled file slices into per-page feature rows.

    Each input row is expected to carry ``domain``, ``valid_count`` and a
    ``file`` struct with ``filepath``, ``offset``, ``length`` and
    ``record_count``. For each row, JSON lines are streamed from
    ``file.filepath`` with ``bytes_range=(offset, offset + length)``, and at
    most ``record_count`` of them are processed. Each record's ``html`` goes
    through ``get_feature``. Records whose feature has no tags are dropped.
    The rest are serialized together with ``layer_n``/``total_n`` from
    ``sum_tags`` and yielded as ``Row(value=<json>, domain=..., valid_count=...)``.

    The following are written to the "get_feature" error log instead of being
    raised:
    - oversized records (serialized length not below MAX_OUTPUT_ROW_SIZE);
    - records that raise while being processed;
    - rows whose file read raises.

    The log is flushed at the end if anything was written to it.
    """
    error_writer = get_s3_doctor("get_feature")
    logged_any = False

    def _log(error_type, error_message, input_data):
        # Invoked from inside the caller's ``except`` block where there is one,
        # so format_exc() still describes the exception being handled.
        error_writer.write({
            "error_type": error_type,
            "error_message": error_message,
            "traceback": traceback.format_exc(),
            "input_data": input_data,
            "timestamp": datetime.now().isoformat(),
        })

    for row in _iter:
        valid_count = row.valid_count
        file_info = row.file
        offset, length, record_count = file_info.offset, file_info.length, file_info.record_count
        lines_read = 0
        try:
            stream = read_s3_lines_with_range(file_info.filepath, use_stream=True,
                                              bytes_range=(offset, offset + length))
            for raw_line in stream:
                lines_read += 1
                # The byte range may run past this slice; stop after record_count lines.
                if lines_read > record_count:
                    break
                # Logged as-is if parsing fails, otherwise the parsed record.
                record = raw_line
                try:
                    record = json_loads(raw_line)
                    feature = get_feature(record["html"])
                    if feature is None or not feature.get("tags"):
                        continue
                    layer_n, total_n = sum_tags(feature["tags"])
                    serialized = json_dumps({
                        "date": record["date"],
                        "track_id": record["track_id"],
                        "url": record["url"],
                        "raw_warc_path": record["raw_warc_path"],
                        "domain": row.domain,
                        "sub_path": row.domain,
                        "valid_count": valid_count,
                        "feature": feature,
                        "layer_n": layer_n,
                        "total_n": total_n,
                    })
                    if len(serialized) >= MAX_OUTPUT_ROW_SIZE:
                        logged_any = True
                        _log("EOFError", "Memory more than required for vector is (2147483648)", record)
                        continue
                    yield Row(value=serialized, domain=row.domain, valid_count=valid_count)
                except Exception as exc:
                    logged_any = True
                    _log(type(exc).__name__, str(exc), record)
        except Exception as exc:
            logged_any = True
            _log(type(exc).__name__, str(exc), str(row))

    if logged_any:
        error_writer.flush()


def crush_output_data(output_data):
    """Serialize ``output_data`` and pair it with its ``domain``.

    Returns ``{"value": <json string>, "domain": output_data["domain"]}``, or
    ``None`` when the JSON string's length is not below MAX_OUTPUT_ROW_SIZE.
    """
    serialized = json_dumps(output_data)
    if len(serialized) >= MAX_OUTPUT_ROW_SIZE:
        return None
    return {"value": serialized, "domain": output_data["domain"]}


def parse_batch_data(fpath, batch_size=800):
    """Yield lists of up to ``batch_size`` sample records read from ``fpath``.

    Each row's ``.value`` is parsed as JSON; presumably these are the sampled
    rows written to DOMAIN_PATH by write_domain_data — confirm. Only the fields
    the clustering needs are kept (feature, layer_n, total_n, track_id, url,
    domain, raw_warc_path, date).

    Batching counts kept samples, so every batch except possibly the last holds
    exactly ``batch_size`` records and no empty batch is yielded. The previous
    version counted raw rows and yielded before appending, so its first batch
    had only ``batch_size - 1`` records.

    Records missing a required field are skipped. JSON parse errors propagate to
    the caller, which logs them per file.
    """
    sample_list = []
    for domain_v in read_s3_rows(fpath, use_stream=True):
        domain_data = json_loads(domain_v.value)
        try:
            lines = {
                "feature": domain_data["feature"],
                "layer_n": domain_data["layer_n"],
                "total_n": domain_data["total_n"],
                "track_id": domain_data["track_id"],
                "url": domain_data["url"],
                "domain": domain_data["domain"],
                "raw_warc_path": domain_data["raw_warc_path"],
                "date": domain_data["date"],
            }
        except (KeyError, TypeError):
            # Missing field (KeyError) or a non-dict JSON value (TypeError):
            # skip the record, as the original best-effort code intended.
            continue
        sample_list.append(lines)
        if len(sample_list) >= batch_size:
            yield sample_list
            sample_list = []
    if sample_list:
        yield sample_list


def calculating_layout(current_host_name, sample_list):
    """Cluster one batch of pages and yield the batch's layout summary.

    ``cluster_html_struct`` returns per-page cluster data plus the list of
    layouts. Up to three representative pages (shallow copies) are kept per
    ``layout_id``; id ``-1`` is skipped (presumably "no layout" — confirm).
    ``max_layer_n`` is taken from the first cluster entry, so an empty result
    raises IndexError.

    Generator: yields at most one dict, the result of ``crush_output_data``
    over ``{"domain", "feature_dict", "layout_list", "max_layer_n"}``. It yields
    nothing when ``layout_list`` is empty or the summary is too large.
    """
    cluster_datas, layout_list = cluster_html_struct(sample_list)
    max_layer_n = cluster_datas[0]["max_layer_n"]

    representatives = defaultdict(list)
    for cluster in cluster_datas:
        layout_id = cluster["layout_id"]
        if layout_id != -1 and len(representatives[layout_id]) < 3:
            representatives[layout_id].append(copy(cluster))

    if not layout_list:
        return
    summary = crush_output_data({
        "domain": current_host_name,
        "feature_dict": dict(representatives),
        "layout_list": layout_list,
        "max_layer_n": max_layer_n,
    })
    if summary:
        yield summary


def parse_layout(domain_list):
    """mapPartitions worker: cluster the sampled pages under each domain prefix.

    ``domain_list`` iterates over S3 prefixes (one partition's share of the
    DOMAIN_PATH listing). Every ``.jsonl`` object below each prefix is read in
    batches via parse_batch_data. Each batch with more than one sample is
    clustered by calculating_layout under a TIMEOUT_SECONDS limit, and the
    resulting ``{"value", "domain"}`` dicts are yielded.

    Timeouts and other per-batch or per-file failures are written to the
    "parse_layout" error log, which is flushed at the end if anything was
    logged.
    """
    s3_doc_writer = get_s3_doctor("parse_layout")
    error_info = None

    def _collect_layouts(host_name, samples):
        # calculating_layout is a generator function, so calling it only builds a
        # generator and runs none of its body. Draining it here makes the
        # clustering itself run inside func_timeout, so TIMEOUT_SECONDS is
        # actually enforced.
        return list(calculating_layout(host_name, samples))

    def _error_record(exc, input_data):
        # Must be called from inside an ``except`` block so format_exc() sees exc.
        return {
            "error_type": type(exc).__name__,
            "error_message": str(exc),
            "traceback": traceback.format_exc(),
            "input_data": input_data,
            "timestamp": datetime.now().isoformat(),
        }

    for domain_f in domain_list:
        domain_paths = [f for f in list_s3_objects(domain_f, recursive=True) if f.endswith(".jsonl")]
        for fpath in domain_paths:
            try:
                for sample_list in parse_batch_data(fpath):
                    try:
                        if len(sample_list) > 1:
                            current_host_name = sample_list[0]["domain"]
                            layouts = func_timeout(TIMEOUT_SECONDS, _collect_layouts,
                                                   (current_host_name, sample_list,))
                            for line in layouts:
                                yield line
                    # FunctionTimedOut is listed explicitly because func_timeout
                    # does not guarantee it is caught by ``except Exception``.
                    except (FunctionTimedOut, Exception) as e:
                        error_info = _error_record(e, str(sample_list))
                        s3_doc_writer.write(error_info)
            except Exception as e:
                error_info = _error_record(e, fpath)
                s3_doc_writer.write(error_info)
    if error_info:
        s3_doc_writer.flush()


def layout_similarity(layout_d1, layout_d2):
    """Merge the layouts of ``layout_d2`` into ``layout_d1`` and return ``layout_d1``.

    ``layout_d1`` is modified in place; the returned object *is* ``layout_d1``:
      * ``max_layer_n`` becomes the larger of the two inputs' values.
      * Each layout group of ``layout_d2`` is tested page by page against every
        sample page already in ``layout_d1``. If no pair reaches
        SIMILARITY_THRESHOLD, the group is added as a new layout. The new id is
        ``max(layout_list) + 1``, counting up, or 0 when ``layout_list`` is
        empty. It is stored under ``str(id)`` in ``feature_dict`` and appended
        to ``layout_list``.

    Groups of ``layout_d2`` are compared only with ``layout_d1``'s original
    pages, not with groups added during this call.
    """
    max_layer_n = max(layout_d1["max_layer_n"], layout_d2["max_layer_n"])
    layout_d1["max_layer_n"] = max_layer_n
    # default=-1 lets an empty layout_list start numbering at 0 instead of raising.
    next_layout_id = max(layout_d1["layout_list"], default=-1)
    # Snapshot of layout_d1's pages taken before any new group is added.
    known_pages = [page for pages in layout_d1["feature_dict"].values() for page in pages]

    for new_pages in layout_d2["feature_dict"].values():
        already_known = any(
            similarity(page["feature"], known["feature"], max_layer_n) >= SIMILARITY_THRESHOLD
            for page in new_pages
            for known in known_pages
        )
        if not already_known:
            next_layout_id += 1
            layout_d1["feature_dict"][str(next_layout_id)] = new_pages
            layout_d1["layout_list"].append(next_layout_id)
    return layout_d1


def merge_layout(domain_list):
    """applyInPandas UDF: fold all layout summaries of one domain into one.

    ``domain_list`` is the pandas DataFrame Spark passes for a single ``domain``
    group. Each row's ``value`` is a JSON layout summary produced by
    parse_layout. Rows are merged pairwise, in order, with layout_similarity.
    A row that fails to merge is logged to the "merge_layout" error log and
    skipped.

    Returns a one-row DataFrame with columns ``layout_dict`` (merged JSON) and
    ``domain``. If the merged JSON is not below MAX_OUTPUT_ROW_SIZE, the error
    is logged and an empty DataFrame with the same columns is returned.
    applyInPandas requires a DataFrame, so returning None would fail the task.
    """
    s3_doc_writer = get_s3_doctor("merge_layout")
    error_info = None
    pre_domain = {}
    domain_v = None
    # Pick the first row by position, not by pandas index label, so a group
    # whose index does not start at 0 still seeds the fold correctly.
    for position, (_, domain_v) in enumerate(domain_list.iterrows()):
        if position == 0:
            pre_domain = json_loads(domain_v.value)
            continue
        try:
            pre_domain = layout_similarity(pre_domain, json_loads(domain_v.value))
        except Exception as e:
            error_info = {
                "error_type": type(e).__name__,
                "error_message": str(e),
                "traceback": traceback.format_exc(),
                "input_data": domain_v.value,
                "timestamp": datetime.now().isoformat()
            }
            s3_doc_writer.write(error_info)

    output_data = json_dumps(pre_domain)
    if len(output_data) < MAX_OUTPUT_ROW_SIZE:
        result = pd.DataFrame({"layout_dict": output_data, "domain": pre_domain["domain"]}, index=[0])
    else:
        error_info = {
            "error_type": "EOFError",
            "error_message": "Memory more than required for vector is (2147483648)",
            "traceback": traceback.format_exc(),
            "input_data": domain_v.value,
            "timestamp": datetime.now().isoformat()
        }
        s3_doc_writer.write(error_info)
        result = pd.DataFrame(columns=["layout_dict", "domain"])

    # Flush on both paths before returning, so errors logged during the merge
    # are not lost when the output fits.
    if error_info:
        s3_doc_writer.flush()
    return result


# main func

In [None]:
def create_spark(spark_name: str):
    """Open a Spark session named ``layout.write.<spark_name>`` using ``config``.

    The session and its SparkContext are published as the module globals
    ``spark`` and ``sc``, and the context's log level is set to ERROR.
    """
    global spark, sc
    spark = new_spark_session(f"layout.write.{spark_name}", config)
    sc = spark.sparkContext
    sc.setLogLevel("ERROR")


def get_input_df(batch: str):
    """Run the whole layout pipeline for the JSON input at ``batch``.

    Loads the batch, adds the ``valid_count`` sampling column, counts the rows
    and hands both to the explode/feature/sample/cluster chain. Returns nothing.
    """
    raw_df = spark.read.format("json").load(batch)
    domain_df = parse_valid_data(raw_df, batch)
    parse_explode_df(domain_df, domain_df.count())


def parse_valid_data(input_df: DataFrame, batch: str):
    """Add an integer ``valid_count`` column equal to ``round(count * rate)``.

    ``rate`` is RATE_MAP's entry for the text before the first "_" in the name
    of ``batch``'s parent directory, or 1 when that key is not in RATE_MAP.
    """
    parent_dir = batch.split("/")[-2]
    size_bucket = parent_dir.split("_")[0]
    rate = RATE_MAP.get(size_bucket, 1)
    scaled_count = _round(col("count") * rate).cast("integer")
    return input_df.withColumn("valid_count", scaled_count)


def parse_explode_df(domain_df: DataFrame, domain_count):
    """Expand ``files`` into one row per element (column ``file``) and continue.

    The original ``files`` array column is dropped from the exploded frame. The
    unexploded ``domain_df`` is passed along unchanged for the final join.
    """
    per_file_df = domain_df.withColumn("file", explode(col("files")))
    per_file_df = per_file_df.drop("files")
    parse_get_feature_df(per_file_df, domain_df, domain_count)


def parse_get_feature_df(explode_df: DataFrame, domain_df: DataFrame, domain_count):
    """Extract page features across NUM_PARTITIONS partitions, then sample them.

    get_all_domain_data runs per partition. Its rows become a DataFrame with
    nullable string ``value``, string ``domain`` and integer ``valid_count``.
    """
    feature_schema = StructType([
        StructField("value", StringType(), True),
        StructField("domain", StringType(), True),
        StructField("valid_count", IntegerType(), True),
    ])
    feature_rdd = explode_df.repartition(NUM_PARTITIONS).rdd.mapPartitions(get_all_domain_data)
    sample_by_valid_count(feature_rdd.toDF(feature_schema), domain_df, domain_count)


def sample_by_valid_count(feature_df: DataFrame, domain_df: DataFrame, domain_count):
    """Randomly keep at most ``valid_count`` feature rows per domain, then continue.

    Rows of each domain are ordered by a random key and numbered. Rows whose
    number exceeds that row's ``valid_count`` are dropped. Only the ``value``
    column survives; it is written by write_domain_data, after which
    calculating_layout_every_batch clusters the written samples.
    """
    # ``Window`` is not among this notebook's explicit imports; import it here
    # instead of relying on one of the wildcard imports to provide it.
    from pyspark.sql.window import Window

    df_with_rand = feature_df.withColumn("rand", rand())
    row_num_window_spec = Window.partitionBy("domain").orderBy(col("rand"))
    df_with_row_num = df_with_rand.withColumn("row_num", row_number().over(row_num_window_spec))
    domain_sample_df = (
        df_with_row_num
        .filter(col("row_num") <= col("valid_count"))
        .drop("rand", "row_num", "valid_count", "domain")
    )
    write_domain_data(domain_sample_df)
    calculating_layout_every_batch(domain_df, domain_count)


def write_domain_data(domain_sample_df: DataFrame):
    """Write the sampled rows to DOMAIN_PATH.

    Side effect: sets "skip_output_version" and "skip_output_check" to True in
    the module-level ``config`` dict, which is then passed to write_any_path.
    """
    config.update({"skip_output_version": True, "skip_output_check": True})
    write_any_path(domain_sample_df, DOMAIN_PATH, config)


def calculating_layout_every_batch(domain_df: DataFrame, domain_count):
    """Cluster layouts for every object directly under DOMAIN_PATH, then merge.

    Each listed object gets its own Spark partition, and parse_layout turns it
    into ``(value, domain)`` layout rows. These go to merge_layout_by_layout_id.
    """
    output_schema = StructType([
        StructField('value', StringType(), True),
        StructField('domain', StringType(), True),
    ])
    domain_lst = list(list_s3_objects(DOMAIN_PATH, recursive=False))
    # SparkContext.parallelize needs at least one slice; with an empty listing,
    # len(domain_lst) == 0 would make it fail instead of yielding an empty RDD.
    page_content = sc.parallelize(domain_lst, max(len(domain_lst), 1))
    layout_df = page_content.mapPartitions(parse_layout).toDF(output_schema)
    merge_layout_by_layout_id(layout_df, domain_df, domain_count)


def merge_layout_by_layout_id(layout_df: DataFrame, domain_df: DataFrame, domain_count):
    """Merge every domain's layout rows into one summary, then join and write.

    ``layout_df`` is hash-partitioned by ``domain`` into ``domain_count``
    partitions, capped at NUM_PARTITIONS. merge_layout then runs once per
    domain group via applyInPandas, producing ``(layout_dict, domain)``.
    """
    output_schema = StructType([
        StructField('layout_dict', StringType(), True),
        StructField('domain', StringType(), True),
    ])

    # repartition() requires a positive partition count; clamp to at least 1
    # so an empty batch (domain_count == 0) does not fail here.
    num_partitions = max(1, min(domain_count, NUM_PARTITIONS))
    merge_rep_df = layout_df.repartition(num_partitions, col("domain"))
    merge_layout_df = merge_rep_df.groupby('domain').applyInPandas(merge_layout, output_schema)
    join_to_write(merge_layout_df, domain_df)


def join_to_write(merge_layout_df: DataFrame, domain_df: DataFrame):
    """Attach each domain's merged ``layout_dict`` and write the result.

    A left join keeps domains without a merged layout. Every output row has a
    single ``value`` column: the JSON of ``(domain, count, files,
    layout_dict)``, written to OUTPUT_PATH.

    Side effect: sets "skip_output_version" and "skip_output_check" to True in
    the module-level ``config``.
    """
    joined = domain_df.join(merge_layout_df, on="domain", how="left")
    record = struct(joined["domain"], joined["count"], joined["files"], joined["layout_dict"])
    output_df = joined.withColumn("value", to_json(record)).select("value")

    config.update({"skip_output_version": True, "skip_output_check": True})
    write_any_path(output_df, OUTPUT_PATH, config)


def close_spark():
    """Stop the module-level Spark session opened by create_spark()."""
    session = spark
    session.stop()


def parse_path(batch):
    """Derive this batch's output locations and return ``batch.split('/')``.

    Sets the globals OUTPUT_PATH and DOMAIN_PATH to
    ``<base><parent dir>/<file name with '.jsonl' removed>/``, using
    BASE_OUTPUT_PATH and BASE_DOMAIN_PATH respectively.
    """
    global OUTPUT_PATH, DOMAIN_PATH
    parts = batch.split('/')
    suffix = f"{parts[-2]}/{parts[-1].replace('.jsonl', '')}/"
    OUTPUT_PATH = BASE_OUTPUT_PATH + suffix
    DOMAIN_PATH = BASE_DOMAIN_PATH + suffix
    return parts


def parse_input_path(input_path):
    """List the input ``.jsonl`` batches under ``input_path`` still to be processed.

    Objects are listed recursively and only ``.jsonl`` keys are kept. Their
    ``s3://`` scheme is rewritten to ``s3a://``. Paths already recorded in the
    comma-separated checkpoint file ``./is_layout_complated.txt`` (appended to
    by main) are left out. If that file cannot be opened, e.g. on the first
    run, nothing is left out.
    """
    try:
        with open("./is_layout_complated.txt", "r", encoding="utf-8") as f:
            content = f.read()
    except OSError:
        content = ""
    already_exist = {i for i in content.split(",") if i}

    input_path_lst = []
    for f in list_s3_objects(input_path, recursive=True):
        if not f.endswith(".jsonl"):
            continue
        # Rewrite only the URI scheme. The former f.replace("s3", "s3a") also
        # rewrote every "s3" inside bucket/key names and turned "s3a://" into
        # "s3aa://". When "s3" appears only in the scheme the result is the same
        # as before, so existing checkpoint entries still match.
        if f.startswith("s3://"):
            f = "s3a://" + f[len("s3://"):]
        if f not in already_exist:
            input_path_lst.append(f)
    return input_path_lst


def main():
    """Run the layout pipeline once per pending input batch.

    Each batch returned by parse_input_path gets its own Spark session, named
    after the batch's parent directory and file name without ``.jsonl``. When
    a batch succeeds, its path is appended, comma-terminated, to
    ``./is_layout_complated.txt`` so later runs skip it.
    """
    input_path_lst = parse_input_path(INPUT_PATH)
    for batch in input_path_lst:
        path_list = parse_path(batch)
        spark_name = '_'.join([path_list[-2], path_list[-1].replace('.jsonl', '')])
        create_spark(spark_name)
        try:
            get_input_df(batch)
        finally:
            # Stop the session even when the batch fails, so a crash does not
            # leave the Spark application running after the exception propagates.
            close_spark()
        # Reached only when the batch succeeded: record it as completed.
        with open("./is_layout_complated.txt", "a", encoding="utf-8") as f:
            f.write(batch + ",")

In [None]:
# Notebook entry point: process every pending input batch.
main()