## Linear Probe for Word Boundaries

This notebook adds a linear probe on top of the pre-trained model to test whether it can detect word boundaries.

In [1]:
import hashlib
import os
import yaml
import scipy.stats
import pandas as pd
import numpy as np
import sys
from IPython.display import clear_output

import torch
import torch.nn as nn
from src.model.model import next_char_transformer
from src.segmentation.probing import WordBoundaryDataset, BoundaryProbe

In [2]:
# Run directory of the pre-trained model to probe. Alternative runs are kept
# commented out so they can be swapped in by hand.
# run_dir = 'wandb/run-20221110_041703-6926109/files'  # old model, aligned
run_dir = 'wandb/run-20221203_063741-9317928/files'  # new model
# run_dir = 'wandb/run-20221216_170649-829928499/files'  # number model

# Seed that is later handed to torch.manual_seed.
seed = 32

# The checkpoint and its config file both sit directly inside the run directory.
checkpoint_path = f'{run_dir}/best.pt'
config_path = f'{run_dir}/config.yaml'
with open(config_path) as config_file:
    config = yaml.safe_load(config_file)

## Load Corpus

In [3]:
# The cache file name is derived from the data root: every '/' becomes '.',
# and the result is wrapped as 'corpus.<key>.data'.
data_dir = config['root_path']['value']
cache_key = data_dir.replace('/', '.')
fn = 'corpus.{}.data'.format(cache_key)

# Guard clause: this cell only reads an existing cache, it never builds one.
if not os.path.exists(fn):
    print('No precached dataset found')
    raise Exception('No precached dataset found')

print('Loading cached dataset...')
# NOTE(review): torch.load unpickles the file, so only load caches you trust.
corpus = torch.load(fn)
ntokens = len(corpus.dictionary)
print(corpus.dictionary.word2idx)

Loading cached dataset...
{'<PAD>': 0, '<BOUNDARY>': 1, 'dʒ': 2, 'ʌ': 3, 's': 4, 't': 5, 'l': 6, 'aɪ': 7, 'k': 8, 'j': 9, 'ʊɹ': 10, 'b': 11, 'ʊ': 12, 'æ': 13, 'h': 14, 'oʊ': 15, 'm': 16, 'd': 17, 'uː': 18, 'w': 19, 'ɑː': 20, 'n': 21, 'ə': 22, 'ð': 23, 'ɐ': 24, 'ɾ': 25, 'ɪ': 26, 'ɛ': 27, 'z': 28, 'iː': 29, 'ɛɹ': 30, 'f': 31, 'eɪ': 32, 'ɡ': 33, 'ᵻ': 34, 'p': 35, 'i': 36, 'əl': 37, 'tʃ': 38, 'θ': 39, 'ŋ': 40, 'oːɹ': 41, 'ɹ': 42, 'ɔɪ': 43, 'ɔː': 44, 'aʊ': 45, 'ɪɹ': 46, 'v': 47, 'ɜː': 48, 'ɚ': 49, 'ɑːɹ': 50, 'ɔːɹ': 51, 'ɔ': 52, 'ʃ': 53, 'æː': 54, 'aɪɚ': 55, 'iə': 56, 'ʔ': 57, 'n̩': 58, 'oː': 59, 'aɪə': 60, 'ʒ': 61, 'aɪʊɹ': 62, 'ɑ̃': 63, 'r': 64, 'ɫ': 65, 'ɬ': 66, 'aɪʊ': 67, 'ɛː': 68, 'ɐː': 69, 'nʲ': 70, 'x': 71, '(es)': 72, 'o': 73, 'a': 74, '(enus)': 75}


In [4]:
from torch.utils.data import DataLoader

# Both splits use the sequence-length cap stored in the model's run config.
max_sequence_length = config['sequence_length']['value']

# NOTE(review): the probe's *training* split is read from valid.txt —
# presumably deliberate; confirm.
training_data = WordBoundaryDataset(
    'data/Eng-NA/valid.txt', corpus, max_sequence_length=max_sequence_length
)
test_data = WordBoundaryDataset(
    'data/Eng-NA/test.txt', corpus, max_sequence_length=max_sequence_length
)

# One dataset item per batch, reshuffled on every pass over the loader.
train_dataloader = DataLoader(training_data, shuffle=True, batch_size=1)
test_dataloader = DataLoader(test_data, shuffle=True, batch_size=1)

## Define New Model

## Load Model

In [5]:
# Seed torch's random number generator by hand, for reproducibility.
torch.manual_seed(seed)

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f'Loading on device: {device}')

Loading on device: cpu


In [6]:
# Read the checkpoint of the pre-trained model, mapping tensors onto `device`.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
with open(checkpoint_path, 'rb') as f:
    checkpoint = torch.load(f, map_location=device)

# Rebuild the architecture from the hyper-parameters stored in the run config.
model = next_char_transformer(ntokens,
                              n_layers=config['n_layers']['value'],
                              hidden_size=config['hidden_size']['value'],
                              inner_linear=config['inner_linear']['value'],
                              max_sequence_len=config['sequence_length']['value']).to(device)

# strict=False tolerates key mismatches between checkpoint and model. Report
# them instead of ignoring them: a parameter missing from the checkpoint keeps
# its freshly initialised value, which would silently skew the probe results.
load_result = model.load_state_dict(checkpoint['learner_state_dict'], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print('Warning: checkpoint does not fully match the model')
    print(f'  missing keys: {load_result.missing_keys}')
    print(f'  unexpected keys: {load_result.unexpected_keys}')
model.eval()

# Wrap the restored model in the boundary probe and keep it on the same device.
probe = BoundaryProbe(model).to(device)

In [7]:
# Smoke test: run the dataset item at index 0 through the probe and display the
# raw output as the cell result.
# NOTE: `input` shadows the Python builtin of the same name in this kernel.
input, target = train_dataloader.dataset[0]
single_batch = input.unsqueeze(0).to(device)  # add a leading axis of size 1
all_ones_mask = torch.ones(1, input.shape[0]).to(device)
probe(single_batch, all_ones_mask)

tensor([[[-0.9061,  0.6991],
         [-0.6792,  1.1001],
         [ 0.0355,  0.6220],
         [-0.3786,  1.4314],
         [-0.6435,  1.8600],
         [-0.2370,  0.9272],
         [-1.1037,  0.8358],
         [-0.2219,  0.7312],
         [-0.5623,  0.8842],
         [-0.1591,  0.4556],
         [-0.0991,  0.4577],
         [-0.8836,  1.0321],
         [-0.0318,  1.0325],
         [-0.0408,  0.4338],
         [-0.7807,  0.8874],
         [-1.1760,  0.5431],
         [-0.5551,  1.5818]]], grad_fn=<ViewBackward0>)

In [9]:
NUM_EPOCHS = 10
# BEST_MODEL_PATH = 'best_model.pth'
best_accuracy = 0.0

import torch.optim as optim
import torch.nn.functional as F
from src.data.data import subsequent_mask
from src.segmentation.evaluate import evaluate


def _format_utterance(phoneme_ids, boundary_flags, idx2word):
    """Render one utterance as a single space-separated string.

    Each id in `phoneme_ids` is looked up in `idx2word`. A symbol whose paired
    entry in `boundary_flags` is non-zero is prefixed with ';eword '.
    """
    return ' '.join(
        (';eword ' if flag.item() else '') + idx2word[idx.item()]
        for idx, flag in zip(phoneme_ids, boundary_flags)
    )


# Only the probe's classifier layer is handed to the optimizer, so these are
# the only parameters that optimizer.step() updates.
optimizer = optim.SGD(probe.classifier_layer.parameters(), lr=0.001, momentum=0.9)
length = len(train_dataloader)

for epoch in range(NUM_EPOCHS):

    # ---- Training pass (the loaders use batch_size=1, hence the [0] indexing).
    for i, (phonemes, boundaries) in enumerate(train_dataloader, start=1):
        phonemes = phonemes.to(device)
        boundaries = boundaries.to(device)
        optimizer.zero_grad()
        # Keep the mask on the same device as the inputs so a CUDA run does not
        # mix devices (assumes subsequent_mask returns a tensor — TODO confirm).
        mask = subsequent_mask(phonemes.shape[1]).to(device)
        outputs = probe(phonemes, mask)
        loss = F.cross_entropy(outputs[0], boundaries[0])
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print('Epoch: %d, Loss: %f, Batch: %d/%d' % (epoch, loss.item(), i, length))

    # ---- Evaluation pass over the test split.
    test_error_count = 0.0
    total_boundaries = 0
    gold_utterances = []
    predicted_utterances = []
    # No gradients are needed here; no_grad avoids building autograd graphs.
    with torch.no_grad():
        for phonemes, boundaries in test_dataloader:
            phonemes = phonemes.to(device)
            boundaries = boundaries.to(device)
            mask = subsequent_mask(phonemes.shape[1]).to(device)
            outputs = probe(phonemes, mask)
            # Predicted class per position, computed once and reused below.
            predicted_boundaries = outputs[0].argmax(1)
            test_error_count += float(torch.sum(torch.abs(boundaries[0] - predicted_boundaries)))
            total_boundaries += outputs[0].shape[0]
            # Position 0 is left out of the rendered utterances.
            gold_utterances.append(
                _format_utterance(phonemes[0, 1:], boundaries[0, 1:], corpus.dictionary.idx2word))
            predicted_utterances.append(
                _format_utterance(phonemes[0, 1:], predicted_boundaries[1:], corpus.dictionary.idx2word))

    results = evaluate(gold_utterances, predicted_utterances)
    # Per-position accuracy over every position of every test utterance.
    test_accuracy = 1.0 - float(test_error_count) / total_boundaries
    print('%d: %f' % (epoch, test_accuracy))
    print(results)
    if test_accuracy > best_accuracy:
        # torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_accuracy = test_accuracy

Epoch: 0, Loss: 0.102730, Batch: 100/1000
Epoch: 0, Loss: 0.303036, Batch: 200/1000
Epoch: 0, Loss: 0.120319, Batch: 300/1000
Epoch: 0, Loss: 0.054726, Batch: 400/1000
Epoch: 0, Loss: 0.025972, Batch: 500/1000
Epoch: 0, Loss: 0.056543, Batch: 600/1000
Epoch: 0, Loss: 0.186891, Batch: 700/1000
Epoch: 0, Loss: 0.123649, Batch: 800/1000
Epoch: 0, Loss: 0.172437, Batch: 900/1000
Epoch: 0, Loss: 0.059721, Batch: 1000/1000
0: 0.954726
OrderedDict([('token_precision', 0.7925716084356311), ('token_recall', 0.808346709470305), ('token_fscore', 0.8003814367450731), ('type_precision', 0.7696793002915452), ('type_recall', 0.6583541147132169), ('type_fscore', 0.7096774193548387), ('boundary_all_precision', 0.9365573378022504), ('boundary_all_recall', 0.9506682867557715), ('boundary_all_fscore', 0.9435600578871202), ('boundary_noedge_precision', 0.8782728525493799), ('boundary_noedge_recall', 0.9040189125295508), ('boundary_noedge_fscore', 0.8909599254426841)])
Epoch: 1, Loss: 0.086875, Batch: 100/1

KeyboardInterrupt: 