# Proof of Concept: Fine-Tuning a Model with PEFT on Daily News

In [None]:
!pip install peft datasets
!pip install mistral_inference
#!pip install accelerate


In [11]:
import gc
from datetime import datetime

import chromadb
import spacy
import torch
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    pipeline,
)

In [12]:
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
#from accelerate import dispatch_model

# Load spaCy's "en_core_web_sm" pipeline for Named Entity Recognition (NER).
# `nlp` is the callable that extract_named_entities applies to article text.
nlp = spacy.load("en_core_web_sm")

In [13]:
import sys
import os

# Make the directory above the working directory importable so the local
# `newsies` package resolves. Adjust the relative path if needed.
project_root = os.path.abspath("..")
# Guard the append so re-running this cell does not add duplicate entries.
if project_root not in sys.path:
    sys.path.append(project_root)

from newsies.chromadb_client import ChromaDBClient, collections, get_all_headlines, find_ordinal
from newsies import targets


In [14]:
# Create the output directory for training data (no-op when it already exists).
# os.makedirs replaces the shell `mkdir -p` so the cell does not depend on a POSIX shell.
os.makedirs("./training_data", exist_ok=True)

In [15]:
# Step 1: Connect to ChromaDB and Retrieve Data
def fetch_news_data(collection_date=None):
    """Fetch every DOCUMENT-target entry from one day's ``ap_news_<date>`` collection.

    Args:
        collection_date: Day to load, either a ``YYYY-MM-DD`` string or an
            object with ``strftime`` (``date`` / ``datetime``). Defaults to
            today, which is what this function always used before the
            parameter existed.

    Returns:
        A ``(documents, metadatas)`` tuple taken from the collection's ``get``
        result.
    """
    if collection_date is None:
        collection_date = datetime.now()
    if hasattr(collection_date, "strftime"):
        collection_date = collection_date.strftime(r"%Y-%m-%d")

    client = ChromaDBClient()
    client.collection_name = f"ap_news_{collection_date}"
    print(f"collection name: {client.collection.name}")
    collection = client.collection

    # count() covers the whole collection, not just DOCUMENT entries; it is
    # used as an upper bound for `limit` so the filtered query is not cut short.
    n = collection.count()
    print(f"there are {n} stories in the collection")
    results = collection.get(where={"target": {"$eq": targets.DOCUMENT}}, limit=n)
    return results["documents"], results["metadatas"]

# Load the article texts and their metadata (the two lists are iterated in
# lockstep by generate_qa_pairs).
news_docs, news_metadata = fetch_news_data()

collection name: ap_news_2025-03-12
there are 4142 stories in the collection


In [16]:
# Inspect the text of the first retrieved article.
news_docs[0]

'NEW YORK (AP) — Harvey Weinstein ‘s #MeToo retrial next month will largely be an abridged version of the original, with one big addition: a charge based on an allegation from a woman who wasn’t a part of the first case.\nJust how the reprise of the disgraced movie mogul’s prosecution plays out is coming into focus at a hearing Wednesday, where a judge is set to issue rulings on a variety of issues, including the scope of accuser testimony and potential expert witnesses.\nWeinstein, 72, was in court for the hearing, which started more than a hour late after Judge Curtis Farber met with the prosecution and defense behind closed doors to discuss matters still under seal.\nThose included a prosecution request that two of the three accusers in the case be allowed to testify about other alleged encounters with Weinstein. They also discussed evidence of the accusers’ sexual history, which prosecutors say should be barred under New York’s Rape Shield Law.'

In [17]:
# Inspect the metadata dictionary stored for the first retrieved article.
news_metadata[0]

{'chunk_index': 0,
 'collection': 'ap_news_2025-03-12',
 'date': '2025-03-12',
 'embedding_model': 'sentence-transformers/all-MiniLM-L6-v2',
 'headline0': 'Harvey Weinstein appears in court  as judge weighs key rulings for his looming #MeToo retrial',
 'headline1': 'Harvey Weinstein appears in court as judge weighs key rulings for his looming #MeToo retrial',
 'headline2': 'Harvey Weinstein due in court for key rulings as his #MeToo retrial nears',
 'section0': '',
 'section1': 'politics',
 'section2': 'technology',
 'target': 'DOCUMENT',
 'text': 'NEW YORK (AP) — Harvey Weinstein ‘s #MeToo retrial next month will largely be an abridged version of the original, with one big addition: a charge based on an allegation from a woman who wasn’t a part of the first case.\nJust how the reprise of the disgraced movie mogul’s prosecution plays out is coming into focus at a hearing Wednesday, where a judge is set to issue rulings on a variety of issues, including the scope of accuser testimony an

## Use Flan-T5-large to generate questions for each article and for the named entities in it

In [27]:
# Step 2: Generate Question-Answer Pairs using an LLM
# Use CUDA device 0 when a GPU is available, otherwise fall back to CPU (-1).
generation_device = 0 if torch.cuda.is_available() else -1
qa_generator = pipeline(
    "text2text-generation",
    model="google/flan-t5-large",
    device=generation_device,
)

def extract_named_entities(text):
    """Return the unique PERSON, ORG and GPE entity strings found in ``text``.

    Entities are listed in order of first appearance, so the result is the
    same on every run for the same input.

    Args:
        text: The text to run through the module-level spaCy pipeline ``nlp``.

    Returns:
        A list of entity strings with duplicates removed.
    """
    doc = nlp(text)
    wanted_labels = {"PERSON", "ORG", "GPE"}
    # dict.fromkeys de-duplicates while keeping first-seen order; a plain set
    # would yield a different ordering between runs (string hash randomisation).
    return list(dict.fromkeys(ent.text for ent in doc.ents if ent.label_ in wanted_labels))

Device set to use cuda:0


In [28]:

def save_qa_to_parquet(qa_data, file_path):
    """Write QA records to ``file_path`` as a Parquet file.

    ``qa_data`` is turned into a DataFrame (so it can be anything
    ``pd.DataFrame`` accepts) and saved without the DataFrame index.
    """
    pd.DataFrame(qa_data).to_parquet(file_path, index=False)

def load_qa_from_parquet(file_path):
    """Read a Parquet file and return its rows as a list of dicts (one per row)."""
    return pd.read_parquet(file_path).to_dict(orient="records")

In [31]:
def generate_qa_pairs(news_docs, news_metadata, batch_size=1000, entity_batch_size=1000):
    """Generate questions for every article and save them to Parquet, one file per article batch.

    For each batch of ``batch_size`` articles the function:
      1. builds one prompt per article asking for 3 questions, plus one prompt
         per named entity returned by ``extract_named_entities``;
      2. runs the prompts through ``qa_generator`` (entity prompts in chunks of
         ``entity_batch_size`` per call);
      3. writes the batch's records to ``qa_dataset_batch_<n>.parquet``, a bare
         file name, i.e. relative to the current working directory.

    Each record has the keys ``"questions"`` (list of str), ``"context"`` (the
    article text) and ``"answer"`` (a one-element list holding the article's
    ``headline0`` and ``uri`` metadata values).

    Args:
        news_docs: Article texts.
        news_metadata: Metadata dicts aligned with ``news_docs``; the keys
            ``section0``, ``section1``, ``section2``, ``headline0`` and ``uri``
            are read.
        batch_size: Number of articles processed and saved together.
        entity_batch_size: Number of entity prompts passed to the generator per call.

    Returns:
        None. Results are only written to disk.
    """
    total_batches = (len(news_docs) + batch_size - 1) // batch_size

    for batch_start in tqdm(range(0, len(news_docs), batch_size), desc="Processing Article Batches"):
        batch_number = batch_start // batch_size + 1
        batch_docs = news_docs[batch_start:batch_start + batch_size]
        batch_meta = news_metadata[batch_start:batch_start + batch_size]

        qa_data = []  # Records for this batch only; a fresh list replaces the old clear().
        question_prompts = []
        entity_prompts = []
        entity_mapping = []  # (doc, meta) per entity prompt, in the same order as entity_prompts

        for doc, meta in zip(batch_docs, batch_meta):
            # NOTE(review): the article text is repeated once per section label and
            # the generator is called with truncation=True, so the later
            # repetitions may never reach the model - confirm this is intended.
            context = f"{meta['section0'] or 'front-page'}: {doc}"
            if meta["section1"] != "N/A":
                context += f"\n{meta['section1']}: {doc}"
            if meta["section2"] != "N/A":
                context += f"\n{meta['section2']}: {doc}"

            # One prompt asking for 3 diverse questions about the article.
            question_prompts.append(
                f"For the following question, return the section, headline, and URI: Generate 3 different questions about the following news article. "
                f"Include questions that focus on key details, impacts, and reasons. "
                f"News: {context}"
            )

            # One prompt per named entity found in the article.
            for entity in extract_named_entities(doc):
                entity_prompts.append(
                    f"For the following question, return the section, headline, and URI: Generate a question about {entity} in relation to the following news article. "
                    f"News: {context}"
                )
                entity_mapping.append((doc, meta))

        print(datetime.now(), f"Processing article batch {batch_number}/{total_batches}")

        # Article-level questions: one generator output per article prompt.
        article_questions = qa_generator(question_prompts, max_length=50, truncation=True)
        for doc, meta, article_question_output in zip(batch_docs, batch_meta, article_questions):
            qa_data.append({
                "questions": article_question_output["generated_text"].split("\n"),
                "context": doc,
                "answer": [{"headline": meta["headline0"], "uri": meta["uri"]}]
            })

        # Entity-level questions, chunked by entity_batch_size. The batch total and
        # the progress number are derived from entity_batch_size (not batch_size)
        # so they stay correct when the two sizes differ.
        total_entity_batches = (len(entity_prompts) + entity_batch_size - 1) // entity_batch_size
        for entity_batch_start in tqdm(range(0, len(entity_prompts), entity_batch_size), desc="Processing Entity Batches"):
            entity_batch_end = entity_batch_start + entity_batch_size
            entity_batch = entity_prompts[entity_batch_start:entity_batch_end]
            entity_batch_number = entity_batch_start // entity_batch_size + 1

            print(datetime.now(), f"Processing named entity batch {entity_batch_number}/{total_entity_batches}")
            entity_results = qa_generator(entity_batch, max_length=50, truncation=True)

            for (doc, meta), entity_question_output in zip(entity_mapping[entity_batch_start:entity_batch_end], entity_results):
                qa_data.append({
                    "questions": [entity_question_output["generated_text"]],
                    "context": doc,
                    "answer": [{"headline": meta["headline0"], "uri": meta["uri"]}]
                })

        # Persist this batch before moving on so the records do not pile up in memory.
        batch_file = f"qa_dataset_batch_{batch_number}.parquet"
        save_qa_to_parquet(qa_data, batch_file)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(datetime.now(), "All batches processed")


### Generate the training data

In [32]:
# Run the full generation pass over all loaded articles.
# NOTE: the recorded run below took about an hour for 3 article batches, so
# avoid re-running this cell unnecessarily.
generate_qa_pairs(news_docs, news_metadata)

Processing Article Batches:   0%|                                                                                | 0/3 [00:00<?, ?it/s]

2025-03-13 07:21:39.571169 Processing article batch 1/3



[Acessing Entity Batches:   0%|                                                                                 | 0/7 [00:00<?, ?it/s]

2025-03-13 07:24:20.221756 Processing named entity batch 1/7



[Acessing Entity Batches:  14%|██████████▎                                                             | 1/7 [02:40<16:03, 160.56s/it]

2025-03-13 07:27:00.784498 Processing named entity batch 2/7



[Acessing Entity Batches:  29%|████████████████████▌                                                   | 2/7 [05:22<13:27, 161.41s/it]

2025-03-13 07:29:42.795059 Processing named entity batch 3/7



[Acessing Entity Batches:  43%|██████████████████████████████▊                                         | 3/7 [08:05<10:47, 161.96s/it]

2025-03-13 07:32:25.411027 Processing named entity batch 4/7



[Acessing Entity Batches:  57%|█████████████████████████████████████████▏                              | 4/7 [10:42<07:59, 160.00s/it]

2025-03-13 07:35:02.392411 Processing named entity batch 5/7



[Acessing Entity Batches:  71%|███████████████████████████████████████████████████▍                    | 5/7 [13:32<05:27, 163.58s/it]

2025-03-13 07:37:52.326540 Processing named entity batch 6/7



[Acessing Entity Batches:  86%|█████████████████████████████████████████████████████████████▋          | 6/7 [16:13<02:42, 162.95s/it]

2025-03-13 07:40:34.045242 Processing named entity batch 7/7



Processing Entity Batches: 100%|████████████████████████████████████████████████████████████████████████| 7/7 [18:52<00:00, 161.85s/it]
Processing Article Batches:  33%|███████████████████████▎                                              | 1/3 [21:48<43:37, 1308.94s/it]

2025-03-13 07:43:27.510871 Processing article batch 2/3



[Acessing Entity Batches:   0%|                                                                                 | 0/7 [00:00<?, ?it/s]

2025-03-13 07:46:09.182424 Processing named entity batch 1/7



[Acessing Entity Batches:  14%|██████████▎                                                             | 1/7 [02:54<17:26, 174.35s/it]

2025-03-13 07:49:03.533733 Processing named entity batch 2/7



[Acessing Entity Batches:  29%|████████████████████▌                                                   | 2/7 [05:45<14:20, 172.19s/it]

2025-03-13 07:51:54.208713 Processing named entity batch 3/7



[Acessing Entity Batches:  43%|██████████████████████████████▊                                         | 3/7 [08:37<11:29, 172.46s/it]

2025-03-13 07:54:46.999430 Processing named entity batch 4/7



[Acessing Entity Batches:  57%|█████████████████████████████████████████▏                              | 4/7 [11:20<08:25, 168.43s/it]

2025-03-13 07:57:29.245051 Processing named entity batch 5/7



[Acessing Entity Batches:  71%|███████████████████████████████████████████████████▍                    | 5/7 [14:03<05:33, 166.52s/it]

2025-03-13 08:00:12.370961 Processing named entity batch 6/7



[Acessing Entity Batches:  86%|█████████████████████████████████████████████████████████████▋          | 6/7 [16:52<02:47, 167.60s/it]

2025-03-13 08:03:02.064344 Processing named entity batch 7/7



Processing Entity Batches: 100%|████████████████████████████████████████████████████████████████████████| 7/7 [17:46<00:00, 152.39s/it]
Processing Article Batches:  67%|██████████████████████████████████████████████▋                       | 2/3 [42:31<21:09, 1269.89s/it]

2025-03-13 08:04:09.949738 Processing article batch 3/3



[Acessing Entity Batches:   0%|                                                                                 | 0/6 [00:00<?, ?it/s]

2025-03-13 08:06:45.214458 Processing named entity batch 1/6



[Acessing Entity Batches:  17%|████████████                                                            | 1/6 [02:41<13:28, 161.64s/it]

2025-03-13 08:09:26.859152 Processing named entity batch 2/6



[Acessing Entity Batches:  33%|████████████████████████                                                | 2/6 [05:30<11:03, 165.93s/it]

2025-03-13 08:12:15.793945 Processing named entity batch 3/6



[Acessing Entity Batches:  50%|████████████████████████████████████                                    | 3/6 [08:12<08:12, 164.31s/it]

2025-03-13 08:14:58.162742 Processing named entity batch 4/6



[Acessing Entity Batches:  67%|████████████████████████████████████████████████                        | 4/6 [11:05<05:35, 167.55s/it]

2025-03-13 08:17:50.698864 Processing named entity batch 5/6



[Acessing Entity Batches:  83%|████████████████████████████████████████████████████████████            | 5/6 [13:50<02:46, 166.65s/it]

2025-03-13 08:20:35.744719 Processing named entity batch 6/6



Processing Entity Batches: 100%|████████████████████████████████████████████████████████████████████████| 6/6 [15:54<00:00, 159.10s/it]
Processing Article Batches: 100%|████████████████████████████████████████████████████████████████████| 3/3 [1:01:15<00:00, 1225.11s/it]

2025-03-13 08:22:39.813166 All batches processed





In [33]:

# Load the generated QA data from <project_root>/notebooks/training_data
# (presumably a directory of Parquet files - verify) into a single DataFrame.
training_data_path = f"{project_root}/notebooks/training_data"
training_data = pd.read_parquet(training_data_path)

In [34]:
# Rows whose "question" is missing; count() tallies the non-null cells per
# column, so "question" itself reports 0 for this selection.
missing_question = training_data["question"].isna()
training_data[missing_question].count()

question        0
context     22023
answer      22023
dtype: int64

In [35]:
# Rows that do have a "question"; count() tallies the non-null cells per column.
has_question = training_data["question"].notna()
training_data[has_question].count()

question    3023
context     3023
answer      3023
dtype: int64

In [41]:
# Print the first 100 generated questions without truncation. option_context
# lifts the display limits only inside this block, so later cells keep the
# pandas defaults instead of inheriting unlimited rows/columns.
with pd.option_context(
    "display.max_columns", None,
    "display.max_rows", None,
    "display.max_colwidth", None,
):
    print(training_data[training_data["question"].notna()]["question"].head(100))

0                                                                    What is the main reason for the rise in food poisoning in November and December?
1                           What is the first time in the past 19 years that there has been no team entering the March Madness with zero or one loss?
2                                                                                                                      What team did Palmer play for?
3                                                                                    What is the name of the planet that is visible to the naked eye?
4                                                                                         What is the name of the hospital where Richard Webby works?
5                                  What is the estimated cost of reconstruction and recovery for Lebanon following the 14-month Israel-Hezbollah war?
6                                                                            What is the name of the

In [44]:
# List the contents of <project_root>/notebooks.
notebooks_dir = f"{project_root}/notebooks"
print(os.listdir(notebooks_dir))

['qa_dataset_batch_2.parquet', 'newsies_mistral_peft.ipynb', 'ngrams.ipynb', 'qa_dataset.parquet', 'story_named_entity_embedding.ipynb', 'qa_dataset_batch_1.parquet', '.ipynb_checkpoints', 'qa_dataset_batch_3.parquet', 'news_finetune_model']


## Remove the Flan-T5 model from the GPU

In [None]:
# Drop the reference to the Flan-T5 pipeline, run the garbage collector so
# objects kept alive only by reference cycles are released too, then hand the
# cached CUDA memory back.
del qa_generator
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [10]:
from huggingface_hub import snapshot_download
from pathlib import Path

# Local directory (under the user's home) that receives the downloaded files.
mistral_models_path = Path.home() / "mistral_models" / "7B-v0.3"
mistral_models_path.mkdir(parents=True, exist_ok=True)

# Download only the three files named in allow_patterns.
# NOTE(review): mistral_models_path is not referenced by any later cell visible
# in this notebook (the model is loaded further down by repo id) - confirm this
# download is still needed.
snapshot_download(
    repo_id="mistralai/Mistral-7B-v0.3",
    allow_patterns=["params.json", "consolidated.safetensors", "tokenizer.model.v3"],
    local_dir=mistral_models_path,
)


Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

params.json:   0%|          | 0.00/202 [00:00<?, ?B/s]

tokenizer.model.v3:   0%|          | 0.00/587k [00:00<?, ?B/s]

consolidated.safetensors:   0%|          | 0.00/14.5G [00:00<?, ?B/s]

'/home/mpeters/mistral_models/7B-v0.3'

In [24]:
# Step 4: Prepare Data for Fine-Tuning
def _questions_in(item):
    """Return the non-empty question strings of one QA record.

    Handles both record layouts that appear in this notebook: a single
    "question" string, and the "questions" list written by generate_qa_pairs.
    """
    single = item.get("question")
    if isinstance(single, str):
        candidates = [single]
    else:
        many = item.get("questions")
        if many is None or isinstance(many, float):  # absent, or NaN after a Parquet round trip
            candidates = []
        elif isinstance(many, str):
            candidates = [many]
        else:
            candidates = list(many)
    return [question for question in candidates if isinstance(question, str) and question.strip()]


def format_dataset(qa_dataset, test_size=0.2, seed=42):
    """Turn QA records into a train/test split of ``input_text`` / ``output_text`` rows.

    Args:
        qa_dataset: Iterable of records (dict-like, supporting ``.get``) with
            an "answer" entry and either a "question" string or a "questions"
            list. Records without a usable question are skipped.
        test_size: Fraction of rows placed in the test split.
        seed: Seed for the shuffle, so the split is the same on every run.

    Returns:
        The result of ``Dataset.train_test_split`` (with "train" and "test" entries).

    Raises:
        ValueError: If no record contains a usable question.
    """
    records = []
    for item in qa_dataset:
        answer_text = str(item["answer"])
        for question in _questions_in(item):
            records.append({"input_text": question, "output_text": answer_text})

    if not records:
        raise ValueError("qa_dataset contains no records with a question")

    dataset = Dataset.from_pandas(pd.DataFrame(records))
    return dataset.train_test_split(test_size=test_size, seed=seed)

# Build the QA records from the training data loaded earlier, keeping only the
# rows that have a question, then split them into train and test sets.
# NOTE(review): qa_dataset had no visible definition in this notebook; confirm
# that training_data is the intended source.
qa_dataset = training_data[training_data["question"].notna()].to_dict(orient="records")
split_dataset = format_dataset(qa_dataset)
train_dataset = split_dataset["train"]
test_dataset = split_dataset["test"]


In [33]:
# Step 5: Load Model and Apply LoRA Fine-Tuning
base_model_name = "mistralai/Mistral-7B-v0.3"
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# Load the weights in float16; device placement is delegated via device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# LoRA Configuration
# task_type marks the adapter as wrapping a causal language model.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
# NOTE: this rebinds `model`; re-running the cell without reloading the base
# model would wrap an already-wrapped model.
model = get_peft_model(model, lora_config)

# Tokenization
# The Trainer needs token ids, not the raw "input_text"/"output_text" strings,
# so each example is joined into one text (question, newline, answer) and tokenized.
MAX_SEQ_LENGTH = 512

# The collator pads each batch, which requires the tokenizer to have a pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


def tokenize_qa(example):
    """Tokenize one example as '<question>\\n<answer>', truncated to MAX_SEQ_LENGTH tokens."""
    text = f"{example['input_text']}\n{example['output_text']}"
    return tokenizer(text, truncation=True, max_length=MAX_SEQ_LENGTH)


train_tokenized = train_dataset.map(tokenize_qa, remove_columns=train_dataset.column_names)
test_tokenized = test_dataset.map(tokenize_qa, remove_columns=test_dataset.column_names)

# mlm=False selects causal-LM collation: labels are a copy of input_ids with
# padding positions ignored by the loss.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Training Arguments
# NOTE(review): recent transformers releases rename `evaluation_strategy` to
# `eval_strategy` and the Trainer's `tokenizer` argument to `processing_class`;
# adjust to the installed version if these names are rejected.
training_args = TrainingArguments(
    output_dir="./news_finetune_model",
    per_device_train_batch_size=1,
    num_train_epochs=3,
    logging_steps=10,
    save_strategy="epoch",
    evaluation_strategy="epoch",
    fp16=True,
    optim="adamw_torch",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=test_tokenized,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

trainer.train()

In [None]:
# Step 6: Evaluate the Fine-Tuned Model
def evaluate_model(sample_question):
    """Generate up to 50 new tokens for ``sample_question`` with the global ``model``.

    Args:
        sample_question: Prompt text passed to the tokenizer.

    Returns:
        The first generated sequence decoded to a string with special tokens
        removed. The whole sequence is decoded, so the text is expected to
        start with the prompt itself.
    """
    # Move the inputs to wherever the model lives instead of assuming "cuda",
    # so the call also works when the model is not on the default GPU.
    inputs = tokenizer(sample_question, return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=50)
    return tokenizer.decode(output[0], skip_special_tokens=True)

# Try the model on a held-out question. The sample comes from the test split
# (whose "input_text" column holds the question) rather than from `qa_dataset`,
# which has no visible definition in this notebook.
sample_question = test_dataset[0]["input_text"]
response = evaluate_model(sample_question)
print(f"Q: {sample_question}\nA: {response}")
