In [None]:
import gc
import json
import polars as pl
import numpy as np
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# Read the video/channel-country CSV and drop every row whose "country" is null.
video_country = (
    pl.read_csv('.././data/video_with_channelcountry.csv')
    .filter(pl.col("country").is_not_null())
)

In [None]:
# Load the STEM tag vocabulary. The encoding is pinned to UTF-8 so decoding does
# not depend on the platform's locale default (which may mangle non-ASCII tags).
with open('.././data/stem_tags.json', 'r', encoding='utf-8') as file:
    keywords = json.load(file)
# Materialise the "tags" entry of the JSON document as a plain list.
tags = list(keywords['tags'])

In [None]:
def _tags_meet_stem_threshold(tags_str, stem_tags_set, threshold) -> bool:
    """Return True when the share of STEM tags in ``tags_str`` reaches ``threshold``.

    ``tags_str`` is split on commas; each piece is stripped and lower-cased and
    empty pieces are discarded. Non-string values (e.g. ``None``) and strings
    that contain no usable tag yield ``False``.
    """
    if not isinstance(tags_str, str):
        return False

    video_tags = [tag.strip().lower() for tag in tags_str.split(",") if tag.strip()]
    if not video_tags:
        # Nothing to score; also avoids dividing by zero below.
        return False

    stem_count = sum(1 for tag in video_tags if tag in stem_tags_set)
    return (stem_count / len(video_tags)) >= threshold


def classify_stem_videos(
    df: pl.DataFrame, 
    stem_tags: list[str], 
    tag_column: str = "tags", 
    threshold: float = 0.5,
    batch_size: int = 10000
) -> pl.DataFrame:
    """Add a boolean ``is_stem`` column flagging rows whose tags are mostly STEM.

    A row is flagged when the fraction of its comma-separated tags (stripped,
    lower-cased) that appear in ``stem_tags`` is at least ``threshold``.

    Args:
        df: Frame to classify; it is not modified, a new frame is returned.
        stem_tags: Reference tags. ``None`` entries are ignored; the rest are
            lower-cased and stripped before comparison.
        tag_column: Name of the column holding the comma-separated tag string.
        threshold: Minimum fraction of matching tags for a row to be flagged.
        batch_size: Number of rows processed per progress-bar step.

    Returns:
        ``df`` with an extra ``is_stem`` column, one value per input row.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Set membership keeps the per-tag lookup cheap.
    stem_tags_set: set[str] = {tag.lower().strip() for tag in stem_tags if tag is not None}

    total_rows = df.height
    if total_rows == 0:
        # No batches would be produced, so there is nothing to concatenate.
        return df.with_columns([
            pl.lit(False).alias("is_stem")
        ])

    results = []

    with tqdm(total=total_rows, desc="Classifying videos") as pbar:
        for i in range(0, total_rows, batch_size):
            batch_len = min(batch_size, total_rows - i)
            try:
                batch = df.slice(i, batch_len)
                flags = [
                    _tags_meet_stem_threshold(tags, stem_tags_set, threshold)
                    for tags in batch[tag_column]
                ]
            except Exception as e:
                print(f"Error processing batch starting at index {i}: {e}")
                # Keep the output aligned with `df`: a failed batch still
                # contributes one (False) value per row instead of being
                # dropped, which would make the final column too short.
                flags = [False] * batch_len

            results.append(pl.Series(values=flags, dtype=pl.Boolean))
            pbar.update(batch_len)

    # Every batch contributed exactly `batch_len` values, so the concatenated
    # series has one entry per row of `df`.
    is_stem_column = pl.concat(results, rechunk=True)
    return df.with_columns([
        is_stem_column.alias("is_stem")
    ])

In [None]:
# Classify every video using the default tag column, threshold and batch size.
result_df = classify_stem_videos(
    df=video_country,
    stem_tags=tags,
)

In [None]:
# Keep only the rows flagged as STEM. A boolean column is already a valid
# filter predicate, so no comparison against True is needed.
stem_videos = result_df.filter(pl.col("is_stem"))

In [None]:
# Persist the STEM subset as CSV; the relative path resolves against the
# current working directory.
stem_videos.write_csv(file='stem_videos.csv')