<a href="https://colab.research.google.com/github/eli-jaffe/BLS-AI-Researcher-Demo/blob/main/BLS_AI_Researcher_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# BLS AI Researcher Demo

This notebook builds a Bureau of Labor Statistics (BLS) research crew of AI agents, powered by Anthropic's Claude model and the CrewAI framework. The agents take a question about the U.S. labor market or a specific BLS data series and provide analysis. For demo purposes, the notebook is designed to run in Google Colab.

## 1) Installing libraries

First, we need to install the extra packages that are not included in the Colab environment by default.

In [None]:
!pip install langchain_community colab-xterm crewai crewai_tools anthropic

# Colab may ask you to restart the runtime; accept, then run this cell again
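After the restart, a quick way to confirm the packages are available is to query their installed versions with the standard library's `importlib.metadata`. This is a minimal sketch: `installed_versions` is a hypothetical helper, and whether an underscore spelling such as `langchain_community` resolves to the `langchain-community` distribution may depend on your Python version.

```python
# Hypothetical helper: report installed versions of the notebook's dependencies
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {package: installed version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None
    return versions


# Same package names as the pip command above
for name, ver in installed_versions(
    ["langchain_community", "colab-xterm", "crewai", "crewai_tools", "anthropic"]
).items():
    print(f"{name}: {ver or 'NOT INSTALLED'}")
```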

## 2) AI Agents


## What is CrewAI?

To work with the Agents, we will use CrewAI.

This framework lets us set up each Agent with its goals, tasks, tools, etc.

The AI Agents, in this case powered by Anthropic's Claude, will interact with one another to complete tasks.

In this example we will set up the following Agents:

1.   Researcher
2.   Analyst
3.   BLS Librarian
4.   Data Visualizer
5.   Chart Type Selector
6.   Plot Builder


We will then give them a simple task to solve together.

The task is:

 **"Analyze the most recent data on total nonfarm payroll employment in the United States from 2023 through the end of 2025."**


## Setting up the AI Agents

In [None]:
from google.colab import userdata, drive
import matplotlib.pyplot as plt
import seaborn as sns

from crewai import LLM, Agent, Task, Crew, Process
from crewai.tools import tool, BaseTool

import os
import re
import json
from datetime import date, datetime

import requests
import pandas as pd
import numpy as np

### Define tools

In [None]:
# Helper to avoid passing improperly formed series IDs to the BLS API
def is_valid_bls_series_id(series_id: str) -> bool:
    """
    Validate whether a given string is a well-formed BLS series ID.

    Known formats:
    - CES (13 chars): 'CE' + seasonal code (S/U) + 8-digit supersector/industry code
      + 2-character data type, e.g., CES0000000001
    - JOLTS (21 chars): 'JT' + seasonal code (S/U) + 6-digit industry + 2-digit state
      + 5-digit area + 2-digit size class + 2-letter data element + rate/level code (R/L),
      e.g., JTS000000000000000JOL
    """
    if not isinstance(series_id, str):
        return False

    patterns = [
        re.compile(r"^CE[SU][a-zA-Z0-9]{10}$"), # CES
        re.compile(r"^JT[SU]\d{6}\d{2}\d{5}\d{2}(HI|JO|LD|OS|QU|TS|UN|UO)[LR]$"),  # JOLTS
        # Add more patterns for LAUS, CPI, etc., if needed
    ]

    return any(p.match(series_id.upper()) for p in patterns)
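The BLS Librarian's prompt (further down) describes how CES and JOLTS series IDs are assembled. The same rules can be written as small builder functions that check their output against patterns equivalent to those in `is_valid_bls_series_id`. `build_ces_series_id` and `build_jolts_series_id` are hypothetical helpers, not used by the agents; treat them as a sketch for constructing or sanity-checking IDs by hand.

```python
import re

# Hypothetical helpers, not used by the agents. Patterns are equivalent to those in
# is_valid_bls_series_id (IDs are upper-cased before matching).
CES_PATTERN = re.compile(r"^CE[SU][A-Z0-9]{10}$")
JOLTS_PATTERN = re.compile(r"^JT[SU]\d{6}\d{2}\d{5}\d{2}(HI|JO|LD|OS|QU|TS|UN|UO)[LR]$")


def build_ces_series_id(supersector="00", industry="000000", data_type="01", seasonal="S"):
    """'CE' + seasonal code + 2-digit supersector + 6-digit industry + 2-char data type."""
    series_id = f"CE{seasonal}{supersector}{industry}{data_type}".upper()
    if not CES_PATTERN.match(series_id):
        raise ValueError(f"Malformed CES series ID: {series_id}")
    return series_id


def build_jolts_series_id(element="JO", rate_or_level="L", industry="000000",
                          state="00", area="00000", size_class="00", seasonal="S"):
    """'JT' + seasonal + industry + state + area + size class + data element + rate/level."""
    series_id = f"JT{seasonal}{industry}{state}{area}{size_class}{element}{rate_or_level}".upper()
    if not JOLTS_PATTERN.match(series_id):
        raise ValueError(f"Malformed JOLTS series ID: {series_id}")
    return series_id


print(build_ces_series_id())                  # CES0000000001: total nonfarm, all employees
print(build_ces_series_id(supersector="20"))  # CES2000000001: construction
print(build_jolts_series_id(element="QU"))    # JTS000000000000000QUL: total nonfarm quits level
```

Because the builders raise on malformed input, an ID with a wrong-length segment (for example a 1-digit supersector) fails immediately instead of reaching the API.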

In [None]:
# This tool queries the BLS Public Data API (v1) for the requested series
class BLSFetcherTool(BaseTool):
    name: str = "get_bls_by_seriesid"
    description: str = (
        "Fetches job market data from the BLS API given a list of series IDs "
        "and an optional date range (startyear, endyear). Returns the data as a JSON string."
    )

    def _run(
        self,
        seriesid: list,
        startyear: int = None,
        endyear: int = None,
        print_message: bool = False
    ) -> str:

        base_url = "https://api.bls.gov/publicAPI/v1/timeseries/data/"

        # Accept a single ID or a list of IDs
        if not isinstance(seriesid, list):
            seriesid = [seriesid]
        seriesid = [str(s).upper() for s in seriesid]

        # Fail fast on malformed IDs so the agent can correct them
        if not all(is_valid_bls_series_id(s) for s in seriesid):
            raise ValueError(f"Invalid BLS Series ID format in {seriesid}")
        print("All series IDs are valid.")

        try:
            print(f"SeriesID: {seriesid}")
            print(f"Startyear: {startyear}, Endyear: {endyear}")

            headers = {'Content-type': 'application/json'}
            payload = {"seriesid": seriesid}

            # v1 returns at most 10 years per request, so cap an open-ended range
            if startyear is not None:
                if endyear is None:
                    endyear = min(int(startyear) + 9, date.today().year)
                payload.update({"startyear": startyear, "endyear": endyear})

            print(f'Pinging BLS API at {base_url} with payload: {json.dumps(payload)}')

            # API request
            response = requests.post(base_url, data=json.dumps(payload), headers=headers, timeout=30)
            response.raise_for_status()
            json_data = response.json()

            print(json_data.get("status", "No status"))
            if print_message:
                print(json_data.get("message", "No message"))

            # Convert to dataframe: one row per observation, tagged with its seriesID
            df = pd.json_normalize(json_data['Results']['series'], record_path='data', meta='seriesID')

            # Feature engineering
            df = df.assign(
                series_type=lambda d: d.seriesID.str[-3:],
                naics_2digit=lambda d: d.seriesID.str[3:5],
                series_type_month=lambda d: d.series_type + d.year + d.period,
                year_month=lambda d: pd.to_datetime(d.year.astype(str) + '-' + d.periodName.astype(str),
                                                    format='%Y-%B') + pd.tseries.offsets.MonthEnd(),
                series_entry_id=lambda d: d.seriesID + d.year_month.astype(str)
            )

            # Flag preliminary estimates (footnote code 'P') with a 'p' suffix
            df['series_entry_id'] = np.where(
                df.footnotes.astype(str).str.contains('P'),
                df.series_entry_id + 'p',
                df.series_entry_id
            )

            df['value'] = df['value'].astype(float)

            print(f"Data found! Returning {len(df)} rows.")

            # Serialize once (row records, ISO dates); wrapping the result in json.dumps
            # again would double-encode it into a quoted string
            return df.to_json(orient="records", date_format="iso")

        except Exception as e:
            print(f"Tool execution failed: {e}")
            return f"Tool execution failed: {e}"
    def run(self, **kwargs):
        return self._run(**kwargs)
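The tool calls the v1 endpoint, which the BLS API documentation lists as returning at most 10 years per query (alongside daily request limits; check the current documentation, as these limits can change). For a longer span, one approach is to split the range into windows and send one request per window. This is a sketch: `year_windows` is a hypothetical helper, not part of the notebook or any BLS library, and the payloads simply mirror the keys the tool sends.

```python
import json
from datetime import date


# Hypothetical helper: split an inclusive year range into API-sized windows
def year_windows(startyear, endyear, max_years=10):
    """Split [startyear, endyear] (inclusive) into consecutive windows of at most max_years years."""
    if startyear > endyear:
        raise ValueError("startyear must not be after endyear")
    windows = []
    start = startyear
    while start <= endyear:
        end = min(start + max_years - 1, endyear)
        windows.append((start, end))
        start = end + 1
    return windows


# One request payload per window, in the same shape the tool sends
payloads = [
    {"seriesid": ["CES0000000001"], "startyear": s, "endyear": e}
    for s, e in year_windows(2005, date.today().year)
]
for p in payloads:
    print(json.dumps(p))
```

The per-window results can then be concatenated before analysis. The demo's 2023 to 2025 request fits in a single window.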

In [None]:
# SavePlotTool saves any generated visualizations to Google Drive
class SavePlotTool(BaseTool):
    name: str = "save_plot"
    description: str = (
        "Saves a matplotlib plot to Google Drive. "
        "If no `plt` object is passed, executes the given code to generate one."
    )

    def _run(self, input_string: str = None, **kwargs) -> str:
        print(f"input_string ({type(input_string)}): {input_string}")

        try:
            # Mount Google Drive (prompts for authorization on first use)
            drive.mount('/content/drive', force_remount=False)

            # Generate output file path, creating the folder if needed
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            file_path = f"/content/drive/MyDrive/Colab Notebooks/crewai/plot_{timestamp}.png"
            os.makedirs(os.path.dirname(file_path), exist_ok=True)

            # Use provided plt directly if available
            if kwargs.get("plt") is not None:
                kwargs["plt"].savefig(file_path)
                return f"Plot saved to {file_path}"

            # Otherwise get code from input_string: JSON {"code": ...} or raw code
            if input_string:
                try:
                    code_dict = json.loads(input_string)
                    code_string = code_dict.get("code", "") if isinstance(code_dict, dict) else input_string
                except json.JSONDecodeError:
                    code_string = input_string

                if not code_string:
                    return "❌ Error: No 'code' found in input string."

                # WARNING: exec runs LLM-generated code; only use this in a sandboxed notebook
                exec_globals = {"plt": plt, "pd": pd, "np": np, "sns": sns, "json": json}
                exec(code_string, exec_globals)

                # Save the figure the code drew, then free it. The code should not call
                # plt.show(), which may clear the current figure in Colab.
                plt.savefig(file_path)
                plt.close()
                return f"Plot created and saved to {file_path}"

            return "❌ Error: Provide either a `plt` object or an `input_string` containing a 'code' key."

        except Exception as e:
            return f"❌ Failed to save plot: {str(e)}"
    def run(self, input_string=None, **kwargs):
        return self._run(input_string=input_string, **kwargs)
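`SavePlotTool` accepts either raw Python code or a JSON string with a `code` key. That parsing step can be checked on its own, without Google Drive or CrewAI. The sketch below approximately mirrors the logic in `SavePlotTool._run` as a hypothetical standalone helper, `extract_plot_code`; the sample plotting code is illustrative only.

```python
import json


# Hypothetical standalone version of the parsing step in SavePlotTool._run
def extract_plot_code(input_string):
    """Return plotting code from a JSON payload like {"code": "..."} or from raw code."""
    if not input_string:
        return ""
    try:
        payload = json.loads(input_string)
    except json.JSONDecodeError:
        return input_string  # not JSON: treat the whole string as code
    if isinstance(payload, dict):
        return payload.get("code", "")
    return input_string  # valid JSON but not an object: treat as raw code


# Illustrative plotting code an agent might send
code = "import matplotlib.pyplot as plt\nplt.plot([1, 2, 3], [2, 4, 3])\nplt.title('Example')"
payload = json.dumps({"code": code})
print(extract_plot_code(payload) == code)  # True
```

Using `json.dumps` to build the payload takes care of escaping newlines and quotes in the code, which is a common failure point when an LLM writes the JSON by hand.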


### Define our agents
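The cell below reads the Anthropic API key from a Colab secret named `eli-gcolab-api-key` (the author's secret name; yours will likely differ). To run the notebook outside Colab, one option is to fall back to the `ANTHROPIC_API_KEY` environment variable. A sketch, using a hypothetical helper `get_anthropic_api_key`:

```python
import os


# Hypothetical helper: prefer a Colab secret, fall back to an environment variable
def get_anthropic_api_key(colab_secret_name="eli-gcolab-api-key"):
    try:
        from google.colab import userdata  # only importable inside Colab
        key = userdata.get(colab_secret_name)
        if key:
            return key
    except Exception:
        # ImportError outside Colab, or an error if the secret is missing
        # or the notebook has not been granted access to it
        pass
    key = os.environ.get("ANTHROPIC_API_KEY")
    if not key:
        raise RuntimeError(
            "No Anthropic API key found: set a Colab secret or the ANTHROPIC_API_KEY variable"
        )
    return key
```

You could then pass `api_key=get_anthropic_api_key()` to `LLM(...)` instead of calling `userdata.get` directly.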

In [None]:
# We will use Claude 3.5 Haiku for this demo
llm = LLM(
    model="claude-3-5-haiku-20241022",
    base_url="https://api.anthropic.com",
    api_key=userdata.get("eli-gcolab-api-key")  # Colab secret that stores the Anthropic API key
)


class ResearcherAgent(Agent):
    def __init__(self, llm, tools=None):
        super().__init__(
            role="Researcher",
            goal="You will be given a task to pull together data. First you will need to determine how to pull the data. Then you will gather the requested data and return it in JSON format.",
            backstory="""You are an experienced data analyst. You provide data and analysis for your team.
        When presented with a task, you will use your tools to pull the data needed and then hand it off for analysis.
        """,
            tools=tools,
            llm=llm)


class AnalystAgent(Agent):
    def __init__(self, llm, tools=None):
        super().__init__(
            role="Analyst",
            goal="You will receive a string representation of a JSON dataframe from the researcher. Analyze the information found by the researcher. You may ask the data visualizer for supporting visuals if required. If you do so, pass the string representation of the JSON data you are working with to the data visualizer.",
            backstory="""You are an experienced data analyst. When you receive data from the researcher, you will analyze it and provide insights.""",
            tools=tools,
            allow_delegation=True,
            llm=llm)


class BLSLibrarianAgent(Agent):
    def __init__(self, llm, tools=None):
        super().__init__(
            role="BLS Librarian",
            goal="Determine the most appropriate Bureau of Labor Statistics (BLS) data codes. You may return a single seriesID as a string, or multiple seriesIDs as a list of strings. Each string you return must be a valid BLS seriesID.",
            backstory="""You are an expert librarian for the BLS.
        When you receive a request for BLS data, you generate the most relevant seriesID that can be used in the API to pull data. Return only the final seriesID code.
        Useful information:
            - Details on job openings, turnover, and quits typically come from 'JOLTS'.
            - Details on total employment typically come from 'CES'. Requests for 'total' or 'current' employment in the US will most likely refer to data from the CES series.


        ## How to build a seriesID:

        Each BLS survey (CES, JOLTS, etc) follows a specific pattern for seriesID construction.

        1. A survey prefix:
        - CES: 'CE'
        - JOLTS: 'JT'

        2. A seasonal adjustment code. Seasonal adjustment removes the effect of intra-year variations that tend to recur in the same period each year.
        - Seasonally adjusted: 'S'
        - Unadjusted: 'U'

        3. Series-specific information. This is where the pattern starts to diverge, based on series type.

        **For CES**:

        The Current Employment Statistics (CES) report provides employment, hours, and earnings estimates based on payroll records of business establishments. The series pattern continues:

        a) 8-digit sequence of supersector and industry code.
          - For the most part, you will use the 2-digit supersector code only. This list is:
          00	Total nonfarm
          05	Total private
          06	Goods-producing
          07	Service-providing
          08	Private service-providing
          10	Mining and logging
          20	Construction
          30	Manufacturing
          31	Durable Goods
          32	Nondurable Goods
          40	Trade, transportation, and utilities
          41	Wholesale trade
          42	Retail trade
          43	Transportation and warehousing
          44	Utilities
          50	Information
          55	Financial activities
          60	Professional and business services
          65	Private education and health services
          70	Leisure and hospitality
          80	Other services
          90	Government

          If not otherwise specified, default to the Total nonfarm code '00'.

          - The remaining 6 digits are the industry code. Default to '000000', which selects the total for the chosen supersector. If the user asks for a specific industry whose code you know, you may supply that 6-digit code here, after the supersector code.

          Putting this together gives the 2-digit supersector code detailed previously followed by the 6-digit industry code, for a total of 8 digits.

        b) Finally, a 2-digit data type code. The possible values are:
          01	ALL EMPLOYEES, THOUSANDS
          02	AVERAGE WEEKLY HOURS OF ALL EMPLOYEES
          03	AVERAGE HOURLY EARNINGS OF ALL EMPLOYEES
          04	AVERAGE WEEKLY OVERTIME HOURS OF ALL EMPLOYEES
          06	PRODUCTION AND NONSUPERVISORY EMPLOYEES, THOUSANDS
          07	AVERAGE WEEKLY HOURS OF PRODUCTION AND NONSUPERVISORY EMPLOYEES
          08	AVERAGE HOURLY EARNINGS OF PRODUCTION AND NONSUPERVISORY EMPLOYEES
          09	AVERAGE WEEKLY OVERTIME HOURS OF PRODUCTION AND NONSUPERVISORY EMPLOYEES
          10	WOMEN EMPLOYEES, THOUSANDS
          11	AVERAGE WEEKLY EARNINGS OF ALL EMPLOYEES
          12	AVERAGE WEEKLY EARNINGS OF ALL EMPLOYEES, 1982-1984 DOLLARS
          13	AVERAGE HOURLY EARNINGS OF ALL EMPLOYEES, 1982-1984 DOLLARS
          15	AVERAGE HOURLY EARNINGS OF ALL EMPLOYEES, EXCLUDING OVERTIME
          16	INDEXES OF AGGREGATE WEEKLY HOURS OF ALL EMPLOYEES, 2007=100
          17	INDEXES OF AGGREGATE WEEKLY PAYROLLS OF ALL EMPLOYEES, 2007=100
          19	AVERAGE WEEKLY HOURS OF ALL EMPLOYEES, QUARTERLY AVERAGES, SEASONALLY ADJUSTED
          20	AVERAGE WEEKLY OVERTIME HOURS OF ALL EMPLOYEES, QUARTERLY AVERAGES, SEASONALLY ADJUSTED
          21	DIFFUSION INDEXES, 1-MONTH SPAN, SEASONALLY ADJUSTED
          22	DIFFUSION INDEXES, 3-MONTH SPAN, SEASONALLY ADJUSTED
          23	DIFFUSION INDEXES, 6-MONTH SPAN, SEASONALLY ADJUSTED
          24	DIFFUSION INDEXES, 12-MONTH SPAN, NOT SEASONALLY ADJUSTED
          25	ALL EMPLOYEES, QUARTERLY AVERAGES, SEASONALLY ADJUSTED, THOUSANDS
          26	ALL EMPLOYEES, 3-MONTH AVERAGE CHANGE, SEASONALLY ADJUSTED, THOUSANDS
          30	AVERAGE WEEKLY EARNINGS OF PRODUCTION AND NONSUPERVISORY EMPLOYEES
          31	AVERAGE WEEKLY EARNINGS OF PRODUCTION AND NONSUPERVISORY EMPLOYEES, 1982-84 DOLLARS
          32	AVERAGE HOURLY EARNINGS OF PRODUCTION AND NONSUPERVISORY EMPLOYEES, 1982-84 DOLLARS
          33	AVERAGE HOURLY EARNINGS OF PRODUCTION AND NONSUPERVISORY EMPLOYEES, EXCLUDING OVERTIME
          34	INDEXES OF AGGREGATE WEEKLY HOURS OF PRODUCTION AND NONSUPERVISORY EMPLOYEES, 2002=100
          35	INDEXES OF AGGREGATE WEEKLY PAYROLLS OF PRODUCTION AND NONSUPERVISORY EMPLOYEES, 2002=100
          36	AVERAGE WEEKLY HOURS, PRODUCTION/NONSUPERVISORY EMPLOYEES, QUARTERLY AVERAGES, SEASONALLY ADJUSTED
          37	AVERAGE WEEKLY OVERTIME HOURS,PRODUCTION/NONSUPERVISORY EMPLOYEES,QUARTERLY AVG,SEASONALLY ADJUSTED
          56	AGGREGATE WEEKLY HOURS OF ALL EMPLOYEES, THOUSANDS
          57	AGGREGATE WEEKLY PAYROLLS OF ALL EMPLOYEES, THOUSANDS
          58	AGGREGATE WEEKLY OVERTIME HOURS OF ALL EMPLOYEES, THOUSANDS
          81	AGGREGATE WEEKLY HOURS OF PRODUCTION AND NONSUPERVISORY EMPLOYEES, THOUSANDS
          82	AGGREGATE WEEKLY PAYROLLS OF PRODUCTION AND NONSUPERVISORY EMPLOYEES, THOUSANDS
          83	AGGREGATE WEEKLY OVERTIME HOURS OF PRODUCTION AND NONSUPERVISORY EMPLOYEES, THOUSANDS
          C1	FIRST CLOSING COLLECTION RATE, NOT SEASONALLY ADJUSTED
          C2	SECOND CLOSING COLLECTION RATE, NOT SEASONALLY ADJUSTED
          C3	THIRD CLOSING COLLECTION RATE, NOT SEASONALLY ADJUSTED
          RR	THIRD CLOSING RESPONSE RATE, NOT SEASONALLY ADJUSTED

        If not indicated, default to the data type code '01' for all employees, thousands.

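        The end result for CES will be a 13-character string starting with 'CE', e.g., 'CES0000000001' for total nonfarm employment (all employees, thousands, seasonally adjusted).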

        **For JOLTS**:

        JOLTS collects data on Total Employment, Job Openings, Hires, Quits, Layoffs & Discharges, and Other Separations. The series pattern continues:

        a) 6-digit industry code. This is different from the industry codes used for the CES. Possible values are:
          000000	Total nonfarm
          100000	Total private
          110099	Mining and logging
          230000	Construction
          300000	Manufacturing
          320000	Durable goods manufacturing
          340000	Nondurable goods manufacturing
          400000	Trade, transportation, and utilities
          420000	Wholesale trade
          440000	Retail trade
          480099	Transportation, warehousing, and utilities
          510000	Information
          510099	Financial activities
          520000	Finance and insurance
          530000	Real estate and rental and leasing
          540099	Professional and business services
          600000	Private education and health services
          610000	Private educational services
          620000	Health care and social assistance
          700000	Leisure and hospitality
          710000	Arts, entertainment, and recreation
          720000	Accommodation and food services
          810000	Other services
          900000	Government
          910000	Federal
          920000	State and local
          923000	State and local government education
          929000	State and local government, excluding education

        b) 2-digit state code. Use the state's FIPS code, with '00' for Total US. FIPS codes follow alphabetical order by state but skip some numbers (Alabama is '01', Alaska is '02', Arizona is '04', etc.).
          Unless otherwise indicated, choose the state code '00' for total US.

        c) 5-digit area code. This will always be '00000'.

        d) 2-digits for size class. Default to '00'

        e) 2-letter data element code. Possible values are:
          HI	Hires
          JO	Job openings
          LD	Layoffs and discharges
          OS	Other separations
          QU	Quits
          TS	Total separations
          UN	Unemployment rate
          UO	Unemployed persons per job opening ratio

        f) A final 1-letter code for Rate or Level. Default to Level.
          L	Level - In Thousands
          R	Rate

        The end result for JOLTS will be a 21-character string starting with 'JT' and ending with 'L' or 'R', e.g., 'JTS000000000000000JOL' for total nonfarm job openings (level, seasonally adjusted).

        """,
            tools=tools,
            allow_delegation=True,
            llm=llm)


class DataVisualizerAgent(Agent):
    def __init__(self, llm, tools=None):
        super().__init__(
            role="Data Visualizer",
            goal="To collaborate with your colleagues and provide clear and engaging data visualizations. You receive a dataframe and a summary of key trends. You should save a visual that supports the key trends.",
            backstory="""You are an expert in data visualization and storytelling. You receive data and commentary from your colleagues and create elegant charts that clearly support the key takeaways.""",
            tools=tools,
            allow_delegation=True,
            verbose=True,
            llm=llm)

class ChartSelectorAgent(Agent):
    def __init__(self, tools, llm):
        super().__init__(
            role="Chart Type Selector",
            goal="""Analyze the JSON data (represented in a string variable) and select the best chart type (bar, scatter, line, histogram, area, etc). Return structured JSON instructions:
              "chart_type": "chart_type_name",
              "x": "column_name",
              "y": "column_name",
              "data": "dataframe"
            """,
            backstory="A visual reasoning assistant that recommends how to plot data.",
            tools=tools,
            llm=llm,
            allow_delegation=True)


class PlotBuilderAgent(Agent):
    def __init__(self, tools, llm):
        super().__init__(
            role="Plot Builder",
            goal="Create and save the best possible plot from structured instructions",
            backstory="Turns JSON instructions on chart type and visualization goal into beautiful plots. ",
            tools=tools,
            llm=llm)
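The Chart Type Selector is asked to return JSON with `chart_type`, `x`, `y`, and `data` keys, but LLM replies often wrap JSON in prose or markdown fences. One option is to validate the reply before handing it to the Plot Builder. A sketch, assuming `chart_type`, `x`, and `y` are the required fields and that the supported chart types are the ones named in the selector's goal; `parse_chart_spec` is hypothetical and is not wired into the crew.

```python
import json

# Assumptions: these keys are required, and these chart types are the ones the
# Chart Type Selector's goal names; adjust both to match your plotting code.
REQUIRED_KEYS = {"chart_type", "x", "y"}
SUPPORTED_CHART_TYPES = {"bar", "scatter", "line", "histogram", "area"}


# Hypothetical helper, not wired into the crew
def parse_chart_spec(reply):
    """Extract and validate the JSON object in a Chart Type Selector reply."""
    start, end = reply.find("{"), reply.rfind("}")
    if start == -1 or end < start:
        raise ValueError("No JSON object found in reply")
    spec = json.loads(reply[start:end + 1])  # ignores surrounding prose or markdown fences
    missing = REQUIRED_KEYS - set(spec)
    if missing:
        raise ValueError(f"Chart spec is missing keys: {sorted(missing)}")
    chart_type = str(spec["chart_type"]).strip().lower()
    if chart_type not in SUPPORTED_CHART_TYPES:
        raise ValueError(f"Unsupported chart type: {chart_type}")
    return {**spec, "chart_type": chart_type}


reply = 'Recommended spec:\n{"chart_type": "Line", "x": "year_month", "y": "value"}'
print(parse_chart_spec(reply))
```

Rejecting a malformed spec early produces a clearer error than a failed `exec` inside `SavePlotTool`.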


In [None]:
# Initialize tools
fetch_tool = BLSFetcherTool()
save_tool = SavePlotTool()

# Initialize agents
bls_librarian = BLSLibrarianAgent(llm=llm)
researcher = ResearcherAgent(llm=llm, tools=[fetch_tool])
analyst = AnalystAgent(llm=llm)
data_visualizer = DataVisualizerAgent(llm=llm)
chart_selector = ChartSelectorAgent(llm=llm, tools=None)
plot_builder = PlotBuilderAgent(llm=llm, tools=[save_tool])

### Define the tasks  

In [None]:
task1 = Task(
    description=("""
        Find the most recent data on total nonfarm payroll employment in the United States from 2023 through the end of 2025.
        Your ONLY job is to return data in JSON format.
        Do NOT try to interpret or modify the output.
        Simply return the result as your final answer.
    """
    ),
    expected_output="JSON representing the dataframe you pulled together.",
    agent=researcher,
    max_iter=3
)

task2 = Task(
    description=("""You have received JSON with BLS labor market data. Analyze the data
    and identify the key trends and takeaways.
    """),
    expected_output="""
    A summary of key trends""",
    max_iter=5,
    agent=analyst
)

task3 = Task(
    description="Use the JSON data and key trends to determine the best chart type and x/y columns for plotting. Once you've determined the appropriate chart type, create and save a compelling visual that supports the key trends.",
    expected_output='A saved image file along with the key trends.',
    agent=data_visualizer,
    context = [task1, task2]
)
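Task 2 asks the Analyst to identify trends. For payroll employment, a core figure is the month-over-month change in the level (CES data type '01' is reported in thousands of employees). The sketch below computes it from data shaped like the `Results -> series -> data` block of a BLS API response. The values are placeholders, not real BLS figures, and `monthly_changes` is a hypothetical helper; it sorts explicitly rather than relying on the API's ordering.

```python
# Shape mirrors Results -> series -> data in a BLS API response.
# The values are placeholders, NOT real BLS figures.
sample_response = {
    "status": "REQUEST_SUCCEEDED",
    "Results": {"series": [{
        "seriesID": "CES0000000001",
        "data": [
            {"year": "2024", "period": "M03", "periodName": "March", "value": "103.0", "footnotes": [{}]},
            {"year": "2024", "period": "M02", "periodName": "February", "value": "101.5", "footnotes": [{}]},
            {"year": "2024", "period": "M01", "periodName": "January", "value": "100.0", "footnotes": [{}]},
        ],
    }]},
}


# Hypothetical helper: month-over-month change in a monthly BLS series
def monthly_changes(series):
    """Return [(YYYY-MM, level, change from prior month)] in chronological order."""
    points = sorted(
        (f"{obs['year']}-{obs['period'][1:]}", float(obs["value"]))
        for obs in series["data"]
        if obs["period"].startswith("M") and obs["period"] != "M13"  # monthly rows only; M13 = annual average
    )
    changes = []
    previous = None
    for month, level in points:
        change = None if previous is None else round(level - previous, 1)
        changes.append((month, level, change))
        previous = level
    return changes


for month, level, change in monthly_changes(sample_response["Results"]["series"][0]):
    print(month, level, change)
```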

## Running the tasks with the agents

In [None]:
crew = Crew(
    agents=[researcher, bls_librarian, analyst, data_visualizer, chart_selector, plot_builder],
    tasks=[task1,  task2, task3],
    verbose=True, # Lets us see the agents' thoughts and messages
    process=Process.sequential)


result = crew.kickoff()

In [None]:
print(result)