Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 106 additions & 40 deletions apps/desktop/src/utils/segment/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import { segmentationPass } from "./pass-build-segments";
import { mergeSegmentsPass } from "./pass-merge-segments";
import { normalizeWordsPass } from "./pass-normalize-words";
import { identityPropagationPass } from "./pass-propagate-identity";
import { resolveIdentitiesPass } from "./pass-resolve-speakers";
Expand All @@ -15,6 +14,7 @@ import type {
SegmentWord,
SpeakerIdentity,
SpeakerState,
StageId,
WordLike,
} from "./shared";

Expand Down Expand Up @@ -44,22 +44,71 @@ export function buildSegments<
}

const context = createSegmentPassContext(speakerHints, options);
const initialGraph: SegmentGraph = {
finalWords,
partialWords,
};

const graph = runSegmentPipeline(defaultSegmentPasses, initialGraph, context);
return finalizeSegments(graph.segments ?? []);
const initialGraph: SegmentGraph = { finalWords, partialWords };
const graph = runSegmentPipeline(initialGraph, context);
const segmentsGraph = ensureGraphKey(
graph,
"segments",
"Segment pipeline must produce segments",
);
return finalizeSegments(segmentsGraph.segments);
}

const defaultSegmentPasses: readonly SegmentPass[] = [
normalizeWordsPass,
resolveIdentitiesPass,
segmentationPass,
identityPropagationPass,
mergeSegmentsPass,
];
type SegmentPipelineStage<
TNeeds extends readonly (keyof SegmentGraph)[],
TEnsures extends keyof SegmentGraph,
> = {
pass: SegmentPass<TNeeds[number]>;
needs: TNeeds;
ensures: TEnsures;
errorMessage: string;
};

const SEGMENT_PIPELINE = [
{
pass: normalizeWordsPass,
needs: [] as const,
ensures: "words",
errorMessage: "normalizeWordsPass must produce words",
},
{
pass: resolveIdentitiesPass,
needs: ["words"] as const,
ensures: "frames",
errorMessage: "resolveIdentitiesPass must produce frames",
},
{
pass: segmentationPass,
needs: ["frames"] as const,
ensures: "segments",
errorMessage: "segmentationPass must produce segments",
},
{
pass: identityPropagationPass,
needs: ["segments"] as const,
ensures: "segments",
errorMessage: "identityPropagationPass must preserve segments",
},
] as const satisfies readonly SegmentPipelineStage<
readonly (keyof SegmentGraph)[],
keyof SegmentGraph
>[];

function runSegmentPipeline(
initialGraph: SegmentGraph,
ctx: SegmentPassContext,
): SegmentGraph {
return SEGMENT_PIPELINE.reduce<SegmentGraph>((graph, stage) => {
const ensuredGraph = ensureGraphHasKeys(graph, stage.needs, stage.pass.id);
return runPassAndExpectKey(
stage.pass,
ensuredGraph,
ctx,
stage.ensures,
stage.errorMessage,
);
}, initialGraph);
}

function createSpeakerState(
speakerHints: readonly RuntimeSpeakerHint[],
Expand Down Expand Up @@ -111,36 +160,53 @@ function createSegmentPassContext(
};
}

function ensurePassRequirements(pass: SegmentPass, graph: SegmentGraph) {
if (!pass.needs || pass.needs.length === 0) {
return;
}

const missing = pass.needs.filter((key) => graph[key] === undefined);
if (missing.length > 0) {
throw new Error(
`Segment pass "${pass.id}" missing required graph keys: ${missing.join(", ")}`,
);
}
}

function runSegmentPipeline(
passes: readonly SegmentPass[],
initialGraph: SegmentGraph,
ctx: SegmentPassContext,
): SegmentGraph {
return passes.reduce((graph, pass) => {
ensurePassRequirements(pass, graph);
return pass.run(graph, ctx);
}, initialGraph);
}

function finalizeSegments(segments: ProtoSegment[]): Segment[] {
return segments.map((segment) => ({
key: segment.key,
words: segment.words.map(({ word }) => {
const { order: _order, ...rest } = word;
const { order, ...rest } = word;
return rest as SegmentWord;
}),
}));
}

type GraphWithKey<TKey extends keyof SegmentGraph> = SegmentGraph & {
[K in TKey]-?: NonNullable<SegmentGraph[K]>;
};

function ensureGraphHasKeys<TKeys extends readonly (keyof SegmentGraph)[]>(
graph: SegmentGraph,
keys: TKeys,
stageId: StageId,
): GraphWithKey<TKeys[number]> {
const ensured = keys.reduce<SegmentGraph>((current, key) => {
return ensureGraphKey(current, key, `${stageId} requires ${String(key)}`);
}, graph);

return ensured as GraphWithKey<TKeys[number]>;
}

function ensureGraphKey<TKey extends keyof SegmentGraph>(
graph: SegmentGraph,
key: TKey,
errorMessage: string,
): GraphWithKey<TKey> {
if (graph[key] == null) {
throw new Error(errorMessage);
}
return graph as GraphWithKey<TKey>;
}

function runPassAndExpectKey<
TNeeds extends keyof SegmentGraph,
TEnsures extends keyof SegmentGraph,
>(
pass: SegmentPass<TNeeds>,
graph: GraphWithKey<TNeeds>,
ctx: SegmentPassContext,
key: TEnsures,
errorMessage: string,
): GraphWithKey<TEnsures> {
const next = pass.run(graph, ctx);
return ensureGraphKey(next, key, errorMessage);
}
Loading
Loading