In [None]:
# Reload edited Python modules automatically before each cell runs, and
# render matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os
import json

import numpy as np
import ffmpeg
import whisper

from matplotlib import pyplot as plt
import matplotlib as mpl

In [None]:
# Load the 'base'-size Whisper speech-recognition model (used below to transcribe audio_path).
whisper_model = whisper.load_model(name="base")

In [None]:
# Working directories: the transcript cache and output audio live under ~/cobra/data.
HOME_DIR = os.path.expanduser('~')
COBRA_DIR = os.path.join(HOME_DIR, 'cobra')
DATA_DIR = os.path.join(COBRA_DIR, 'data')
# exist_ok=True avoids the race between an exists() check and makedirs(),
# and keeps the cell safe to re-run.
os.makedirs(DATA_DIR, exist_ok=True)

In [None]:
# Source audio. Download it to exactly audio_path with (yt-dlp accepts the same flags):
#   youtube-dl -x --audio-format mp3 -o "$HOME/Downloads/fireside.%(ext)s" \
#     "https://www.youtube.com/watch?v=ESDeUi8Yl-8"
# (without -o the file is named after the video title and saved in the cwd).
audio_path = os.path.join(HOME_DIR, 'Downloads', 'fireside.mp3')

# Cached Whisper transcript and the re-timed output audio.
transcript_path = os.path.join(DATA_DIR, 'fireside.json')
output_path = os.path.join(DATA_DIR, 'fireside_smooth.mp3')

In [None]:
# Load the cached Whisper transcript if an earlier run saved one; otherwise
# leave `result` for the transcription cell below to produce.
if os.path.exists(transcript_path):
  with open(transcript_path, 'r') as f:
    result = json.load(f)

In [None]:
# Transcribe only when there is no cached transcript; delete transcript_path
# to force a fresh transcription.
if not os.path.exists(transcript_path):
  result = whisper_model.transcribe(audio_path)

In [None]:
# Persist the transcript so a later session can reload it instead of re-running Whisper.
with open(transcript_path, mode='w') as transcript_file:
  json.dump(result, fp=transcript_file)

In [None]:
# Whisper's timed transcript segments; display how many there are and the first one.
segments = result['segments']
(len(segments), segments[0])

In [None]:
# Distribution of segment lengths. seg_durations is kept as a notebook-level
# name because later cells may read it.
seg_durations = [seg['end'] - seg['start'] for seg in segments]
fig, ax = plt.subplots()
ax.hist(seg_durations)
ax.set(title='Segment durations',
       xlabel='Duration (seconds)',
       ylabel='Number of segments')
plt.show()

In [None]:
def compute_speedups(info_densities, min_speedup=None, max_speedup=None):
  """Map per-segment information densities to playback speed-up factors.

  Each factor is ``mean(info_densities) / density``: segments denser than
  average get a factor < 1 (slowed down), sparser ones a factor > 1 (sped up).
  A density of 0 yields an infinite factor unless ``max_speedup`` is given.

  Args:
    info_densities: Sequence or numpy array of per-segment densities.
      Plain lists are accepted as well as arrays.
    min_speedup, max_speedup: Optional bounds. When given, the factors are
      clipped into ``[min_speedup, max_speedup]`` (e.g. to stay within the
      tempo range ffmpeg's ``atempo`` filter accepts).

  Returns:
    Float numpy array with one speed-up factor per segment.
  """
  # asarray so lists work too (a float divided by a list raises TypeError).
  densities = np.asarray(info_densities, dtype=float)
  avg_density = np.mean(densities)
  speedups = avg_density / densities
  # np.clip requires at least one bound, so only clip when one is set.
  if min_speedup is not None or max_speedup is not None:
    speedups = np.clip(speedups, min_speedup, max_speedup)
  return speedups

In [None]:
# NOTE: needs `info_densities`, which a cell further down computes; under
# Restart & Run All this cell raises NameError until the cells are reordered.
speedups = compute_speedups(info_densities=info_densities)

In [None]:
# ffmpeg-python input node for the source audio; the per-segment filters are chained on it.
in_file = ffmpeg.input(filename=audio_path)

In [None]:
# Tempo range accepted by ffmpeg's atempo filter (current releases); values
# outside it make ffmpeg reject the filter graph.
ATEMPO_MIN, ATEMPO_MAX = 0.5, 100.0

# Cut each transcript segment out of the source audio and re-time it by its speed-up.
segs = []
for data, speedup in zip(segments, speedups):
  seg = (
    in_file
    .filter('atrim', start=data['start'], end=data['end'])
    # atrim keeps the source timestamps; restart them at 0 so the concat
    # filter joins the pieces back-to-back instead of leaving timing gaps.
    .filter('asetpts', 'PTS-STARTPTS')
    .filter('atempo', float(np.clip(speedup, ATEMPO_MIN, ATEMPO_MAX)))
  )
  segs.append(seg)

In [None]:
# Join the re-timed segments into one audio-only stream (v=0 video, a=1 audio
# stream per segment). An empty list would otherwise build a concat filter with
# no inputs that only fails later, inside ffmpeg.
if not segs:
  raise ValueError('no segments to concatenate')
cat = ffmpeg.concat(*segs, v=0, a=1)

In [None]:
# overwrite_output=True passes -y, so re-running the cell replaces an existing
# output file rather than stopping at ffmpeg's interactive overwrite prompt
# (which fails or hangs without a terminal).
cat.output(output_path).run(overwrite_output=True)

In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from queue import Queue

In [None]:
# Small seq2seq language model used to score how predictable each transcript
# segment is. The scoring code sends token ids to `device`, so the model is
# moved to the same device explicitly; changing `device` then works.
LLM_NAME = "google/flan-t5-small"
device = 'cpu'
llm = AutoModelForSeq2SeqLM.from_pretrained(LLM_NAME).to(device)
llm.eval()  # inference only: make sure dropout is off
tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)

In [None]:
def compute_info_densities(segments, verbose=False):
  """Estimate the information density (bits per second) of each segment.

  Each segment's text is tokenized separately and the token ids are
  concatenated into one stream. For segment ``i`` the model is run on the
  stream up to and including segment ``i``. All earlier tokens are masked out
  of the labels with ``-100``, so the loss is the mean negative log-likelihood
  of segment ``i``'s tokens. That mean is multiplied back into a total,
  converted from nats to bits, and divided by the segment's duration.

  Args:
    segments: Sequence of dicts with ``'text'``, ``'start'`` and ``'end'``
      keys (as produced by Whisper; start/end presumably in seconds).
    verbose: If True, print ``nll, avg_nll, index, total`` per segment.

  Returns:
    Numpy array with one density per segment (empty for empty input).
    A zero-duration segment yields inf.

  Relies on the notebook globals ``tokenizer``, ``llm`` and ``device``.
  """
  if len(segments) == 0:
    return np.zeros(0)
  seg_encodings = [tokenizer(seg['text'], return_tensors='pt') for seg in segments]
  input_ids = [enc.input_ids.to(device) for enc in seg_encodings]
  seg_lens = [x.shape[1] for x in input_ids]
  cat_input_ids = torch.cat(input_ids, dim=1)
  # Durations of the segments actually passed in (not a notebook-level list),
  # so the result stays aligned for any list of segments.
  seg_durations = np.array([seg['end'] - seg['start'] for seg in segments], dtype=float)
  idx = 0
  seg_nlls = []
  # Inference only: no autograd graph for the ever-growing context.
  with torch.no_grad():
    for i, seg_len in enumerate(seg_lens):
      idx += seg_len
      ctxt_ids = cat_input_ids[:, :idx]
      target_ids = ctxt_ids.clone()
      # Score only the current segment's tokens; -100 labels are ignored by the loss.
      target_ids[:, :-seg_len] = -100
      # NOTE(review): for a seq2seq model the encoder input (ctxt_ids) also
      # contains the target segment, so the decoder may largely copy it and
      # the NLL can understate surprisal -- confirm this is intended.
      avg_nll = llm(ctxt_ids, labels=target_ids).loss.item()
      nll = avg_nll * seg_len
      seg_nlls.append(nll)
      if verbose:
        print(nll, avg_nll, i, len(seg_lens))
  seg_nlls = np.array(seg_nlls) / np.log(2)  # nats -> bits
  info_densities = seg_nlls / seg_durations
  return info_densities

In [None]:
# Score every transcript segment, printing per-segment progress.
info_densities = compute_info_densities(segments=segments, verbose=True)

In [None]:
# Start time of each segment (plotted in minutes below).
times = np.asarray([segment['start'] for segment in segments])

In [None]:
def smooth(xs, win=10):
  """Moving average of ``xs`` over full windows of ``win`` samples.

  ``win`` is first capped at ``len(xs)``. The result has ``len(xs) - win + 1``
  entries, and entry ``k`` is the mean of ``xs[k:k + win]``, except entry 0,
  which is overwritten with ``xs[0]``.

  Args:
    xs: 1-D sequence or array of numbers.
    win: Window length in samples; must be >= 1.

  Returns:
    Float numpy array of smoothed values (empty when ``xs`` is empty).

  Raises:
    ValueError: if ``win < 1``.
  """
  if win < 1:
    raise ValueError(f'win must be >= 1, got {win}')
  xs = np.asarray(xs)
  if xs.size == 0:
    return np.zeros(0)
  win = min(len(xs), win)
  # Prefix sums with a leading 0, so psums[j] == sum(xs[:j]).
  psums = np.concatenate((np.zeros(1), np.cumsum(xs)))
  rtn = (psums[win:] - psums[:-win]) / win
  # NOTE(review): this replaces the first window's mean with the raw first
  # sample; a no-op for win == 1 but likely unintended otherwise -- confirm.
  rtn[0] = xs[0]
  return rtn

In [None]:
# Information density over the talk, averaged over `win` consecutive segments
# (win = 1 plots the raw per-segment values).
win = 1
fig, ax = plt.subplots()
# smooth() returns one value per full window, so drop the first win-1 start
# times to keep x and y the same length.
ax.plot(times[win-1:] / 60, smooth(info_densities, win=win))
ax.set(title='Information density over time',
       xlabel='Time (minutes)',
       ylabel='Information density (bits per second)')
plt.show()

In [None]:
# Segment indices ordered from lowest to highest information density.
sorted_seg_idxes = sorted(range(len(segments)), key=info_densities.__getitem__)

In [None]:
# Print the transcript text of every segment that ends before 600 s (the first ten minutes).
for text in (seg['text'] for seg in segments if seg['end'] < 600):
  print(text)

In [None]:
# The 20 lowest-density segments, each paired with its relative position in
# the talk (index / number of segments: 0 = start, close to 1 = end).
[(idx / len(segments), segments[idx]['text']) for idx in sorted_seg_idxes[:20]]

In [None]:
# The 20 highest-density segments (tail of the ascending order), each paired
# with its relative position in the talk (index / number of segments).
[(idx / len(segments), segments[idx]['text']) for idx in sorted_seg_idxes[-20:]]

In [None]:
# Print the transcript text of segments lying strictly between 2000 s and 3000 s.
for text in (seg['text'] for seg in segments if 2000 < seg['start'] and seg['end'] < 3000):
  print(text)