In [None]:
import pandas as pd
from collections import defaultdict
import os
import time
import pickle
import requests
import json
from typing import List, Dict, Generator, Tuple, Any
import tiktoken
from groq import Groq
from google.oauth2 import service_account
from googleapiclient.discovery import build
import re

def save_object(obj=None, path='warn_data.pkl'):
    """Pickle an object to disk.

    Args:
        obj: Object to serialize. When omitted, the module-level ``warn_data``
            instance is saved, which is what a bare ``save_object()`` call has
            always done.
        path: Destination file. Defaults to ``'warn_data.pkl'`` in the current
            working directory.

    Raises:
        NameError: If ``obj`` is omitted and no global ``warn_data`` exists yet.
    """
    if obj is None:
        # Backward-compatible default: fall back to the notebook-level instance.
        obj = warn_data
    with open(path, 'wb') as f:
        pickle.dump(obj, f)

class WarnFirmographics:
    """Process a list of companies and classify them into industries using the Groq API.

    Call order used by the driver code in this file:
    ``data_elt`` -> ``group_by_starting_letter`` -> ``create_llm_query_dedupe``
    -> ``dedupe_normalized_company_list`` -> ``process_chunks``.
    Each step stores its result on the instance for the next step to read.
    """

    # Google Sheet IDs (taken from the sheet URLs): historical data, then 2025 data.
    DEFAULT_SPREADSHEET_IDS = (
        '1ayO8dl7sXaIYBAwkBGRUjbDms6MAbZFvvxxRp8IyxvY',
        '1Qx6lv3zAL9YTsKJQNALa2GqBLXq0RER2lHvzyx32pRs',
    )

    def __init__(self, api_key: str, encoding_name: str, chunk_size: int = 100, start_chunk: int = 0):
        """Create the Groq client and the tiktoken encoding.

        Args:
            api_key: Groq API key.
            encoding_name: Name passed to ``tiktoken.get_encoding``.
            chunk_size: Number of companies sent per classification request.
            start_chunk: Initial value of ``self.start_chunk``; ``process_chunks``
                overwrites it with its own ``start_chunk`` argument.
        """
        self.client = Groq(api_key=api_key)
        self.encoding = tiktoken.get_encoding(encoding_name)
        self.chunk_size = chunk_size
        self.start_chunk = start_chunk
        self.dict_master_response: Dict[int, Any] = {}

    def load_google_sheet(self, spreadsheet_id: str, range_name: str, credentials_path: str) -> pd.DataFrame:
        """Load data from a Google Sheet into a pandas DataFrame.

        The first returned row is used as the column header.

        Args:
            spreadsheet_id: ID of the sheet (from its URL).
            range_name: A1-notation range, e.g. ``'Sheet1!A1:M'``.
            credentials_path: Path to the service-account JSON key file.

        Returns:
            The sheet contents, or an empty DataFrame when the range has no values.
        """
        # Request read-only access explicitly, matching the module-level loader in this file.
        credentials = service_account.Credentials.from_service_account_file(
            credentials_path,
            scopes=['https://www.googleapis.com/auth/spreadsheets.readonly'],
        )
        service = build('sheets', 'v4', credentials=credentials)
        sheet = service.spreadsheets()
        result = sheet.values().get(spreadsheetId=spreadsheet_id, range=range_name).execute()
        values = result.get('values', [])

        if not values:
            print('No data found.')
            return pd.DataFrame()
        return pd.DataFrame(values[1:], columns=values[0])

    # Helper function to remove punctuation and normalize text
    def normalize_text(self, text: str) -> str:
        """Lower-case ``text``, strip punctuation and collapse runs of whitespace.

        Missing values (``None`` / NaN) normalize to an empty string instead of
        raising, so a blank cell does not abort the whole ``apply``.
        """
        if not isinstance(text, str):
            if pd.isna(text):
                return ''
            text = str(text)
        text = text.lower()
        text = re.sub(r'[^\w\s]', '', text)
        # Remove extra whitespace
        return ' '.join(text.split())

    def data_elt(self, spreadsheet_ids=None, range_name: str = 'Sheet1!A1:M',
                 credentials_path: str = '../warn-analytics.json',
                 min_date: str = '2022-01-01') -> None:
        """Perform ETL on the WARN data and store the result in ``self.df_full``.

        Args:
            spreadsheet_ids: Iterable of Google Sheet IDs to load and stack.
                Defaults to ``DEFAULT_SPREADSHEET_IDS``.
            range_name: A1-notation range read from every sheet.
            credentials_path: Path to the service-account JSON key file.
            min_date: Earliest ``'WARN Received Date'`` to keep, as an ISO
                ``YYYY-MM-DD`` string (it is compared against the formatted
                date strings).
        """
        if spreadsheet_ids is None:
            spreadsheet_ids = self.DEFAULT_SPREADSHEET_IDS

        # Use this instance's loader rather than a module-level global.
        frames = [
            self.load_google_sheet(sheet_id, range_name, credentials_path)
            for sheet_id in spreadsheet_ids
        ]
        df = pd.concat(frames, ignore_index=True)

        df['WARN Received Date'] = pd.to_datetime(df['WARN Received Date']).dt.strftime('%Y-%m-%d')
        df = df.sort_values('WARN Received Date', ascending=False)
        # .copy() so the column added below is written to an independent frame, not a slice.
        df = df[df['WARN Received Date'] >= min_date].copy()
        df['company_name_normalized'] = df['Company'].apply(self.normalize_text)
        self.df_full = df

    def group_by_starting_letter(self, column: str) -> None:
        """Group the unique values of ``column`` by their first character.

        Allows for smarter iteration of all potential duplicates having the
        same starting letter. The result is stored in ``self.grouped_companies``
        as ``{first_character_lowercased: [sorted values]}``. Missing and empty
        values are skipped because they have no first character.
        """
        self.grouped_companies = defaultdict(list)
        names = [
            value for value in self.df_full[column].unique()
            if isinstance(value, str) and value
        ]
        for name in sorted(names):
            self.grouped_companies[name[0].lower()].append(name)

    @staticmethod
    def _extract_json_objects(text: str) -> List[Dict[str, Any]]:
        """Return every JSON object that can be decoded from ``text``, in order.

        Scans for ``{`` and decodes one object at a time, so it accepts a JSON
        array, newline-delimited objects, back-to-back objects and objects
        surrounded by other text. Text that does not decode is skipped.
        """
        decoder = json.JSONDecoder()
        objects: List[Dict[str, Any]] = []
        pos = text.find('{')
        while pos != -1:
            try:
                obj, end = decoder.raw_decode(text, pos)
            except json.JSONDecodeError:
                # Not a decodable object here; try the next opening brace.
                pos = text.find('{', pos + 1)
                continue
            objects.append(obj)
            pos = text.find('{', end)
        return objects

    def convert_to_dict(self, json_str: str) -> Dict[str, str]:
        """Convert a classification response into ``{company: classification}``.

        Accepts a JSON array, newline-delimited objects, or a single
        ``{"companies": [...]}`` wrapper. Entries missing either the
        ``company`` or ``classification`` key are ignored.

        Returns:
            The mapping, or ``{}`` (after printing the input) when no JSON
            object can be decoded.
        """
        records = self._extract_json_objects(json_str)
        if not records:
            print("Error decoding JSON: no JSON objects found")
            print(json_str)
            return {}

        response_dict: Dict[str, str] = {}
        for record in records:
            # Some responses wrap the list in a single "companies" key.
            if list(record.keys()) == ['companies'] and isinstance(record['companies'], list):
                entries = record['companies']
            else:
                entries = [record]
            for entry in entries:
                if isinstance(entry, dict) and 'company' in entry and 'classification' in entry:
                    response_dict[entry['company']] = entry['classification']
        return response_dict

    def create_llm_query_dedupe(self) -> None:
        """Create one de-duplication prompt per starting letter.

        Reads ``self.grouped_companies`` and stores the prompts in
        ``self.queries_dedupe`` keyed by the same letter.
        """
        self.queries_dedupe = {}
        for letter, values in self.grouped_companies.items():
            company_list_str = ", ".join(values)
            query = f"""Evaluate the following comma-separated list of companies and determine which ones are likely duplicates.
                If more than one company has the same normalized name, consider them duplicates and "has_duplicates" should be "True".                  
                Return the result as lower-case text with double quotes in a JSON object using the format below. Only provide the JSON object, nothing else.: 
                {{"original_company_name": {{"has_duplicates": "True" or "False", "normalized_company_name": "a deduplicated company name"}}}}.
                ---:
                {company_list_str}
                """
            self.queries_dedupe[letter] = query

    def make_api_call(self, query: str, model: str = "llama-3.3-70b-versatile") -> str:
        """Send ``query`` as a single user message and return the reply.

        Args:
            query: Prompt text.
            model: Groq model name.

        Returns:
            The message content wrapped in ``[`` and ``]``, so that a reply
            consisting of one JSON object parses as a one-element JSON array.
        """
        chat_completion = self.client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": query,
                }
            ],
            model=model,
            temperature=0,
            max_completion_tokens=32768,
            top_p=1,
            stream=False,
            stop=None,
            seed=123,
        )

        response = chat_completion.choices[0].message.content
        return f"[{response}]"

    def dedupe_normalized_company_list(self) -> None:
        """Call the LLM on every de-duplication prompt and collect the results.

        Populates ``self.df_deduped_company_normalized`` (one row per original
        company name, refreshed after every letter so partial results survive
        a failure) and ``self.list_companies_deduped`` (sorted unique
        normalized names).

        Raises:
            json.JSONDecodeError: If a reply is not valid JSON.
        """
        self.list_normalized_companies = []
        self.df_deduped_company_normalized = pd.DataFrame()
        frames = []

        for key in self.queries_dedupe:
            result = self.make_api_call(self.queries_dedupe[key])
            dict_result = json.loads(result)

            transformed_list = []
            for key_result, value_result in dict_result[0].items():
                print(key_result, value_result)
                transformed_dict = {'original_company_name': key_result}
                transformed_dict.update(value_result)
                transformed_list.append(transformed_dict)

            frames.append(pd.DataFrame(transformed_list))
            # Rebuild from the collected frames instead of concatenating onto an empty frame.
            self.df_deduped_company_normalized = pd.concat(frames, ignore_index=True)

        if 'normalized_company_name' in self.df_deduped_company_normalized.columns:
            names = self.df_deduped_company_normalized['normalized_company_name'].dropna()
            # Keep strings only: a NaN or non-string entry would make the sort raise.
            self.list_companies_deduped = sorted({name for name in names if isinstance(name, str)})
        else:
            self.list_companies_deduped = []

    def create_llm_query_industry_classification(self, chunk: list) -> str:
        """Build the industry-classification prompt for one chunk of companies.

        Args:
            chunk: Company names; each item is converted with ``str`` so a
                stray non-string value cannot silently produce an empty list.

        Returns:
            The prompt text.
        """
        company_list_str = ", ".join(str(company) for company in chunk)
        query_classification = f"""Take the list of companies provided in the COMPANY-LIST below and classify each company into exactly one of the following 14 possible industries shown below under INDUSTRIES.  
            Return result as a JSON object: example {{"company":"google", "classification": "Information"}}.
            Only provide the JSON object, nothing else
            INDUSTRIES: 
            1. Construction
            2. Education and Health Services
            3. Finance and Insurance
            4. Information
            5. Leisure and Hospitality
            6. Manufacturing
            7. Mining, quarrying, and oil and gas extraction
            8. Professional and Business Services
            9. Real Estate,Rental and Leasing
            10. Retail Trade
            11. Transportation and Warehousing
            12. Utilities
            13. Wholesale Trade
            14. Other Services

            COMPANY-LIST:
            {company_list_str}
            """

        return query_classification

    # Function to divide the list into chunks of `chunk_size` companies
    def chunk_list(self, lst, chunk_size, start_chunk=0) -> Generator[list, None, None]:
        """Divide the list into chunks of specified size, starting from a specific chunk.

        Missing values and the literal strings ``'nan'`` / ``'None'`` are
        dropped from each chunk after slicing, so a chunk may hold fewer than
        ``chunk_size`` items.
        """
        for i in range(start_chunk * chunk_size, len(lst), chunk_size):
            list_focus = lst[i:i + chunk_size]
            yield [x for x in list_focus if pd.notna(x) and x not in ('nan', 'None')]

    def _save_checkpoint(self) -> None:
        """Pickle ``self.dict_master_response`` to ``dict_master_response.pkl``."""
        with open('dict_master_response.pkl', 'wb') as f:
            pickle.dump(self.dict_master_response, f)
        print("Saved dict_master_response.pkl")

    def process_chunks(self, start_chunk=0, end_chunk=None) -> None:
        """Process the list of companies in chunks and make API calls.

        Results are stored in ``self.dict_master_response`` keyed by chunk
        index. A checkpoint is written every 10th chunk and once more when the
        loop ends, so the final chunks are saved too.

        Args:
            start_chunk: Index of the first chunk to process. With 0 the
                stored results are reset; with a positive value the results
                already held in memory are kept so a run can be resumed.
            end_chunk: Index of the last chunk to process (inclusive);
                ``None`` processes every remaining chunk.
        """
        # Imported here because only the exception types are needed in this method.
        import groq

        # The Groq SDK raises its own exception types rather than `requests` ones.
        transient_errors = (
            requests.exceptions.RequestException,
            groq.APIConnectionError,
            groq.RateLimitError,
            groq.InternalServerError,
        )

        self.start_chunk = start_chunk
        self.ctr = self.start_chunk
        self.chunk_list_companies_deduped = self.chunk_list(
            self.list_companies_deduped, chunk_size=self.chunk_size, start_chunk=self.start_chunk
        )

        if self.start_chunk == 0 or not isinstance(getattr(self, 'dict_master_response', None), dict):
            self.dict_master_response = {}

        for chunk in self.chunk_list_companies_deduped:
            if end_chunk is not None and self.ctr > end_chunk:
                break

            if chunk:
                query = self.create_llm_query_industry_classification(chunk)
                while True:
                    try:
                        self.classification_json_response = self.make_api_call(query)
                        break  # Exit the WHILE loop only if the API call is successful
                    except transient_errors as e:
                        print(f"API call failed: {e}")
                        print("Waiting for 5 minutes before retrying...")
                        time.sleep(300)  # Wait for 5 minutes (300 seconds) before retrying

                self.classification_list_of_dicts = self._extract_json_objects(
                    self.classification_json_response
                )
                if self.classification_list_of_dicts:
                    self.dict_master_response[self.ctr] = self.classification_list_of_dicts
                else:
                    # Leave the chunk out rather than store an empty list for consumers to index.
                    print(f"No JSON objects found in response for chunk {self.ctr}; skipping.")

            if self.ctr % 10 == 0:
                self._save_checkpoint()

            self.ctr += 1

        self._save_checkpoint()

# Example usage
#if __name__ == "__main__":
    # Change the working directory
# os.chdir('/Users/daniel.lapushin6sense.com/Documents/GitHub/trade_advisor/llm_files')
# print(os.getcwd())

# Initialize the WarnFirmographics class
# The Groq API key is read from the GROQ_API_KEY environment variable so the
# secret is never stored in the notebook. Any key that was ever committed in
# this file should be treated as compromised and rotated.
groq_api_key = os.environ.get('GROQ_API_KEY')
if not groq_api_key:
    raise RuntimeError('Set the GROQ_API_KEY environment variable before running this cell.')

warn_data = WarnFirmographics(api_key=groq_api_key, encoding_name="cl100k_base")
# Pipeline: load sheets -> group names by first letter -> build dedupe prompts
# -> run dedupe -> classify the de-duplicated names -> pickle the instance.
warn_data.data_elt()
warn_data.group_by_starting_letter('company_name_normalized')
warn_data.create_llm_query_dedupe()
warn_data.dedupe_normalized_company_list()
warn_data.process_chunks()
save_object()


1 click logistics {'has_duplicates': 'False', 'normalized_company_name': '1 click logistics'}
10 roads express llc {'has_duplicates': 'False', 'normalized_company_name': '10 roads express'}
1266 apartment corp {'has_duplicates': 'False', 'normalized_company_name': '1266 apartment corp'}
1888 mills llc {'has_duplicates': 'False', 'normalized_company_name': '1888 mills'}
2024 democratic national convention committee {'has_duplicates': 'False', 'normalized_company_name': '2024 democratic national convention committee'}
23andme inc {'has_duplicates': 'False', 'normalized_company_name': '23andme inc'}
247ai {'has_duplicates': 'False', 'normalized_company_name': '247ai'}
2seventy bio inc {'has_duplicates': 'False', 'normalized_company_name': '2seventy bio inc'}
360 behavioral health {'has_duplicates': 'False', 'normalized_company_name': '360 behavioral health'}
365 delivery inc {'has_duplicates': 'False', 'normalized_company_name': '365 delivery inc'}
3m hq {'has_duplicates': 'True', 'normal

In [None]:
# Flatten the per-chunk classification results into one DataFrame.
chunk_frames = []

for chunk_records in warn_data.dict_master_response.values():
    if not chunk_records:
        # Nothing was parsed for this chunk; indexing [0] would raise.
        continue
    first_record = chunk_records[0]
    if list(first_record.keys()) == ['companies']:
        # The model wrapped the whole list in a single {"companies": [...]} object.
        chunk_frames.append(pd.DataFrame(first_record['companies']))
    else:
        chunk_frames.append(pd.DataFrame(chunk_records))

# Concatenate once; pd.concat raises on an empty list, so fall back to an empty frame.
df_master_response = pd.concat(chunk_frames, ignore_index=True) if chunk_frames else pd.DataFrame()


Saved dict_master_response.pkl
Saved dict_master_response.pkl
Saved dict_master_response.pkl
Saved dict_master_response.pkl
Saved dict_master_response.pkl


In [None]:
# Flatten the per-chunk classification results into one DataFrame and export it.
chunk_frames = []

for chunk_records in warn_data.dict_master_response.values():
    if not chunk_records:
        # Nothing was parsed for this chunk; indexing [0] would raise.
        continue
    first_record = chunk_records[0]
    if list(first_record.keys()) == ['companies']:
        # The model wrapped the whole list in a single {"companies": [...]} object.
        chunk_frames.append(pd.DataFrame(first_record['companies']))
    else:
        chunk_frames.append(pd.DataFrame(chunk_records))

# Concatenate once; pd.concat raises on an empty list, so fall back to an empty frame.
df_master_response = pd.concat(chunk_frames, ignore_index=True) if chunk_frames else pd.DataFrame()

df_master_response.to_csv('warn_firmographics.csv', index=False)


In [486]:
#result = warn_data.make_api_call(warn_data.queries_dedupe['k'])
#dict_result = json.loads(result)


# Flatten {company: {field: value, ...}} into one row per company so the
# result can be loaded straight into a DataFrame.
transformed_list = []
company_details = dict_result[0]
for key in company_details:
    value = company_details[key]
    print(key, value)
    transformed_dict = {'original_company_name': key}
    transformed_dict.update(value)
    transformed_list += [transformed_dict]

df = pd.DataFrame(transformed_list)

#     print(item)
#     print(dict_result)
    # for key, value in item.items():
    #     transformed_dict = {'original_company_name': key}
    #     transformed_dict.update(value)
    #     transformed_list.append(transformed_dict)

# Convert the transformed list into a DataFrame
#df = pd.DataFrame(transformed_list)
# Print the DataFrame
#print(df)
#pd.DataFrame(dict_result[0].items())  


kaanapali beach hotel and the plantation inn {'has_duplicates': 'False', 'normalized_company_name': 'kaanapali beach hotel and the plantation inn'}
kai management services llc {'has_duplicates': 'True', 'normalized_company_name': 'kai management services'}
kaiser aluminum sherman {'has_duplicates': 'False', 'normalized_company_name': 'kaiser aluminum'}
kaiser foundation hospitals {'has_duplicates': 'True', 'normalized_company_name': 'kaiser foundation hospitals'}
kaiser foundation hospitals arrow {'has_duplicates': 'True', 'normalized_company_name': 'kaiser foundation hospitals'}
kaiser foundation hospitals canyon {'has_duplicates': 'True', 'normalized_company_name': 'kaiser foundation hospitals'}
kaiser foundation hospitals fair oaks {'has_duplicates': 'True', 'normalized_company_name': 'kaiser foundation hospitals'}
kaiser foundation hospitals harrison {'has_duplicates': 'True', 'normalized_company_name': 'kaiser foundation hospitals'}
kaiser foundation hospitals lakeside {'has_dupli

Unnamed: 0,original_company_name,has_duplicates,normalized_company_name
0,kaanapali beach hotel and the plantation inn,False,kaanapali beach hotel and the plantation inn
1,kai management services llc,True,kai management services
2,kaiser aluminum sherman,False,kaiser aluminum
3,kaiser foundation hospitals,True,kaiser foundation hospitals
4,kaiser foundation hospitals arrow,True,kaiser foundation hospitals
...,...,...,...
116,kuka toledo production operations llc,False,kuka toledo production operations
117,kunzler company inc,False,kunzler company
118,kuubix global llc,False,kuubix global
119,kyowa kirin inc,False,kyowa kirin


In [None]:
#!pip install google-api-python-client google-auth-httplib2 google-auth-oauthlib
import pandas as pd
import json

# df_normalized_companies = pd.read_csv('/Users/daniel.lapushin6sense.com/Documents/GitHub/trade_advisor/llm_files/normalized_companies.csv')
# #list(df_normalized_companies['normalized'])
# df_normalized_companies.iloc[0:4192].tail()

from google.oauth2.credentials import Credentials
from google.oauth2 import service_account
from googleapiclient.discovery import build
from google.oauth2.service_account import Credentials

def load_google_sheet(spreadsheet_id, range_name, credentials_path):
    """
    Read a range of a Google Sheet and return it as a pandas DataFrame.

    The first row of the range supplies the column names. Any failure is
    reported on stdout and an empty DataFrame is returned instead of raising.

    Args:
        spreadsheet_id (str): The ID of the Google Sheet (from the URL)
        range_name (str): The range to read (e.g. 'Sheet1!A1:D100')
        credentials_path (str): Path to the service account credentials JSON file

    Returns:
        pandas.DataFrame: The sheet data, or an empty DataFrame when the range
        holds no values or an error occurred.
    """
    readonly_scope = 'https://www.googleapis.com/auth/spreadsheets.readonly'
    try:
        # Authenticate as the service account with read-only access.
        creds = service_account.Credentials.from_service_account_file(
            credentials_path, scopes=[readonly_scope]
        )
        sheets_api = build('sheets', 'v4', credentials=creds).spreadsheets()

        # Fetch the requested range.
        request = sheets_api.values().get(spreadsheetId=spreadsheet_id, range=range_name)
        rows = request.execute().get('values', [])

        if rows:
            # Header row first, data rows after it.
            header, *records = rows
            return pd.DataFrame(records, columns=header)

        print('No data found in sheet')
        return pd.DataFrame()

    except Exception as e:
        print(f"Error loading Google Sheet: {e!s}")
        return pd.DataFrame()

# Example usage:
# Sheet IDs are taken from the Google Sheets URLs.
spreadsheet_hist_id = '1ayO8dl7sXaIYBAwkBGRUjbDms6MAbZFvvxxRp8IyxvY'
spreadsheet_2025_id = '1Qx6lv3zAL9YTsKJQNALa2GqBLXq0RER2lHvzyx32pRs'

range_name = 'Sheet1!A1:M'
credentials_path = '../warn-analytics.json'

# Load the historical sheet first, then the 2025 sheet.
df_hist, df_2025 = (
    load_google_sheet(sheet_id, range_name, credentials_path)
    for sheet_id in (spreadsheet_hist_id, spreadsheet_2025_id)
)

# Stack both sheets, rewrite the received date as YYYY-MM-DD text and sort newest first.
df_full = pd.concat([df_hist, df_2025], ignore_index=True)
received_dates = pd.to_datetime(df_full['WARN Received Date'])
df_full['WARN Received Date'] = received_dates.dt.strftime('%Y-%m-%d')
df_full = df_full.sort_values('WARN Received Date', ascending=False)



# Company to Industry mapping

```mermaid
---
title: Process
---
flowchart TD
    A[Load WARN Dataset 
    through Google API client]
    B[Extract Company Names if 
    layoffs happened 
    after 2021]
    C[Use LLM to normalize 
    company names]
    D[Use LLM to determine 
    industry of normalized
    company name]
    
    A --> B
    B --> C
    C --> D
```