<a href="https://colab.research.google.com/github/nsmq-ai/nsmqai/blob/kojomensahonums-stt-inference-notebook-version_2/STT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Import and install the required libraries for asr

%%capture
!pip install git+https://github.com/openai/whisper.git
!pip install jiwer
!pip install tabulate
!pip install pydub
!pip install transformers
import torch
import numpy as np
import whisper
import jiwer
import time
import pandas as pd
from tabulate import tabulate
from pydub import AudioSegment
import os
import joblib
import re
from transformers import BertTokenizer, BertModel
import torch
import torch.nn.functional as F
from torch import nn, Tensor

In [None]:

# Install required libraries for web api
!pip  install fastapi
!pip -q install pyngrok
!pip -q install uvicorn
!pip -q install nest_asyncio
!pip -q install python-multipart
# nest_asyncio.apply()
#!pip uninstall fastapi typing-extensions -y
#!pip install fastapi typing-extensions==4.8.0 --no-cache-dir

In [None]:
# Import libraries
import uvicorn
#from typing_extensions import Annotated, Doc
from fastapi import FastAPI,Response
from fastapi.middleware.cors import CORSMiddleware
from pyngrok import ngrok
from pydantic import BaseModel
import nest_asyncio
import shutil

# Import models for serialisation/ deserialisation
from pydantic import BaseModel
import base64
import io
import wave

# Patch asyncio so that uvicorn can start its event loop from inside this
# notebook (presumably because the kernel already runs a loop -- confirm).
nest_asyncio.apply()

# FastAPI application object; the endpoint decorators below register on it.
app = FastAPI()

# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=['*'],
#     allow_credentials=True,
#     allow_methods=['*'],
#     allow_headers=['*'],
# )


### Miscellaneous functions

In [None]:
# Choose the device for whisper: GPU when torch can see one, CPU otherwise.
if torch.cuda.is_available():
  DEVICE = "cuda"
else:
  DEVICE = "cpu"

# "medium.en" is the checkpoint name handed to whisper; change it here to
# trade accuracy for speed.
model = whisper.load_model("medium.en", device=DEVICE)




In [None]:
def transcribe(path_to_audio):
  """Transcribe an audio file with the module-level whisper ``model``.

  Args:
    path_to_audio: Path of the audio file; handed to ``whisper.load_audio``.

  Returns:
    The ``"text"`` entry of the result produced by ``model.transcribe``.
  """
  waveform = whisper.load_audio(path_to_audio)
  transcription = model.transcribe(waveform)
  return transcription["text"]

  # Alternative approach, kept for reference only (never executed).
  # It gives finer control over the decoding steps:
  #
  #   audio = whisper.load_audio(path_to_audio)
  #   audio = whisper.pad_or_trim(audio)
  #   mel = whisper.log_mel_spectrogram(audio).to(model.device)
  #   options = whisper.DecodingOptions(language= "en", without_timestamps= True, fp16 = False)
  #   result = whisper.decode(model, mel, options)
  #   return result.text

In [None]:
def detect_start_point(transcribed_text):
  """Find the riddle-start cue contained in a transcript.

  The transcript is lowercased and searched for each known cue phrase as a
  plain substring. Cues are tried in the order they are listed below, so the
  result is the first *listed* cue that occurs, not necessarily the one that
  appears earliest in the text.

  Args:
    transcribed_text: Transcript text to search.

  Returns:
    The matching cue phrase (lowercase), or ``None`` when no cue occurs.
  """
  cue_phrases = (
      "we begin", "i begin", "let's begin",
      "first riddle", "1st riddle", "riddle number one", "riddle number 1",
      "second riddle", "2nd riddle", "riddle number two", "riddle number 2",
      "third riddle", "3rd riddle", "riddle number three", "riddle number 3",
      "fourth riddle", "4th riddle", "riddle number four", "riddle number 4",
      "fifth riddle", "5th riddle", "riddle number five", "riddle number 5",
      "last riddle", "final riddle", "last one", "next one", "first one",
      "second one", "third one", "fourth one", "fifth one",
      "first redo", "second redo", "third redo", "fourth redo", "last redo",
      "final redo", "fifth redo", "fast riddle",
  )

  lowered = transcribed_text.lower()
  return next((cue for cue in cue_phrases if cue in lowered), None)

In [None]:
def detect_end_point(transcribed_text):
  """Find the riddle-closing phrase contained in a transcript.

  Args:
    transcribed_text: Transcript text to search (matched case-insensitively
      as a plain substring).

  Returns:
    The matching closing phrase (lowercase), or ``None`` when none occurs.
  """
  # Phrases that mark the end of a riddle.
  closing_phrases = ("who am i",)

  lowered = transcribed_text.lower()
  hits = [phrase for phrase in closing_phrases if phrase in lowered]
  return hits[0] if hits else None

In [None]:
class BertClassifier(nn.Module):
  """Classification head stacked on top of a pretrained BERT-style encoder.

  The encoder's ``last_hidden_state`` is averaged over the sequence
  dimension (all positions, the attention mask is not used for pooling) and
  passed through a two-layer fully connected classifier.

  Args:
    pretrained_bert: Encoder exposing ``config.hidden_size`` and returning
      an object with ``last_hidden_state`` when called.
    num_classes: Number of output classes.
  """

  def __init__(self, pretrained_bert, num_classes):
    super().__init__()
    self.model = pretrained_bert
    self.input_size = self.model.config.hidden_size
    # hidden_size -> 256 -> num_classes
    self.classifier = nn.Sequential(
        nn.Linear(self.input_size, 256),
        nn.ReLU(),
        nn.Linear(256, num_classes),
    )

  def forward(self, input_ids, attention_mask, labels=None):
    """Return logits, or the cross-entropy loss when ``labels`` is given."""
    encoder_out = self.model(input_ids=input_ids, attention_mask=attention_mask)
    pooled = encoder_out.last_hidden_state.mean(dim=1)  # mean pooling
    logits = self.classifier(pooled)

    if labels is None:
      return logits

    criterion = nn.CrossEntropyLoss()
    return criterion(logits, labels)

def preprocess_bert_features(sentence,tokenizer):
  """Tokenize one sentence into fixed-length tensors.

  The sentence is padded or truncated to 512 tokens and returned as
  PyTorch tensors (``return_tensors='pt'``).

  Args:
    sentence: Text to encode.
    tokenizer: Tokenizer providing ``encode_plus``.

  Returns:
    Tuple ``(input_ids, attention_mask)`` taken from the tokenizer output.
  """
  encoding = tokenizer.encode_plus(
      sentence,
      padding='max_length',
      max_length=512,
      truncation=True,
      return_tensors='pt',
  )
  input_ids = encoding["input_ids"]
  attention_mask = encoding["attention_mask"]
  return input_ids, attention_mask

def predict_clue(sentence, model, tokenizer):
    """Classify a single sentence and return the predicted class index.

    Args:
        sentence: Text to classify (one sentence per call; the result is
            reduced with ``.item()``, which requires a single prediction).
        model: Classifier called with ``input_ids`` and ``attention_mask``
            keyword arguments and returning logits.
        tokenizer: Tokenizer forwarded to ``preprocess_bert_features``.

    Returns:
        The index of the highest-scoring class as a Python ``int``.
    """
    # Preprocess the sentence
    input_ids, attention_mask = preprocess_bert_features(sentence, tokenizer)

    # Put the inputs on whatever device the model's weights live on, so the
    # call does not fail with a device mismatch if the model is moved to GPU.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_ids = input_ids.to(first_param.device)
        attention_mask = attention_mask.to(first_param.device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        predicted_class = torch.argmax(logits, dim=1).item()

    return predicted_class

# Tokenizer and encoder backbone for the riddle-clue classifier.
tokenizer = BertTokenizer.from_pretrained("prajjwal1/bert-tiny")
pretrained_bert_model = BertModel.from_pretrained('prajjwal1/bert-tiny')

# Binary classifier (num_classes=2) built on top of the encoder.
bert_model = BertClassifier(pretrained_bert=pretrained_bert_model, num_classes=2)

# Load the custom-trained weights. map_location="cpu" lets a checkpoint that
# was saved from a GPU be loaded on a CPU-only host; the model itself is
# constructed on the CPU, so this matches where the weights are copied to.
bert_model.load_state_dict(
    torch.load('speech-to-text/bert_classifier_model.pth', map_location="cpu")
)
bert_model.eval()  # Set the model to evaluation mode

In [None]:
# Module-level state shared across calls to process_audio_chunk.
# accumulated_clues: every clue sentence found since the current riddle began.
# store_to_count: one entry per chunk that contained at least one clue; its
#   length is what gets reported as the clue count.
accumulated_clues = []
store_to_count = []

def process_audio_chunk(audio_filename):
  """Transcribe one audio chunk and update the running list of riddle clues.

  The chunk is transcribed, checked for riddle start/end cues, split into
  fragments after each ``.``, ``,`` or ``?``, and every non-empty fragment is
  run through the clue classifier. Fragments classified as clues (class 1)
  are appended to the module-level ``accumulated_clues``. When a start cue is
  found, both module-level lists are reset first and only the text after the
  cue is classified.

  Args:
    audio_filename: Path of the audio chunk to transcribe.

  Returns:
    A 5-tuple ``(chunk_transcript, clues, clue_count, is_new_riddle,
    end_of_clues)``:
      * chunk_transcript: full transcript of this chunk.
      * clues: all accumulated clues joined by spaces, or ``" "`` when this
        chunk contributed no clue.
      * clue_count: number of chunks with clues so far in this riddle, or
        ``0`` when this chunk contributed no clue.
      * is_new_riddle: ``True`` if a start cue was detected in this chunk.
      * end_of_clues: ``True`` if an end cue was detected in this chunk.
  """
  chunk_transcript = transcribe(audio_filename)

  start_point = detect_start_point(chunk_transcript)
  end_point = detect_end_point(chunk_transcript)
  is_new_riddle = bool(start_point)
  end_of_clues = bool(end_point)

  # Where the text to classify begins; moved past the start cue if present.
  previous_end_index = 0
  if is_new_riddle:
    # A new riddle invalidates everything gathered for the previous one.
    accumulated_clues.clear()
    store_to_count.clear()
    start_index = chunk_transcript.lower().find(start_point.lower())
    previous_end_index = start_index + len(start_point)

  transcribed_text = chunk_transcript[previous_end_index:].strip()

  # Split *after* each delimiter (lookbehind) so punctuation stays attached.
  clues_found = []
  for fragment in re.split(r'(?<=[.,?])', transcribed_text):
    sentence = fragment.strip()
    if not sentence:
      # The split yields an empty tail when the text ends in punctuation;
      # there is nothing to classify in it.
      continue
    if predict_clue(sentence, bert_model, tokenizer) == 1:
      clues_found.append(sentence)

  if not clues_found:
    return chunk_transcript, " ", 0, is_new_riddle, end_of_clues

  accumulated_clues.extend(clues_found)
  store_to_count.append(" ".join(clues_found))
  clue_counter = len(store_to_count)
  return chunk_transcript, " ".join(accumulated_clues), clue_counter, is_new_riddle, end_of_clues

### For API Endpoint

In [None]:
# Request payload for the transcription endpoint.
class AudioBytes(BaseModel):
  # Base64-encoded audio content (the endpoint decodes it with base64.b64decode).
  data: bytes
  # Client-supplied name of the audio file.
  filename: str

# NOTE(review): this route reads a request body on GET, which some clients and
# proxies do not forward; kept as GET so existing callers continue to work.
@app.get("/get-transcript")
async def get_transcript(audio: AudioBytes):
  """Transcribe a base64-encoded audio chunk and report the riddle clues.

  The decoded bytes are written unchanged to a temporary file, which is
  passed to ``process_audio_chunk`` and always deleted afterwards.

  Args:
    audio: Payload with the base64 audio ``data`` and the client ``filename``.
      Only the extension of ``filename`` is used (as the temp-file suffix);
      the client never controls where the file is written.

  Returns:
    On success, a dict with keys ``transcript``, ``clues``, ``clue_count``,
    ``is_start_of_riddle`` and ``is_end_of_riddle``. On any failure, a dict
    ``{"error": <message>}``.
  """
  import tempfile

  temp_path = None
  try:
    decoded_data = base64.b64decode(audio.data)

    # Keep only the extension of the client-supplied name; basename() first
    # so path components can never leak into the suffix.
    suffix = os.path.splitext(os.path.basename(audio.filename))[1]
    fd, temp_path = tempfile.mkstemp(suffix=suffix)
    with os.fdopen(fd, "wb") as file:
      file.write(decoded_data)

    # current_clues contains previous + newly identified clues
    chunk_transcript, current_clues, clue_counter, is_new_riddle, end_of_clues = process_audio_chunk(temp_path)
    return {"transcript": chunk_transcript, "clues": current_clues, "clue_count":clue_counter, "is_start_of_riddle":is_new_riddle, "is_end_of_riddle":end_of_clues}
  except Exception as e:
    # Errors are reported in the response body rather than raised.
    return {"error":str(e)}
  finally:
    # Remove the temporary file whether or not processing succeeded.
    if temp_path is not None and os.path.exists(temp_path):
      os.remove(temp_path)

@app.get("/stt-test")
async def stt_test():
  # Connectivity check: returns a fixed greeting using the same keys as the
  # transcription response, with every other field set to an empty string.
  payload = {"transcript":"Hello from STT."}
  blank_fields = ("clues", "clue_count", "is_start_of_riddle", "is_end_of_riddle")
  payload.update(dict.fromkeys(blank_fields, ""))
  return payload

In [None]:
# Attach personal token
#!ngrok config add-authtoken <place_your_ngrok_auth_token_here>
!ngrok config add-authtoken # place_your_ngrok_auth_token_here

In [None]:
# Expose the speech-to-text service publicly: open an ngrok tunnel to the
# local port, print its URL, then start the server on that same port.
SERVER_PORT = 8000

ngrok_tunnel = ngrok.connect(SERVER_PORT)
print("Public URL:", ngrok_tunnel.public_url)

uvicorn.run(app, port=SERVER_PORT)