In [1]:
import os
import time
import torch
import csv
from datasets import load_dataset
from transformers import BartForConditionalGeneration, BartTokenizer

In [2]:
ratios = [0.5, 0.6, 0.7]
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset_name = "multifieldqa_en"
dataset = load_dataset("THUDM/LongBench", dataset_name, split="test", cache_dir="custom_cache_dir", trust_remote_code=True)

In [3]:
summ_model_name = 'facebook/bart-large-cnn'
summ_model = BartForConditionalGeneration.from_pretrained(summ_model_name).to(device)
summ_tokenizer = BartTokenizer.from_pretrained(summ_model_name)

In [4]:
def summarize_text(input_text, ratio):
    inputs = summ_tokenizer(input_text, return_tensors="pt", max_length=1024, truncation=True).to(device)
    input_token_length = inputs['input_ids'].shape[1]
    target_length = int(input_token_length * ratio)
    summary_ids = summ_model.generate(
        inputs['input_ids'], 
        max_length=min(target_length + 100, 1024),  # Maximum summary length based on compression ratio
        min_length=target_length,  # Optional: set a minimum length to avoid very short summaries
        # length_penalty=2.0,  # Optional: tweak the length penalty to get more compact summaries
        num_beams=4,  # Optional: use beam search for better quality
        
        early_stopping=True  # Stops early if the beam has converged
    )
    # print(input_token_length, len(summary_ids[0]))
    summary = summ_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary

def generate_multiple_summaries(input_text, ratio):
  idx = 0
  input_len = len(input_text)
  summary = ""
  while idx < input_len:
    end_idx = min(idx+4000, input_len)
    summary += " " + summarize_text(input_text[idx:end_idx], ratio)
    idx = end_idx + 1
  return summary

def write_dicts_to_csv(data, filename):
    file_exists = os.path.isfile(filename)
    fieldnames = list(data[0].keys())
    # Open the file in append mode if it exists, otherwise write mode
    with open(filename, mode='a' if file_exists else 'w', newline='', encoding='utf-8') as file:
        writer = csv.DictWriter(file, fieldnames=fieldnames)

        # Write the header only if the file is being created
        if not file_exists:
            writer.writeheader()

        # Write the rows
        writer.writerows(data)

In [5]:
def generate_dataset_summaries(ratio, n, start = 0):
    all_data = [None] * n
    filename = f"summarize_{int(ratio * 100)}.csv"
    
    for i in range(start, start + n):
        item = dataset[i].copy()
        start_time = time.time()
        try:
            summ_context = generate_multiple_summaries(item["context"], ratio)
            item["summary"] = summ_context
            all_data[i - start] = item
        except Exception as e:
            item["summary"] = "Failed to generate summary"
            print("Failed to generate summary :", e)
            all_data[i - start] = item
        print(f"Step {i}: {time.time()-start_time}")
    
    write_dicts_to_csv(all_data, filename)
    return all_data

In [7]:
all_data_70 = generate_dataset_summaries(ratio=0.7, n=10, start=0)

Step 0: 21.386173963546753
Step 1: 180.09903645515442
Step 2: 133.42056441307068
Step 3: 201.87702631950378
Step 4: 118.67760038375854
Step 5: 162.7499794960022
Step 6: 124.14846968650818
Step 7: 44.573601961135864
Step 8: 22.610684871673584
Step 9: 194.89532351493835


In [12]:
all_data_70_pt2 = generate_dataset_summaries(ratio=0.7, n=20, start=10)

Step 10: 113.77387475967407
Step 11: 160.10191774368286
Step 12: 194.67646169662476
Step 13: 122.11745715141296
Step 14: 150.01493501663208
Step 15: 40.058085918426514
Step 16: 131.24954986572266
Step 17: 145.52287673950195
Step 18: 156.06720209121704
Step 19: 61.887725591659546
Step 20: 32.349858045578
Step 21: 117.89512300491333
Step 22: 88.87893509864807
Step 23: 82.55633282661438
Step 24: 183.78623151779175
Step 25: 100.80478835105896
Step 26: 151.42920184135437
Step 27: 229.25206351280212
Step 28: 122.76475739479065
Step 29: 141.39921927452087


In [19]:
all_data_70_pt3 = generate_dataset_summaries(ratio=0.7, n=20, start=31)

Step 31: 86.04330492019653
Step 32: 120.18569111824036
Step 33: 89.70646524429321
Step 34: 109.9111316204071
Step 35: 85.49796867370605
Step 36: 31.43401861190796
Step 37: 87.48950815200806
Step 38: 47.55823850631714
Step 39: 39.38619804382324
Step 40: 41.11901021003723
Step 41: 21.376736164093018
Step 42: 55.44414305686951
Step 43: 102.46521282196045
Step 44: 79.87469625473022
Step 45: 93.14845943450928
Step 46: 65.12417578697205
Step 47: 77.8309257030487
Step 48: 29.090771198272705
Step 49: 22.946280479431152
Step 50: 27.28671884536743


In [21]:
all_data_50 = generate_dataset_summaries(ratio=0.5, n=20, start=0)

Step 0: 9.92168378829956
Step 1: 82.9163761138916
Step 2: 63.05988359451294
Step 3: 92.89555954933167
Step 4: 55.51726841926575
Step 5: 72.20669102668762
Step 6: 58.414210081100464
Step 7: 20.26397132873535
Step 8: 10.262131452560425
Step 9: 84.88992023468018
Step 10: 68.33250164985657
Step 11: 74.01338791847229
Step 12: 85.16296243667603
Step 13: 58.198068141937256
Step 14: 69.83821105957031
Step 15: 19.995205879211426
Step 16: 60.45523977279663
Step 17: 65.05537915229797
Step 18: 71.2682557106018
Step 19: 27.353734970092773


In [6]:
all_data_50_pt2 = generate_dataset_summaries(ratio=0.5, n=30, start=21)

Step 21: 53.4440541267395
Step 22: 40.368181467056274
Step 23: 37.19951367378235
Step 24: 80.91890096664429
Step 25: 46.82876968383789
Step 26: 65.73007082939148
Step 27: 100.17182731628418
Step 28: 54.463985443115234
Step 29: 65.86658453941345
Step 30: 29.05880856513977
Step 31: 38.95351052284241
Step 32: 55.36174559593201
Step 33: 40.17489957809448
Step 34: 51.12066864967346
Step 35: 37.7234148979187
Step 36: 13.751059532165527
Step 37: 52.18760824203491
Step 38: 32.754300355911255
Step 39: 26.962358236312866
Step 40: 27.517829179763794
Step 41: 14.517345666885376
Step 42: 37.877095222473145
Step 43: 71.99322628974915
Step 44: 56.82924556732178
Step 45: 65.41368222236633
Step 46: 57.69158315658569
Step 47: 64.61412501335144
Step 48: 25.192261219024658
Step 49: 21.414840936660767
Step 50: 24.731569528579712
