In [1]:
from transformers import pipeline

  from .autonotebook import tqdm as notebook_tqdm


In [17]:
import torch

def _pick_device():
    """Return ``(device, banner)`` for the best available backend.

    Preference order: CUDA GPU, then Apple MPS, then CPU. The device value
    is ``0`` for CUDA, ``"mps"`` for MPS and ``-1`` for CPU.
    """
    if torch.cuda.is_available():
        return 0, "Using NVIDIA GPU"  # NVIDIA GPU

    mps_backend = torch.backends.mps
    if mps_backend.is_available() and mps_backend.is_built():
        return "mps", "Using Mac GPU (MPS)"  # Mac GPU

    return -1, "Using CPU"  # CPU


# `device` is consumed further down when the pipelines are constructed.
device, _device_banner = _pick_device()
print(_device_banner)

print(f"Device: {device}")
print("Loading models...")

Using Mac GPU (MPS)
Device: mps
Loading models...


In [2]:
# Named-entity-recognition pipeline (token classification).
ner_pipeline = pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-english")

# NOTE: the zero-shot classifier is deliberately NOT created in this cell.
# It is built further down on the selected `device` (with a CPU fallback);
# creating it here as well would load facebook/bart-large-mnli twice on a
# Restart & Run All.

Some weights of the model checkpoint at dbmdz/bert-large-cased-finetuned-conll03-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Device set to use mps:0
Device set to use mps:0


In [18]:
# Build the zero-shot classifier on the selected device. If that fails on an
# accelerator, retry once on CPU; if we were already on CPU there is nothing
# to fall back to, so the original error is re-raised unchanged.
zero_shot_model_name = "facebook/bart-large-mnli"

try:
    zero_shot_classifier = pipeline(
        "zero-shot-classification", model=zero_shot_model_name, device=device
    )
except Exception as e:
    if device == -1:
        # Already on CPU: a second identical attempt would only mask `e`.
        raise
    print(f"Error loading on {device}: {e}")
    print("Falling back to CPU...")
    device = -1
    zero_shot_classifier = pipeline(
        "zero-shot-classification", model=zero_shot_model_name, device=device
    )

# Reached on both the direct and the fallback path.
print("Models loaded successfully!")

Device set to use mps


Models loaded successfully!


In [3]:
import pandas as pd

# Read the merged "cc" parquet file into a DataFrame.
cc_parquet_path = "shared-data-volume/cc/cc_merged.parquet"
df = pd.read_parquet(cc_parquet_path)

In [4]:
# Show the column index of the loaded frame.
df.columns

Index(['url', 'url_host_tld', 'url_host_registered_domain', 'fetch_status',
       'content_mime_detected', 'content_mime_type', 'warc_filename',
       'warc_record_offset', 'warc_record_length', 'cc_abn', 'company_name',
       'business_info', 'raw_text_body'],
      dtype='object')

In [13]:
# Candidate labels offered to the zero-shot classifier when the caller does
# not pass its own list.
DEFAULT_INDUSTRY_LABELS = (
    "Technology",
    "Healthcare",
    "Finance",
    "Retail",
    "Manufacturing",
    "Education",
    "Transportation",
    "Energy",
    "Real Estate",
    "Hospitality",
    "Media",
    "Telecommunications",
    "Agriculture",
    "Construction",
    "E-commerce",
)


def extract_company_info(
    text, labels=None, min_confidence=0.5, max_chars=2000, classifier=None
):
    """Extract industry info with quality filtering.

    Args:
        text: Free text describing a company. Anything that is not a
            non-blank ``str`` (``None``, NaN, lists, arrays, bytes, ...) is
            treated as empty input.
        labels: Optional iterable of candidate industry labels. Defaults to
            ``DEFAULT_INDUSTRY_LABELS``.
        min_confidence: The top label is kept only when its rounded score is
            at least this value. Defaults to ``0.5``.
        max_chars: Only the first ``max_chars`` characters of ``text`` are
            sent to the classifier. Defaults to ``2000``.
        classifier: Optional callable invoked as
            ``classifier(text, labels, multi_label=False)`` and returning a
            mapping with ``"labels"`` and ``"scores"`` sequences, best first.
            Defaults to the notebook-level ``zero_shot_classifier``.

    Returns:
        dict with keys:
            ``industry_type``: the top label, or ``None`` when the input is
                empty, the score is below the threshold, or an error occurred.
            ``industry_confidence``: top score rounded to 3 decimals
                (``0`` for empty input or errors).
            ``extraction_quality``: ``"good"``, ``"low_confidence"``,
                ``"empty_text"``, or ``"error: <first 30 chars of message>"``.

    Errors raised while classifying are not propagated; they are reported
    through ``extraction_quality``.
    """
    # Type check first: calling pd.isna() on a multi-element list/array
    # returns an array whose truth value is ambiguous and would raise before
    # the isinstance test was ever reached.
    if not isinstance(text, str) or not text.strip():
        return {
            "industry_type": None,
            "industry_confidence": 0,
            "extraction_quality": "empty_text",
        }

    try:
        # Resolved inside the try so a missing global classifier is reported
        # as an "error: ..." record like any other classification failure.
        active_classifier = (
            zero_shot_classifier if classifier is None else classifier
        )
        candidate_labels = list(
            DEFAULT_INDUSTRY_LABELS if labels is None else labels
        )

        text_limited = text[:max_chars]

        # NOTE(review): with multi_label=False the transformers pipeline
        # normalizes scores across all candidate labels, so with many labels
        # the top score may often sit below the threshold -- confirm that the
        # default threshold is appropriate for this label set.
        industry_result = active_classifier(
            text_limited, candidate_labels, multi_label=False
        )
        industry_type = industry_result["labels"][0]
        industry_confidence = round(industry_result["scores"][0], 3)

        # Keep the label only when the rounded confidence reaches the threshold.
        if industry_confidence < min_confidence:
            industry_type = None
            extraction_quality = "low_confidence"
        else:
            extraction_quality = "good"

        return {
            "industry_type": industry_type,
            "industry_confidence": industry_confidence,
            "extraction_quality": extraction_quality,
        }

    except Exception as e:
        # Best effort: one bad row must not abort a batch run.
        return {
            "industry_type": None,
            "industry_confidence": 0,
            "extraction_quality": f"error: {str(e)[:30]}",
        }

In [19]:
# Small sample for a trial run: rows at positions 10..29. `.iloc` makes the
# slice explicitly positional regardless of the frame's index.
new_df = df.iloc[10:30]

In [20]:
# Classify each sampled row's raw text, producing one result dict per row in
# row order. Iterating the column directly avoids building a Series per row
# (as iterrows() does) and the unused enumerate index.
results = [extract_company_info(text) for text in new_df["raw_text_body"]]

In [21]:
# Display the per-row extraction results collected above.
results

[{'industry_type': None,
  'industry_confidence': 0.329,
  'extraction_quality': 'low_confidence'},
 {'industry_type': None,
  'industry_confidence': 0.155,
  'extraction_quality': 'low_confidence'},
 {'industry_type': None,
  'industry_confidence': 0.203,
  'extraction_quality': 'low_confidence'},
 {'industry_type': None,
  'industry_confidence': 0.364,
  'extraction_quality': 'low_confidence'},
 {'industry_type': None,
  'industry_confidence': 0.271,
  'extraction_quality': 'low_confidence'},
 {'industry_type': None,
  'industry_confidence': 0.106,
  'extraction_quality': 'low_confidence'},
 {'industry_type': None,
  'industry_confidence': 0.25,
  'extraction_quality': 'low_confidence'},
 {'industry_type': None,
  'industry_confidence': 0.266,
  'extraction_quality': 'low_confidence'},
 {'industry_type': None,
  'industry_confidence': 0.212,
  'extraction_quality': 'low_confidence'},
 {'industry_type': None,
  'industry_confidence': 0.208,
  'extraction_quality': 'low_confidence'},
 