In [None]:
from xinghe.spark import *
from app.common.json_util import *
from xinghe.s3 import *

# Options handed to new_spark_session when the Spark session is created.
config = dict(
    spark_conf_name="spark_4",
    skip_success_check=True,
)

import pandas as pd
import time
import random
import uuid
import heapq
from pyspark.sql import Row, DataFrame
from pyspark.sql.functions import from_json, row_number, col, sum as _sum, to_json, struct, pandas_udf, PandasUDFType, \
    lit, spark_partition_id
from pyspark.sql.types import StructType, StructField, IntegerType, LongType

# Offset added to every partition_id when naming output files.
BASE_PARTITION_ID = 0
# S3 prefixes for results and for the .jsonl inputs (these look like redacted placeholders).
OUTPUT_PATH = "s3://xxx/"
INPUT_PATH = "s3://xxx/"
# Count buckets: (label, exclusive lower bound, inclusive upper bound or None for unbounded).
# "w" in the labels is 万 = 10,000, e.g. "1.5w" = 15,000 and "10w" = 100,000.
# NOTE(review): because the lower bound is exclusive, domains with count == 1 do not land in
# any bucket — confirm that is intended.
COUNT_MAP = [
    ("1-1600", 1, 1_600),
    ("1600-1.5w", 1_600, 15_000),
    ("1.5w-10w", 15_000, 100_000),
    ("10w", 100_000, None),
]
# Approximate sum of `count` per output partition: partitions = round(bucket total / this).
DATA_SIZE_PER_BATCH = 1_000_000_000

# Utilities: greedy partition assignment and the per-partition S3 writer

In [None]:
def divide_list_to_chunks(values, n_chunks):
    """Greedily balance weighted domains across ``n_chunks`` partitions.

    Items are visited from heaviest to lightest ``count`` (equal counts keep
    their input order) and each one is placed on the partition with the
    smallest running total, ties going to the lower partition id. This is the
    longest-processing-time-first heuristic.

    Args:
        values: iterable of mappings / Rows exposing ``"count"`` and ``"domain"``.
        n_chunks: requested number of partitions. Values below 1 are treated as 1,
            so a small input still gets one partition instead of failing on an
            empty heap.

    Yields:
        ``Row(domain=..., partition_id=...)`` for every input item, heaviest first.
    """
    # Min-heap of (running total, partition id).
    loads = [(0, chunk_id) for chunk_id in range(max(1, n_chunks))]
    heapq.heapify(loads)

    # sorted(..., reverse=True) is stable, so equal counts keep their input order.
    for item in sorted(values, key=lambda v: v["count"], reverse=True):
        load, chunk_id = heapq.heappop(loads)
        heapq.heappush(loads, (load + item["count"], chunk_id))
        yield Row(domain=item["domain"], partition_id=chunk_id)


def write_by_partitionid(_iter):
    """Write every row of one partition group to a single JSONL file on S3.

    Used as the body of a grouped-map pandas UDF. Despite its name, ``_iter``
    is a pandas DataFrame: the caller groups by ``partition_id``, so all rows
    share one ``partition_id``. Each row needs ``domain``, ``count``,
    ``partition_id`` and ``files`` columns. ``files`` must expose ``.tolist()``
    (presumably a numpy array produced by Arrow — confirm).

    The output path is built from module globals set on the driver:
    ``OUTPUT_PATH``, ``count_data`` (current COUNT_MAP bucket), ``total_count``
    and ``BASE_PARTITION_ID``.

    Yields:
        One dict ``{"write_size": n}``, where ``n`` is the number of rows written.
        ``n`` is 0 for an empty group, and in that case no writer is created.
    """
    s3_writer = None
    written = 0
    for _, detail_data in _iter.iterrows():
        line = {
            "domain": detail_data["domain"],
            "count": detail_data["count"],
            "partition_id": detail_data["partition_id"],
            "files": detail_data["files"].tolist(),
        }
        if s3_writer is None:
            # All rows of the group share a partition_id, so the first row names the file.
            partition_id = detail_data["partition_id"] + BASE_PARTITION_ID
            output_file = f"{OUTPUT_PATH}{count_data[0]}_{total_count}/{partition_id}.jsonl"
            s3_writer = S3DocWriter(output_file)
        s3_writer.write(line)
        written += 1

    if s3_writer is not None:
        s3_writer.flush()
    yield {"write_size": written}

# Choose domain: bucket domains by count, assign partition ids, and write each partition to S3

In [None]:
def create_spark(spark_name: str):
    """Create the Spark session and store it in the module-level ``spark``.

    The session is named ``layout.<spark_name>`` and built from the module
    ``config``. Driver log output is reduced to ERROR level.
    """
    global spark
    app_name = f"layout.{spark_name}"
    spark = new_spark_session(app_name, config)
    spark.sparkContext.setLogLevel("ERROR")


# Ledger of input files that were already processed (relative to the current working directory).
_ALREADY_EXIST_PATH = "./already_exist.txt"


def _load_already_exist():
    """Return the already-processed input paths, or [] if the ledger is missing or unreadable."""
    import ast

    try:
        with open(_ALREADY_EXIST_PATH, "r", encoding="utf-8") as f:
            content = f.read()
    except OSError:
        return []
    if not content.strip():
        return []
    try:
        # The ledger is written as str(list). literal_eval parses it without
        # executing arbitrary code the way eval() would.
        parsed = ast.literal_eval(content)
    except (ValueError, SyntaxError):
        # Corrupt ledger: treat it as "nothing processed yet".
        return []
    return list(parsed) if isinstance(parsed, (list, tuple, set)) else []


def _to_s3a(path):
    """Rewrite only the ``s3://`` scheme to ``s3a://``; bucket/key text is left untouched."""
    prefix = "s3://"
    if path.startswith(prefix):
        return "s3a://" + path[len(prefix):]
    return path


def get_input_path(INPUT_PATH):
    """Find ``.jsonl`` inputs under ``INPUT_PATH`` that were not processed yet and process them.

    ``./already_exist.txt`` stores, as a Python list literal, every input path
    already processed, in the ``s3a://`` form handed to Spark. Candidates are
    converted to that form *before* the membership check, so the comparison
    matches the stored entries.

    Args:
        INPUT_PATH: S3 prefix listed recursively. The name shadows the module
            constant and is kept for interface compatibility.
    """
    already_exist = _load_already_exist()
    seen = set(already_exist)
    input_path_lst = [
        path
        for path in map(_to_s3a, list_s3_objects(INPUT_PATH, recursive=True))
        if path.endswith(".jsonl") and path not in seen
    ]
    if not input_path_lst:
        # Nothing new: skip the Spark read, since there would be no paths to load.
        return

    get_input_df(input_path_lst)

    # Record files only after processing returned, so a failed run is retried next
    # time instead of those files being skipped forever.
    already_exist.extend(input_path_lst)
    with open(_ALREADY_EXIST_PATH, "w", encoding="utf-8") as f:
        f.write(str(already_exist))


def get_input_df(input_path_lst):
    """Load the given JSON-lines files into one DataFrame and run the bucket pipeline on it.

    The DataFrame is cached because filter_batch_df filters it once per
    COUNT_MAP bucket. It is unpersisted afterwards, even if processing raises,
    so the cached blocks are released.

    Args:
        input_path_lst: list of input paths. An empty list is a no-op.
    """
    if not input_path_lst:
        return
    df = spark.read.format("json").load(input_path_lst)
    df.cache()
    try:
        filter_batch_df(df)
    finally:
        df.unpersist()


def filter_batch_df(df: DataFrame):
    """Split ``df`` into the COUNT_MAP buckets and run the partition/write step on each.

    Each ``(label, lower, upper)`` entry of COUNT_MAP selects rows with
    ``lower < count <= upper``; an ``upper`` of None means unbounded. Before a
    bucket is processed, its entry is published as the global ``count_data``,
    because write_by_partitionid reads it to name the output directory.

    Buckets whose ``count`` total is None or 0 are skipped. They have nothing
    to write and would produce zero partitions downstream.
    """
    global count_data
    for i, bucket in enumerate(COUNT_MAP):
        count_data = bucket
        lower, upper = bucket[1], bucket[2]
        # NOTE(review): the lower bound is exclusive, so count == 1 falls outside the
        # "1-1600" bucket — confirm that is intended.
        bucket_df = df.filter(col("count") > lower)
        if upper is not None:
            bucket_df = bucket_df.filter(col("count") <= upper)
        total_batch_count(bucket_df)  # sets the global total_count
        if not total_count:
            continue
        parse_partition_id(bucket_df, i)


def total_batch_count(filter_df: DataFrame):
    """Sum the ``count`` column of ``filter_df`` and publish it as the global ``total_count``.

    A SQL SUM with no non-null input comes back as None. It is normalized to 0
    so that downstream arithmetic on ``total_count`` does not fail.

    Returns:
        The total. It is also returned for convenience; existing callers keep
        reading the global.
    """
    global total_count
    bucket_total = filter_df.select(_sum("count")).collect()[0][0]
    total_count = 0 if bucket_total is None else bucket_total
    return total_count


def parse_partition_id(filter_df: DataFrame, i: int):
    """Give every domain in ``filter_df`` a ``partition_id``, then join and write the bucket.

    Strategy depends on the bucket index ``i``:
      * ``i == 0``: hash-repartition the domains and use Spark's partition id.
        NOTE(review): presumably chosen because the smallest-count bucket has
        too many domains to collect to the driver — confirm.
      * otherwise: collect ``(domain, count)`` to the driver and balance the
        totals with divide_list_to_chunks.

    Reads the globals ``total_count`` and ``DATA_SIZE_PER_BATCH``.
    """
    # The number of partitions depends on total_count and the per-batch data volume.
    # Clamp to 1: a bucket smaller than half a batch would otherwise round to 0 partitions.
    num_partitions = max(1, round((total_count or 0) / DATA_SIZE_PER_BATCH))
    if i == 0:
        repart_df = filter_df.select(["domain"]).repartition(num_partitions, col("domain"))
        partition_df = repart_df.withColumn("partition_id", spark_partition_id())
    else:
        weight_datas = filter_df.select(["domain", "count"]).collect()
        if not weight_datas:
            # Nothing to place, and an empty list gives createDataFrame nothing to infer a schema from.
            return
        partition_list = list(divide_list_to_chunks(weight_datas, num_partitions))
        partition_df = spark.createDataFrame(partition_list)
    join_to_write(filter_df, partition_df)


def join_to_write(filter_df: DataFrame, partition_df: DataFrame):
    """Attach partition ids to ``filter_df`` and write each partition through a grouped-map UDF.

    ``filter_df`` is joined with ``partition_df`` on ``domain``. The result is
    grouped by ``partition_id``, and each group is handed to
    write_by_partitionid, which writes it to S3 and reports a ``write_size``.
    The final ``count()`` is the action that makes Spark actually run the
    writes.
    """
    df_with_weight = filter_df.join(partition_df, on="domain")

    output_schema = StructType([
        StructField('write_size', LongType(), True),
    ])

    # NOTE(review): PandasUDFType.GROUPED_MAP is deprecated in Spark 3.x in favour of
    # groupby(...).applyInPandas(func, schema); kept for compatibility with the current setup.
    @pandas_udf(output_schema, PandasUDFType.GROUPED_MAP)
    def pandas_udf_repartition(data_collected_series):
        # Always return a DataFrame with the declared column. A generator object is
        # always truthy, so a truthiness check on it can never select a fallback.
        return pd.DataFrame(list(write_by_partitionid(data_collected_series)), columns=["write_size"])

    output_df = df_with_weight.groupby('partition_id').apply(pandas_udf_repartition)
    # Action that forces the lazy UDF (and therefore the S3 writes) to execute.
    output_df.count()


def close_spark():
    """Stop the module-level Spark session stored in ``spark`` by create_spark."""
    session = spark
    session.stop()


def main():
    """Run the domain-layout job: create the Spark session, process new inputs, stop the session.

    The session is stopped in a ``finally`` block, so an exception while
    processing does not leave it running.
    """
    spark_name = "choose_domain"
    create_spark(spark_name)
    try:
        get_input_path(INPUT_PATH)
    finally:
        close_spark()

In [None]:
# Entry point: create the Spark session, process any new input files, then stop the session.
main()