This notebook explores using the pomegranate library to train a HMM with 100 emission states (e), 100 context length (L), 100 hidden states (h).

Input file:

- `/n/netscratch/sham_lab/Everyone/jchooi/in-context-language-learning/data/TinyStories-100-train.txt`
- `/n/netscratch/sham_lab/Everyone/jchooi/in-context-language-learning/data/TinyStories-100-test.txt`

and the tokenizer

`/n/netscratch/sham_lab/Everyone/jchooi/in-context-language-learning/data/tokenizer.json`

Output file:

`/n/netscratch/sham_lab/Everyone/jchooi/in-context-language-learning/models/TinyStories-L-100-h-100-e-100-pomegranate.pkl`

In [1]:
%pylab inline
import seaborn; seaborn.set_style('whitegrid')

import torch

import numpy as np

numpy.random.seed(0)
numpy.set_printoptions(suppress=True)

%load_ext watermark
%watermark -m -n -p numpy,scipy,torch,pomegranate

%pylab is deprecated, use %matplotlib inline and import the required libraries.
Populating the interactive namespace from numpy and matplotlib
numpy      : 1.26.4
scipy      : 1.14.0
torch      : 2.2.2
pomegranate: 1.1.0

Compiler    : GCC 11.3.0
OS          : Linux
Release     : 4.18.0-513.18.1.el8_9.x86_64
Machine     : x86_64
Processor   : x86_64
CPU cores   : 64
Architecture: 64bit



In [2]:
from pomegranate.hmm import DenseHMM, SparseHMM
from pomegranate.distributions import Categorical, Normal


num_emission = 100


def train_hmm(num_emission, num_hidden_states, data, verbose=False):
    # distributions is the array of emission probabilities
    distributions = [Categorical(probs=[[1/num_emission] * num_emission])] * int(num_hidden_states)
    # edges are the transition matrix
    edges = [[1/num_hidden_states] * num_hidden_states] * num_hidden_states
    model = SparseHMM(distributions=distributions, edges=edges, max_iter=3, verbose=True)
    model.fit(data)
    
    return model

In [3]:
# read the data and divide it into chunks of length context length

def get_hmm_train_data(num_emission, context_length):
    train_data_file_name = f"/n/netscratch/sham_lab/Everyone/jchooi/in-context-language-learning/data/TinyStories-{num_emission}-train.txt"
    with open(train_data_file_name, 'r') as file:
        lines = file.readlines()
    
    # put all lines on one line
    text = ' '.join([line.strip() for line in lines])
    
    # parse as tokens
    tokens = [int(token) for token in text.split()]
    
    # get chunks of size context length
    chunks = []
    idex = 0
    while idex+context_length < len(tokens):
        chunks.append(tokens[idex:idex+context_length])
        idex += context_length
    
    # for above, note that it is possible that the final bits of the text will be cut off
    # the resulting shape has to be (batch_size, sequence length, dimensionality).
    # chunks = numpy.array(chunks, shape=(len(chunks), context_length, 1))
    chunks = np.expand_dims(np.array(chunks), axis=-1)
    return chunks

In [4]:
# combinations

class ModelParams:
    L: int
    h: int
    e: int
    
    def __init__(self, L, h, e):
        self.L = L
        self.h = h
        self.e = e
        
    def __str__(self):
        return f"ModelParams(L={self.L}, h={self.h}, e={self.e})"

Ls = [100, 200, 400, 800, 1600]
hs = [100, 200, 400, 800, 1600]

combinations = [ModelParams(L=L, h=h, e=100) for L in Ls for h in hs]

In [5]:
# test out the simplest combination
simple_model_params = combinations[0]
str(simple_model_params)

'ModelParams(L=100, h=100, e=100)'

In [6]:
simple_train_data = get_hmm_train_data(num_emission=simple_model_params.e, context_length=simple_model_params.L)

In [7]:
simple_train_data.shape

(74514, 100, 1)

In [8]:
# for testing, just use 100 examples
simple_train_data = simple_train_data[:200]
simple_train_data.shape

(200, 100, 1)

In [9]:
num_emission=simple_model_params.e
num_hidden_states=simple_model_params.h
data=simple_train_data
verbose=True

In [36]:
Categorical([[0.5,0.2,0.3],[1,0,0]]).sample(10)

tensor([[0, 0],
        [1, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [1, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0]])

In [40]:
# distributions is the array of emission probabilities
distributions = [Categorical(probs=[[np.float32(1/num_emission)] * num_emission])] * int(num_hidden_states)

model = DenseHMM(distributions=distributions, max_iter=3, verbose=True)
model.fit(data, sample_weight=None)
simple_model = model

[1] Improvement: nan, Time: 0.6401s
[2] Improvement: nan, Time: 0.6426s
[3] Improvement: nan, Time: 1.285s


In [103]:
simple_model = train_hmm(
    num_emission=simple_model_params.e, num_hidden_states=simple_model_params.h, data=simple_train_data, verbose=True
)

ValueError: Each edge must have three elements.

In [88]:
# save the model
import pickle


file_name = f"/n/netscratch/sham_lab/Everyone/jchooi/in-context-language-learning/data/hmm-L-{simple_model_params.L}-h-{simple_model_params.h}-e-{simple_model_params.e}-pomegranate.pkl"

with open(file_name, "wb") as file:
    pickle.dump(simple_model, file)

In [89]:
with open(file_name, 'rb') as file:
	# introducing types to get the intellisense to work
	model: DenseHMM = pickle.load(file)

In [None]:
simple_model.state_dict()

OrderedDict([('_device', tensor([0.])),
             ('edges',
              tensor([[nan, nan, nan,  ..., nan, nan, nan],
                      [nan, nan, nan,  ..., nan, nan, nan],
                      [nan, nan, nan,  ..., nan, nan, nan],
                      ...,
                      [nan, nan, nan,  ..., nan, nan, nan],
                      [nan, nan, nan,  ..., nan, nan, nan],
                      [nan, nan, nan,  ..., nan, nan, nan]])),
             ('starts',
              tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
                      nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
                      nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
                      nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, 

In [None]:
model.state_dict()

OrderedDict([('_device', tensor([0.])),
             ('starts',
              tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
                      nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
                      nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
                      nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
                      nan, nan, nan, nan])),
             ('ends',
              tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
                      nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
                 