diff --git a/apps/desktop/src/utils/segment/index.ts b/apps/desktop/src/utils/segment/index.ts
index a50561ee8..f22e2ad2b 100644
--- a/apps/desktop/src/utils/segment/index.ts
+++ b/apps/desktop/src/utils/segment/index.ts
@@ -1,5 +1,4 @@
 import { segmentationPass } from "./pass-build-segments";
-import { mergeSegmentsPass } from "./pass-merge-segments";
 import { normalizeWordsPass } from "./pass-normalize-words";
 import { identityPropagationPass } from "./pass-propagate-identity";
 import { resolveIdentitiesPass } from "./pass-resolve-speakers";
@@ -15,6 +14,7 @@ import type {
   SegmentWord,
   SpeakerIdentity,
   SpeakerState,
+  StageId,
   WordLike,
 } from "./shared";
 
@@ -44,22 +44,71 @@ export function buildSegments<
   }
 
   const context = createSegmentPassContext(speakerHints, options);
-  const initialGraph: SegmentGraph = {
-    finalWords,
-    partialWords,
-  };
-
-  const graph = runSegmentPipeline(defaultSegmentPasses, initialGraph, context);
-  return finalizeSegments(graph.segments ?? []);
+  const initialGraph: SegmentGraph = { finalWords, partialWords };
+  const graph = runSegmentPipeline(initialGraph, context);
+  const segmentsGraph = ensureGraphKey(
+    graph,
+    "segments",
+    "Segment pipeline must produce segments",
+  );
+  return finalizeSegments(segmentsGraph.segments);
 }
 
-const defaultSegmentPasses: readonly SegmentPass[] = [
-  normalizeWordsPass,
-  resolveIdentitiesPass,
-  segmentationPass,
-  identityPropagationPass,
-  mergeSegmentsPass,
-];
+type SegmentPipelineStage<
+  TNeeds extends readonly (keyof SegmentGraph)[],
+  TEnsures extends keyof SegmentGraph,
+> = {
+  pass: SegmentPass;
+  needs: TNeeds;
+  ensures: TEnsures;
+  errorMessage: string;
+};
+
+const SEGMENT_PIPELINE = [
+  {
+    pass: normalizeWordsPass,
+    needs: [] as const,
+    ensures: "words",
+    errorMessage: "normalizeWordsPass must produce words",
+  },
+  {
+    pass: resolveIdentitiesPass,
+    needs: ["words"] as const,
+    ensures: "frames",
+    errorMessage: "resolveIdentitiesPass must produce frames",
+  },
+  {
+    pass: segmentationPass,
+    needs: ["frames"] as const,
+    ensures: "segments",
+    errorMessage: "segmentationPass must produce segments",
+  },
+  {
+    pass: identityPropagationPass,
+    needs: ["segments"] as const,
+    ensures: "segments",
+    errorMessage: "identityPropagationPass must preserve segments",
+  },
+] as const satisfies readonly SegmentPipelineStage<
+  readonly (keyof SegmentGraph)[],
+  keyof SegmentGraph
+>[];
+
+function runSegmentPipeline(
+  initialGraph: SegmentGraph,
+  ctx: SegmentPassContext,
+): SegmentGraph {
+  return SEGMENT_PIPELINE.reduce((graph, stage) => {
+    const ensuredGraph = ensureGraphHasKeys(graph, stage.needs, stage.pass.id);
+    return runPassAndExpectKey(
+      stage.pass,
+      ensuredGraph,
+      ctx,
+      stage.ensures,
+      stage.errorMessage,
+    );
+  }, initialGraph);
+}
 
 function createSpeakerState(
   speakerHints: readonly RuntimeSpeakerHint[],
@@ -111,36 +160,53 @@ function createSegmentPassContext(
   };
 }
 
-function ensurePassRequirements(pass: SegmentPass, graph: SegmentGraph) {
-  if (!pass.needs || pass.needs.length === 0) {
-    return;
-  }
-
-  const missing = pass.needs.filter((key) => graph[key] === undefined);
-  if (missing.length > 0) {
-    throw new Error(
-      `Segment pass "${pass.id}" missing required graph keys: ${missing.join(", ")}`,
-    );
-  }
-}
-
-function runSegmentPipeline(
-  passes: readonly SegmentPass[],
-  initialGraph: SegmentGraph,
-  ctx: SegmentPassContext,
-): SegmentGraph {
-  return passes.reduce((graph, pass) => {
-    ensurePassRequirements(pass, graph);
-    return pass.run(graph, ctx);
-  }, initialGraph);
-}
-
 function finalizeSegments(segments: ProtoSegment[]): Segment[] {
   return segments.map((segment) => ({
     key: segment.key,
     words: segment.words.map(({ word }) => {
-      const { order: _order, ...rest } = word;
+      const { order, ...rest } = word;
       return rest as SegmentWord;
     }),
   }));
 }
+
+type GraphWithKey<TKey extends keyof SegmentGraph> = SegmentGraph & {
+  [K in TKey]-?: NonNullable<SegmentGraph[K]>;
+};
+
+function ensureGraphHasKeys<TKeys extends readonly (keyof SegmentGraph)[]>(
+  graph: SegmentGraph,
+  keys: TKeys,
+  stageId: StageId,
+): GraphWithKey<TKeys[number]> {
+  const ensured = keys.reduce((current, key) => {
+    return ensureGraphKey(current, key, `${stageId} requires ${String(key)}`);
+  }, graph);
+
+  return ensured as GraphWithKey<TKeys[number]>;
+}
+
+function ensureGraphKey<TKey extends keyof SegmentGraph>(
+  graph: SegmentGraph,
+  key: TKey,
+  errorMessage: string,
+): GraphWithKey<TKey> {
+  if (graph[key] == null) {
+    throw new Error(errorMessage);
+  }
+  return graph as GraphWithKey<TKey>;
+}
+
+function runPassAndExpectKey<
+  TNeeds extends keyof SegmentGraph,
+  TEnsures extends keyof SegmentGraph,
+>(
+  pass: SegmentPass<TNeeds>,
+  graph: GraphWithKey<TNeeds>,
+  ctx: SegmentPassContext,
+  key: TEnsures,
+  errorMessage: string,
+): GraphWithKey<TEnsures> {
+  const next = pass.run(graph, ctx);
+  return ensureGraphKey(next, key, errorMessage);
+}
diff --git a/apps/desktop/src/utils/segment/pass-build-segments.ts b/apps/desktop/src/utils/segment/pass-build-segments.ts
index 9394a4312..33cd12552 100644
--- a/apps/desktop/src/utils/segment/pass-build-segments.ts
+++ b/apps/desktop/src/utils/segment/pass-build-segments.ts
@@ -7,24 +7,12 @@ import type {
   SegmentPass,
   SpeakerIdentity,
 } from "./shared";
-import { SegmentKey as SegmentKeyModule } from "./shared";
+import { SegmentKey as SegmentKeyUtils } from "./shared";
 
-export const segmentationPass: SegmentPass = {
+export const segmentationPass: SegmentPass<"frames"> = {
   id: "build_segments",
-  needs: ["frames"],
   run(graph, ctx) {
-    const frames = graph.frames ?? [];
-    const segments: ProtoSegment[] = [];
-    const activeSegments = new Map<string, ProtoSegment>();
-
-    frames.forEach((frame) => {
-      const key = createSegmentKeyFromIdentity(
-        frame.word.channel,
-        frame.identity,
-      );
-      placeFrameInSegment(frame, key, segments, activeSegments, ctx.options);
-    });
-
+    const segments = collectSegments(graph.frames, ctx.options);
     return { ...graph, segments };
   },
 };
@@ -47,46 +35,124 @@ function createSegmentKeyFromIdentity(
     params.speaker_human_id = identity.human_id;
   }
 
-  return SegmentKeyModule.make(params);
+  return SegmentKeyUtils.make(params);
 }
 
-function hasSpeakerIdentity(key: SegmentKey): boolean {
-  return key.speaker_index !== undefined || key.speaker_human_id !== undefined;
+type ChannelSegmentsState = {
+  activeByKey: Map<string, ProtoSegment>;
+  lastAnonymous?: ProtoSegment;
+};
+
+type SegmentationReducerState = {
+  segments: ProtoSegment[];
+  channelState: Map<ChannelProfile, ChannelSegmentsState>;
+};
+
+function collectSegments(
+  frames: ResolvedWordFrame[],
+  options?: SegmentBuilderOptions,
+): ProtoSegment[] {
+  const initial: SegmentationReducerState = {
+    segments: [],
+    channelState: new Map(),
+  };
+
+  const finalState = frames.reduce(
+    (state, frame) => reduceFrame(state, frame, options),
+    initial,
+  );
+
+  return finalState.segments;
 }
 
-function sameKey(a: SegmentKey, b: SegmentKey): boolean {
-  return (
-    a.channel === b.channel &&
-    a.speaker_index === b.speaker_index &&
-    a.speaker_human_id === b.speaker_human_id
+function reduceFrame(
+  state: SegmentationReducerState,
+  frame: ResolvedWordFrame,
+  options?: SegmentBuilderOptions,
+): SegmentationReducerState {
+  const key = createSegmentKeyFromIdentity(frame.word.channel, frame.identity);
+  const channelState = channelStateFor(state.channelState, key.channel);
+  const extension = selectSegmentExtension(
+    state,
+    channelState,
+    key,
+    frame,
+    options,
   );
+
+  if (extension) {
+    extension.segment.words.push(frame);
+    channelState.activeByKey.set(
+      SegmentKeyUtils.serialize(extension.segment.key),
+      extension.segment,
+    );
+    trackAnonymousSegment(channelState, extension.segment);
+    return state;
+  }
+
+  const segment = startSegment(state.segments, key, frame);
+  channelState.activeByKey.set(SegmentKeyUtils.serialize(key), segment);
+  trackAnonymousSegment(channelState, segment);
+  return state;
+}
+
+function selectSegmentExtension(
+  state: SegmentationReducerState,
+  channelState: ChannelSegmentsState,
+  key: SegmentKey,
+  frame: ResolvedWordFrame,
+  options?: SegmentBuilderOptions,
+): { segment: ProtoSegment } | undefined {
+  const segmentId = SegmentKeyUtils.serialize(key);
+  const activeSegment = channelState.activeByKey.get(segmentId);
+
+  if (activeSegment && canExtend(state, activeSegment, key, frame, options)) {
+    return { segment: activeSegment };
+  }
+
+  const anonymousSegment = channelState.lastAnonymous;
+  if (
+    !SegmentKeyUtils.hasSpeakerIdentity(key) &&
+    frame.word.isFinal &&
+    anonymousSegment &&
+    canExtend(state, anonymousSegment, anonymousSegment.key, frame, options)
+  ) {
+    return { segment: anonymousSegment };
+  }
+
+  return undefined;
 }
 
-function segmentKeyId(key: SegmentKey): string {
-  return JSON.stringify([
-    key.channel,
-    key.speaker_index ?? null,
-    key.speaker_human_id ?? null,
-  ]);
+function startSegment(
+  segments: ProtoSegment[],
+  key: SegmentKey,
+  frame: ResolvedWordFrame,
+): ProtoSegment {
+  const segment: ProtoSegment = { key, words: [frame] };
+  segments.push(segment);
+  return segment;
 }
 
-function canExtendSegment(
+function canExtend(
+  state: SegmentationReducerState,
   existingSegment: ProtoSegment,
   candidateKey: SegmentKey,
   frame: ResolvedWordFrame,
-  segments: ProtoSegment[],
   options?: SegmentBuilderOptions,
 ): boolean {
-  if (hasSpeakerIdentity(candidateKey)) {
-    const lastSegment = segments[segments.length - 1];
-    if (!lastSegment || !sameKey(lastSegment.key, candidateKey)) {
+  if (SegmentKeyUtils.hasSpeakerIdentity(candidateKey)) {
+    const lastSegment = state.segments[state.segments.length - 1];
+    if (
+      !lastSegment ||
+      !SegmentKeyUtils.equals(lastSegment.key, candidateKey)
+    ) {
      return false;
    }
  }

  if (
    !frame.word.isFinal &&
-    existingSegment !== segments[segments.length - 1]
+    existingSegment !== state.segments[state.segments.length - 1]
  ) {
    const allWordsArePartial = existingSegment.words.every(
      (w) => !w.word.isFinal,
@@ -101,38 +167,27 @@ function canExtendSegment(
   return frame.word.start_ms - lastWord.end_ms <= maxGapMs;
 }
 
-function placeFrameInSegment(
-  frame: ResolvedWordFrame,
-  key: SegmentKey,
-  segments: ProtoSegment[],
-  activeSegments: Map<string, ProtoSegment>,
-  options?: SegmentBuilderOptions,
-): void {
-  const segmentId = segmentKeyId(key);
-  const existing = activeSegments.get(segmentId);
-
-  if (existing && canExtendSegment(existing, key, frame, segments, options)) {
-    existing.words.push(frame);
-    return;
+function channelStateFor(
+  channelState: Map<ChannelProfile, ChannelSegmentsState>,
+  channel: ChannelProfile,
+): ChannelSegmentsState {
+  const existing = channelState.get(channel);
+  if (existing) {
+    return existing;
   }
 
-  if (frame.word.isFinal && !hasSpeakerIdentity(key)) {
-    for (const [id, segment] of activeSegments) {
-      if (
-        !hasSpeakerIdentity(segment.key) &&
-        segment.key.channel === key.channel
-      ) {
-        if (canExtendSegment(segment, segment.key, frame, segments, options)) {
-          segment.words.push(frame);
-          activeSegments.set(segmentId, segment);
-          activeSegments.set(id, segment);
-          return;
-        }
-      }
-    }
-  }
+  const state: ChannelSegmentsState = {
+    activeByKey: new Map(),
+  };
+  channelState.set(channel, state);
+  return state;
+}
 
-  const newSegment: ProtoSegment = { key, words: [frame] };
-  segments.push(newSegment);
-  activeSegments.set(segmentId, newSegment);
+function trackAnonymousSegment(
+  state: ChannelSegmentsState,
+  segment: ProtoSegment,
+): void {
+  if (!SegmentKeyUtils.hasSpeakerIdentity(segment.key)) {
+    state.lastAnonymous = segment;
+  }
 }
diff --git a/apps/desktop/src/utils/segment/pass-merge-segments.ts b/apps/desktop/src/utils/segment/pass-merge-segments.ts
deleted file mode 100644
index 7f843756c..000000000
--- a/apps/desktop/src/utils/segment/pass-merge-segments.ts
+++ /dev/null
@@ -1,58 +0,0 @@
-import type { ProtoSegment, SegmentKey, SegmentPass } from "./shared";
-
-export const mergeSegmentsPass: SegmentPass = {
-  id: "merge_segments",
-  needs: ["segments"],
-  run(graph) {
-    if (!graph.segments) {
-      return graph;
-    }
-
-    return { ...graph, segments: mergeAdjacentSegments(graph.segments) };
-  },
-};
-
-function hasSpeakerIdentity(key: SegmentKey): boolean {
-  return key.speaker_index !== undefined || key.speaker_human_id !== undefined;
-}
-
-function sameKey(a: SegmentKey, b: SegmentKey): boolean {
-  return (
-    a.channel === b.channel &&
-    a.speaker_index === b.speaker_index &&
-    a.speaker_human_id === b.speaker_human_id
-  );
-}
-
-function canMergeSegments(seg1: ProtoSegment, seg2: ProtoSegment): boolean {
-  if (!hasSpeakerIdentity(seg1.key) && !hasSpeakerIdentity(seg2.key)) {
-    return false;
-  }
-
-  return true;
-}
-
-function mergeAdjacentSegments(segments: ProtoSegment[]): ProtoSegment[] {
-  if (segments.length <= 1) {
-    return segments;
-  }
-
-  const merged: ProtoSegment[] = [];
-
-  segments.forEach((segment) => {
-    const last = merged[merged.length - 1];
-
-    if (
-      last &&
-      sameKey(last.key, segment.key) &&
-      canMergeSegments(last, segment)
-    ) {
-      last.words.push(...segment.words);
-      return;
-    }
-
-    merged.push(segment);
-  });
-
-  return merged;
-}
diff --git a/apps/desktop/src/utils/segment/pass-normalize-words.ts b/apps/desktop/src/utils/segment/pass-normalize-words.ts
index ef3652653..01ae64ae6 100644
--- a/apps/desktop/src/utils/segment/pass-normalize-words.ts
+++ b/apps/desktop/src/utils/segment/pass-normalize-words.ts
@@ -1,35 +1,5 @@
 import type { SegmentPass, SegmentWord, WordLike } from "./shared";
 
-export function normalizeWords<
-  TFinal extends WordLike,
-  TPartial extends WordLike,
->(
-  finalWords: readonly TFinal[],
-  partialWords: readonly TPartial[],
-): SegmentWord[] {
-  const finalNormalized = finalWords.map((word) => ({
-    text: word.text,
-    start_ms: word.start_ms,
-    end_ms: word.end_ms,
-    channel: word.channel,
-    isFinal: true,
-    ...("id" in word && word.id ? { id: word.id as string } : {}),
-  }));
-
-  const partialNormalized = partialWords.map((word) => ({
-    text: word.text,
-    start_ms: word.start_ms,
-    end_ms: word.end_ms,
-    channel: word.channel,
-    isFinal: false,
-    ...("id" in word && word.id ? { id: word.id as string } : {}),
-  }));
-
-  return [...finalNormalized, ...partialNormalized].sort(
-    (a, b) => a.start_ms - b.start_ms,
-  );
-}
-
 export const normalizeWordsPass: SegmentPass = {
   id: "normalize_words",
   run(graph) {
@@ -41,3 +11,31 @@ export const normalizeWordsPass: SegmentPass = {
     return { ...graph, words: normalized };
   },
 };
+
+function normalizeWords<TFinal extends WordLike, TPartial extends WordLike>(
+  finalWords: readonly TFinal[],
+  partialWords: readonly TPartial[],
+): SegmentWord[] {
+  const normalized = [
+    ...finalWords.map((word) => toSegmentWord(word, true)),
+    ...partialWords.map((word) => toSegmentWord(word, false)),
+  ];
+
+  return normalized.sort((a, b) => a.start_ms - b.start_ms);
+}
+
+const toSegmentWord = (word: WordLike, isFinal: boolean): SegmentWord => {
+  const normalized: SegmentWord = {
+    text: word.text,
+    start_ms: word.start_ms,
+    end_ms: word.end_ms,
+    channel: word.channel,
+    isFinal,
+  };
+
+  if ("id" in word && word.id) {
+    normalized.id = word.id as string;
+  }
+
+  return normalized;
+};
diff --git a/apps/desktop/src/utils/segment/pass-propagate-identity.ts b/apps/desktop/src/utils/segment/pass-propagate-identity.ts
index 65c66a873..5f0498a6d 100644
--- a/apps/desktop/src/utils/segment/pass-propagate-identity.ts
+++ b/apps/desktop/src/utils/segment/pass-propagate-identity.ts
@@ -4,58 +4,73 @@ import type {
   SegmentPass,
   SpeakerState,
 } from "./shared";
-import { SegmentKey as SegmentKeyModule } from "./shared";
+import { SegmentKey as SegmentKeyUtils } from "./shared";
 
-export function propagateCompleteChannelIdentities(
+export const identityPropagationPass: SegmentPass<"segments"> = {
+  id: "propagate_identity",
+  run(graph, ctx) {
+    postProcessSegments(graph.segments, ctx.speakerState);
+    return { ...graph, segments: graph.segments };
+  },
+};
+
+function postProcessSegments(
   segments: ProtoSegment[],
   state: SpeakerState,
 ): void {
-  state.completeChannels.forEach((channel) => {
-    const humanId = state.humanIdByChannel.get(channel);
-    if (!humanId) {
-      return;
+  let writeIndex = 0;
+  let lastKept: ProtoSegment | undefined;
+
+  for (const segment of segments) {
+    assignCompleteChannelHumanId(segment, state);
+
+    if (
+      lastKept &&
+      SegmentKeyUtils.equals(lastKept.key, segment.key) &&
+      SegmentKeyUtils.hasSpeakerIdentity(segment.key)
+    ) {
+      lastKept.words.push(...segment.words);
+      continue;
     }
 
-    segments.forEach((segment) => {
-      if (
-        segment.key.channel !== channel ||
-        segment.key.speaker_human_id !== undefined
-      ) {
-        return;
-      }
-
-      const params: {
-        channel: ChannelProfile;
-        speaker_index?: number;
-        speaker_human_id: string;
-      } = {
-        channel,
-        speaker_human_id: humanId,
-      };
-
-      if (segment.key.speaker_index !== undefined) {
-        params.speaker_index = segment.key.speaker_index;
-      }
-
-      segment.key = SegmentKeyModule.make(params);
-    });
-  });
+    segments[writeIndex] = segment;
+    lastKept = segment;
+    writeIndex += 1;
+  }
+
+  segments.length = writeIndex;
 }
 
-export const identityPropagationPass: SegmentPass = {
-  id: "propagate_identity",
-  needs: ["segments"],
-  run(graph, ctx) {
-    if (!graph.segments) {
-      return graph;
-    }
+function assignCompleteChannelHumanId(
+  segment: ProtoSegment,
+  state: SpeakerState,
+): void {
+  if (segment.key.speaker_human_id !== undefined) {
+    return;
+  }
 
-    const segments = graph.segments.map((segment) => ({
-      ...segment,
-      words: [...segment.words],
-    }));
+  const channel = segment.key.channel;
+  if (!state.completeChannels.has(channel)) {
+    return;
+  }
 
-    propagateCompleteChannelIdentities(segments, ctx.speakerState);
-    return { ...graph, segments };
-  },
-};
+  const humanId = state.humanIdByChannel.get(channel);
+  if (!humanId) {
+    return;
+  }
+
+  const params: {
+    channel: ChannelProfile;
+    speaker_index?: number;
+    speaker_human_id: string;
+  } = {
+    channel,
+    speaker_human_id: humanId,
+  };
+
+  if (segment.key.speaker_index !== undefined) {
+    params.speaker_index = segment.key.speaker_index;
+  }
+
+  segment.key = SegmentKeyUtils.make(params);
+}
diff --git a/apps/desktop/src/utils/segment/pass-resolve-speakers.ts b/apps/desktop/src/utils/segment/pass-resolve-speakers.ts
index 01e9c6fc0..ff707022f 100644
--- a/apps/desktop/src/utils/segment/pass-resolve-speakers.ts
+++ b/apps/desktop/src/utils/segment/pass-resolve-speakers.ts
@@ -1,73 +1,72 @@
 import type {
-  IdentityProvenance,
   SegmentPass,
   SegmentWord,
   SpeakerIdentity,
-  SpeakerIdentityResolution,
   SpeakerState,
 } from "./shared";
 
-export function resolveSpeakerIdentity(
-  word: SegmentWord,
-  assignment: SpeakerIdentity | undefined,
-  state: SpeakerState,
-): SpeakerIdentityResolution {
-  const provenance: IdentityProvenance[] = [];
-  const identity: SpeakerIdentity = {};
+type SpeakerStateSnapshot = Pick<
+  SpeakerState,
+  | "completeChannels"
+  | "humanIdByChannel"
+  | "humanIdBySpeakerIndex"
+  | "lastSpeakerByChannel"
+>;
 
-  if (assignment) {
-    if (assignment.speaker_index !== undefined) {
-      identity.speaker_index = assignment.speaker_index;
-    }
-    if (assignment.human_id !== undefined) {
-      identity.human_id = assignment.human_id;
-    }
-    provenance.push("explicit_assignment");
-  }
+type IdentityRuleArgs = {
+  assignment?: SpeakerIdentity;
+  snapshot: SpeakerStateSnapshot;
+  word: SegmentWord;
+};
 
-  if (identity.speaker_index !== undefined && identity.human_id === undefined) {
-    const humanId = state.humanIdBySpeakerIndex.get(identity.speaker_index);
-    if (humanId !== undefined) {
-      identity.human_id = humanId;
-      provenance.push("speaker_index_lookup");
-    }
-  }
+type IdentityRule = (
+  identity: SpeakerIdentity,
+  args: IdentityRuleArgs,
+) => SpeakerIdentity;
 
-  if (
-    identity.human_id === undefined &&
-    state.completeChannels.has(word.channel)
-  ) {
-    const channelHumanId = state.humanIdByChannel.get(word.channel);
-    if (channelHumanId !== undefined) {
-      identity.human_id = channelHumanId;
-      provenance.push("channel_completion");
-    }
-  }
+export const resolveIdentitiesPass: SegmentPass<"words"> = {
+  id: "resolve_speakers",
+  run(graph, ctx) {
+    const frames = graph.words.map((word, index) => {
+      const assignment = ctx.speakerState.assignmentByWordIndex.get(index);
+      const identity = applyIdentityRules(word, assignment, ctx.speakerState);
+      rememberIdentity(word, assignment, identity, ctx.speakerState);
 
-  if (
-    !word.isFinal &&
-    (identity.speaker_index === undefined || identity.human_id === undefined)
-  ) {
-    const last = state.lastSpeakerByChannel.get(word.channel);
-    if (last) {
-      if (
-        identity.speaker_index === undefined &&
-        last.speaker_index !== undefined
-      ) {
-        identity.speaker_index = last.speaker_index;
-        provenance.push("last_speaker");
-      }
-      if (identity.human_id === undefined && last.human_id !== undefined) {
-        identity.human_id = last.human_id;
-        provenance.push("last_speaker");
-      }
-    }
-  }
+      return {
+        word,
+        identity,
+      };
+    });
+
+    return { ...graph, frames };
+  },
+};
+
+function applyIdentityRules(
+  word: SegmentWord,
+  assignment: SpeakerIdentity | undefined,
+  snapshot: SpeakerStateSnapshot,
+): SpeakerIdentity {
+  const rules: IdentityRule[] = [
+    applyExplicitAssignment,
+    applySpeakerIndexHumanId,
+    applyChannelHumanId,
+    carryPartialIdentityForward,
+  ];
 
-  return { identity, provenance };
+  const args: IdentityRuleArgs = {
+    assignment,
+    snapshot,
+    word,
+  };
+
+  return rules.reduce(
+    (identity, rule) => rule(identity, args),
+    {} as SpeakerIdentity,
+  );
 }
 
-export function rememberIdentity(
+function rememberIdentity(
   word: SegmentWord,
   assignment: SpeakerIdentity | undefined,
   identity: SpeakerIdentity,
@@ -104,27 +103,82 @@ export function rememberIdentity(
   }
 }
 
-export const resolveIdentitiesPass: SegmentPass = {
-  id: "resolve_speakers",
-  needs: ["words"],
-  run(graph, ctx) {
-    const words = graph.words ?? [];
-    const frames = words.map((word, index) => {
-      const assignment = ctx.speakerState.assignmentByWordIndex.get(index);
-      const resolution = resolveSpeakerIdentity(
-        word,
-        assignment,
-        ctx.speakerState,
-      );
-      rememberIdentity(word, assignment, resolution.identity, ctx.speakerState);
+const applyExplicitAssignment: IdentityRule = (identity, { assignment }) => {
+  if (!assignment) {
+    return identity;
+  }
 
-      return {
-        word,
-        identity: resolution.identity,
-        provenance: resolution.provenance,
-      };
-    });
+  const updates: Partial<SpeakerIdentity> = {};
+  if (assignment.speaker_index !== undefined) {
+    updates.speaker_index = assignment.speaker_index;
+  }
+  if (assignment.human_id !== undefined) {
+    updates.human_id = assignment.human_id;
+  }
 
-    return { ...graph, frames };
-  },
+  return Object.keys(updates).length > 0
+    ? { ...identity, ...updates }
+    : identity;
+};
+
+const applySpeakerIndexHumanId: IdentityRule = (identity, { snapshot }) => {
+  if (identity.speaker_index === undefined || identity.human_id !== undefined) {
+    return identity;
+  }
+
+  const humanId = snapshot.humanIdBySpeakerIndex.get(identity.speaker_index);
+  if (humanId !== undefined) {
+    return { ...identity, human_id: humanId };
+  }
+
+  return identity;
+};
+
+const applyChannelHumanId: IdentityRule = (identity, { snapshot, word }) => {
+  if (identity.human_id !== undefined) {
+    return identity;
+  }
+
+  if (!snapshot.completeChannels.has(word.channel)) {
+    return identity;
+  }
+
+  const humanId = snapshot.humanIdByChannel.get(word.channel);
+  if (humanId !== undefined) {
+    return { ...identity, human_id: humanId };
+  }
+
+  return identity;
+};
+
+const carryPartialIdentityForward: IdentityRule = (
+  identity,
+  { snapshot, word },
+) => {
+  if (
+    word.isFinal ||
+    (identity.speaker_index !== undefined && identity.human_id !== undefined)
+  ) {
+    return identity;
+  }
+
+  const last = snapshot.lastSpeakerByChannel.get(word.channel);
+  if (!last) {
+    return identity;
+  }
+
+  const updates: Partial<SpeakerIdentity> = {};
+  if (
+    identity.speaker_index === undefined &&
+    last.speaker_index !== undefined
+  ) {
+    updates.speaker_index = last.speaker_index;
+  }
+  if (identity.human_id === undefined && last.human_id !== undefined) {
+    updates.human_id = last.human_id;
+  }
+
+  return Object.keys(updates).length > 0
+    ? { ...identity, ...updates }
+    : identity;
 };
diff --git a/apps/desktop/src/utils/segment/shared.ts b/apps/desktop/src/utils/segment/shared.ts
index 8e162a9e5..c8e81dd07 100644
--- a/apps/desktop/src/utils/segment/shared.ts
+++ b/apps/desktop/src/utils/segment/shared.ts
@@ -51,6 +51,28 @@ export const SegmentKey = {
       speaker_human_id: string;
     }>,
   ): SegmentKey => Data.struct(params),
+
+  hasSpeakerIdentity: (key: SegmentKey): boolean => {
+    return (
+      key.speaker_index !== undefined || key.speaker_human_id !== undefined
+    );
+  },
+
+  equals: (a: SegmentKey, b: SegmentKey): boolean => {
+    return (
+      a.channel === b.channel &&
+      a.speaker_index === b.speaker_index &&
+      a.speaker_human_id === b.speaker_human_id
+    );
+  },
+
+  serialize: (key: SegmentKey): string => {
+    return JSON.stringify([
+      key.channel,
+      key.speaker_index ?? null,
+      key.speaker_human_id ?? null,
+    ]);
+  },
 };
 
 export type SegmentBuilderOptions = {
@@ -62,31 +84,18 @@ export type StageId =
   | "normalize_words"
   | "resolve_speakers"
   | "build_segments"
-  | "propagate_identity"
-  | "merge_segments";
+  | "propagate_identity";
 
 export type SpeakerIdentity = {
   speaker_index?: number;
   human_id?: string;
 };
 
-export type IdentityProvenance =
-  | "explicit_assignment"
-  | "speaker_index_lookup"
-  | "channel_completion"
-  | "last_speaker";
-
 export type NormalizedWord = SegmentWord & { order: number };
 
 export type ResolvedWordFrame = {
   word: NormalizedWord;
   identity?: SpeakerIdentity;
-  provenance: IdentityProvenance[];
-};
-
-export type SpeakerIdentityResolution = {
-  identity: SpeakerIdentity;
-  provenance: IdentityProvenance[];
 };
 
 export type ProtoSegment = {
@@ -102,10 +111,14 @@ export type SegmentGraph = {
   segments?: ProtoSegment[];
 };
 
-export type SegmentPass = {
+type RequireKeys<T, K extends keyof T> = Omit<T, K> & Required<Pick<T, K>>;
+
+export type SegmentPass<TNeeds extends keyof SegmentGraph = never> = {
   id: StageId;
-  needs?: (keyof SegmentGraph)[];
-  run: (graph: SegmentGraph, ctx: SegmentPassContext) => SegmentGraph;
+  run: (
+    graph: RequireKeys<SegmentGraph, TNeeds>,
+    ctx: SegmentPassContext,
+  ) => SegmentGraph;
 };
 
 export type SegmentPassContext = {