<a href="https://colab.research.google.com/github/dlapushin/warn_analysis/blob/main/WARN_Layoffs_Data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# WARN Layoffs Data Analysis
Class definition for WarnFirmographics() with methods for:
* data ingestion (elt)
* LLM query and processing to group and normalize self-similar company names down to one name
* LLM query and processing to map companies to one of 14 NAICS industry sectors



In [78]:
!uv pip install --upgrade google-api-python-client google-auth-httplib2 google-auth-oauthlib tiktoken groq tqdm

import pandas as pd
from collections import defaultdict
import os
import time
import pickle
import requests
import json
from tqdm.notebook import tqdm
import plotly.express as px
import pandas as pd
import numpy as np

from typing import List, Dict, Generator, Tuple, Any
from logging import exception
import tiktoken
import re
from groq import Groq
from google.colab import auth
auth.authenticate_user()

import gspread
from google.colab import auth
from google.auth import default


def save_object(obj, filename: str) -> None:
    with open(filename, 'wb') as f:
        pickle.dump(obj, f)

class WarnFirmographics:
    """Class to process a list of companies and classify them into industries using the Groq API."""
    def __init__(self, api_key: str, encoding_name: str, chunk_size: int = 100, start_chunk: int = 0):
        self.client = Groq(api_key=api_key)
        self.encoding = tiktoken.get_encoding(encoding_name)
        self.chunk_size = chunk_size
        self.start_chunk = start_chunk
        self.dict_master_response: Dict[int, Any] = {}

    def load_google_sheet(self, spreadsheet_id: str, range_name: str, credentials_path: str) -> pd.DataFrame:
        """Load data from a Google Sheet into a pandas DataFrame."""
        credentials = service_account.Credentials.from_service_account_file(credentials_path)
        service = build('sheets', 'v4', credentials=credentials)
        sheet = service.spreadsheets()
        result = sheet.values().get(spreadsheetId=spreadsheet_id, range=range_name).execute()
        values = result.get('values', [])

        if not values:
            print('No data found.')
            return pd.DataFrame()
        else:
            df = pd.DataFrame(values[1:], columns=values[0])
            return df

    def data_elt_colab(self, cutoff_date: str) -> None:

      # Authenticate with Google Colab
      auth.authenticate_user()
      creds, _ = default()

      gc = gspread.authorize(creds)

      spreadsheet_hist = 'https://docs.google.com/spreadsheets/d/1ayO8dl7sXaIYBAwkBGRUjbDms6MAbZFvvxxRp8IyxvY/edit?usp=sharing'
      spreadsheet_2025 = 'https://docs.google.com/spreadsheets/d/1Qx6lv3zAL9YTsKJQNALa2GqBLXq0RER2lHvzyx32pRs/edit?gid=0#gid=0'
      range_name = 'Sheet1!A1:M'
      worksheet_hist = gc.open_by_url(spreadsheet_hist)
      worksheet_2025 = gc.open_by_url(spreadsheet_2025)

      sheet_hist = worksheet_hist.values_get(range=range_name).values()
      sheet_2025 = worksheet_2025.values_get(range=range_name).values()

      sheet_hist_vals = sheet_hist.mapping['values']
      sheet_2025_vals = sheet_2025.mapping['values']

      df_hist = pd.DataFrame.from_records(sheet_hist_vals[1:], columns=sheet_hist_vals[0])
      df_2025 = pd.DataFrame.from_records(sheet_2025_vals[1:], columns=sheet_hist_vals[0])

      self.df_full = pd.concat([df_hist, df_2025], ignore_index=True)
      self.df_full['WARN Received Date'] = pd.to_datetime(df_full['WARN Received Date'])
      #df_full['WARN Received Date'] = df_full['WARN Received Date'].dt.strftime('%Y-%m-%d')
      self.df_full.sort_values('WARN Received Date', ascending=False, inplace=True)
      self.df_full = self.df_full[self.df_full['WARN Received Date'] >= cutoff_date]
      self.df_full['company_name_normalized'] = self.df_full['Company'].apply(self.normalize_text)
      print('#######')
      print('WARN data import complete')
      print('#######')

    # Helper function to remove punctuation and normalize text
    def normalize_text(self, text: str) -> str:
        # Convert to lowercase and remove punctuation
        text = text.lower()
        text = re.sub(r'[^\w\s]', '', text)
        # Remove extra whitespace
        text = ' '.join(text.split())
        return text

    def data_elt(self) -> None:
        """Perform ETL on the WARN data."""
        # Load the Google Sheet
        spreadsheet_hist_id = '1ayO8dl7sXaIYBAwkBGRUjbDms6MAbZFvvxxRp8IyxvY' # From Google Sheets URL
        spreadsheet_2025_id = '1Qx6lv3zAL9YTsKJQNALa2GqBLXq0RER2lHvzyx32pRs' # From Google Sheets URL

        range_name = 'Sheet1!A1:M'
        credentials_path = '../warn-analytics.json'

        df_hist = warn_data.load_google_sheet(spreadsheet_hist_id, range_name, credentials_path)
        df_2025 = warn_data.load_google_sheet(spreadsheet_2025_id, range_name, credentials_path)
        self.df_full = pd.concat([df_hist , df_2025], ignore_index=True)

        self.df_full['WARN Received Date'] = pd.to_datetime(self.df_full['WARN Received Date'])
        self.df_full['WARN Received Date'] = self.df_full['WARN Received Date'].dt.strftime('%Y-%m-%d')
        self.df_full.sort_values('WARN Received Date', ascending=False, inplace=True)
        self.df_full = self.df_full[self.df_full['WARN Received Date'] >= '2022-01-01']
        self.df_full['company_name_normalized'] = self.df_full['Company'].apply(self.normalize_text)

    def group_by_starting_letter(self, column: str) -> None:
        """Group values in the DataFrame by the starting letter of the specified column."""
        """Allows for smarter iteration of all potential duplicates having same starting letter"""
        self.grouped_companies = defaultdict(list)
        list_of_companies_sorted = list(self.df_full[column].unique())
        list_of_companies_sorted.sort()

        for value in list_of_companies_sorted:
            if pd.notna(value):
                first_letter = value[0].lower()
                self.grouped_companies[first_letter].append(value)

    def convert_to_dict(self, json_str: str) -> Dict[str, str]:
        """Convert a JSON string to a dictionary."""
        json_objects = json_str.strip().split('\n')
        json_array_str = f"{','.join(json_objects)}"
        try:
            response_list = json.loads(json_array_str)
            response_dict = {item['company']: item['classification'] for item in response_list}
            return response_dict
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON: {e}")
            print(json_array_str)
            return {}

    def create_llm_query_dedupe(self) -> None:
        """Create an LLM query for each group of values."""
        self.queries_dedupe = {}
        for letter, values in self.grouped_companies.items():
            company_list_str = ", ".join(values)
            query = f"""Evaluate the following comma-separated list of companies and determine which ones are likely duplicates.
                If more than one company has the same normalized name, consider them duplicates and "has_duplicates" should be "True".
                Return the result as lower-case text with double quotes in a JSON object using the format below. Only provide the JSON object, nothing else.:
                {{"original_company_name": {{"has_duplicates": "True" or "False", "normalized_company_name": "a deduplicated company name"}}}}.
                ---:
                {company_list_str}
                """
            self.queries_dedupe[letter] = query

    def make_api_call(self, query: str) -> Tuple[int, str]:
        """Make an API call with the given query."""
        #print('current query:', query)
        chat_completion = self.client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": query,
                }
            ],
            model="llama-3.3-70b-versatile",
            temperature=0,
            max_completion_tokens=32768,
            top_p=1,
            stream=False,
            stop=None,
            seed=123,
        )

        response = chat_completion.choices[0].message.content
        json_str = f"[{response}]"

        return json_str

    def dedupe_normalized_company_list(self) -> None:
        """Call LLM to evaluate company list and determine duplicates."""
        self.list_normalized_companies = []
        self.df_deduped_company_normalized = pd.DataFrame()

        print('#######')
        print('Starting dedupe of company names')
        print('#######')

        for key in tqdm(self.queries_dedupe.keys()):
            result = self.make_api_call(self.queries_dedupe[key])
            dict_result = json.loads(result)

            transformed_list = []
            for key_result, value_result in dict_result[0].items():
                #print(key_result, value_result)
                transformed_dict = {'original_company_name': key_result}
                transformed_dict.update(value_result)
                transformed_list.append(transformed_dict)

            df = pd.DataFrame(transformed_list)
            self.df_deduped_company_normalized = pd.concat([self.df_deduped_company_normalized, df], ignore_index=True)

        self.list_companies_deduped = list(set(self.df_deduped_company_normalized['normalized_company_name']))
        self.list_companies_deduped.sort()
        print('#######')
        print('Dedupe complete')
        print('#######')


    def create_llm_query_industry_classification(self, chunk: list) -> None:
      """Make an API call with the given chunk."""
      company_list_str = ""
      try:
          company_list_str = ", ".join(chunk)
      except Exception as e:
          print(e)
          print(chunk)
      query_classification = f"""Take the list of companies provided in the COMPANY-LIST below and classify each company into exactly one of the 14 industries shown below under INDUSTRIES. Do not use any other industry classifications.
      Return the result as a valid JSON array of objects, where each object contains the company name and its classification.
      Ensure the JSON is properly formatted with no extraneous characters or formatting issues.
      Example: [{{"company":"google", "classification": "Information"}}, {{"company":"amazon", "classification": "Retail Trade"}}]
      Only provide the JSON array, nothing else.
      INDUSTRIES:
      1. Construction
      2. Education and Health Services
      3. Finance and Insurance
      4. Information
      5. Leisure and Hospitality
      6. Manufacturing
      7. Mining, quarrying, and oil and gas extraction
      8. Professional and Business Services
      9. Real Estate,Rental and Leasing
      10. Retail Trade
      11. Transportation and Warehousing
      12. Utilities
      13. Wholesale Trade
      14. Other Services
      Ensure that each classification is one of the 14 industries listed above. Do not use any other categories.

      COMPANY-LIST:
      {company_list_str}
      """

      return query_classification

    # Function to divide the list into chunks of 100 companies
    def chunk_list(self, lst, chunk_size, start_chunk=0):
        """Divide the list into chunks of specified size, starting from a specific chunk."""

        for i in range(start_chunk * chunk_size, len(lst), chunk_size):
            list_focus = lst[i:i + chunk_size]
            cleaned_list = [x for x in list_focus if (x != 'nan') and (x != 'None')]

            yield cleaned_list #lst[i:i + chunk_size]

    def process_chunks(self, start_chunk=0, end_chunk=None) -> None:
        """Process the list of companies in chunks and make API calls."""
        if (start_chunk != 0):
            self.start_chunk = start_chunk
        else:
            self.start_chunk = 0

        self.ctr = self.start_chunk
        self.chunk_list_companies_deduped = self.chunk_list(self.list_companies_deduped, chunk_size=100, start_chunk=self.start_chunk)

        print('#######')
        print('Starting industry classification for all normalized companies in dataset')
        print('#######')

        self.dict_master_response = {}

        end_chunk = int(len(self.list_companies_deduped) / 100) + 1
        for chunk in tqdm(self.chunk_list_companies_deduped):
            while True:
                try:
                    query = self.create_llm_query_industry_classification(chunk)
                    self.classification_json_response = self.make_api_call(query)
                    self.classification_json_objects = self.classification_json_response.strip().split('\n')
                    self.classification_json_array_str = f"{''.join(self.classification_json_objects)}"
                    self.classification_json_array_str = self.classification_json_array_str.replace("}{" ,"}, {")
                    self.classification_list_of_dicts = json.loads(self.classification_json_array_str)

                    self.dict_master_response[self.ctr] = self.classification_list_of_dicts

                    if self.ctr % 10 == 0:
                        with open(f'dict_master_response.pkl', 'wb') as f:
                            pickle.dump(self.dict_master_response, f)
                        print(f"Saved dict_master_response.pkl")

                    self.ctr += 1
                    break  # Exit the WHILE loop only if the API call is successful
                except requests.exceptions.RequestException as e:
                    print(f"API call failed: {e}")
                    print("Waiting for 5 minutes before retrying...")
                    time.sleep(300)  # Wait for 5 minutes (300 seconds) before retrying

                if self.ctr > end_chunk:
                    save_object(self.dict_master_response, 'dict_master_response.pkl')
                    break


# Initialize the WarnFirmographics class
warn_data = WarnFirmographics(api_key=user_data.get('groq_secret'), encoding_name="cl100k_base")
warn_data.data_elt_colab(cutoff_date = '2022-01-01')
warn_data.group_by_starting_letter('company_name_normalized')
warn_data.create_llm_query_dedupe()
warn_data.dedupe_normalized_company_list()
warn_data.process_chunks()

[2mUsing Python 3.11.11 environment at: /usr[0m
[37m⠋[0m [2mResolving dependencies...                                                     [0m[2K[37m⠋[0m [2mResolving dependencies...                                                     [0m[2K[37m⠙[0m [2mResolving dependencies...                                                     [0m[2K[37m⠙[0m [2mgoogle-api-python-client==2.162.0                                             [0m[2K[37m⠙[0m [2mgoogle-auth-httplib2==0.2.0                                                   [0m[2K[37m⠙[0m [2mgoogle-auth-oauthlib==1.2.1                                                   [0m[2K[37m⠙[0m [2mtiktoken==0.9.0                                                               [0m[2K[37m⠙[0m [2mgroq==0.18.0                                                                  [0m[2K[37m⠙[0m [2mtqdm==4.67.1                                                                  [0m[2K[37m⠙[0m [2mhttplib2==0.22.0 

  0%|          | 0/34 [00:00<?, ?it/s]

#######
Dedupe complete
#######
#######
Starting industry classification for all normalized companies in dataset
#######


0it [00:00, ?it/s]

Saved dict_master_response.pkl
Saved dict_master_response.pkl
Saved dict_master_response.pkl
Saved dict_master_response.pkl
Saved dict_master_response.pkl


In [133]:
def create_master_dataframe():  ## includes some cleanup in cases of minor LLM error
  valid_categories = [
                "Construction", "Education and Health Services", "Finance and Insurance", "Information",
                "Leisure and Hospitality", "Manufacturing", "Mining, quarrying, and oil and gas extraction",
                "Professional and Business Services", "Real Estate,Rental and Leasing", "Retail Trade",
                "Transportation and Warehousing", "Utilities", "Wholesale Trade", "Other Services"
            ]

  df_master_response = pd.DataFrame()

  for chunk in warn_data.dict_master_response.keys():
    try:
      df_chunk = pd.DataFrame(warn_data.dict_master_response[chunk][0])
    except Exception as e:
      print(e)
      if(list(warn_data.dict_master_response[chunk][0].keys()) == ['companies']):
          df_chunk = pd.DataFrame(warn_data.dict_master_response[chunk][0]['companies'])
      else:
        break
    df_master_response = pd.concat([df_master_response, df_chunk], ignore_index=True)

  df_enrich = warn_data.df_deduped_company_normalized.merge(df_master_response, left_on='normalized_company_name', right_on='company', how='left')
  df_full_enriched = warn_data.df_full.merge(df_enrich, left_on='company_name_normalized', right_on='original_company_name', how='left')
  df_full_enriched = df_full_enriched[df_full_enriched['Number of Workers'] != '']
  df_full_enriched['Number of Workers'] = df_full_enriched['Number of Workers'].astype(int)
  df_full_enriched = df_full_enriched[df_full_enriched['classification'].isin(valid_categories)]

  return df_full_enriched

df_full_enriched = create_master_dataframe()



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



In [146]:
df_full_enriched['pct_of_total'] = df_full_enriched['Number of Workers'] / df_full_enriched['Number of Workers'].sum() * 100

fig = px.treemap(df_full_enriched[df_full_enriched['Number of Workers']>0],
                 path=['classification', 'State'],
                 values='Number of Workers',
                 title='Interactive Treemap',
                 color_continuous_scale='Viridis'
                 )

# Update the layout to make the treemap bigger
fig.update_layout(
    width=1500,  # Set the width of the treemap
    height=1000,  # Set the height of the treemap
    title_font_size=24  # Set the font size of the title
)

fig.show()