# In [1]:
import hashlib
import io
import json
import os
from datetime import datetime, timezone

import cv2
import numpy as np
import webdataset as wds
from huggingface_hub import hf_hub_download
from PIL import Image
from pymongo import MongoClient
from pymongo.errors import BulkWriteError
from skimage.metrics import structural_similarity

# Configuration.
# MONGO_URI can be overridden with the environment variable of the same name,
# so the script can target another server without editing the source; the
# literal below is only the fallback.
MONGO_URI = os.environ.get("MONGO_URI", "mongodb://localhost:27017")
DB_NAME = "video_lectures"
COLLECTION_NAME = "frames"
REPO_ID = "aegean-ai/ai-lectures-spring-24"  # Hugging Face dataset repo
TAR_FILENAME = "youtube_dataset.tar"  # WebDataset archive inside REPO_ID
DOWNLOAD_DIR = "./data"  # local destination for the downloaded archive
os.makedirs(DOWNLOAD_DIR, exist_ok=True)

class WebDatasetToMongoDB:
    """Copy de-duplicated video frames from a WebDataset tar into MongoDB.

    Samples are read as raw bytes:

    * Entries whose key starts with ``frame_`` are treated as encoded images.
    * The optional ``metadata.json`` entry supplies ``video_id`` and is stored
      with every document.
    * The optional ``subtitle.txt`` entry is attached to every frame of its
      sample.

    A frame is kept only when its SSIM against the previously kept frame of
    the same video is at most ``similarity_threshold``.
    """

    def __init__(self, mongo_uri, db_name, collection_name, similarity_threshold=0.9):
        """Connect to MongoDB and ensure the collection's indexes exist.

        Args:
            mongo_uri: Connection string passed to ``MongoClient``.
            db_name: Database name.
            collection_name: Collection receiving one document per kept frame.
            similarity_threshold: SSIM score above which a frame is treated
                as a duplicate of the previously kept frame (default 0.9).
        """
        self.client = MongoClient(mongo_uri)
        self.db = self.client[db_name]
        self.collection = self.db[collection_name]
        # Last kept frame (RGB uint8 ndarray) and the video it belongs to.
        # Carried across samples, so duplicates that straddle a sample
        # boundary are still skipped.
        self.prev_frame = None
        self.prev_video_id = None
        self.current_subtitle = None  # not read by this class; kept for compatibility
        self.similarity_threshold = similarity_threshold
        self._create_indexes()

    def _create_indexes(self):
        """Create indexes for querying and de-duplication.

        ``frame_hash`` is unique, so the server rejects a frame whose hash is
        already stored in the collection.
        """
        self.collection.create_index("timestamp")
        self.collection.create_index([("subtitle", "text")])
        self.collection.create_index("video_id")
        self.collection.create_index("frame_hash", unique=True)

    def _frame_to_hash(self, frame_array):
        """Return a SHA-1 hex digest of the frame's 16x16 grayscale thumbnail.

        Frames whose thumbnails are byte-identical get the same value.
        The unique index on ``frame_hash`` then de-duplicates them across
        the whole collection.
        """
        gray = cv2.cvtColor(frame_array, cv2.COLOR_RGB2GRAY)
        thumbnail = cv2.resize(gray, (16, 16))
        # A fixed 40-char digest keeps the unique index key small.
        # The previous ``str(bytes)`` repr could exceed 1 KB, which is
        # MongoDB's index key limit before 4.2.
        # NOTE: hashes written by older runs used that repr format, so new
        # inserts will not collide with them.
        return hashlib.sha1(thumbnail.tobytes()).hexdigest()

    def _is_similar_frame(self, frame1, frame2):
        """Return True when two RGB frames look alike (SSIM > threshold).

        Both frames are converted to grayscale and resized to 64x64 before
        comparison, so their resolutions need not match. ``None`` for either
        argument (no previous frame yet) means "not similar".
        """
        if frame1 is None or frame2 is None:
            return False

        small1 = cv2.resize(cv2.cvtColor(frame1, cv2.COLOR_RGB2GRAY), (64, 64))
        small2 = cv2.resize(cv2.cvtColor(frame2, cv2.COLOR_RGB2GRAY), (64, 64))
        score = structural_similarity(small1, small2)
        return score > self.similarity_threshold

    @staticmethod
    def _as_text(value):
        """Return *value* as ``str``, decoding UTF-8 bytes when necessary."""
        if isinstance(value, (bytes, bytearray)):
            return value.decode("utf-8")
        return str(value)

    @staticmethod
    def _frame_timestamp(key):
        """Parse the numeric timestamp from a key such as ``frame_12.png`` -> 12.0."""
        return float(key.split("_")[1].split(".")[0])

    def _process_sample(self, sample):
        """Turn one WebDataset sample into a list of MongoDB frame documents.

        Frames are visited in timestamp order. A frame is skipped when it is
        similar to the previously kept frame of the same video. Frames whose
        key or image data cannot be parsed are reported and skipped.

        Args:
            sample: Mapping from entry key to value. Values are normally raw
                ``bytes``. Already-decoded ``dict``/``str`` values for
                ``metadata.json``/``subtitle.txt`` are accepted as well.

        Returns:
            list[dict]: Documents ready for ``insert_many``.
        """
        documents = []

        raw_metadata = sample.get("metadata.json")
        if raw_metadata is None:
            metadata = {}
        elif isinstance(raw_metadata, dict):
            metadata = raw_metadata
        else:
            metadata = json.loads(self._as_text(raw_metadata))
        video_id = metadata.get("video_id", "unknown")

        subtitle = self._as_text(sample.get("subtitle.txt", b"")).strip()

        # Never compare the first frame of a video against the last frame of
        # a different video.
        if video_id != self.prev_video_id:
            self.prev_frame = None
            self.prev_video_id = video_id

        # Tar member order need not follow time, but the similarity check
        # only makes sense when frames are visited chronologically.
        frames = []
        for key in sample:
            if not key.startswith("frame_"):
                continue
            try:
                frames.append((self._frame_timestamp(key), key))
            except (IndexError, ValueError) as exc:
                print(f"Error processing frame {key}: {str(exc)}")
        frames.sort()

        for timestamp, key in frames:
            frame_data = sample[key]
            try:
                with Image.open(io.BytesIO(frame_data)) as image:
                    frame_format = (image.format or "png").lower()
                    # Normalise palette/grayscale/RGBA input so the OpenCV
                    # RGB->GRAY conversions always get three channels.
                    frame_array = np.array(image.convert("RGB"))
                if self._is_similar_frame(frame_array, self.prev_frame):
                    continue
                frame_hash = self._frame_to_hash(frame_array)
            except Exception as exc:  # best-effort: one bad frame must not stop the sample
                print(f"Error processing frame {key}: {str(exc)}")
                continue

            height, width = frame_array.shape[:2]
            documents.append({
                "video_id": video_id,
                "timestamp": timestamp,
                "frame_data": frame_data,  # original encoded image bytes
                "frame_format": frame_format,
                "frame_width": width,
                "frame_height": height,
                "frame_hash": frame_hash,
                "subtitle": subtitle,
                "metadata": metadata,
                "processing_date": datetime.now(timezone.utc),
            })
            self.prev_frame = frame_array

        return documents

    def _insert_batch(self, batch):
        """Insert *batch* into the collection; return how many documents were written.

        ``insert_many(..., ordered=False)`` attempts every document even when
        some fail, for example on the unique ``frame_hash`` index. It then
        raises one ``BulkWriteError`` whose ``details["nInserted"]`` counts
        the successes.

        Retrying the whole batch one by one would only re-fail the documents
        already written, because ``insert_many`` has given them an ``_id``.
        It would also under-count.
        """
        try:
            result = self.collection.insert_many(batch, ordered=False)
        except BulkWriteError as exc:
            rejected = len(exc.details.get("writeErrors", []))
            # str(exc) can embed the failed documents, raw frame bytes
            # included, so only a count is reported.
            print(f"Batch insert error: {rejected} document(s) rejected")
            return exc.details.get("nInserted", 0)
        except Exception as exc:  # batch-level failure: fall back to single inserts
            print(f"Batch insert error: {str(exc)}")
            return self._insert_one_by_one(batch)
        return len(result.inserted_ids)

    def _insert_one_by_one(self, batch):
        """Insert documents individually and return how many succeeded.

        NOTE(review): documents that the failed ``insert_many`` had already
        written will fail here on their ``_id`` and are not counted.
        """
        inserted = 0
        for doc in batch:
            try:
                self.collection.insert_one(doc)
            except Exception:  # best-effort: a rejected document must not stop the rest
                continue
            inserted += 1
        return inserted

    def process_webdataset(self, tar_path, batch_size=100):
        """Stream a downloaded WebDataset tar into MongoDB in batches.

        Args:
            tar_path: Path (or URL/pattern accepted by ``wds.WebDataset``) of
                the archive.
            batch_size: Number of documents buffered before each insert.

        Returns:
            int: Number of documents actually inserted.
        """
        # No ``.decode()`` stage. ``_process_sample`` needs the raw encoded
        # bytes, both to decode frames itself and to store them unchanged.
        # An image handler would turn frames into float arrays, and the
        # default handlers would turn .json/.txt entries into dict/str.
        dataset = wds.WebDataset(tar_path)

        total_processed = 0
        batch = []
        for sample in dataset:
            batch.extend(self._process_sample(sample))
            if len(batch) >= batch_size:
                total_processed += self._insert_batch(batch)
                batch = []
                print(f"Processed {total_processed} documents")

        if batch:
            total_processed += self._insert_batch(batch)

        print(f"Processing complete. Total documents inserted: {total_processed}")
        return total_processed

# Step 1: fetch the WebDataset archive from the Hugging Face Hub dataset repo
# into DOWNLOAD_DIR. local_dir_use_symlinks=False asks for a real file rather
# than a link into the HF cache (newer huggingface_hub versions ignore it).
print("Downloading dataset...")
tar_path = hf_hub_download(
    repo_type="dataset",
    repo_id=REPO_ID,
    filename=TAR_FILENAME,
    local_dir=DOWNLOAD_DIR,
    local_dir_use_symlinks=False,
)
print(f"Dataset downloaded to: {tar_path}")

# Step 2: stream the archive, keep visually distinct frames, and write them
# to MongoDB 200 documents at a time.
print("Processing dataset and storing in MongoDB...")
processor = WebDatasetToMongoDB(MONGO_URI, DB_NAME, COLLECTION_NAME)
processor.process_webdataset(tar_path, batch_size=200)

print("All done!")

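# Output of the cell above when pymongo is not installed:
# ModuleNotFoundError: No module named 'pymongo'
# Install the missing dependency first (e.g. `pip install pymongo`), then re-run the cell.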