In [17]:
from enum import StrEnum
import json

from datasets import load_dataset
from openai import OpenAI
import pandas as pd
from pydantic import BaseModel, Field, ValidationError


In [2]:

# Load the BBC News dataset and combine its train/test splits for an overview.
dataset = load_dataset("SetFit/bbc-news")

train_df = pd.DataFrame(dataset['train'])
test_df = pd.DataFrame(dataset['test'])
# ignore_index=True gives combined_df a unique 0..n-1 index; without it the two
# per-split indexes (each starting at 0) would overlap and produce duplicate labels.
combined_df = pd.concat([train_df, test_df], ignore_index=True)

print("Dataset Overview:")
print(f"Dataset size: {len(combined_df)}")
# value_counts() runs over combined_df, i.e. train + test, not the training split alone.
print("\nLabel distribution (train + test):")
print(combined_df['label_text'].value_counts())
# combined_df starts with the train rows, so row 0 is the first training article.
print("\nSample text from training set:")
print(combined_df['text'].iloc[0][:200])

# Mean character length of the training split. Computed into a separate Series
# rather than added as a column, so re-running this cell does not mutate train_df.
train_text_lengths = train_df['text'].str.len()
print(f"\nText length mean (training set): {train_text_lengths.mean():.2f}")

Dataset Overview:
Dataset size: 2225

Label distribution in training set:
label_text
sport            511
business         510
politics         417
tech             401
entertainment    386
Name: count, dtype: int64

Sample text from training set:
wales want rugby league training wales could follow england s lead by training with a rugby league club.  england have already had a three-day session with leeds rhinos  and wales are thought to be in

Text length mean: 2288.95


In [3]:
# Allowed values for NewsArticle.category. They mirror the five `label_text`
# values shown in the dataset's label distribution.
# Member values (in this order) become the "enum" list under
# $defs/ArticleCategory in the JSON schema that is embedded in the prompt.
# NOTE(review): intentionally no class docstring. Pydantic v2 copies an enum's
# docstring into the schema "description", which would change the prompt text.
class ArticleCategory(StrEnum):
    BUSINESS = "business"
    ENTERTAINMENT = "entertainment"
    POLITICS = "politics"
    SPORT = "sport"
    TECH = "tech"

# Allowed values for NamedEntity.type. These become the "enum" list under
# $defs/NamedEntityType in the generated JSON schema.
# Because the field is typed with this enum, any other type string in the model's
# output makes NewsArticle validation fail.
# NOTE(review): entities that fit none of these four (e.g. organisations that are
# not companies) may account for some validation failures. Unverified.
# Intentionally no class docstring, which would alter the schema/prompt.
class NamedEntityType(StrEnum):
    COMPANY = "company"
    COUNTRY = "country"
    LOCATION = "location"
    PERSON = "person"

# One named entity mentioned in an article; both fields are required.
# The Field descriptions are emitted into the JSON schema shown to the model, so
# editing them changes the prompt. For the same reason, use comments, not a class
# docstring (Pydantic would add it as the schema "description").
class NamedEntity(BaseModel):
    name: str = Field(description="Name of the person, company, etc.")
    type: NamedEntityType = Field(description="Type of named entity")

# Target structure the model must produce for each article. Its JSON schema
# (NewsArticle.model_json_schema()) is embedded in the extraction prompt, and
# extract_article_data validates the model's response against this class.
# Field order and descriptions here determine the schema text, so changes alter
# the prompt. Keep documentation in comments rather than a class docstring.
class NewsArticle(BaseModel):
    title: str = Field(description="Appropriate headline for the news article")
    category: ArticleCategory = Field(description="Category of news that the article belongs to")
    # default_factory=list: a response that omits this key still validates.
    mentioned_entities: list[NamedEntity] = Field(description="List of all named entities in the article", default_factory=list)
    summary: str = Field(description="Brief summary of the article content.")



In [4]:
# JSON Schema dict for NewsArticle. extract_article_data serializes this same dict
# with indent=2 into the prompt, so the printout below is what the model receives.
news_article_schema = NewsArticle.model_json_schema()
print(json.dumps(obj=news_article_schema, indent=2))

{
  "$defs": {
    "ArticleCategory": {
      "enum": [
        "business",
        "entertainment",
        "politics",
        "sport",
        "tech"
      ],
      "title": "ArticleCategory",
      "type": "string"
    },
    "NamedEntity": {
      "properties": {
        "name": {
          "description": "Name of the person, company, etc.",
          "title": "Name",
          "type": "string"
        },
        "type": {
          "$ref": "#/$defs/NamedEntityType",
          "description": "Type of named entity"
        }
      },
      "required": [
        "name",
        "type"
      ],
      "title": "NamedEntity",
      "type": "object"
    },
    "NamedEntityType": {
      "enum": [
        "company",
        "country",
        "location",
        "person"
      ],
      "title": "NamedEntityType",
      "type": "string"
    }
  },
  "properties": {
    "title": {
      "description": "Appropriate headline for the news article",
      "title": "Title",
      "type": "strin

In [9]:
# str.format() template used by extract_article_data:
#   {news_article_schema} -> NewsArticle JSON Schema, pretty-printed with indent=2
#   {news_article}        -> raw article text
# The model is told to wrap its answer in a ```json fence; extract_article_data
# only accepts responses containing such a fence.
# NOTE: because this is a format template, any literal "{" or "}" added to the
# text must be written as "{{" / "}}".
EXTRACTION_PROMPT = """
Read the news article below and extract the information into a
JSON object. This is the schema you should use to extract the
data. It's provided in standard JSON Schema format.

```json
{news_article_schema}
```

Here is the news article to generate the JSON data for. The article
is surrounded by triple backticks.

```
{news_article}
```

Format your response as a JSON object surrounded by triple backticks
with the 'json' identifier like this:
```json
your JSON object here
```

Be sure to follow the provided schema exactly.
"""

If you are running this notebook in GitHub Codespaces, set `OPENAI_API_KEY` as a Codespaces secret in your GitHub user settings. If you are running it locally, set the `OPENAI_API_KEY` environment variable before starting Jupyter so the kernel can see it.

In [21]:
def _extract_fenced_json(text: str) -> str | None:
    """Return the stripped contents of the first ```json ... ``` fence, or None.

    Only the first "```json" opener is considered; the block ends at the next
    "```" after it. Returns None if there is no opener or it is never closed.
    """
    _, opener, after_opener = text.partition("```json")
    if not opener:
        return None
    content, closer, _ = after_opener.partition("```")
    if not closer:
        return None
    return content.strip()


def extract_article_data(article: str, model_name: str, client: OpenAI) -> bool:
    """Ask `model_name` to extract a NewsArticle from `article` and check the reply.

    Args:
        article: Raw article text inserted into EXTRACTION_PROMPT.
        model_name: Model identifier passed to the Responses API.
        client: OpenAI client used to make the request.

    Returns:
        True if the response contains a ```json fenced block whose content parses
        as JSON and validates against NewsArticle. False if the fence is missing or
        unterminated, the JSON is malformed, or schema validation fails.
    """
    formatted_prompt = EXTRACTION_PROMPT.format(
        news_article_schema=json.dumps(news_article_schema, indent=2),
        news_article=article,
    )

    # temperature=0.0 keeps outputs as repeatable as the API allows across runs.
    response = client.responses.create(
        model=model_name,
        input=formatted_prompt,
        temperature=0.0,
    )

    json_content = _extract_fenced_json(response.output_text)
    if json_content is None:
        return False

    try:
        NewsArticle.model_validate(json.loads(json_content))
    except (json.JSONDecodeError, ValidationError):
        return False
    return True

In [24]:
openai_client = OpenAI()

# Number of articles, taken from the start of combined_df, to send to the model.
# Both the loop and the summary line use this constant, so they always agree.
ARTICLE_COUNT = 20

model_name = "gpt-3.5-turbo"

articles = combined_df['text'].iloc[:ARTICLE_COUNT]
# extract_article_data returns a bool; summing bools counts the successes.
success_count = sum(
    extract_article_data(article, model_name, openai_client) for article in articles
)

print(f"For {model_name} successfully validated {success_count} of {len(articles)} records")

For gpt-3.5-turbo successfully validated 18 of 20 records
