In [7]:
from enum import StrEnum
import json

from datasets import load_dataset
import pandas as pd
from pydantic import BaseModel, Field


In [2]:

# Fetch the BBC News dataset via `datasets.load_dataset` and turn both splits
# into DataFrames so they can be inspected with pandas.
dataset = load_dataset("SetFit/bbc-news")

train_df = pd.DataFrame(dataset['train'])
test_df = pd.DataFrame(dataset['test'])
# NOTE(review): the index is not reset here, so (assuming both frames carry the
# default 0..n-1 index) row labels from the test split repeat labels of the
# train split. Use positional access (.iloc) on combined_df, not .loc.
combined_df = pd.concat([train_df, test_df])

print("Dataset Overview:")
print(f"Dataset size: {len(combined_df)}")
# The counts below are computed over train + test together, so the heading
# names the combined dataset rather than the training split.
print("\nLabel distribution in combined dataset:")
print(combined_df['label_text'].value_counts())
# Train rows come first in the concatenation, so iloc[0] is the first training
# article; only its first 200 characters are shown.
print("\nSample text from training set:")
print(combined_df['text'].iloc[0][:200])

# Character count per article, stored as a new column on train_df only; the
# mean printed below therefore describes the training split, not combined_df.
train_df['text_length'] = train_df['text'].str.len()
print(f"\nText length mean: {train_df['text_length'].mean():.2f}")

Dataset Overview:
Dataset size: 2225

Label distribution in training set:
label_text
sport            511
business         510
politics         417
tech             401
entertainment    386
Name: count, dtype: int64

Sample text from training set:
wales want rugby league training wales could follow england s lead by training with a rugby league club.  england have already had a three-day session with leeds rhinos  and wales are thought to be in

Text length mean: 2288.95


In [5]:
# Closed set of news categories an article can be assigned to. Members are
# StrEnum values, so each one is also a plain `str` equal to its lowercase
# value. The values match the `label_text` values printed in the dataset
# overview above (business, entertainment, politics, sport, tech).
# NOTE(review): written as a `#` comment on purpose; pydantic may copy a class
# docstring into the generated JSON schema as a `description` -- confirm before
# converting this to a docstring.
class ArticleCategory(StrEnum):
    BUSINESS = "business"
    ENTERTAINMENT = "entertainment"
    POLITICS = "politics"
    SPORT = "sport"
    TECH = "tech"

# Closed set of entity kinds that can be attached to a NamedEntity. Members are
# StrEnum values, so each one is also a plain `str` equal to its lowercase
# value.
# NOTE(review): written as a `#` comment on purpose; pydantic may copy a class
# docstring into the generated JSON schema as a `description` -- confirm before
# converting this to a docstring.
class NamedEntityType(StrEnum):
    COMPANY = "company"
    COUNTRY = "country"
    LOCATION = "location"
    PERSON = "person"

# One named entity mentioned in an article: its name plus a NamedEntityType.
# Both fields are required (neither declares a default).
# NOTE(review): kept as a `#` comment because pydantic emits a model's class
# docstring as the schema `description`; a docstring here would change the
# output of `model_json_schema()`.
class NamedEntity(BaseModel):
    # Free-text name of the entity; the description string is part of the schema.
    name: str = Field(
        description="Name of the person, company, etc.",
    )
    # Restricted to the NamedEntityType members.
    type: NamedEntityType = Field(
        description="Type of named entity",
    )

# Structured representation of a news article: headline, category, mentioned
# entities and a short summary. Field declaration order is preserved because it
# determines the property order in the generated JSON schema.
# NOTE(review): kept as a `#` comment because pydantic emits a model's class
# docstring as the schema `description`; a docstring here would change the
# output of `model_json_schema()`.
class NewsArticle(BaseModel):
    title: str = Field(
        description="Appropriate headline for the news article",
    )
    # Restricted to the ArticleCategory members.
    category: ArticleCategory = Field(
        description="Category of news that the article belongs to",
    )
    # Optional: a fresh empty list is created per instance when omitted.
    mentioned_entities: list[NamedEntity] = Field(
        default_factory=list,
        description="List of all named entities in the article",
    )
    summary: str = Field(
        description="Brief summary of the article content.",
    )



In [9]:
# Build the JSON Schema dict for NewsArticle (nested enums and models are
# emitted under "$defs", as the output below shows) and keep it in a
# module-level name.
news_article_schema = NewsArticle.model_json_schema()
# Pretty-print with 2-space indentation for inspection.
print(json.dumps(news_article_schema, indent=2))

{
  "$defs": {
    "ArticleCategory": {
      "enum": [
        "business",
        "entertainment",
        "politics",
        "sport",
        "tech"
      ],
      "title": "ArticleCategory",
      "type": "string"
    },
    "NamedEntity": {
      "properties": {
        "name": {
          "description": "Name of the person, company, etc.",
          "title": "Name",
          "type": "string"
        },
        "type": {
          "$ref": "#/$defs/NamedEntityType",
          "description": "Type of named entity"
        }
      },
      "required": [
        "name",
        "type"
      ],
      "title": "NamedEntity",
      "type": "object"
    },
    "NamedEntityType": {
      "enum": [
        "company",
        "country",
        "location",
        "person"
      ],
      "title": "NamedEntityType",
      "type": "string"
    }
  },
  "properties": {
    "title": {
      "description": "Appropriate headline for the news article",
      "title": "Title",
      "type": "strin