In [1]:
from transformers import pipeline
import torch
import pandas as pd
import numpy as np
import polars as pl
import json
import time

In [2]:
# Pick an accelerator: MPS first, then CUDA, then CPU.
# The second GPU ("cuda:1") is kept when more than one CUDA device exists.
# On a single-GPU machine that index does not exist, so fall back to "cuda:0".
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    device = "cpu"

model_path = "facebook/bart-large-mnli"

# batch_size is handed to the pipeline itself.
# The row batches of BATCH_SIZE built later in the notebook are a separate knob.
classification_task = pipeline(
    "zero-shot-classification",
    model=model_path,
    tokenizer=model_path,
    device=device,
    batch_size=64,
)

print(device)

mps


In [3]:
# Candidate labels for zero-shot classification.
# Each label also becomes a "<label>_classification" column when the batch classifier runs.
sport_labels = [
    "football",
    "basketball",
    "tennis",
    "golf",
    "rugby",
    "cricket",
    "hockey",
    "baseball",
    "volleyball",
    "american football",
    "olympics",
]

In [4]:
# Sanity check on one sentence.
# Note that the returned 'labels' come back sorted by score, not in sport_labels order.
classification_task("federer is a good player", candidate_labels=sport_labels)

{'sequence': 'federer is a good player',
 'labels': ['tennis',
  'olympics',
  'football',
  'golf',
  'rugby',
  'american football',
  'basketball',
  'volleyball',
  'baseball',
  'hockey',
  'cricket'],
 'scores': [0.27511724829673767,
  0.10333362966775894,
  0.09027232974767685,
  0.08981011807918549,
  0.08795522898435593,
  0.08006685227155685,
  0.07014228403568268,
  0.06668460369110107,
  0.0487465038895607,
  0.04826689511537552,
  0.039604317396879196]}

In [5]:
# Load the DataFrame using Polars
# Load the pre-filtered sport-category video metadata with Polars.
# Per the file name, the description column has already been dropped.
filtered_df_sport_category = pl.read_parquet(
    source="filtered_sport_category_without_description_column_metadata.parquet"
)

In [6]:
# Row count of the loaded frame (Polars' .height is the number of rows).
filtered_df_sport_category.height

4354412

In [7]:
BATCH_SIZE = 256
# NOTE(review): MAX_LENGTH is not referenced anywhere in this cell.
# Input length is left to the pipeline/tokenizer. TODO: wire it up or remove it.
MAX_LENGTH = 512  # Maximum length for the model

def zero_shot_classification_in_batches(df, batch_size=BATCH_SIZE, labels=None, classifier=None):
    """Score every row of ``df`` against each candidate label with a zero-shot classifier.

    Each row's ``tags`` and ``title`` are serialised to a JSON string and sent to the
    classifier with ``multi_label=True``, so every label gets its own score. One float
    column named ``<label>_classification`` is appended per label, in ``labels`` order.

    Args:
        df: Polars DataFrame with ``tags`` and ``title`` columns.
        batch_size: Number of rows passed to the classifier per call.
        labels: Candidate labels. Defaults to the notebook-level ``sport_labels``.
        classifier: Callable with the zero-shot pipeline interface. Defaults to the
            notebook-level ``classification_task``.

    Returns:
        A new DataFrame equal to ``df`` plus the ``*_classification`` columns.
    """
    if labels is None:
        labels = sport_labels
    if classifier is None:
        classifier = classification_task

    total_rows = len(df)
    classifications = {label: [] for label in labels}

    for start in range(0, total_rows, batch_size):
        end = min(start + batch_size, total_rows)
        batch = df[start:end]

        # ensure_ascii=False keeps emoji / non-Latin characters in titles and tags
        # as real text. Otherwise the model would see \uXXXX escape sequences.
        texts = [
            json.dumps({"tags": row["tags"], "title": row["title"]}, ensure_ascii=False)
            for row in batch.select(["tags", "title"]).iter_rows(named=True)
        ]

        results = classifier(texts, labels, multi_label=True)
        # A single-string call returns a bare dict. Normalise defensively in case a
        # one-row batch comes back the same way.
        if isinstance(results, dict):
            results = [results]

        for result in results:
            # The pipeline reorders 'labels' by descending score, so a score's
            # position does NOT match its position in `labels`.
            # Look scores up by label name instead.
            score_by_label = dict(zip(result["labels"], result["scores"]))
            for label in labels:
                classifications[label].append(score_by_label[label])

        print(f"Processed {end}/{total_rows} rows...")

    # Add all score columns in one call rather than rebuilding the frame per label.
    return df.with_columns(
        [pl.Series(f"{label}_classification", scores) for label, scores in classifications.items()]
    )

# The frame was 4,354,412 rows when this was written: about 17k classifier calls at BATCH_SIZE=256.
# TODO(review): cache the result (e.g. to parquet) so Restart & Run All does not redo this.
filtered_df_sport_category = zero_shot_classification_in_batches(filtered_df_sport_category)

Processed 256/4354412 rows...


KeyboardInterrupt: 

In [None]:
# Inspect the frame.
# Once the classification cell above completes, it should also carry one
# "<label>_classification" column per sport label.
# The captured output below predates that: the run was interrupted.
filtered_df_sport_category

categories,channel_id,crawl_date,dislike_count,display_id,duration,like_count,tags,title,upload_date,view_count
str,str,str,f64,str,i64,f64,str,str,str,f64
"""Sports""","""UCzWn_gTaXyH5Idyo8Raf7_A""","""2019-11-03 16:39:57.427254""",35.0,"""JOeSxtcNdHQ""",8620,1673.0,"""catfishing,fishing,fishing cha…","""Catching 100 lbs of Catfish 🔴L…","""2019-10-01 00:00:00""",48737.0
"""Sports""","""UCzWn_gTaXyH5Idyo8Raf7_A""","""2019-11-03 16:39:58.108323""",15.0,"""EPMLTw2zINw""",355,1297.0,"""""","""big cat""","""2019-10-01 00:00:00""",19999.0
"""Sports""","""UCzWn_gTaXyH5Idyo8Raf7_A""","""2019-11-03 16:39:58.773085""",78.0,"""Y1_pK68iSYQ""",603,3305.0,"""Catfishing,how to catch catfis…","""Classy Catfishing - How to Cat…","""2019-09-28 00:00:00""",58518.0
"""Sports""","""UCzWn_gTaXyH5Idyo8Raf7_A""","""2019-11-03 16:39:59.465346""",70.0,"""jF8TSo3ZfTc""",1426,1889.0,"""Fishing,Fishing uk,Angling,Sea…","""2 Day Saltwater Fishing Catch …","""2019-09-21 00:00:00""",71998.0
"""Sports""","""UCzWn_gTaXyH5Idyo8Raf7_A""","""2019-11-03 16:40:00.188768""",73.0,"""Gp00dNaVouo""",990,2699.0,"""Fishing,catfish,wels catfish,h…","""How to Catch Wels Catfish - Fi…","""2019-09-14 00:00:00""",101924.0
…,…,…,…,…,…,…,…,…,…,…
"""Sports""","""UCrwEMKhsjY8P9-GuIKMYVrQ""","""2019-11-17 22:39:14.232693""",7.0,"""Q9H_fk6uHDk""",1121,89.0,"""hypnosis,progressive hypnosis,…","""Play Better Golf Part 4 ★ Putt…","""2017-02-14 00:00:00""",20430.0
"""Sports""","""UCrwEMKhsjY8P9-GuIKMYVrQ""","""2019-11-17 22:39:14.843290""",13.0,"""3lwXzOboOzk""",1341,91.0,"""hypnosis,progressive hypnosis,…","""Play Better Golf Part 3 ★ Owni…","""2017-02-14 00:00:00""",25817.0
"""Sports""","""UCrwEMKhsjY8P9-GuIKMYVrQ""","""2019-11-17 22:39:15.484430""",19.0,"""242JzJuuG78""",1098,115.0,"""hypnosis,progressive hypnosis,…","""Play Better Golf Part 2 ★ Fair…","""2017-02-14 00:00:00""",29909.0
"""Sports""","""UCrwEMKhsjY8P9-GuIKMYVrQ""","""2019-11-17 22:39:16.111873""",28.0,"""CpMWSgoRwNI""",1245,372.0,"""hypnosis,progressive hypnosis,…","""Play Better Golf Part 1 ★ Gett…","""2017-02-14 00:00:00""",61980.0
