In [1]:
import torch
from transformers import pipeline

# 1. Initialize the pipeline
summarizer = pipeline(
    task="text-generation",
    model="google/gemma-3-1b-it",
    # model="models/gemma-3-1b-sft/merged/",
    device=0,
    torch_dtype=torch.bfloat16
)

  from .autonotebook import tqdm as notebook_tqdm
2026-01-26 08:13:29.503236: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Device set to use cuda:0


In [2]:
import json

input_file = "dataset/org_with_cefr_labels/us_test.jsonl"
dataset = [json.loads(line.strip()) for line in open(input_file)]
print(dataset[0]['text'])
dataset = dataset[:100]

SECTION 1. TRANSFERS OF MOTORBOAT FUEL TAXES FROM HIGHWAY TRUST FUND.

    (a) Authorization of Transfers.--Section 9503(c)(4) of the Internal 
Revenue Code of 1986 (26 U.S.C. 9503(c)(4)) is amended--
            (1) by striking subparagraph (A) of section 9503(c)(4);
            (2) by redesignating subparagraph (B) as subparagraph (A) 
        and amending it to read as follows:
                    ``(A) $1,000,000 per year transferred to land and 
                water conservation fund.--
                            ``(i) In general.--The Secretary shall pay 
                        from time to time from the Highway Trust Fund 
                        into the land and water conservation fund 
                        provided for in title I of the Land and Water 
                        Conservation Fund Act of 1965 amounts (as 
                        determined by him) equivalent to the motorboat 
                        fuel taxes received on or after October 1, 
              

In [3]:
import json
from jinja2 import Template

SYSTEM_PROMPT = "You are helpful assistant designed to make English legal text more readable for different target audience at different CEFR readability levels."

PROMPT_TEMPLATE = "Summarize the following text for a {{ level }} reader. {{ text }} \n Please output the summary as a paragraph."

jinja_template = Template(PROMPT_TEMPLATE)
print(jinja_template)

<Template memory:7891b71416f0>


In [4]:
# from transformers import pipeline
# from tqdm import tqdm

# cefr_labels = ["A1", "A2", "B1", "B2", "C1", "C2"]

# for i in tqdm(range(len(dataset))):
    
#     text = dataset[i]["text"]
#     messages = [
#             [
#                 {
#                     "role": "system",
#                     "content": [{"type": "text", "text": prompt_json["system_prompt"]},]
#                 },
#                 {
#                     "role": "user",
#                     "content": [{"type": "text", "text": jinja_template.render(text=text, level=label)},]
#                 },
#             ]
#         for label in cefr_labels
#     ]

#     summaries = summarizer(messages, max_new_tokens=512, batch_size=6)

#     all_summaries = {}
#     for summary, label in zip(summaries, cefr_labels):
#         summary_text = summary[-1]["generated_text"][-1]["content"]

#         summary_text = summary_text.split("\n\n")
#         summary_text = "\n\n".join(summary_text[1:])

#         all_summaries[label] = summary_text

#     dataset[i]["predictions"] = all_summaries


In [4]:
from transformers.pipelines.pt_utils import KeyDataset
from tqdm import tqdm
import numpy as np
import networkx as nx
import networkx.utils

cefr_labels = ["A1", "A2", "B1", "B2", "C1", "C2"]

if not hasattr(np, 'int'):
    np.int = int
    
# 1. Flatten the dataset into "Task Rows"
# Each row in 'flat_data' will represent ONE summary task
flat_data = []
for item in dataset:
    for label in cefr_labels:
        # Pre-render the messages
        msg = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": jinja_template.render(text=item["text"], level=label)}
        ]
        flat_data.append({"messages": msg})

# 2. Use a Generator to stream data (prevents the 'sequential' warning)
def data_generator():
    for item in flat_data:
        # print(item)
        yield item["messages"]

# 3. Run the pipeline (Properly configured)
# 'batch_size' goes here, NOT in model_kwargs
summaries = []
for out in tqdm(summarizer(data_generator(), batch_size=32, max_new_tokens=512), total=len(flat_data)):
    # Extract the text content from the pipeline output
    summary_text = out[-1]["generated_text"][-1]["content"]
    
    # Cleaning logic (removing the preamble if exists)
    # parts = summary_text.split("\n\n")
    # cleaned = "\n\n".join(parts[1:]) if len(parts) > 1 else summary_text
    summaries.append(summary_text)

# 4. Re-shape the results back into your original 18,948 dataset rows
for i in range(len(dataset)):
    start_idx = i * 6
    end_idx = start_idx + 6
    # Map the 6 sequential summaries back to their CEFR labels
    dataset[i]["predictions"] = dict(zip(cefr_labels, summaries[start_idx:end_idx]))

100%|██████████| 600/600 [04:58<00:00,  2.01it/s]


In [5]:
print(dataset[0])

{'bill_id': '103_s2052', 'text': "SECTION 1. TRANSFERS OF MOTORBOAT FUEL TAXES FROM HIGHWAY TRUST FUND.\n\n    (a) Authorization of Transfers.--Section 9503(c)(4) of the Internal \nRevenue Code of 1986 (26 U.S.C. 9503(c)(4)) is amended--\n            (1) by striking subparagraph (A) of section 9503(c)(4);\n            (2) by redesignating subparagraph (B) as subparagraph (A) \n        and amending it to read as follows:\n                    ``(A) $1,000,000 per year transferred to land and \n                water conservation fund.--\n                            ``(i) In general.--The Secretary shall pay \n                        from time to time from the Highway Trust Fund \n                        into the land and water conservation fund \n                        provided for in title I of the Land and Water \n                        Conservation Fund Act of 1965 amounts (as \n                        determined by him) equivalent to the motorboat \n                        fuel taxe

In [6]:
from collections import Counter

output_file = "dataset/predictions/oob/us_test.jsonl"

with open(output_file, "w") as fp:
    for instance in dataset:
        fp.write(json.dumps(instance).strip() + "\n")

    