In [None]:
from typing import List

from punctuators.models import PunctCapSegModelONNX

# Load the pretrained "pcs_en" model.
# The first call downloads the ONNX and SPE model files. To clean up afterwards,
# delete this model from your HF cache directory.
m = PunctCapSegModelONNX.from_pretrained("pcs_en")

# Sample inputs: lower-cased text with no punctuation or sentence boundaries.
input_texts: List[str] = [
    # Literally my weekend
    "i woke up at 6 am and took the dog for a hike in the metacomet mountains we like to take morning adventures on the weekends",
    "despite being mid march it snowed overnight and into the morning here in connecticut it was snowier up in the mountains than in the farmington valley where i live",
    "when i got home i trained this model on the lambda cloud on an a100 gpu with about 10 million lines of text the total budget was less than 5 dollars",
    # Real acronyms in sentences that I made up
    "george hw bush was the president of the us for 8 years",
    "i saw mr smith at the store he was shopping for a new lawn mower i suggested he get one of those new battery operated ones they're so much quieter",
    # See how the model performs on made-up acronyms
    "i went to the fgw store and bought a new tg optical scope",
    # First few sentences from today's featured article summary on wikipedia
    "it's that man again itma was a radio comedy programme that was broadcast by the bbc for twelve series from 1939 to 1949 featuring tommy handley in the central role itma was a character driven comedy whose satirical targets included officialdom and the proliferation of minor wartime regulations parts of the scripts were rewritten in the hours before the broadcast to ensure topicality"
]

# Run the whole list through the model in one call; `results` holds one list of
# output strings per input text, in the same order as `input_texts`.
results: List[List[str]] = m.infer(input_texts)

# Echo every input followed by its outputs, one tab-indented line per output.
for source_text, segments in zip(input_texts, results):
    print(f"Input: {source_text}")
    print("Outputs:")
    # Each segment carries its own newline; print() then appends the blank
    # separator line between entries.
    print("".join(f"\t{segment}\n" for segment in segments))


Input: i woke up at 6 am and took the dog for a hike in the metacomet mountains we like to take morning adventures on the weekends
Outputs:
	I woke up at 6 a.m. and took the dog for a hike in the Metacomet Mountains.
	We like to take morning adventures on the weekends.

Input: despite being mid march it snowed overnight and into the morning here in connecticut it was snowier up in the mountains than in the farmington valley where i live
Outputs:
	Despite being mid March, it snowed overnight and into the morning.
	Here in Connecticut, it was snowier up in the mountains than in the Farmington Valley where I live.

Input: when i got home i trained this model on the lambda cloud on an a100 gpu with about 10 million lines of text the total budget was less than 5 dollars
Outputs:
	When I got home, I trained this model on the Lambda Cloud.
	On an A100 GPU with about 10 million lines of text, the total budget was less than 5 dollars.

Input: george hw bush was the president of the us for 8 yea

In [19]:
import pandas as pd
# Load the pickled transcripts object (presumably the cleaned validation
# transcripts, judging by the file name — confirm).
# NOTE(review): unpickling can execute arbitrary code, so only load pickle files
# that come from a trusted source.
# NOTE(review): this is an absolute, machine-specific path; the notebook will not
# run elsewhere without editing it.
transcripts_path = "/data2/brain2text/b2t_25/transcripts_val_cleaned.pkl"
transcripts = pd.read_pickle(transcripts_path)

In [None]:
import torch
import time

# Benchmark `m.infer` over every transcript in fixed-size batches, printing the
# running throughput after each batch and, when CUDA is available, PyTorch's
# current / peak VRAM counters. All model outputs are collected in `all_results`.

BYTES_PER_GIB = 1024**3


def _throughput(count: int, seconds: float) -> float:
    """Return items per second without dividing by a zero-length interval.

    Args:
        count: Number of items processed.
        seconds: Elapsed time in seconds.

    Returns:
        ``count / seconds`` when ``seconds`` is positive; otherwise ``inf`` if
        any items were processed and ``0.0`` if none were.
    """
    if seconds > 0:
        return count / seconds
    return float("inf") if count else 0.0


# Get all sentences from transcripts as a plain list so it can be sliced.
# NOTE(review): the `list(transcripts)` fallback would yield column labels rather
# than rows if the pickle holds a DataFrame — confirm it holds a Series/list-like.
sentences = transcripts.tolist() if hasattr(transcripts, 'tolist') else list(transcripts)
num_sentences = len(sentences)
print(f"Total sentences: {num_sentences}")

# NOTE(review): the torch.cuda counters below only cover memory managed by
# PyTorch's own allocator. `m` is an ONNX model, so GPU memory used by its
# runtime presumably does not show up in these numbers — confirm before
# treating them as the model's VRAM footprint.
cuda_available = torch.cuda.is_available()

# Check initial VRAM
if cuda_available:
    # Reset the peak counter so the peak reported later reflects this run only.
    torch.cuda.reset_peak_memory_stats()
    initial_mem = torch.cuda.memory_allocated() / BYTES_PER_GIB
    print(f"Initial VRAM: {initial_mem:.2f} GB")

# Process in batches of 500
batch_size = 500
all_results = []

# perf_counter() is the clock intended for measuring intervals: it is monotonic
# and high-resolution, unlike time.time(), which follows wall-clock adjustments.
start_time = time.perf_counter()

for i in range(0, num_sentences, batch_size):
    batch = sentences[i:i+batch_size]
    batch_results = m.infer(batch)
    all_results.extend(batch_results)

    # Print progress; the final batch may be shorter than batch_size, hence min().
    processed = min(i + batch_size, num_sentences)
    elapsed = time.perf_counter() - start_time
    rate = _throughput(processed, elapsed)
    print(f"Processed {processed}/{num_sentences} sentences ({rate:.1f} sent/sec)")

    if cuda_available:
        current_mem = torch.cuda.memory_allocated() / BYTES_PER_GIB
        peak_mem = torch.cuda.max_memory_allocated() / BYTES_PER_GIB
        print(f"  Current VRAM: {current_mem:.2f} GB, Peak VRAM: {peak_mem:.2f} GB")

total_time = time.perf_counter() - start_time
print(f"\n{'='*50}")
print(f"Total time: {total_time:.2f} seconds")
print(f"Average speed: {_throughput(num_sentences, total_time):.1f} sentences/sec")
if cuda_available:
    print(f"Peak VRAM used: {torch.cuda.max_memory_allocated() / BYTES_PER_GIB:.2f} GB")