In [5]:
import pandas as pd
import torch
from transformers import pipeline
from classify import classify

# Candidate labels handed to the zero-shot classifier. The call below uses
# multi_label=True, so the labels are not treated as mutually exclusive.
purpose_labels = [
    "lecture or academic course",
    "hacks",
    "conference",
    "tutorial or DIY",
    "interview or Q&A or review",
    "kids content",
    "entertaining explanation or science popularization",
    "documentary",
]

# This cell only runs on the MPS backend. Abort with a non-zero exit status
# when it is missing; SystemExit is used instead of the site-provided exit()
# helper, which is meant for interactive sessions only.
if not torch.backends.mps.is_available():
    raise SystemExit("MPS backend is not available on this device.")

# Set the device to MPS
device = torch.device('mps')

# Load data (only the first 10,000 rows are classified)
data = pd.read_csv('data/Education_videos_7_cleaned.csv').head(10000)

# Initialize the BART-based zero-shot classifier on the selected device
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=device,
)

# Pre-process text: combine title and tags into the single column that gets
# classified. Missing values are replaced by '' so one NaN no longer turns the
# whole text into NaN, and a space separator keeps the last word of the title
# from fusing with the first tag.
# NOTE(review): a row with neither title nor tags now yields '' instead of
# NaN - confirm how classify() treats empty strings.
data['text'] = (
    data['title'].fillna('').astype(str)
    + ' '
    + data['tags'].fillna('').astype(str)
).str.strip()

# Perform classification
print('Start classification...')
final_data = classify(
    data,
    candidate_labels=purpose_labels,
    on='text',
    classifier=classifier,
    batch_size=32,
    multi_label=True,
)

# Save results
final_data.to_csv('data/Education_videos_7_classified_1.csv', index=False)
print('CSV saved')


Start classification...
Converting to dataset...
Processing...


Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Converting back to DataFrame...
CSV saved
