<a href="https://colab.research.google.com/github/oademid2/atila-core-service/blob/master/summarize_youtube_video.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Summarize a YouTube Video

In [None]:
!pip install transformers sentence_transformers pytube

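# Optional: install PyTorch explicitly. A GPU build makes transcription faster.
# NOTE: the extra index URL below serves the CPU-only wheels (Linux command shown). For a GPU (CUDA) build,
# or for macOS/Windows, use the command from https://pytorch.org/get-started/locally/ instead.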
!pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu

!pip install git+https://github.com/openai/whisper.git -q
!apt install ffmpeg # https://stackoverflow.com/questions/51856340/how-to-install-package-ffmpeg-in-google-colab

In [None]:

from typing import Dict

from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import whisper
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import pytube
import time


class EndpointHandler():
    """
    Callable request handler backed by three pretrained models.

    ``__init__`` loads a Whisper model, a sentence-transformer model and a
    seq2seq question-answer model (with its tokenizer); ``__call__`` receives
    a request dict and dispatches on the keys it contains.
    """
    # Names of the pretrained models that ``__init__`` loads.
    WHISPER_MODEL_NAME = "tiny.en"
    SENTENCE_TRANSFORMER_MODEL_NAME = "multi-qa-mpnet-base-dot-v1"
    QUESTION_ANSWER_MODEL_NAME = "vblagoje/bart_lfqa"
    # CUDA device when one is available, otherwise CPU.
    # NOTE(review): ``__init__`` builds its own local device string and does not
    # read this attribute; confirm which other methods rely on it.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def __init__(self, path=""):
        """
        Load the three models used by the handler and report load times.

        Args:
            path: accepted for interface compatibility; not used by this
                constructor.

        Sets ``whisper_model``, ``sentence_transformer_model``,
        ``question_answer_tokenizer`` and ``question_answer_model`` on the
        instance. The Whisper and question-answer models are moved to CUDA
        when it is available, otherwise they stay on the CPU.
        """
        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
        print(f'whisper and question_answer_model will use: {device}')

        # Speech-to-text model.
        started = time.time()
        self.whisper_model = whisper.load_model(self.WHISPER_MODEL_NAME).to(device)
        elapsed = time.time() - started
        print(f'Finished loading whisper_model in {elapsed} seconds')

        # Sentence embedding model (left on whatever device the library picks).
        started = time.time()
        self.sentence_transformer_model = SentenceTransformer(self.SENTENCE_TRANSFORMER_MODEL_NAME)
        elapsed = time.time() - started
        print(f'Finished loading sentence_transformer_model in {elapsed} seconds')

        # Tokenizer loading is deliberately outside the timed section, so the
        # reported duration covers the question-answer model only.
        self.question_answer_tokenizer = AutoTokenizer.from_pretrained(self.QUESTION_ANSWER_MODEL_NAME)
        started = time.time()
        qa_model = AutoModelForSeq2SeqLM.from_pretrained(self.QUESTION_ANSWER_MODEL_NAME)
        self.question_answer_model = qa_model.to(device)
        elapsed = time.time() - started
        print(f'Finished loading question_answer_model in {elapsed} seconds')

    def __call__(self, data: Dict[str, str]) -> Dict:
        """
        Args:
            data (:obj:):
                includes the URL to video for transcription
        Return:
            A :obj:`dict`:. transcribed dict
        """
        # process input
        print('data', data)

        if "inputs" not in data:
            raise Exception(f"data is missing 'inputs' key which  EndpointHandler expects. Received: {data}"
                            f" See: https://huggingface.co/docs/inference-endpoints/guides/custom_handler#2-create-endpointhandler-cp")
        video_url = data.pop("video_url", None)
        query = data.pop("query", None)
        long_form_answer = data.pop("long_form_answer", None)
        encoded_segments = {}
        if video_url:
            video_with_transcript = self.transcribe_video(video_url)
            video_with_transcript['transcript']['transcription_source'] = f"whisper_{self.WHISPER_MODEL_NAME}"
            encode_transcript = data.pop("encode_transcript", True)
            if encode_transcript:
                encoded_segments = self.combine_transcripts(video_with_transcript)
                encoded_segments = {
                    "encoded_segments": self.encode_sentences(encoded_segments)
                }
            return {
                **video_with_transcript,
                **encoded_segments
            }
        elif query:
            if long_form_answer:
                context = data.pop("context", None)
                answer = self.generate_answer(query, context)
                response = {
                    "answer": answer
                }

                return response
            else:
                query = [{"text": query, "id": ""}] if isinstance(query, str) else query
                encoded_segments = self.encode_sentences(query)

                response = {
                    "encoded_segments": encoded_segments
                }

                return response

        else:
            return {
                "error": "'video_url' or 'query' must be provided"
            }

    def summarize(self, video_url):
        pass