# Train a segment classifier
Given two texts, train a classifier to predict whether they belong to the same segment.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from datetime import datetime
import json
import math
import os

from transformers import pipeline

from models.segment_eval import eval_segment_boundaries

In [None]:
# configure: locations of the labeled input data and of model output
input_dir = '../data/segment/labeled/'
in_filename = '2023-03-18.json'  # labeled talk sections read in the next cells

output_dir = '../data/segment/model/'
# YYYY-MM-DD stamp; not used in the cells shown — presumably for naming outputs
today = f"{datetime.today():%Y-%m-%d}"


In [None]:
# Baseline paragraph-pair classifier from the Hugging Face Hub; used by
# predict_segments below to decide where segment boundaries fall.
pipe = pipeline(task="text-classification", model="dennlinger/bert-wiki-paragraphs")

In [None]:
# read labeled data
# JSON is UTF-8 by spec (RFC 8259); pass the encoding explicitly so the file
# is not decoded with the platform's locale codec (e.g. cp1252 on Windows).
with open(os.path.join(input_dir, in_filename), encoding='utf-8') as f:
    talk_sections = json.load(f)

In [None]:
# predict segments using bert-wiki-paragraphs to begin
def predict_segments(paragraphs, threshold, *, same_label="LABEL_1", classifier=None):
    """Assign a 1-based segment id to each paragraph, in order.

    Each adjacent pair of paragraphs is scored by a text-classification
    pipeline. A new segment starts whenever the probability that the two
    paragraphs belong to the same segment falls below ``threshold``.

    Args:
        paragraphs: Sequence of paragraph strings, in document order.
        threshold: Minimum probability of ``same_label`` needed to keep the
            next paragraph in the current segment.
        same_label: Pipeline label that means "same segment".
            NOTE(review): assumes bert-wiki-paragraphs uses the default
            ``LABEL_1`` for "same section" — confirm against
            ``pipe.model.config.id2label``.
        classifier: Callable following the Hugging Face pipeline calling
            convention (returns a list of ``{'label', 'score'}`` dicts).
            Defaults to the notebook-level ``pipe``.

    Returns:
        A list with one segment id per paragraph; empty for empty input.
    """
    if classifier is None:
        classifier = pipe
    if not paragraphs:
        # Keep the output length equal to the input length.
        return []

    current = 1
    segments = [current]
    for prev, nxt in zip(paragraphs, paragraphs[1:]):
        # NOTE(review): with a single inline "[SEP]" string, truncation cuts
        # from the right, so a very long `prev` can crowd out `nxt` entirely.
        result = classifier(f"{prev} [SEP] {nxt}", truncation=True)[0]
        # The pipeline reports only the top label's score. Convert it to
        # P(same segment); otherwise a confident "different" prediction would
        # be read as high confidence that the paragraphs belong together.
        # 1 - score is valid because the model has two softmax-normalized labels.
        if result["label"] == same_label:
            p_same = result["score"]
        else:
            p_same = 1.0 - result["score"]
        if p_same < threshold:
            current += 1
        segments.append(current)
    return segments

In [None]:
# predict segments for each talk_section and score them against the labels
same_segment_threshold = 0.85  # min "same segment" confidence; tune on labeled data

pk_scores = []
window_diff_scores = []
for talk_section in talk_sections:
    paragraph_segments = talk_section['paragraphs']
    if not paragraph_segments:
        # nothing to segment or evaluate in an empty section
        continue
    paragraphs = [paragraph_segment['text'] for paragraph_segment in paragraph_segments]
    true_segments = [paragraph_segment['segment'] for paragraph_segment in paragraph_segments]
    print(true_segments)
    # predict
    pred_segments = predict_segments(paragraphs, same_segment_threshold)
    print(pred_segments)
    # eval - lower is better
    pk_diff, window_diff = eval_segment_boundaries(true_segments, pred_segments)
    print(pk_diff, window_diff)
    pk_scores.append(pk_diff)
    window_diff_scores.append(window_diff)

# Single summary numbers make thresholds/models comparable across runs.
# NOTE(review): assumes eval_segment_boundaries returns numeric scores — confirm.
if pk_scores:
    print(
        f"mean over {len(pk_scores)} sections: "
        f"pk={sum(pk_scores) / len(pk_scores):.4f} "
        f"window_diff={sum(window_diff_scores) / len(window_diff_scores):.4f}"
    )

Next, consider training our own model on pairs of paragraph n-grams, where a paragraph n-gram is a run of n consecutive paragraphs that all belong to the same segment. Label a pair 1 if both n-grams come from the same segment, or 0 if they come from different segments.