In [None]:
from time import sleep
from supabase import create_client, Client
import os
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from tqdm import tqdm

# Read the Supabase project URL (line 1) and API key (line 2) from a
# separate credentials file instead of hard-coding them in the notebook.
with open("supabase_credentials.txt", "r") as credentials_file:
    url, key = (credentials_file.readline().strip() for _ in range(2))

# Create the Supabase client from the loaded credentials.
supabase: Client = create_client(url, key)

# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Load the tokenizer and seq2seq model from a local copy stored under
# models/ two directory levels above the current working directory.
# Upstream model on Hugging Face: https://huggingface.co/sshleifer/distilbart-cnn-12-6
model_path = os.path.join(
    os.getcwd(),
    os.pardir,
    os.pardir,
    "models/distilbart-cnn-12-6/",
)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)

# Wrap the already-loaded model and tokenizer in a summarization pipeline
# that runs on the same device.
nlp = pipeline(
    "summarization",
    model=model,
    tokenizer=tokenizer,
    device=device,
    use_fast=True,
)
# Rows fetched per request, and the row offset to start from. The start
# offset is hard-coded so an interrupted run can be resumed; set it to 0 to
# process the whole table from the beginning.
limit = 100
offset = 38500

# Fetch the total row count. With count="exact", PostgREST reports the full
# table size (via Content-Range) even though .limit(1) keeps the returned
# payload down to a single row.
response = supabase.table("news").select("id", count="exact").limit(1).execute()
total_rows = response.count
print(f"Total rows to process: {total_rows}")

tqdm_bar = tqdm(range(offset, total_rows, limit), desc="Processing Batches")
for offset in tqdm_bar:
    # .range() bounds are inclusive, hence the "- 1"; ordering by id gives a
    # deterministic page order between requests.
    response = (
        supabase.table("news")
        .select("id, content, summarization")
        .order("id")
        .range(offset, offset + limit - 1)
        .execute()
    )
    rows = response.data
    if not rows:
        break
    # Show the first id of every fetched page, not only pages that had work.
    tqdm_bar.set_postfix({'curr_id': str(rows[0]['id'])})

    # Summarize only rows that are not yet summarized and whose content is
    # longer than 512 characters. `or ""` treats NULL content as empty, so
    # len() cannot raise a TypeError and abort the whole run.
    filtered_rows = [
        row for row in rows
        if row['summarization'] is None and len(row['content'] or "") > 512
    ]
    if not filtered_rows:
        continue

    texts = [row['content'] for row in filtered_rows]
    ids = [row['id'] for row in filtered_rows]
    summaries = nlp(
        texts,
        max_length=105,
        min_length=86,
        truncation=True,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    # The pipeline returns one result per input text, in input order.
    updates = [
        {"id": row_id, "summarization": summary['summary_text']}
        for row_id, summary in zip(ids, summaries)
    ]
    # on_conflict is documented as a comma-separated string of column names,
    # not a list.
    supabase.table("news").upsert(updates, on_conflict="id").execute()