Skip to content

Commit

Permalink
feat: Annotation with new types
Browse files Browse the repository at this point in the history
  • Loading branch information
purzelrakete committed Feb 14, 2024
1 parent c26bc82 commit 4c83576
Show file tree
Hide file tree
Showing 14 changed files with 292 additions and 68 deletions.
18 changes: 12 additions & 6 deletions align.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import common


def piecewise_linear(transcription: common.Transcription):
    """
    Compute a naive word-level alignment for a transcription.

    Each transcript segment's time span is divided evenly among its
    words, so every word gets an equal slice of [start, end). The
    resulting common.Alignment is stored on transcription.alignment;
    nothing is returned.

    NOTE(review): assumes transcription.transcript is a whisper-style
    dict with a 'segments' list of {'text', 'start', 'end'} — confirm
    against the transcriber output.
    """
    alignment = common.Alignment()
    for segment in transcription.transcript['segments']:
        text = segment['text']
        start = float(segment['start'])
        end = float(segment['end'])
        duration = end - start
        words = text.strip().split(' ')
        for i, word in enumerate(words):
            # spread the segment duration evenly across its words; the
            # score is a constant 1.0 because this heuristic carries no
            # per-word confidence.
            alignment.words.append(common.Segment(
                word,
                start + i * duration / len(words),
                start + (i + 1) * duration / len(words),
                1.0
            ))

    transcription.alignment = alignment
68 changes: 43 additions & 25 deletions annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,10 @@
logging.basicConfig(level=logging.INFO)


@dataclass
class AnnotationProgress:
    """
    Snapshot of a running diarization job's progress.
    """
    # completion of the annotation run, 0-100
    percent_done: int
    # speaker turns found so far; None until any are available
    annotations: typing.Optional[typing.List[common.Turn]] = None


class AnnotationError(Exception):
Expand All @@ -33,12 +26,24 @@ class AnnotationError(Exception):
# Container image for the diarization job. pyannote.audio is pinned so
# the pretrained "speaker-diarization-3.1" pipeline keeps loading
# reproducibly; use the standard '==' exact pin rather than PEP 440
# arbitrary equality ('===').
annotation_image = (
    Image
    .debian_slim(python_version="3.10.8")
    .pip_install("pyannote.audio==3.1.1")
)


class Progress:
    """
    No-op progress callback.

    Provides the callable signature a progress hook expects; this base
    implementation discards every event. Subclass and override
    __call__ to actually report progress.
    """

    def __call__(
            self,
            step_name,
            step_artifact,
            file = None,
            total: typing.Optional[int] = None,
            completed: typing.Optional[int] = None):
        # deliberately swallow the event — nothing to record here
        return None


@stub.function(
gpu="A10G",
cpu=8.0,
container_idle_timeout=180,
image=annotation_image,
network_file_systems=common.nfs,
Expand All @@ -48,26 +53,39 @@ class AnnotationError(Exception):
],
)
def annotate(transcription_id):
    """
    Run speaker diarization for a stored transcription.

    Looks up the transcription record, loads the pretrained pyannote
    speaker-diarization-3.1 pipeline (a gated model — requires the
    HF_TOKEN environment variable), runs it over the transcoded audio
    and returns the speaker turns as a common.Diarization.

    Raises AnnotationError if transcription_id is unknown.
    """
    from pyannote.audio.pipelines.utils.hook import ProgressHook
    from pyannote.audio import Pipeline
    import torchaudio
    import torch

    t = common.db.select(transcription_id)
    if not t:
        raise AnnotationError(f"invalid id : {transcription_id}")

    # os.getenv is a function call, not a mapping subscript
    hf_token = os.getenv('HF_TOKEN')
    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        use_auth_token=hf_token,
    )

    device = torch.device(common.get_device())
    pipeline.to(device)
    logger.info(f"pipeline loaded onto {device}")

    with ProgressHook() as hook:
        # load audio up front instead of handing the pipeline a file
        # path. https://github.com/m-bain/whisperX/issues/399
        waveform, sample_rate = torchaudio.load(t.transcoded_file)
        logger.info(f"loaded waveform {waveform.size()}")
        diarization = pipeline({
            "waveform": waveform,
            "sample_rate": sample_rate,
            "hook": hook
        })

    turns = []
    for turn, _, speaker in diarization.itertracks(yield_label=True):
        turns.append(common.Turn(speaker, turn.start, turn.end))

    return common.Diarization(
        turns=turns
    )
101 changes: 95 additions & 6 deletions common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass, asdict
from dataclasses import dataclass, asdict, field
import inspect
from pathlib import Path
import contextlib
Expand Down Expand Up @@ -91,18 +91,97 @@ def from_probe(probe):
return Track(**tags)


@dataclass
class Turn:
    """
    Diarization information: a single speaker turn.
    """
    # diarization label for the speaker (e.g. "SPEAKER_00")
    speaker: str
    # turn start, in seconds from the beginning of the audio
    start: float
    # turn end, in seconds from the beginning of the audio
    end: float


@dataclass
class Diarization:
    """
    Diarization of this transcript
    """

    # speaker turns
    turns: typing.List[Turn]

    @staticmethod
    def from_dict(d):
        """
        Build a Diarization from a dict, silently ignoring unknown
        keys. @staticmethod so the call also works on an instance,
        not only via class attribute access.

        NOTE(review): entries in d['turns'] are passed through as-is;
        plain dicts are not converted to Turn — confirm against the
        stored format.
        """
        return Diarization(**{
            k: v for k, v in d.items()
            if k in inspect.signature(Diarization).parameters
        })


@dataclass
class Segment:
    """
    A single alignment segment
    """
    # token
    label: str
    # time in seconds
    start: float
    # time in seconds
    end: float
    # likelihood of this alignment
    score: float


@dataclass
class Alignment:
    """
    A single word level alignment
    """

    # original word segments
    words: typing.List[Segment] = field(default_factory=list)

    @staticmethod
    def from_dict(d):
        """
        Build an Alignment from a dict, silently ignoring unknown
        keys. @staticmethod so the call also works on an instance,
        not only via class attribute access.

        NOTE(review): entries in d['words'] are passed through as-is;
        plain dicts are not converted to Segment — confirm against the
        stored format.
        """
        return Alignment(**{
            k: v for k, v in d.items()
            if k in inspect.signature(Alignment).parameters
        })


@dataclass
class Transcription:
"""
A complete transcription and metadata of a processed Upload.
"""

# canonical id
transcription_id: str

# initial upload information
upload: UploadInfo

    # track metadata
track: Track = None
transcoded: bool = False

# transcript
transcript: str = None

# diarization
diarization: typing.Optional[Diarization] = None

# alignment
alignment: typing.Optional[Alignment] = None

# is it already transcoded
transcoded: bool = False

# path to the original upload
path: str = None

# source language
language: str = None

@property
Expand Down Expand Up @@ -131,14 +210,24 @@ def from_dict(d: dict):
if 'track' in d and d['track']:
track = Track.from_dict(d['track'])

ui = UploadInfo()
upload_info = UploadInfo()
if 'upload' in d and d['upload']:
ui = UploadInfo.from_dict(d['upload'])
upload_info = UploadInfo.from_dict(d['upload'])

alignment = Alignment()
if 'alignment' in d and d['alignment']:
alignment = Alignment.from_dict(d['alignment'])

diarization = Diarization(turns=[])
if 'diarization' in d and d['diarization']:
diarization = Diarization.from_dict(d['diarization'])

return Transcription(
transcription_id=d['transcription_id'],
upload=ui,
track=track,
upload=upload_info,
alignment=alignment,
diarization=diarization,
transcoded=d.get('transcoded', False),
transcript=d.get('transcript', {}),
path=d.get('path'),
Expand Down Expand Up @@ -210,7 +299,7 @@ def dataclass_to_event(x):

def get_device():
    """
    Return the torch device string to use: "cuda:0" when CUDA is
    available, otherwise "cpu". The explicit device index makes
    pipeline placement unambiguous on multi-GPU hosts.
    """
    import torch
    return "cuda:0" if torch.cuda.is_available() else "cpu"


@contextlib.contextmanager
Expand Down
7 changes: 5 additions & 2 deletions frontend/src/Editor.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import downloadImg from '../download.svg';
import '../css/Editor.css';

// lib
import { WhisperResult, Track } from './lib.ts';
import { WhisperResult, Track, Turn } from './lib.ts';

// data
import { debouncedPutTitle } from './data.ts';
Expand All @@ -27,6 +27,7 @@ export default function Editor({
initialTranscriptionId,
transcript,
track,
turns = [],
// focus
focus,
setFocus,
Expand All @@ -38,6 +39,7 @@ export default function Editor({
} : {
transcript: WhisperResult | null,
track: Track | null,
turns?: Turn[],
focus: number,
setFocus: (f: number) => void,
initialTranscriptionId: string,
Expand All @@ -59,6 +61,7 @@ export default function Editor({
}

function targets(text: string) {
console.log(turns);
return text.trim().split(' ').map((word, index) => {
if (index === focus) {
return (
Expand Down Expand Up @@ -122,7 +125,7 @@ export default function Editor({
</div>
)}
<div className="contents" onClick={hClick}>
{targets(transcript ? transcript.text : '')}
{ transcript !== null && targets(transcript.text) }
</div>
</article>
</>
Expand Down
2 changes: 2 additions & 0 deletions frontend/src/Upload.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,8 @@ export default function Upload({
} else {
// pretend to upload. hashing will be fast
hashingProgress = () => {}; // disable hashing progress
setUploading(true);
showState(states.uploading);
}

// initialize or resume the upload based on local storage
Expand Down
Loading

0 comments on commit 4c83576

Please sign in to comment.