In [19]:
from utils import split_data, add_subcomments, create_comments_list
from tqdm import tqdm
from relation_classifier import relation_classifier
from model_summ import T5summarizer
from textrank import preprocess_comment
import json

import warnings
warnings.filterwarnings("ignore")

In [2]:
# Summarizer used by the processing loop below (via model.batch_summarize).
# The argument is presumably a Hugging Face model id — confirm in model_summ.
model = T5summarizer("gromoboy/rut5_base_summ_brand")

In [10]:
# PARAMETERS
# Name of the input file; it is read from data/<dataset_input>.
dataset_input = "dataset.jsonl"
# Name of the output file; one JSON record per line is written to
# data/<dataset_output>.
dataset_output = "result.jsonl"
# Selects which comments of a post are summarized:
#   "post_comments"  -> filtered with relation_classifier(..., "direct")
#   "topic_comments" -> filtered with relation_classifier(..., "indirect")
#   "all_comments"   -> no filtering, every comment of the post is used
split_type = "all_comments"

In [4]:
# GLOBAL
BATCH_SIZE = 4 # batch size passed to model.batch_summarize
CHUNK_SIZE = 30 # number of preprocessed comments joined (newline-separated) into one chunk for summarization

In [5]:
jsonl_file_path = f'data/{dataset_input}'
posts, comments = split_data(jsonl_file_path)

# Group comments under the id of their root post (comment["root_id"] is
# matched against post["id"] below).
comments_by_root = {}
for comment in tqdm(comments, desc="Root ID sort comments"):
    comments_by_root.setdefault(comment["root_id"], []).append(comment)


# Build one entry per post: [(post_hash, post_text), comments_list].
# The processing loop later unpacks every item of comments_list as a
# (hash, text) pair.
data = []
for post in tqdm(posts, desc="Gathering preprocessed posts/comments together"):
    post_comments = comments_by_root.get(post["id"])

    if post_comments:
        comments_structured = add_subcomments(post_comments)
        comments_unstructured = create_comments_list(comments_structured)
    else:
        # A post without comments has no key in comments_by_root; indexing it
        # directly raised KeyError. Keep the post with an empty comment list
        # so the summarization step can emit its "no content" fallback.
        comments_unstructured = []

    data.append([(post['hash'], post['text']),
                 comments_unstructured])

Reading data...


443898it [00:31, 14043.07it/s]
Root ID sort comments: 100%|██████████| 433890/433890 [00:00<00:00, 1868060.66it/s]
Gathering preprocessed posts/comments together: 100%|██████████| 10008/10008 [00:01<00:00, 9291.75it/s]


In [17]:
# Debug window: only this slice of `data` is summarized and therefore written
# to the output file. Set to slice(None) to process every post.
POSTS_WINDOW = slice(2020, 2040)

# split_type -> mode passed to relation_classifier. "all_comments" is not
# listed because it means "do not filter".
RELATION_MODE_BY_SPLIT = {
    "post_comments": "direct",
    "topic_comments": "indirect",
}

# Fail fast on a mistyped split_type instead of silently summarizing the
# unfiltered comment set after an expensive run.
if split_type != "all_comments" and split_type not in RELATION_MODE_BY_SPLIT:
    raise ValueError(
        f"Unknown split_type {split_type!r}; expected 'all_comments', "
        "'post_comments' or 'topic_comments'"
    )

result = []
for post, comments in data[POSTS_WINDOW]:
    # `post` is the (hash, text) pair built in the data-gathering cell.
    post_hash = post[0]

    # Optionally narrow the comments down to those selected by the classifier.
    relation_mode = RELATION_MODE_BY_SPLIT.get(split_type)
    if relation_mode is not None:
        comments = relation_classifier(post, comments, relation_mode)

    # Every comment is unpacked as a (hash, text) pair.
    comments_hash = [comment_hash for comment_hash, _text in comments]
    comments_texts = [preprocess_comment(text) for _hash, text in comments]

    # Join up to CHUNK_SIZE preprocessed comments per chunk, one per line.
    chunks = [
        "\n".join(comments_texts[start : start + CHUNK_SIZE])
        for start in range(0, len(comments_texts), CHUNK_SIZE)
    ]

    if chunks:
        summary = model.batch_summarize(chunks, BATCH_SIZE)
    else:
        # NOTE(review): the fallback is a plain string; the return type of
        # batch_summarize is not visible here — confirm consumers of the
        # output file accept both shapes.
        summary = "Отсутствует содержание"

    result.append({"summary": summary,
                   "post_hash": post_hash,
                   "comments_hash": comments_hash})

  0%|          | 0/1 [00:00<?, ?it/s]Your max_length is set to 200, but your input_length is only 109. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=54)
100%|██████████| 1/1 [00:01<00:00,  1.44s/it]
  0%|          | 0/1 [00:00<?, ?it/s]Your max_length is set to 200, but your input_length is only 161. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=80)
100%|██████████| 1/1 [00:03<00:00,  3.38s/it]
  0%|          | 0/1 [00:00<?, ?it/s]Your max_length is set to 200, but your input_length is only 132. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=66)
100%|██████████| 1/1 [00:03<00:00,  3.24s/it]
 

In [18]:
filename = f"data/{dataset_output}"

# Dump the collected records as JSON Lines: one JSON object per line, with
# non-ASCII characters written as-is (ensure_ascii=False).
with open(filename, 'w', encoding='utf-8') as sink:
    sink.writelines(
        json.dumps(record, ensure_ascii=False) + '\n'
        for record in result
    )

In [17]:
# Inspection cell: number of preprocessed comment texts left in
# `comments_texts` by the last iteration of the processing loop.
len(comments_texts)

28

In [57]:
# Inspection cell: size of what relation_classifier returns in "indirect"
# mode for the `post`/`comments` left over from the processing loop.
# NOTE(review): relies on leaked loop state and a stale execution count.
len(relation_classifier(post, comments, "indirect"))

4

In [38]:
# Inspection cell: rebuilds the hash list from the first element of every
# item in the current `comments`. This overwrites the `comments_hash`
# produced by the processing loop.
comments_hash = [comment[0] for comment in comments]

In [42]:
# Inspection cell: number of comment hashes currently held in `comments_hash`.
len(comments_hash)

32

In [52]:
# Inspection cell (same expression as an earlier scratch cell): size of the
# "indirect" relation_classifier result for the current `post`/`comments`.
len(relation_classifier(post, comments, "indirect"))

4