In [None]:
import datasets
# from transformers import AutoTokenizer
from tqdm import tqdm
import pyarrow.parquet as pq
import pyarrow as pa
from functools import partial 
import os
# Leave two cores free for the rest of the machine, but always use at least
# one worker. os.cpu_count() can return None when the count is unknown, so
# fall back to 1 instead of failing on `None - 2`.
NUM_PROC=max((os.cpu_count() or 1)-2,1)
# Shared loader: every dataset load in this notebook goes through the same
# worker count and the same on-disk cache directory.
load_dataset=partial(datasets.load_dataset,num_proc=NUM_PROC,cache_dir="cache")

In [None]:
# Create the metadata output directory from Python so the cell does not depend
# on a Unix shell; exist_ok makes re-running the cell harmless.
os.makedirs("output/metadata_map",exist_ok=True)

In [None]:
# Flatten the Zalo legal corpus into one passage per article. Each passage
# gets a notebook-wide id (`oid`) equal to its position in `nds1`; the article
# rows (with `law_id` and `oid` added) are saved separately as a metadata map.
# No `split` is requested, so the result is indexed by split name below.
zalo_ds:datasets.DatasetDict=load_dataset(path="json",data_files="./data/zalo/legal_corpus.json")
print(zalo_ds["train"][0])
nds1=[]  # unified passage list, extended by the following dataset cells
originals=[]  # metadata rows for this dataset only
dataset_map="zalo_legal_corpus"
for law in tqdm(zalo_ds["train"]):
    for article in law["articles"]:
        oid=len(nds1)  # next free position in the unified list
        article["law_id"]=law["law_id"]
        article["oid"]=oid
        originals.append(article)
        nds1.append({
            "text":article["text"],
            "dataset":dataset_map,
            "oid":oid
        })
pq.write_table(pa.Table.from_pylist(originals),f"output/metadata_map/{dataset_map}")
del originals

In [None]:
# Flatten the SFT dataset: every context passage of every example becomes one
# entry in the shared `nds1` list, and each example records the oids of its
# passages under `contextoid` so they can be mapped to clusters later.
sft_ds:datasets.Dataset=load_dataset(path="parquet",data_files="./data/sft_v2.parquet",split="train")
dataset_map="sft_tvpl"
originals=[]
for item in tqdm(sft_ds):
    context_oids=[]
    for text in item["context"]:
        context_oids.append(len(nds1))  # oid = position in the unified list
        nds1.append({
            "text":text,
            "dataset":dataset_map,
            "oid":context_oids[-1]
        })
    item["contextoid"]=context_oids
    originals.append(item)
pq.write_table(pa.Table.from_pylist(originals),f"output/metadata_map/{dataset_map}")
del originals

In [None]:
# Flatten the structured TVPL documents: each entry of a document's `data`
# list has `full_text` renamed to `text`, receives an `oid`, and is appended
# to the shared `nds1` list. The updated documents are saved as a metadata map.
tvpl_ds = load_dataset("parquet", data_files="data/TVPL/structured_data_doc.parquet")
print(tvpl_ds["train"][0])
dataset_map="tvpl_structured"
originals=[]
for item in tqdm(tvpl_ds["train"]):
    children=item["data"]
    for child in children:
        # Rename full_text -> text, then tag the entry with its global id.
        child["text"]=child.pop("full_text")
        child["oid"]=len(nds1)
        nds1.append(dict(text=child["text"],dataset=dataset_map,oid=child["oid"]))
    item["data"]=children
    originals.append(item)
pq.write_table(pa.Table.from_pylist(originals),f"output/metadata_map/{dataset_map}")
del originals

In [None]:
# with open("./temp/temp_legal_corpus_withstpvl.json","w+") as f:
#     f.write(json.dumps(nds1))
#     f.write("\n")
#     f.flush() 
# union_ds=datasets.Dataset.from_list(nds1)
# Create the scratch directory from Python (no shell dependency), then persist
# the unified passage list as the input file for the minhash step.
os.makedirs("./temp",exist_ok=True)
pq.write_table(pa.Table.from_pylist(nds1),"./temp/temp_legal_corpus_withstpvl")
# Release the large in-memory objects; the next cell reloads the corpus from disk.
del nds1
del zalo_ds
del sft_ds
del tvpl_ds

In [None]:
# Reload the unified corpus written by the previous cell and show its row count.
ds=load_dataset(path="parquet",data_files="./temp/temp_legal_corpus_withstpvl")
ds=ds["train"]
len(ds)

In [None]:
# Deduplication settings. They are passed to the minhash command below and are
# also embedded in the "_reindexed(...)" / "_filtered(...)" file names that
# later cells read and write.
JACCARD_THRESH=0.52  # passed as --jaccard_thresh
MINHASH_THRESH=0.4  # passed as --threshold
JACCARD_MODE="PREVIOUS"  # passed as --jaccard_mode
OUTPUT_PATH="output/merged_zalo_sftlaw_tvpl_3"  # passed as --output; prefix of the result files

In [None]:
# Echo the minhash command with the current settings substituted, so the exact
# invocation is recorded in the cell output. The trailing backslashes are line
# continuations inside the string literal, so the command prints as one line.
print(
    f'python -m minhash \
  --path "parquet" \
  --data_files "temp/temp_legal_corpus_withstpvl"\
  --name "gl" \
  --split "train" \
  --cache_dir "./cache" \
  --output {OUTPUT_PATH} \
  --column "text" \
  --threshold {MINHASH_THRESH} \
  --min_length 5\
  --batch_size 10000 \
  --lsh_false_positive_thresh_weight 0.5 \
  --jaccard_mode {JACCARD_MODE} \
  --jaccard_thresh {JACCARD_THRESH}\
    --num_proc {NUM_PROC}'
)

In [None]:
# Run the minhash module on the unified corpus ("text" column). The {NAME}
# placeholders are filled in by IPython from the notebook variables. Later
# cells read the result from f"{OUTPUT_PATH}_reindexed(...)".
!python -m minhash \
  --path "parquet" \
  --data_files "temp/temp_legal_corpus_withstpvl"\
  --name "gl" \
  --split "train" \
  --cache_dir "./cache" \
  --output {OUTPUT_PATH} \
  --column "text" \
  --threshold {MINHASH_THRESH} \
  --min_length 5\
  --batch_size 10000 \
  --lsh_false_positive_thresh_weight 0.5 \
  --jaccard_mode {JACCARD_MODE} \
  --jaccard_thresh {JACCARD_THRESH}\
  --num_proc {NUM_PROC}


In [None]:
# Count how many distinct cluster ids the reindexed minhash output contains.
merged_ds=load_dataset("parquet",data_files=f"{OUTPUT_PATH}_reindexed({MINHASH_THRESH}|{JACCARD_THRESH}|{JACCARD_MODE})")["train"]
clid={row["__cluster__"] for row in tqdm(merged_ds)}
print(len(clid))
del clid

In [None]:
# Print every member of up to five clusters for manual inspection. A cluster
# is picked when some row's `__cluster__` differs from its own `oid`.
merged_ds=load_dataset("parquet",data_files=f"{OUTPUT_PATH}_reindexed({MINHASH_THRESH}|{JACCARD_THRESH}|{JACCARD_MODE})")["train"]
seencluster=set()
for _ in range(5):
    # First cluster id not shown yet; stays -1 when no such row is left.
    duplicated_id=next(
        (
            row["__cluster__"]
            for row in merged_ds
            if row["__cluster__"]!=row["oid"] and row["__cluster__"] not in seencluster
        ),
        -1,
    )
    members=merged_ds.filter(lambda item: item["__cluster__"]==duplicated_id,num_proc=NUM_PROC)
    for item in members:
        print(item)
    seencluster.add(duplicated_id)

In [None]:
# Same inspection as the previous cell, but a cluster is only picked through a
# row that comes from the "sft_tvpl" dataset.
merged_ds=load_dataset("parquet",data_files=f"{OUTPUT_PATH}_reindexed({MINHASH_THRESH}|{JACCARD_THRESH}|{JACCARD_MODE})")["train"]
seencluster=set()
for _ in range(5):
    # First matching cluster id not shown yet; stays -1 when none is left.
    duplicated_id=next(
        (
            row["__cluster__"]
            for row in merged_ds
            if row["__cluster__"]!=row["oid"]
            and row["dataset"]=="sft_tvpl"
            and row["__cluster__"] not in seencluster
        ),
        -1,
    )
    members=merged_ds.filter(lambda item: item["__cluster__"]==duplicated_id,num_proc=NUM_PROC)
    for item in members:
        print(item)
    seencluster.add(duplicated_id)

In [None]:
# clusterlen: cluster id -> number of rows in that cluster.
clusterlen={}
for row in tqdm(merged_ds):
    cluster_id=row["__cluster__"]
    if cluster_id in clusterlen:
        clusterlen[cluster_id]+=1
    else:
        clusterlen[cluster_id]=1
# lendist: cluster size -> number of clusters having that size.
lendist={}
for size in tqdm(clusterlen.values()):
    if size in lendist:
        lendist[size]+=1
    else:
        lendist[size]=1

In [None]:
# Show the histogram built in the previous cell: {cluster size: number of clusters}.
print(lendist)

In [None]:
merged_ds=load_dataset("parquet",data_files=f"{OUTPUT_PATH}_reindexed({MINHASH_THRESH}|{JACCARD_THRESH}|{JACCARD_MODE})")["train"]
# Drop the helper columns that are actually present. Asking remove_columns for
# a missing column raises, so filter the list first instead of swallowing
# every exception (which would also hide unrelated failures).
droppable=[col for col in ("__score__","oparent") if col in merged_ds.column_names]
if droppable:
    merged_ds=merged_ds.remove_columns(droppable)
merged_ds=merged_ds.sort("__cluster__")
precount=0  # number of rows visited by the selection loop below
toKeep={}  # cluster id -> (text length, oid, dataset) of the row to keep
# "zalo_legal_corpus"
# "tvpl_structured"
clusterMember={}  # cluster id -> list of oids of all rows in that cluster
def isPreferedDataset(ds_label):
    """Return True when `ds_label` names a preferred source dataset.

    A label is preferred when it starts with "sft" or is exactly
    "zalo_legal_corpus".
    """
    if ds_label.startswith("sft"):
        return True
    return ds_label=="zalo_legal_corpus"
# Pick one row per cluster. A row from a preferred dataset beats a row from a
# non-preferred one; within the same preference class the longer text wins.
for item in tqdm(merged_ds,desc="Getting oid to keep"):
    precount+=1
    cid=item["__cluster__"]
    # Append in place: rebuilding the list with `+` for every row is quadratic
    # in the size of the cluster.
    clusterMember.setdefault(cid,[]).append(item["oid"])
    pre=toKeep.get(cid,(-1,-1,""))  # length -1 marks "nothing kept yet"
    cur=(len(item["text"]),item["oid"],item["dataset"])
    if pre[0]<0:
        toKeep[cid]=cur
    elif isPreferedDataset(cur[2]):
        # Preferred row replaces a non-preferred keeper, or a shorter keeper.
        if (not isPreferedDataset(pre[2])) or pre[0]<cur[0]:
            toKeep[cid]=cur
    else:
        # Non-preferred row only replaces a shorter non-preferred keeper.
        if (not isPreferedDataset(pre[2])) and pre[0]<cur[0]:
            toKeep[cid]=cur
print(toKeep)
keep_oid={kept[1] for kept in toKeep.values()}
print(precount)
print(len(keep_oid))
import gc
# Freeze the current objects and pause the collector while the multi-process
# filter/map run (presumably to keep the large lookup tables untouched in the
# worker processes - confirm). The finally block restores the collector even
# when filter/map fail, so a crash does not leave GC disabled for the kernel.
gc.freeze()
gc.disable()
try:
    filtered=merged_ds.filter(
        function=lambda item: item["oid"] in keep_oid,
        num_proc=4,
        desc="Filtering clusters...",
    ).map(
        function= lambda item: {"__cluster_member__": clusterMember[item["__cluster__"]]},
        num_proc=4,
        desc="Adding cluster member IDs ..."
    )
finally:
    gc.enable()
    gc.collect()
# Export with Dataset.to_parquet, like the other export cells in this notebook;
# Table.from_pylist would first turn every row into Python objects.
filtered.to_parquet(f"{OUTPUT_PATH}_filtered({MINHASH_THRESH}|{JACCARD_THRESH}|{JACCARD_MODE})")

In [None]:
REMAP=f"({MINHASH_THRESH}|{JACCARD_THRESH}|{JACCARD_MODE})"
# REMAP contains "(", "|" and ")", which are shell metacharacters, so create
# the directory from Python instead of interpolating it into a shell command.
os.makedirs(f"output/data_remapped_{REMAP}",exist_ok=True)
# Do NOT use the filtered dataset here: every oid needs its cluster id, and
# the filtered file only contains the one kept row per cluster.
merged_ds=load_dataset("parquet",data_files=f"{OUTPUT_PATH}_reindexed({MINHASH_THRESH}|{JACCARD_THRESH}|{JACCARD_MODE})")["train"]
# oid -> cluster id lookup used by all mapper functions below.
oid_cluster_dict={}
for item in tqdm(merged_ds):
    oid_cluster_dict[item["oid"]]=item["__cluster__"]

old_ds=load_dataset("parquet",data_files="output/metadata_map/tvpl_structured")["train"]
def tvpl_mapper(item,ocdict):
    """Add a `__cluster__` id to every entry of `item["data"]`.

    The id is looked up in `ocdict` by the entry's `oid`; an oid missing from
    `ocdict` is used as its own cluster id. `item` is modified in place and
    returned.
    """
    children=item["data"]
    for child in children:
        oid=child["oid"]
        child["__cluster__"]=ocdict.get(oid,oid)
    item["data"]=children
    return item
# Attach cluster ids to the TVPL metadata map and export it.
old_ds=old_ds.map(
    tvpl_mapper,
    fn_kwargs=dict(ocdict=oid_cluster_dict),
    num_proc=NUM_PROC,
    desc="Writing file tvpl_structured",
)
old_ds.to_parquet(f"output/data_remapped_{REMAP}/tvpl_dataset")

# Next input: the SFT metadata map.
old_ds=load_dataset(path="parquet",data_files="output/metadata_map/sft_tvpl")
old_ds=old_ds["train"]
def sft_tpvl_mapper(item,ocdict):
    """Add `__context_cluster__`: the cluster id of each oid in `contextoid`.

    An oid missing from `ocdict` is used as its own cluster id. `item` is
    modified in place and returned.
    """
    item["__context_cluster__"]=[ocdict.get(oid,oid) for oid in item["contextoid"]]
    return item
# Add `__context_cluster__` to every SFT example and export the result.
old_ds=old_ds.map(
    function=sft_tpvl_mapper,
    fn_kwargs={
        "ocdict":oid_cluster_dict
    },
    num_proc=NUM_PROC,
    # The progress label names the dataset processed here (sft_tvpl).
    desc="Writing file sft_tvpl"
)
old_ds.to_parquet(f"output/data_remapped_{REMAP}/sft_tvpl")


def data_mapper(item,ocdict):
    """Add `__cluster__` to `item`, looked up in `ocdict` by the item's `oid`.

    An oid missing from `ocdict` is used as its own cluster id. `item` is
    modified in place and returned.
    """
    oid=item["oid"]
    item["__cluster__"]=ocdict.get(oid,oid)
    return item
# Attach cluster ids to the flat metadata maps (one `oid` per row) and export them.
for file_name in ["zalo_legal_corpus"]:
    # Go through the notebook's shared loader so the cache directory and the
    # worker count match every other dataset load.
    old_ds=load_dataset("parquet",data_files=f"output/metadata_map/{file_name}")["train"]
    old_ds=old_ds.map(
        function=data_mapper,
        fn_kwargs={
            "ocdict":oid_cluster_dict
        },
        num_proc=NUM_PROC,
        desc=f"Writing file {file_name}"
    )
    old_ds.to_parquet(f"output/data_remapped_{REMAP}/{file_name}")

In [None]:
# Second minhash pass, over the "question" column of the remapped SFT file.
# Cluster ids are written to "__cluster_question__" and row ids to "questionoid".
# NOTE(review): threshold, jaccard threshold and mode are hard-coded here
# (0.4 / 0.52 / "PREVIOUS") rather than taken from MINHASH_THRESH,
# JACCARD_THRESH and JACCARD_MODE - confirm that is intended.
TARGET_IN=f"output/data_remapped_{REMAP}/sft_tvpl"
TARGET_OUT=f"output/data_remapped_{REMAP}/sft_tvpl_out"
! python -m minhash \
  --path "parquet" \
  --data_files "{TARGET_IN}"\
  --name "tvpl_sft" \
  --split "train" \
  --cache_dir "./cache" \
  --output "{TARGET_OUT}" \
  --column "question" \
  --threshold 0.4 \
  --min_length 5\
  --batch_size 10000 \
  --lsh_false_positive_thresh_weight 0.5 \
  --jaccard_mode "PREVIOUS" \
  --jaccard_thresh 0.52\
  --num_proc {NUM_PROC}\
  --cluster_column "__cluster_question__"\
  --index_column "questionoid"


In [None]:
# Print every member of up to five question clusters for manual inspection.
# A cluster is picked when some row's `__cluster_question__` differs from its
# own `questionoid`.
sft=load_dataset("parquet",data_files=TARGET_OUT+"_reindexed(0.4|0.52|PREVIOUS)")["train"]
seencluster=set()
for _ in range(5):
    # First cluster id not shown yet; stays -1 when no such row is left.
    duplicated_id=next(
        (
            row["__cluster_question__"]
            for row in sft
            if row["__cluster_question__"]!=row["questionoid"]
            and row["__cluster_question__"] not in seencluster
        ),
        -1,
    )
    members=sft.filter(lambda item: item["__cluster_question__"]==duplicated_id,num_proc=NUM_PROC)
    for item in members:
        print(item)
    seencluster.add(duplicated_id)

In [None]:
# sft=load_dataset(path="parquet",data_files=TARGET_OUT+"_reindexed(0.4|0.52|PREVIOUS)")["train"]
# CLUSTER_COL="__cluster_question__"
# sft=sft.map(
#     function=lambda x: {"_len_":len(x["text"])},
#     num_proc=NUM_PROC
# )
# sft=sft.sort([CLUSTER_COL,"_len_"])
# pre=sft[0][CLUSTER_COL]
# keeplist=[]
# for i,item in tqdm(enumerate(sft)):
#     if item[CLUSTER_COL]==pre:
#         continue
#     else:
#         keeplist.append(sft[i-1])
# new_sft=datasets.Dataset.from_list(keeplist)
# new_sft=new_sft.map(
#     remove_columns=["_len_","__cluster_question__"]
# )
# ns=new_sft.train_test_split(test_size=10000)
# OUTPUT_DIR_SFT=f"output/data_remapped_{REMAP}/tpvl_sft_dedup_and_resplit"
# os.makedirs(exist_ok=True,name=OUTPUT_DIR_SFT)
# ns["train"].to_parquet(os.path.join(OUTPUT_DIR_SFT,"train.parquet"))
# ns["test"].to_parquet(os.path.join(OUTPUT_DIR_SFT,"test.parquet"))

In [None]:
# Split the question-clustered SFT data into train/test so that no question
# cluster is shared between the two parts: rows are sorted by cluster id and
# the cut is moved back to the start of the cluster it falls in.
sft=load_dataset(path="parquet",data_files=TARGET_OUT+"_reindexed(0.4|0.52|PREVIOUS)")["train"]
CLUSTER_COL="__cluster_question__"
TEST_SIZE=10000  # minimum number of rows in the test part
sortedsft=sft.sort(CLUSTER_COL)
split_index=len(sortedsft)-TEST_SIZE
# Validate before indexing: a non-positive index would otherwise wrap around
# (negative indexing) or run past the data.
if split_index<=0:
    raise ValueError(
        f"dataset has {len(sortedsft)} rows; more than TEST_SIZE={TEST_SIZE} rows are needed to split"
    )
precluster=sortedsft[split_index][CLUSTER_COL]
# Walk back while the previous row belongs to the same cluster as the cut row.
while split_index>0 and precluster==sortedsft[split_index-1][CLUSTER_COL]:
    split_index-=1
if split_index<=0:
    raise ValueError(
        "the cluster at the split point extends to the first row, so the train split would be empty"
    )
OUTPUT_DIR_SFT=f"output/data_remapped_{REMAP}/tpvl_sft_dedup_and_resplit"
os.makedirs(exist_ok=True,name=OUTPUT_DIR_SFT)
def traingen():
    """Yield the rows of `sortedsft` located before `split_index`, in order."""
    for position in range(split_index):
        yield sortedsft[position]
def testgen():
    """Yield the rows of `sortedsft` from `split_index` to the end, in order."""
    for position in range(split_index,len(sortedsft)):
        yield sortedsft[position]
# Materialise both generators as datasets and write them as train/test parquet files.
# NOTE(review): no gen_kwargs are passed, so num_proc may have nothing to
# shard across workers - confirm it has any effect here.
datasets.Dataset.from_generator(traingen,cache_dir="cache",num_proc=NUM_PROC).to_parquet(os.path.join(OUTPUT_DIR_SFT,"train.parquet"))
datasets.Dataset.from_generator(testgen,cache_dir="cache",num_proc=NUM_PROC).to_parquet(os.path.join(OUTPUT_DIR_SFT,"test.parquet"))

In [None]:
# ds:datasets.Dataset=datasets.load_dataset(path="parquet",data_files="output/data_remapped/tvpl_dataset")["train"]
# print(ds[0])

In [None]:
# tpvl_dataset=datasets.load_dataset("json",data_files="output/data_remapped/tvpl_structured")["train"]
# precount=0
# parentnochild_count=0
# toKeep={}
# for item in tpvl_dataset:
#     for child in item["child_data"]:
#         pre=toKeep.get(child["__cluster__"],(-1,-1))
#         if pre[0]<len(child["text"]):
#             toKeep[child["__cluster__"]]=(len(child["text"]),child["oid"])
# tpvl_filtered=[]
# for item in tpvl_dataset:
#     children=item["child_data"]
#     new_children=[]
#     for id in range(len(children)):
#         cid=children[id]["__cluster__"]
#         precount+=1
#         if children[id]["oid"]==toKeep[cid][1]:
#             new_children.append(children[id])
#     if len(new_children)==0:
#         parentnochild_count+=1
#         del item["child_data"]
#         print("NO CHILD:",json.dumps(item,ensure_ascii=False))
#     else:
#         item["child_data"]=new_children
#         tpvl_filtered.append(item)

# pq.write_table(pa.Table.from_pylist(tpvl_filtered),"output/data_remapped/tvpl_dataset_filtered")
# del tpvl_filtered
# print(precount)
# print(len(toKeep))
# print(parentnochild_count)