# Assignment 5: Nonparametric Bayesian Segmentation

## Setup


Downloading data

In [None]:
# from https://medium.com/@Keshav31/colab-features-download-and-upload-e1ec537a83df
from urllib.request import urlretrieve
import os
from zipfile import ZipFile

# Location of the homework archive and the local filename it is saved under.
url = 'https://ttic.uchicago.edu/~kgimpel/teaching/31210-s19/data/31210-s19-hw5.zip'
file = '31210-s19-hw5.zip'

# Only download when the archive is not already present in the working directory.
# NOTE(review): the last shell line below deletes this zip, so on a re-run of
# this cell the guard will not find it and the archive is downloaded again.
if not os.path.isfile(file):
    urlretrieve(url,file)

# Unpack the archive into the current working directory.
with ZipFile(file) as zipf:
    zipf.extractall()

# Rebuild data/ from scratch, move the extracted files into it, then remove
# the zip and the (now empty) extraction directory.
!rm -rf data/
!mkdir data/
!mv 31210-s19-hw5/* data/
!rm -rf 31210-s19-hw5.zip 31210-s19-hw5/

Parsing the data

In [None]:
# Read the character sequences, one sentence per line, without surrounding whitespace.
with open('data/cbt-characters.txt') as f:
    sentences = [line.strip() for line in f.readlines()]

# Read the gold-standard boundary strings, one per line, aligned with `sentences`.
with open('data/cbt-boundaries.txt') as f:
    gs_boundaries = [line.strip() for line in f.readlines()]

# Sanity checks: one boundary string per sentence, made only of '0' and '1'.
assert len(sentences) == len(gs_boundaries)
assert all(flag in ['0', '1'] for flags in gs_boundaries for flag in flags)

# build character set for p_char
char_set = {ch for sent in sentences for ch in sent}

print('sentences loaded and validated.')
print(f'{len(char_set)} characters found.')

sentences loaded and validated.
54 characters found.


Calculate boundary prediction accuracy (BPA) for the unsegmented corpus

In [None]:
def calc_bpa(gs_boundaries, boundaries):
    """Return the fraction of boundary decisions that match the gold standard.

    Only the arguments are consulted (no module-level state), so the function
    can score any pair of aligned boundary sequences.

    Args:
        gs_boundaries: gold boundaries, one indexable sequence of '0'/'1'
            per sentence.
        boundaries: predicted boundaries, one indexable sequence of '0'/'1'
            per sentence, aligned with ``gs_boundaries``.

    Returns:
        correct / total over every position except the last one of each
        sentence, or 0.0 when there is no position to score.

    Raises:
        AssertionError: if the two inputs hold a different number of sentences.
        IndexError: if a gold sequence is shorter than the positions scored
            for its predicted sequence.
    """
    assert len(gs_boundaries) == len(boundaries)
    correct = 0
    total = 0
    for gold, predicted in zip(gs_boundaries, boundaries):
        # The final position of each sentence is left out of the score.
        for pos in range(len(predicted) - 1):
            total += 1
            correct += int(predicted[pos] == gold[pos])
    if total == 0:
        # Nothing to score (no sentences, or only one-character sentences).
        return 0.0
    return correct / total
        
# Baseline: no boundary anywhere except the one closing each sentence.
unseg_boundaries = [['0'] * len(sent[:-1]) + ['1'] for sent in sentences]
print(f'unsegmented BPA: {calc_bpa(gs_boundaries, unseg_boundaries)}')

unsegmented BPA: 0.7327530010591656


## 1. Gibbs Sampling Implementation

In [None]:
import collections

# Reserved key under which the segment counter stores the total number of
# segments (see initialize_seg_counts).
TOT_TOK = '___TOTAL___'

# Uniform distribution over every character observed in the corpus.
char_probs = {l: 1/len(char_set) for l in char_set}
# Cache of base-distribution values, keyed by (segment, BETA); filled by G_0.
G_0_memo = {}

def G_0(y, BETA):
    """Base-distribution value of segment ``y``, memoized per (y, BETA).

    The value is (1 - BETA)^(len(y) - 1) * BETA multiplied by the
    ``char_probs`` entry of each character of ``y``.

    Args:
        y: the candidate segment (a string).
        BETA: the parameter of the length term above.

    Returns:
        The value as a float; it is also stored in ``G_0_memo``.
    """
    key = (y, BETA)
    cached = G_0_memo.get(key)
    if cached is not None:
        return cached

    # Length term first, then one factor per character, in order.
    value = (1 - BETA) ** (len(y) - 1) * BETA
    for ch in y:
        value = value * char_probs[ch]

    G_0_memo[key] = value
    return value

def initialize_seg_counts(sentences, boundaries):
    """Count the segments induced by ``boundaries`` over ``sentences``.

    Args:
        sentences: list of strings.
        boundaries: per sentence, an indexable sequence in which '1' at
            position i marks character i as the last one of a segment.

    Returns:
        A ``collections.Counter`` mapping each segment to its frequency,
        plus the ``TOT_TOK`` key holding the total number of segments.
    """
    seg_counts = collections.Counter()
    for sent_no, sent in enumerate(sentences):
        seg_begin = 0
        for seg_end in range(1, len(sent) + 1):
            if boundaries[sent_no][seg_end - 1] != '1':
                continue
            # Close the segment ending at this character and start a new one.
            seg_counts[sent[seg_begin:seg_end]] += 1
            seg_counts[TOT_TOK] += 1
            seg_begin = seg_end
    return seg_counts

# Smoke test for the segment counter: 'thecowsaysmoo' with boundaries after
# 'the', 'cow', 'says' and 'moo' should give each of those segments once and
# a total of 4.
print(initialize_seg_counts(['thecowsaysmoo'], [['0','0','1','0','0','1','0','0','0','1','0','0','1']]))

Counter({'___TOTAL___': 4, 'the': 1, 'cow': 1, 'says': 1, 'moo': 1})


In [None]:
import random

def choose_new_value(sentences, boundaries, counts, s_ix, b_ix, s, BETA, GAMMA):
    """Resample one boundary variable and keep ``counts`` in sync.

    Boundaries are stored as the strings '0' / '1' (as produced by
    ``gibbs_init``); '1' at position i means character i ends a segment.

    Args:
        sentences: list of sentence strings.
        boundaries: list of per-sentence lists of '0'/'1'; modified in place.
        counts: Counter of segment frequencies including the ``TOT_TOK``
            total; modified in place.
        s_ix: index of the sentence being resampled.
        b_ix: index of the boundary variable inside that sentence. It is
            expected to be a non-final position, so that the segment to its
            right is non-empty.
        s: concentration parameter weighting the base distribution ``G_0``.
        BETA: parameter forwarded to ``G_0``.
        GAMMA: sentence-stop probability; a boundary inside a sentence is
            weighted by (1 - GAMMA).

    Returns:
        A pair (changed, boundaries): ``changed`` is 1 when the sampled value
        differs from the previous one and 0 otherwise.
    """
    sentence = sentences[s_ix]
    cur_bnd = boundaries[s_ix]

    # Left edge: the character right after the nearest boundary to the left.
    y_full_start = 0
    for i in range(b_ix - 1, -1, -1):
        if cur_bnd[i] == '1':
            y_full_start = i + 1
            break

    # Right edge (exclusive): just past the nearest boundary to the right.
    y_full_stop = len(sentence)
    for i in range(b_ix + 1, len(sentence)):
        if cur_bnd[i] == '1':
            y_full_stop = i + 1
            break

    y_full = sentence[y_full_start:y_full_stop]
    y_prev = sentence[y_full_start:b_ix + 1]
    y_next = sentence[b_ix + 1:y_full_stop]

    old_bi = cur_bnd[b_ix]
    same_halves = int(y_prev == y_next)

    # Counts with the segment(s) touching this position taken out.
    if old_bi == '1':
        # Both halves are currently counted. When they are the same string,
        # that one entry holds both tokens, so two must be removed.
        n_y_prev = counts[y_prev] - 1 - same_halves
        n_y_next = counts[y_next] - 1 - same_halves
        n_y_full = counts[y_full]
        n_y_total = counts[TOT_TOK] - 2
    else:
        n_y_prev = counts[y_prev]
        n_y_next = counts[y_next]
        n_y_full = counts[y_full] - 1
        n_y_total = counts[TOT_TOK] - 1

    # Unnormalized weights for "no boundary" and "boundary".
    p_bi_0 = (n_y_full + s*G_0(y_full, BETA)) / (n_y_total + s)
    p_bi_1 = (n_y_prev + s*G_0(y_prev, BETA))\
                * (1 - GAMMA)\
                * (n_y_next + same_halves + s*G_0(y_next, BETA))\
                / ((n_y_total + s) * (n_y_total + 1 + s))

    # Draw '0' with probability p_bi_0 / (p_bi_0 + p_bi_1), otherwise '1'.
    new_bi = '0' if random.random() * (p_bi_0 + p_bi_1) < p_bi_0 else '1'

    boundaries[s_ix][b_ix] = new_bi

    if old_bi == new_bi:
        return 0, boundaries

    if new_bi == '0':
        # Boundary removed: the two halves merge into the full segment.
        counts[y_full] += 1
        counts[y_prev] -= 1
        counts[y_next] -= 1
        counts[TOT_TOK] -= 1
    else:
        # Boundary inserted: the full segment splits into its two halves.
        counts[y_full] -= 1
        counts[y_prev] += 1
        counts[y_next] += 1
        counts[TOT_TOK] += 1

    return 1, boundaries
    
    
def gibbs_samp_iter(sentences, cur_boundaries, counts, s, BETA, GAMMA, n):
    """Run one sweep over every non-final boundary position and report on it.

    Each position is handed to ``choose_new_value``; the change flags it
    returns are summed. After the sweep the iteration number, the BPA against
    the module-level ``gs_boundaries`` and the number of changes are printed.

    Args:
        sentences: list of sentence strings.
        cur_boundaries: current per-sentence boundary lists.
        counts: segment counter passed through to ``choose_new_value``.
        s, BETA, GAMMA: hyperparameters passed through to ``choose_new_value``.
        n: iteration number, used only in the printed report.

    Returns:
        The boundaries as returned by the last ``choose_new_value`` call
        (``cur_boundaries`` itself when no position was visited).
    """
    nbc = 0
    for sent_no, sent in enumerate(sentences):
        # The last position of a sentence is never resampled.
        for pos in range(len(sent) - 1):
            changed, cur_boundaries = choose_new_value(
                sentences, cur_boundaries, counts, sent_no, pos, s, BETA, GAMMA
            )
            nbc = nbc + changed
    print(f'finished iteration {n}')
    print(f'    BPA: {calc_bpa(gs_boundaries, cur_boundaries)}')
    print(f'    NBC: {nbc}')
    return cur_boundaries
            
    
def gibbs_init(sentences, GAMMA):
    """Draw a random initial segmentation.

    Every position but the last of each sentence gets '1' when a fresh
    ``random.random()`` draw is below ``GAMMA`` and '0' otherwise; the last
    position is always '1'.

    Args:
        sentences: list of sentence strings.
        GAMMA: threshold compared against each uniform draw.

    Returns:
        A list with one list of '0'/'1' strings per sentence.
    """
    def _draw_flags(sent):
        # One uniform draw per non-final character, in order.
        flags = ['1' if random.random() < GAMMA else '0' for _ in sent[:-1]]
        flags.append('1')
        return flags

    return [_draw_flags(sent) for sent in sentences]
    

# Number of full sampling sweeps performed by gibbs_train.
N_ITERS = 5
# Final BPA of each finished run, keyed by (s, BETA, GAMMA); gibbs_train also
# uses it to skip hyperparameter settings that were already run.
final_bpa_hypers = {}
def gibbs_train(sentences, gs_boundaries, s, BETA, GAMMA):
    """Run ``N_ITERS`` sampling sweeps for one hyperparameter setting.

    A setting already present in ``final_bpa_hypers`` is skipped. Otherwise a
    random segmentation is drawn, the segment counts are built from it, the
    sweeps are run, and the final BPA is recorded under (s, BETA, GAMMA).

    Args:
        sentences: list of sentence strings.
        gs_boundaries: gold boundaries used for the final BPA.
        s, BETA, GAMMA: hyperparameters forwarded to the sampler; GAMMA is
            also used for the random initialization.

    Returns:
        The final boundaries, or None when the setting was skipped.
    """
    hyper_key = (s, BETA, GAMMA)
    if hyper_key in final_bpa_hypers:
        return None

    segmentation = gibbs_init(sentences, GAMMA)
    seg_counts = initialize_seg_counts(sentences, segmentation)

    # Sweeps are numbered from 1 in the printed progress report.
    for sweep in range(N_ITERS):
        segmentation = gibbs_samp_iter(
            sentences, segmentation, seg_counts, s, BETA, GAMMA, sweep + 1
        )

    final_bpa_hypers[hyper_key] = calc_bpa(gs_boundaries, segmentation)
    return segmentation

# Single training run with s=0.2, BETA=0.5, GAMMA=0.2.
gibbs_train(sentences, gs_boundaries, 0.2, 0.5, 0.2)
# this ends up stabilizing around 73.2% accuracy.

KeyboardInterrupt: ignored

## 2. Experimentation

### a. Changing hyperparameters

In [None]:
from tqdm import tqdm
import itertools

# Candidate values for each hyperparameter of the sweep.
GAMMA_LIST = [0.1, 0.2, 0.3]
BETA_LIST = [0.1, 0.2, 0.5]
s_LIST = [0.1, 0.2, 0.5, 1]

# Train once per (s, BETA, GAMMA) combination; gibbs_train records each final
# BPA in final_bpa_hypers and skips combinations that are already recorded.
for s, BETA, GAMMA in tqdm(list(itertools.product(s_LIST, BETA_LIST, GAMMA_LIST))):
    gibbs_train(sentences, gs_boundaries, s, BETA, GAMMA)

print(final_bpa_hypers)
# NOTE(review): the remark originally left here was cut off mid-sentence
# ("everything seemed to bes"); complete it once the sweep results are in.



  0%|          | 0/36 [00:00<?, ?it/s][A[A

finished iteration 1
    BPA: 0.7308264616863254
    NBC: 96607
finished iteration 2
    BPA: 0.7308286166520668
    NBC: 382
finished iteration 3
    BPA: 0.7308480113437397
    NBC: 376
finished iteration 4
    BPA: 0.7308329265835496
    NBC: 394
finished iteration 5
    BPA: 0.7308458563779983
    NBC: 384




  3%|▎         | 1/36 [01:22<48:08, 82.53s/it][A[A

finished iteration 1
    BPA: 0.7324685455812966
    NBC: 186280
finished iteration 2
    BPA: 0.7324793204100037
    NBC: 98
finished iteration 3
    BPA: 0.7324857853072281
    NBC: 96
finished iteration 4
    BPA: 0.7324900952387109
    NBC: 94
finished iteration 5
    BPA: 0.732478242927133
    NBC: 75




  6%|▌         | 2/36 [02:44<46:44, 82.48s/it][A[A

finished iteration 1
    BPA: 0.7326700348781205
    NBC: 278603
finished iteration 2
    BPA: 0.732672189843862
    NBC: 46
finished iteration 3
    BPA: 0.7326786547410863
    NBC: 64
finished iteration 4
    BPA: 0.7326624924980255
    NBC: 59
finished iteration 5
    BPA: 0.7326764997753448
    NBC: 59




  8%|▊         | 3/36 [04:07<45:20, 82.43s/it][A[A

finished iteration 1
    BPA: 0.7305452386570684
    NBC: 97326
finished iteration 2
    BPA: 0.7305829505575435
    NBC: 399
finished iteration 3
    BPA: 0.7305667883144827
    NBC: 401
finished iteration 4
    BPA: 0.7305969578348628
    NBC: 430
finished iteration 5
    BPA: 0.7305538585200342
    NBC: 396




 11%|█         | 4/36 [05:42<46:01, 86.28s/it][A[A

finished iteration 1
    BPA: 0.7323866568831222
    NBC: 186555
finished iteration 2
    BPA: 0.7323931217803465
    NBC: 134
finished iteration 3
    BPA: 0.7323920442974757
    NBC: 125
finished iteration 4
    BPA: 0.7323812694687686
    NBC: 138
finished iteration 5
    BPA: 0.7323834244345101
    NBC: 140




 14%|█▍        | 5/36 [07:05<44:06, 85.37s/it][A[A

finished iteration 1
    BPA: 0.7326560276008012
    NBC: 279441
finished iteration 2
    BPA: 0.732645252772094
    NBC: 62
finished iteration 3
    BPA: 0.7326549501179305
    NBC: 69
finished iteration 4
    BPA: 0.7326366329091283
    NBC: 73
finished iteration 5
    BPA: 0.7326409428406112
    NBC: 74




 17%|█▋        | 6/36 [08:29<42:22, 84.75s/it][A[A

finished iteration 1
    BPA: 0.7306863889131322
    NBC: 97311
finished iteration 2
    BPA: 0.7306863889131322
    NBC: 310
finished iteration 3
    BPA: 0.7307036286390637
    NBC: 326
finished iteration 4
    BPA: 0.7306971637418395
    NBC: 338
finished iteration 5
    BPA: 0.7306939312932272
    NBC: 321




 19%|█▉        | 7/36 [09:50<40:32, 83.89s/it][A[A

finished iteration 1
    BPA: 0.7325148773447374
    NBC: 186676
finished iteration 2
    BPA: 0.732505179998901
    NBC: 93
finished iteration 3
    BPA: 0.7325084124475131
    NBC: 97
finished iteration 4
    BPA: 0.7324976376188059
    NBC: 118
finished iteration 5
    BPA: 0.7325084124475131
    NBC: 104




 22%|██▏       | 8/36 [11:12<38:53, 83.34s/it][A[A

finished iteration 1
    BPA: 0.7327002043985006
    NBC: 279151
finished iteration 2
    BPA: 0.732694816984147
    NBC: 33
finished iteration 3
    BPA: 0.7326905070526641
    NBC: 30
finished iteration 4
    BPA: 0.7327002043985006
    NBC: 37
finished iteration 5
    BPA: 0.7326969719498885
    NBC: 31




 25%|██▌       | 9/36 [12:35<37:22, 83.07s/it][A[A

finished iteration 1
    BPA: 0.7303362069801496
    NBC: 97379
finished iteration 2
    BPA: 0.7303211222199595
    NBC: 436
finished iteration 3
    BPA: 0.7303135798398644
    NBC: 441
finished iteration 4
    BPA: 0.7303415943945031
    NBC: 458
finished iteration 5
    BPA: 0.7303308195657959
    NBC: 450




 28%|██▊       | 10/36 [13:57<35:52, 82.78s/it][A[A

finished iteration 1
    BPA: 0.7325267296563153
    NBC: 185827
finished iteration 2
    BPA: 0.7325299621049275
    NBC: 73
finished iteration 3
    BPA: 0.7325267296563153
    NBC: 69
finished iteration 4
    BPA: 0.732535349519281
    NBC: 64
finished iteration 5
    BPA: 0.732535349519281
    NBC: 64




 31%|███       | 11/36 [15:19<34:26, 82.68s/it][A[A

finished iteration 1
    BPA: 0.7326678799123791
    NBC: 278641
finished iteration 2
    BPA: 0.7326711123609913
    NBC: 47
finished iteration 3
    BPA: 0.7326764997753448
    NBC: 51
finished iteration 4
    BPA: 0.7326624924980255
    NBC: 43
finished iteration 5
    BPA: 0.7326743448096034
    NBC: 39




 33%|███▎      | 12/36 [16:42<33:01, 82.57s/it][A[A

finished iteration 1
    BPA: 0.730869561001154
    NBC: 97053
finished iteration 2
    BPA: 0.7308727934497661
    NBC: 377
finished iteration 3
    BPA: 0.7308900331756976
    NBC: 384
finished iteration 4
    BPA: 0.7308986530386633
    NBC: 378
finished iteration 5
    BPA: 0.7308706384840247
    NBC: 380




 36%|███▌      | 13/36 [18:05<31:41, 82.65s/it][A[A

finished iteration 1
    BPA: 0.7324836303414867
    NBC: 186383
finished iteration 2
    BPA: 0.7324760879613916
    NBC: 101
finished iteration 3
    BPA: 0.7324836303414867
    NBC: 93
finished iteration 4
    BPA: 0.7324750104785209
    NBC: 94
finished iteration 5
    BPA: 0.7324965601359352
    NBC: 104




 39%|███▉      | 14/36 [19:28<30:23, 82.89s/it][A[A

finished iteration 1
    BPA: 0.7326980494327592
    NBC: 278260
finished iteration 2
    BPA: 0.7326958944670177
    NBC: 50
finished iteration 3
    BPA: 0.7327034368471127
    NBC: 53
finished iteration 4
    BPA: 0.7327012818813713
    NBC: 46
finished iteration 5
    BPA: 0.7326991269156299
    NBC: 50




 42%|████▏     | 15/36 [20:51<29:03, 83.01s/it][A[A

finished iteration 1
    BPA: 0.730869561001154
    NBC: 96540
finished iteration 2
    BPA: 0.7308480113437397
    NBC: 376
finished iteration 3
    BPA: 0.7308447788951276
    NBC: 371
finished iteration 4
    BPA: 0.7308598636553175
    NBC: 356
finished iteration 5
    BPA: 0.730846933860869
    NBC: 382




 44%|████▍     | 16/36 [22:13<27:32, 82.62s/it][A[A

finished iteration 1
    BPA: 0.7325482793137296
    NBC: 185955
finished iteration 2
    BPA: 0.7325439693822468
    NBC: 102
finished iteration 3
    BPA: 0.7325493567966003
    NBC: 103
finished iteration 4
    BPA: 0.7325536667280832
    NBC: 86
finished iteration 5
    BPA: 0.7325418144165053
    NBC: 109




 47%|████▋     | 17/36 [23:35<26:07, 82.52s/it][A[A

finished iteration 1
    BPA: 0.7327131341929491
    NBC: 278627
finished iteration 2
    BPA: 0.7327152891586906
    NBC: 32
finished iteration 3
    BPA: 0.7327174441244321
    NBC: 38
finished iteration 4
    BPA: 0.7327142116758198
    NBC: 39
finished iteration 5
    BPA: 0.7327228315387856
    NBC: 40




 50%|█████     | 18/36 [24:58<24:44, 82.49s/it][A[A

finished iteration 1
    BPA: 0.7307154809506415
    NBC: 97234
finished iteration 2
    BPA: 0.7307111710191587
    NBC: 384
finished iteration 3
    BPA: 0.7307284107450902
    NBC: 380
finished iteration 4
    BPA: 0.7306982412247102
    NBC: 400
finished iteration 5
    BPA: 0.7307047061219344
    NBC: 348




 53%|█████▎    | 19/36 [26:20<23:20, 82.41s/it][A[A

finished iteration 1
    BPA: 0.7325094899303838
    NBC: 186301
finished iteration 2
    BPA: 0.7324857853072281
    NBC: 110
finished iteration 3
    BPA: 0.7324653131326845
    NBC: 117
finished iteration 4
    BPA: 0.7324922502044524
    NBC: 107
finished iteration 5
    BPA: 0.7324857853072281
    NBC: 112




 56%|█████▌    | 20/36 [27:43<22:01, 82.59s/it][A[A

finished iteration 1
    BPA: 0.7326754222924741
    NBC: 277592
finished iteration 2
    BPA: 0.7326668024295083
    NBC: 48
finished iteration 3
    BPA: 0.7326754222924741
    NBC: 48
finished iteration 4
    BPA: 0.7326754222924741
    NBC: 42
finished iteration 5
    BPA: 0.7326732673267327
    NBC: 50




 58%|█████▊    | 21/36 [29:06<20:38, 82.59s/it][A[A

finished iteration 1
    BPA: 0.7305980353177335
    NBC: 97223
finished iteration 2
    BPA: 0.730565710831612
    NBC: 478
finished iteration 3
    BPA: 0.7305840280404142
    NBC: 449
finished iteration 4
    BPA: 0.7306055776978285
    NBC: 420
finished iteration 5
    BPA: 0.7305904929376386
    NBC: 428




 61%|██████    | 22/36 [30:28<19:17, 82.68s/it][A[A