In [None]:
# ======================================
# 🔧 SETUP
# ======================================
!pip install torch torchvision scikit-learn pyyoutube --quiet

import os
import torch
import requests
import numpy as np
from PIL import Image
from pathlib import Path
from random import shuffle
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from pyyoutube import Api

# ======================================
# 🔑 CONFIGURATION
# ======================================
# API key handed to pyyoutube's `Api`. It is read from the YOUTUBE_API_KEY
# environment variable so the secret does not have to be stored in the notebook;
# when the variable is unset the value is the empty placeholder, as before.
API_KEY = os.environ.get("YOUTUBE_API_KEY", "")  # <-- or replace this with your actual API key

# Channels to classify. A thumbnail's label is the index of its channel in this list.
CHANNEL_NAMES = [
    "veritasium", "crashcourse", "MKBHD", "Vox", "GrahamStephan",
    "jacksepticeye", "chloeting", "BingingWithBabish", "CollegeHumor", "5minutecrafts"
]
# Cache file written/read by fetch_thumbnails() via torch.save / torch.load.
THUMBNAIL_FILE = "thumbnails.pt"

# Side length (pixels) every thumbnail is resized to before entering the network.
IMAGE_SIZE = 224

# NOTE(review): USE_VGG is not referenced anywhere in the visible code.
USE_VGG = True
VGG_NAME = "vgg16"   # architecture passed to build_vgg(): "vgg16" or "vgg19"
VGG_FREEZE = True    # when True, the convolutional `features` stack is not trained

BATCH_SIZE = 8
EPOCHS = 5
# Train on the GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Transforms
# Resize to a square, convert to a float tensor, then normalise each channel
# (the mean/std below are the ImageNet statistics commonly used with
# torchvision's pretrained weights).
vgg_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])

# ======================================
# 📥 DATA COLLECTION
# ======================================
class Thumbnail:
    """One labelled sample: an image paired with its class label.

    In this file instances are created with an RGB PIL image and the index of
    the source channel within ``CHANNEL_NAMES``.
    """

    def __init__(self, pil_image, label):
        # Stored unchanged; any tensor conversion happens later in the dataset.
        self.image, self.label = pil_image, label

def fetch_thumbnails():
    """Return the labelled thumbnails, downloading them on first use.

    If ``THUMBNAIL_FILE`` exists it is loaded and returned. Otherwise every name
    in ``CHANNEL_NAMES`` is searched for, the channel's uploads playlist is
    queried (``count=50``), and the medium-size thumbnail of each listed video
    is downloaded and converted to an RGB PIL image. Channels that cannot be
    resolved and thumbnails that fail to download are reported and skipped.
    A non-empty result is written to ``THUMBNAIL_FILE`` for later runs.

    Returns:
        list[Thumbnail]: downloaded samples; ``label`` is the channel's index
        in ``CHANNEL_NAMES``. May be empty if nothing could be downloaded.
    """
    from io import BytesIO  # local import: only the download path needs it

    if Path(THUMBNAIL_FILE).exists():
        print("✔️ Using cached thumbnails.")
        # The cache contains pickled Thumbnail objects holding PIL images, which
        # the weights-only unpickler (the default since PyTorch 2.6) rejects.
        # NOTE: full unpickling can execute arbitrary code — only load a cache
        # file that this notebook produced itself.
        return torch.load(THUMBNAIL_FILE, weights_only=False)

    print("📥 Downloading thumbnails from YouTube...")
    api = Api(api_key=API_KEY)
    thumbnails = []
    for idx, name in enumerate(CHANNEL_NAMES):
        print(f"🔍 Searching for channel: {name}")
        response = api.search_by_keywords(q=name, search_type="channel", count=1)
        if not response.items:
            print(f"⚠️ Channel {name} not found; skipping.")
            continue
        channel_id = response.items[0].id.channelId
        channel_info = api.get_channel_info(channel_id=channel_id)
        if not channel_info.items:
            # Search returned an id but no channel details came back for it.
            print(f"⚠️ No channel details for {name}; skipping.")
            continue
        uploads_playlist = channel_info.items[0].to_dict()['contentDetails']['relatedPlaylists']['uploads']
        playlist_items = api.get_playlist_items(playlist_id=uploads_playlist, count=50)
        for item in playlist_items.items:
            # Best effort: one broken video must not abort the whole download.
            try:
                video_id = item.snippet.resourceId.videoId
                video = api.get_video_by_id(video_id=video_id).items[0].to_dict()
                thumbnail_url = video['snippet']['thumbnails']['medium']['url']
                # Bounded wait, explicit HTTP error check, and a fully read body
                # so the connection is released before PIL decodes the image.
                resp = requests.get(thumbnail_url, timeout=10)
                resp.raise_for_status()
                img = Image.open(BytesIO(resp.content)).convert("RGB")
                thumbnails.append(Thumbnail(img, idx))
            except Exception as e:
                print("⚠️ Failed to process thumbnail:", e)

    if thumbnails:
        torch.save(thumbnails, THUMBNAIL_FILE)
        print(f"✅ Saved {len(thumbnails)} thumbnails.")
    else:
        # Do not cache an empty result: it would be reused on every later run.
        print("⚠️ No thumbnails downloaded; cache file not written.")
    return thumbnails

# ======================================
# 📦 DATASET
# ======================================
class ThumbnailDataset(Dataset):
    """Dataset view over a sequence of samples exposing ``.image`` and ``.label``.

    Each lookup applies ``transform`` to the sample's image and returns the
    result together with the sample's label.
    """

    def __init__(self, data, transform):
        self.data, self.transform = data, transform

    def __len__(self):
        """Number of samples in the underlying sequence."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(transform(image), label)`` for the sample at ``idx``."""
        sample = self.data[idx]
        return self.transform(sample.image), sample.label

# ======================================
# 🧠 VGG MODEL
# ======================================
from torchvision.models import vgg16, vgg19

def build_vgg(num_classes, vgg_name="vgg16", freeze_backbone=True):
    """Build a pretrained VGG whose last classifier layer outputs ``num_classes`` logits.

    Args:
        num_classes: Size of the new output layer.
        vgg_name: Architecture to load, ``"vgg16"`` or ``"vgg19"``.
        freeze_backbone: When true, every parameter of ``model.features`` gets
            ``requires_grad = False``; the classifier layers stay trainable.

    Returns:
        The torchvision VGG model with ``classifier[6]`` replaced by a new
        ``nn.Linear(in_features, num_classes)``.

    Raises:
        ValueError: If ``vgg_name`` is not one of the supported names.
    """
    if vgg_name not in ("vgg16", "vgg19"):
        raise ValueError(
            f"Unsupported vgg_name: {vgg_name!r} (expected 'vgg16' or 'vgg19')"
        )

    builder = vgg16 if vgg_name == "vgg16" else vgg19
    # `pretrained=True` is deprecated since torchvision 0.13; the weights name
    # below selects the same ImageNet checkpoint through the current API.
    model = builder(weights="IMAGENET1K_V1")

    if freeze_backbone:
        for param in model.features.parameters():
            param.requires_grad = False

    # Swap only the final layer; it is created fresh, so it is trainable even
    # when the backbone is frozen.
    in_features = model.classifier[6].in_features
    model.classifier[6] = nn.Linear(in_features, num_classes)

    return model

# ======================================
# 🏋️‍♀️ TRAINING
# ======================================
def train_model(model, train_loader, test_loader, optimizer, criterion, epochs):
    """Train ``model`` for ``epochs`` passes and evaluate after each one.

    The model is moved to ``DEVICE`` once. Every epoch runs one optimisation
    step per batch of ``train_loader``, prints the mean batch loss, and then
    calls ``evaluate`` on ``test_loader``.
    """
    model.to(DEVICE)
    for epoch_idx in range(epochs):
        model.train()  # evaluate() leaves the model in eval mode
        running_loss = 0
        for inputs, targets in train_loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()
            running_loss += batch_loss.item()
        mean_loss = running_loss / len(train_loader)
        print(f"📚 Epoch {epoch_idx + 1}/{epochs} - Loss: {mean_loss:.4f}")
        evaluate(model, test_loader)

def evaluate(model, loader):
    """Print a per-channel classification report for ``model`` on ``loader``.

    The model is switched to eval mode and run without gradient tracking; the
    arg-max over dimension 1 of each output batch is taken as the prediction.

    Args:
        model: Network to evaluate; inputs are moved to ``DEVICE`` before the call.
        loader: Iterable yielding ``(inputs, labels)`` batches.
    """
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for x, y in loader:
            outputs = model(x.to(DEVICE))
            all_preds.extend(outputs.argmax(dim=1).cpu().tolist())
            all_labels.extend(y.cpu().tolist())
    # Pin the label set to every channel index. Without `labels`, the report
    # raises when a class is missing from this split (or a channel was skipped
    # during download) because the class count no longer matches target_names.
    print(classification_report(
        all_labels,
        all_preds,
        labels=list(range(len(CHANNEL_NAMES))),
        target_names=CHANNEL_NAMES,
        zero_division=0,  # classes with no samples/predictions score 0 without warnings
    ))

# ======================================
# 🚀 MAIN
# ======================================
def main():
    """Run the pipeline: load thumbnails, split them, build the VGG, train it.

    The data is split 80/20 with a fixed seed. The split is stratified by
    channel label when the data allows it, so each channel is represented in
    both parts; otherwise it falls back to a plain random split.

    Raises:
        ValueError: If no thumbnails are available, or if the data is too small
            to be split at all.
    """
    thumbnails = fetch_thumbnails()
    if not thumbnails:
        # Fail before the pretrained weights are downloaded for nothing.
        raise ValueError("No thumbnails available; check API_KEY and CHANNEL_NAMES.")

    # train_test_split already shuffles, seeded by random_state. An extra
    # unseeded shuffle beforehand would make the "fixed" split differ per run.
    labels = [thumb.label for thumb in thumbnails]
    try:
        train_data, test_data = train_test_split(
            thumbnails, test_size=0.2, random_state=42, stratify=labels
        )
    except ValueError:
        # Stratification needs enough samples per class; fall back to the
        # plain random split rather than aborting.
        train_data, test_data = train_test_split(
            thumbnails, test_size=0.2, random_state=42
        )

    transform = vgg_transform
    train_loader = DataLoader(ThumbnailDataset(train_data, transform), batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(ThumbnailDataset(test_data, transform), batch_size=BATCH_SIZE)

    model = build_vgg(num_classes=len(CHANNEL_NAMES), vgg_name=VGG_NAME, freeze_backbone=VGG_FREEZE)
    # Only hand the optimizer the parameters that are still trainable.
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    train_model(model, train_loader, test_loader, optimizer, criterion, epochs=EPOCHS)

# Run the pipeline only when this code is executed directly (a notebook cell or
# script has __name__ == "__main__"), not when it is imported as a module.
if __name__ == "__main__":
    main()
