In [None]:
import sys
sys.path.append("../..")
from model.hmm import data, models, training # Generate datasets from text file
path = "."
N = 128
config = data.read_config(N, path)
train_dataset, valid_dataset = data.get_datasets(config, batch_size=128)
checkpoint_path = "."

# Initialize model
model = models.HMM(config=config)

# Train the model
num_epochs = 10
trainer = training.Trainer(model, config, lr=0.003)
trainer.load_checkpoint(checkpoint_path)

for epoch in range(num_epochs):
    print("========= Epoch %d of %d =========" % (epoch + 1, num_epochs))
    train_loss = trainer.train(train_dataset, print_interval=100)
    valid_loss = trainer.test(valid_dataset)
    trainer.save_checkpoint(epoch, checkpoint_path)

    print("========= Results: epoch %d of %d =========" % (epoch + 1, num_epochs))
    print("train loss: %.2f| valid loss: %.2f\n" % (train_loss, valid_loss))


  0%|                                                                                         | 0/1659 [00:00<?, ?it/s]

Could not load previous model; starting from scratch


  0%|                                                                                 | 1/1659 [00:00<15:17,  1.81it/s]

35.19606018066406
asAyJruxxl; goyAiSThlr; p
HPkrxuAi; cJcOvtBbWt; oypqdmOfCf; 


  6%|████▊                                                                          | 101/1659 [00:44<11:19,  2.29it/s]

30.662067413330078
awxslsjudV; cuhmxhWsbv; roAmEengiv; piurvllduz; uner
dLekd; 


 12%|█████████▌                                                                     | 201/1659 [01:27<10:40,  2.28it/s]

29.658903121948242
chuXdepryt; rotheromwz; sbhoysemrN; aiwvYentnz; drnspprmXr; 


 18%|██████████████▎                                                                | 301/1659 [02:10<09:53,  2.29it/s]

26.517515182495117
aromtreylh; cngilissls; gryudlnode; aumQipesvt; namzicgior; 


 24%|███████████████████                                                            | 401/1659 [02:55<09:12,  2.28it/s]

27.164794921875
afodikryma; selpryotau; trtglaurvM; noslReafop; hoaHkitvin; 


 30%|███████████████████████▊                                                       | 501/1659 [03:39<08:24,  2.30it/s]

27.29253387451172
rnerlaliso; abttrichea; gtadfleeri; ialooyBctN; lemileruyl; 


 31%|████████████████████████▎                                                      | 511/1659 [03:44<07:44,  2.47it/s]

In [6]:
import numpy as np
N = len(train_dataset.Sx)

embedding = torch.nn.Embedding(N, embedding_dim=256)
lr = 1e-3
optimizer = torch.optim.Adam(embedding.parameters(), lr=lr)

batch_size = 1024
n_epochs = 20
print_every = 1
k = 2 # skip-gram
device = torch.device('cuda')
embedding = embedding.to(device)
for epoch_idx in range(n_epochs):
    try:
        batch_loss = []
        for batch_idx, (x, T) in enumerate(train_dataset.loader):
            x = x.cpu().numpy()
            T = T.cpu().numpy()
            i = np.random.randint(0, T).astype(np.int64)
            j = np.random.randint(np.maximum(0, i-k), np.minimum(T, i+k+1)).astype(np.int64)
            i, j = x[np.arange(len(T)), i], x[np.arange(len(T)), j]
            i = np.concatenate([i, np.random.randint(0, N, x.shape[0])])
            j = np.concatenate([j, np.random.randint(0, N, x.shape[0])])
            targets = torch.tensor([np.repeat([0, 1], x.shape[0]).astype(np.float64)]).to(device)
            i = torch.tensor(i).to(device)
            j = torch.tensor(j).to(device)
            optimizer.zero_grad()
            xi = embedding(i)
            xj = embedding(j)
            loss = torch.binary_cross_entropy_with_logits(torch.diag(torch.matmul(xi, xj.T)), targets)
            loss = torch.sum(loss)
            batch_loss.append(loss.item())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(embedding.parameters(), 0.5)
            optimizer.step()
        if epoch_idx % print_every == 0:
            print("Epoch {}: loss = {:.3f}".format(epoch_idx, np.mean(batch_loss)))
    except KeyboardInterrupt:
        break

Epoch 0: loss = 397.716
Epoch 1: loss = 43.339
Epoch 2: loss = 36.505
Epoch 3: loss = 35.316
Epoch 4: loss = 34.732
Epoch 5: loss = 34.204
Epoch 6: loss = 34.027
Epoch 7: loss = 33.808
Epoch 8: loss = 33.905
Epoch 9: loss = 33.608
Epoch 10: loss = 33.832
Epoch 11: loss = 33.999
Epoch 12: loss = 33.983
Epoch 13: loss = 33.981
Epoch 14: loss = 33.877
Epoch 15: loss = 33.741
Epoch 16: loss = 33.878
Epoch 17: loss = 33.567
Epoch 18: loss = 33.902
Epoch 19: loss = 34.884


In [26]:
import torch


class HMM(torch.nn.Module):
    """
    Hidden Markov Model.
    (For now, discrete observations only.)
    - forward(): computes the log probability of an observation sequence.
    - viterbi(): computes the most likely state sequence.
    - sample(): draws a sample from p(x).
    """

    def __init__(self, config, embedding):
        super(HMM, self).__init__()
        self.M = config.M  # number of possible observations
        self.N = config.N  # number of states
        self.unnormalized_state_priors = torch.nn.Parameter(torch.randn(self.N))
        self.transition_model = TransitionModel(self.N, embedding)
        self.emission_model = EmissionModel(self.N, self.M)
        self.is_cuda = torch.cuda.is_available()
        if self.is_cuda:
            self.cuda()

    def forward(self, x, T):
        """
        x : IntTensor of shape (batch size, T_max)
        T : IntTensor of shape (batch size)

        Compute log p(x) for each example in the batch.
        T = length of each example
        """
        if self.is_cuda:
            x = x.cuda()
            T = T.cuda()

        batch_size = x.shape[0]
        T_max = x.shape[1]
        log_state_priors = torch.nn.functional.log_softmax(
            self.unnormalized_state_priors, dim=0
        )
        log_alpha = torch.zeros(batch_size, T_max, self.N)
        if self.is_cuda:
            log_alpha = log_alpha.cuda()

        log_alpha[:, 0, :] = self.emission_model(x[:, 0]) + log_state_priors
        for t in range(1, T_max):
            log_alpha[:, t, :] = self.emission_model(x[:, t]) + self.transition_model(
                x[:,t], log_alpha[:, t - 1, :], use_max=False
            )

        log_sums = log_alpha.logsumexp(dim=2)

        # Select the sum for the final timestep (each x has different length).
        log_probs = torch.gather(log_sums, 1, T.view(-1, 1) - 1)
        return log_probs

    def sample(self, T=10):
        state_priors = torch.nn.functional.softmax(
            self.unnormalized_state_priors, dim=0
        )
        emission_matrix = torch.nn.functional.softmax(
            self.emission_model.unnormalized_emission_matrix, dim=1
        )

        # sample initial state
        z_t = torch.distributions.categorical.Categorical(state_priors).sample().item()
        z = []
        x = []
        z.append(z_t)
        for t in range(0, T):
            # sample emission
            x_t = (
                torch.distributions.categorical.Categorical(emission_matrix[z_t])
                .sample()
            )
            x.append(x_t.item())

            # sample transition
            transition_matrix = self.transition_model.log_transition_matrix(x_t.unsqueeze(0))[0]
            z_t = (
                torch.distributions.categorical.Categorical(transition_matrix[:, z_t])
                .sample()
                .item()
            )
            if t < T - 1:
                z.append(z_t)

        return x, z

    def viterbi(self, x, T):
        """
        x : IntTensor of shape (batch size, T_max)
        T : IntTensor of shape (batch size)

        Find argmax_z log p(z|x) for each (x) in the batch.
        """
        if self.is_cuda:
            x = x.cuda()
            T = T.cuda()

        batch_size = x.shape[0]
        T_max = x.shape[1]
        log_state_priors = torch.nn.functional.log_softmax(
            self.unnormalized_state_priors, dim=0
        )
        log_delta = torch.zeros(batch_size, T_max, self.N).float()
        psi = torch.zeros(batch_size, T_max, self.N).long()
        if self.is_cuda:
            log_delta = log_delta.cuda()
            psi = psi.cuda()

        log_delta[:, 0, :] = self.emission_model(x[:, 0]) + log_state_priors
        for t in range(1, T_max):
            max_val, argmax_val = self.transition_model(
                x[:,t], log_delta[:, t - 1, :], use_max=True
            )
            log_delta[:, t, :] = self.emission_model(x[:, t]) + max_val
            psi[:, t, :] = argmax_val

        # Get the probability of the best path
        log_max = log_delta.max(dim=2)[0]
        best_path_scores = torch.gather(log_max, 1, T.view(-1, 1) - 1)

        # This next part is a bit tricky to parallelize across the batch,
        # so we will do it separately for each example.
        z_star = []
        for i in range(0, batch_size):
            z_star_i = [log_delta[i, T[i] - 1, :].max(dim=0)[1].item()]
            for t in range(T[i] - 1, 0, -1):
                z_t = psi[i, t, z_star_i[0]].item()
                z_star_i.insert(0, z_t)

            z_star.append(z_star_i)

        return z_star, best_path_scores


def log_domain_matmul(log_A, log_B, use_max=False):
    """
    log_A : m x n
    log_B : n x p

    output : m x p matrix

    Normally, a matrix multiplication
    computes out_{i,j} = sum_k A_{i,k} x B_{k,j}

    A log domain matrix multiplication
    computes out_{i,j} = logsumexp_k log_A_{i,k} + log_B_{k,j}

    This is needed for numerical stability
    when A and B are probability matrices.
    """
    offset = 1 if len(log_A.size()) > 2 else 0
    m = log_A.shape[0+offset]
    n = log_A.shape[1+offset]
    p = log_B.shape[1+offset]

    log_A_expanded = torch.stack([log_A] * p, dim=2+offset)
    log_B_expanded = torch.stack([log_B] * m, dim=0+offset)

    elementwise_sum = log_A_expanded + log_B_expanded
    out = torch.logsumexp(elementwise_sum, dim=1+offset)
    return out

    
class TransitionModel(torch.nn.Module):
    """
    - forward(): computes the log probability of a transition.
    - sample(): given a previous state, sample a new state.
    """

    def __init__(self, N, embedding, hidden_dim=256, dropout=0.5):
        super(TransitionModel, self).__init__()
        self.N = N  # number of states
        self.embedding = embedding
        self.unnormalized_transition_matrix = torch.nn.Sequential(
            torch.nn.Linear(embedding.embedding_dim, hidden_dim),
            torch.nn.LeakyReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.LeakyReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, N*N)
        )
        self.base_transition_matrix = torch.nn.Parameter(torch.randn(N, N).unsqueeze(0))
    
    def log_transition_matrix(self, x):
        return torch.nn.functional.log_softmax(
            #self.unnormalized_transition_matrix(self.embedding(x)).view(x.size(0), N, N) + 
            self.base_transition_matrix, 
            dim=1
        )

    def forward(self, x, log_alpha, use_max):
        """
        log_alpha : Tensor of shape (batch size, N)

        Multiply previous timestep's alphas by transition matrix (in log domain)
        """
        # Each col needs to add up to 1 (in probability domain)
        log_transition_matrix = self.log_transition_matrix(x)
        
        # Expand matrix to do batchwise multiplication
        log_alpha = log_alpha.unsqueeze(-1)

        # Matrix multiplication in the log domain
        # out = genbmm.logbmm(log_alpha.unsqueeze(0).contiguous(), transition_matrix.unsqueeze(0).contiguous())[0]
        if use_max:
            out1, out2 = maxmul(log_transition_matrix, log_alpha)
            return out1.transpose(0, 1), out2.transpose(0, 1)
        else:
            out = log_domain_matmul(log_transition_matrix, log_alpha)
            out = out.squeeze(-1)
            return out

In [27]:
# Generate datasets from text file
path = "."
N = 128
config = read_config(N, path)
train_dataset, valid_dataset = get_datasets(config)
checkpoint_path = "."

# Initialize model
model = HMM(config=config, embedding=embedding)

# Train the model
num_epochs = 10
trainer = Trainer(model, config, lr=0.003)
#trainer.load_checkpoint(checkpoint_path)

for epoch in range(num_epochs):
    print("========= Epoch %d of %d =========" % (epoch + 1, num_epochs))
    train_loss = trainer.train(train_dataset)
    valid_loss = trainer.test(valid_dataset)
    trainer.save_checkpoint(epoch, checkpoint_path)

    print("========= Results: epoch %d of %d =========" % (epoch + 1, num_epochs))
    print("train loss: %.2f| valid loss: %.2f\n" % (train_loss, valid_loss))


35.482635498046875
TxFGLhCp-k
[57, 102, 127, 116, 115, 115, 44, 31, 89, 45]
wHpucTsRJC
[125, 118, 47, 61, 53, 113, 107, 120, 60, 127]
QHfaGEvBqV
[41, 71, 127, 113, 109, 124, 105, 50, 64, 43]
VBQjGRADIQ
[12, 46, 33, 105, 110, 114, 102, 55, 95, 41]
JAjiWNyqod
[108, 26, 19, 96, 94, 7, 22, 97, 62, 92]
23.872852325439453
dhSnerhtua
[115, 108, 120, 127, 44, 56, 68, 56, 16, 27]
pnhcuhgdvy
[64, 114, 116, 71, 87, 67, 85, 59, 15, 101]
adileispYm
[108, 4, 110, 75, 16, 122, 107, 79, 46, 119]
ceheoapntO
[84, 120, 26, 1, 81, 55, 99, 22, 79, 50]
iqUlacedut
[97, 78, 36, 116, 27, 84, 35, 99, 76, 25]
23.12369728088379
hpbbbggmpl
[97, 107, 41, 38, 84, 39, 98, 113, 64, 37]
jjxdlgrlsp
[83, 109, 113, 92, 89, 85, 112, 43, 19, 34]
omiymogfyb
[92, 12, 110, 3, 84, 22, 39, 58, 21, 41]
qbFuxudhqh
[64, 68, 57, 52, 6, 96, 99, 78, 94, 76]
pokmiabsrw
[41, 73, 104, 49, 7, 3, 33, 74, 111, 77]
24.12399673461914
sduzmprtdl
[73, 99, 65, 20, 82, 34, 69, 106, 98, 9]
fmylaufipu
[118, 119, 120, 82, 48, 65, 40, 98, 62, 96]
umu

KeyboardInterrupt: 