In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os
import json

import numpy as np
import ffmpeg
import whisper

from matplotlib import pyplot as plt
import matplotlib as mpl

In [None]:
# Load the Whisper model named "base"; the transcription cell calls `model.transcribe`.
model = whisper.load_model("base")

In [None]:
# Project directories, all rooted in the current user's home directory.
HOME_DIR = os.path.expanduser('~')
COBRA_DIR = os.path.join(HOME_DIR, 'cobra')
DATA_DIR = os.path.join(COBRA_DIR, 'data')
# exist_ok=True makes this safe to re-run and avoids the check-then-create race.
os.makedirs(DATA_DIR, exist_ok=True)

In [None]:
# Source recording, read from the user's Downloads folder.
audio_path = os.path.join(HOME_DIR, 'Downloads', 'fireside.mp3')
# JSON file the transcription result is saved to and loaded from.
transcript_path = os.path.join(DATA_DIR, 'fireside.json')
# Destination of the re-timed audio written by ffmpeg.
output_path = os.path.join(DATA_DIR, 'fireside_fast.mp3')

In [None]:
# Load a previously saved transcript if one exists. On a first run the file has
# not been written yet, so `result` stays None until the transcription cell
# assigns it, instead of this cell raising FileNotFoundError.
result = None
if os.path.exists(transcript_path):
  with open(transcript_path, 'r') as f:
    result = json.load(f)

In [None]:
# Transcribe the recording with the Whisper model. This rebinds `result`,
# replacing anything that was loaded from the saved transcript.
result = model.transcribe(audio_path)

In [None]:
# Persist the transcription result as JSON so later sessions can reload it.
with open(transcript_path, 'w') as out_file:
  out_file.write(json.dumps(result))

In [None]:
# Per-segment records from the transcript; later cells read each segment's
# 'start', 'end', 'text' and 'tokens' entries.
segments = result['segments']
# Preview only the first few segments rather than dumping the whole transcript
# (including every token list) into the cell output.
segments[:3]

In [None]:
# Total number of tokens across all transcript segments.
sum(map(len, (seg['tokens'] for seg in result['segments'])))

In [None]:
def compute_speedups(segments):
  """Return one playback-speed factor per segment.

  Placeholder (debug) implementation: every segment gets a factor of 1.
  """
  n_segments = len(segments)
  return [1 for _ in range(n_segments)]  # DEBUG

In [None]:
# Open the recording as an ffmpeg input; each segment's filter chain branches off it.
in_file = ffmpeg.input(audio_path)

In [None]:
# `speedups` was never assigned anywhere in the notebook; derive it from the
# helper so this cell does not fail with a NameError.
speedups = compute_speedups(segments)
# zip() would silently drop segments if the lengths disagreed.
assert len(speedups) == len(segments), "need exactly one speedup per segment"

# Build one filter chain per transcript segment: cut it out of the recording,
# rebase its timestamps, then change its tempo.
segs = []
for data, speedup in zip(segments, speedups):
  seg = (
    in_file
    .filter('atrim', start=data['start'], end=data['end'])
    # atrim keeps the original timestamps, but concat expects every segment to
    # start at 0, so shift each trimmed piece back to zero.
    .filter('asetpts', 'PTS-STARTPTS')
    .filter('atempo', speedup)
  )
  segs.append(seg)

In [None]:
# Join the per-segment streams end to end as audio only (v=0 video, a=1 audio stream).
cat = ffmpeg.concat(*segs, v=0, a=1)

In [None]:
# Render the concatenated audio. overwrite_output=True lets the cell be re-run
# after output_path already exists instead of ffmpeg stopping on the existing file.
cat.output(output_path).run(overwrite_output=True)

In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from queue import Queue

In [None]:
# Language model and tokenizer used to score transcript segments.
# NOTE: this rebinds `model`, which held the Whisper model earlier in the
# notebook; reload Whisper before re-running the transcription cell.
device = 'cpu'
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
# Place the model on `device` so the setting above is actually honored.
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small").to(device)

In [None]:
# Tokenize each segment's text separately, requesting PyTorch tensors ('pt').
seg_encodings = [tokenizer(seg['text'], return_tensors='pt') for seg in segments]

In [None]:
# Score every segment with the language model, conditioning on up to
# n_prev_segs preceding segments, and record its total negative log-likelihood.
n_prev_segs = 10
q = Queue(maxsize=n_prev_segs)  # FIFO holding the token ids of the most recent segments
seg_nlls = []
# Inference only: no_grad avoids building an autograd graph on every forward pass.
with torch.no_grad():
  for i, encodings in enumerate(seg_encodings):
    # Follow whatever device the model lives on instead of hard-coding 'cpu'.
    input_ids = encodings.input_ids.to(model.device)
    # Context = previous segments' tokens followed by this segment's tokens.
    ctxt_ids = torch.cat(list(q.queue) + [input_ids], dim=1)
    target_ids = ctxt_ids.clone()
    n = input_ids.shape[1]
    # Labels of -100 are excluded from the loss, so only the current segment's
    # n tokens are scored.
    target_ids[:, :-n] = -100
    # .item() yields a plain Python float regardless of the model's device.
    avg_nll = model(ctxt_ids, labels=target_ids).loss.item()
    # The loss is averaged over the scored tokens; scale back to a per-segment total.
    nll = avg_nll * n
    print(nll, i, len(seg_encodings))
    seg_nlls.append(nll)

    # Drop the oldest segment once the context window is full.
    if q.full():
      q.get()
    q.put(input_ids)
# Convert from nats to bits.
seg_nlls = np.array(seg_nlls) / np.log(2)

In [None]:
# Length of each segment, computed as end time minus start time.
seg_durations = [s['end'] - s['start'] for s in segments]

In [None]:
# Information per unit time: each segment's NLL divided by its duration.
info_densities = np.divide(seg_nlls, seg_durations)

In [None]:
# x-axis values for the plot: the start time of every segment.
times = [s['start'] for s in segments]

In [None]:
# Plot information density against each segment's start time.
fig, ax = plt.subplots()
ax.plot(times, info_densities)
ax.set_xlabel('Time (seconds)')
ax.set_ylabel('Information density (bits per second)')
plt.show()