In [None]:
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim

# Seed torch's global RNG once, up front, so the stochastic steps below repeat run to run.
torch.random.manual_seed(1)
torch.__version__  # shown as this cell's output

# Helper functions

In [None]:
def log_sum_exp(t, dim):
    """Numerically stable ``log(sum(exp(t), dim))`` (natural log) that keeps ``dim``.

    Args:
        t: n-D floating-point tensor.
        dim: dimension to reduce over (negative values allowed).

    Returns:
        n-D tensor shaped like ``t`` except that ``dim`` has size 1 (keepdim).

    Delegates to ``torch.logsumexp``, which applies the same max-shift trick a
    hand-rolled version would, but returns ``-inf`` (rather than ``nan`` from
    ``-inf - (-inf)``) for a slice whose entries are all ``-inf``.
    """
    return torch.logsumexp(t, dim=dim, keepdim=True)
    

In [None]:
# Sanity-check log_sum_exp() on a constant tensor, where the answer has a closed
# form: log-sum-exp over n entries that all equal 1 is ln(n * e) = 1 + ln(n).
#   dim=1 reduces 4 entries per row    -> ln(4e) ≈ 2.3863
#   dim=0 reduces 2 entries per column -> ln(2e) ≈ 1.6931
t = torch.ones(2, 4, dtype=torch.float64)
assert log_sum_exp(t, 1).shape == (2, 1) and log_sum_exp(t, 0).shape == (1, 4)
assert torch.allclose(log_sum_exp(t, 1), 1 + torch.log(torch.full((2, 1), 4.0, dtype=torch.float64)))
assert torch.allclose(log_sum_exp(t, 0), 1 + torch.log(torch.full((1, 4), 2.0, dtype=torch.float64)))
log_sum_exp(t, 0)

# BiLSTM-CRF

In [None]:
# Shape legend used in the comments below:
# L: Length of input sequence
# B: Batch size
# T: Tag set size
# E: Embedding dim
# H: Hidden dim
class BiLSTM_CRF(nn.Module):
    """BiLSTM emission scorer with a linear-chain CRF on top.

    All tensors are time-major: word/tag index tensors are ``[L, B]`` and
    emission features are ``[L, B, T]``. The last two tag indices are reserved:
    ``tagset_size - 2`` is START and ``tagset_size - 1`` is STOP, so the caller's
    tag vocabulary must put START second-to-last and STOP last.
    """

    def __init__(self, vocab_size, tagset_size, embedding_dim, hidden_dim):
        super(BiLSTM_CRF, self).__init__()
        if hidden_dim % 2 != 0:
            # Each LSTM direction gets hidden_dim // 2 units and hidden2tag expects
            # their concatenation to be exactly hidden_dim wide.
            raise ValueError(f"hidden_dim must be even, got {hidden_dim}")

        self.embedding_dim = embedding_dim   # E
        self.hidden_dim = hidden_dim         # H
        self.vocab_size = vocab_size
        self.tagset_size = tagset_size

        self.START_TAG_IDX = tagset_size - 2  # the second to last is start tag
        self.STOP_TAG_IDX = tagset_size - 1   # the last is stop tag

        self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
        self.lstm = nn.LSTM(self.embedding_dim,
                            self.hidden_dim // 2,  # bidirectional: each direction gets H // 2
                            num_layers=1,
                            bidirectional=True)
        self.hidden2tag = nn.Linear(self.hidden_dim, self.tagset_size)

        # transitions[i, j] is the log-potential of moving from tag i to tag j.
        self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size))
        # Forbid moving into START and leaving STOP. no_grad() lets us write into the
        # leaf Parameter without autograd recording the assignment.
        with torch.no_grad():
            self.transitions[:, self.START_TAG_IDX] = -10000.
            self.transitions[self.STOP_TAG_IDX, :] = -10000.

    def _rand_lstm_hidden(self, bsize):
        """Random initial ``(h0, c0)`` for the BiLSTM, each ``[2, bsize, H // 2]``.

        The leading 2 is num_layers (1) * num_directions (2). The state is drawn
        fresh on every call, so forward passes are not deterministic unless the
        global RNG is re-seeded. Tensors match the device/dtype of the parameters.
        """
        weight = self.word_embeddings.weight
        shape = (2, bsize, self.hidden_dim // 2)
        return (torch.randn(shape, device=weight.device, dtype=weight.dtype),
                torch.randn(shape, device=weight.device, dtype=weight.dtype))

    def _lstm_features(self, words_input):
        """Emission scores: word indices ``[L, B]`` -> ``[L, B, T]``."""
        B = words_input.shape[1]
        hidden_state = self._rand_lstm_hidden(B)
        embeddings = self.word_embeddings(words_input)          # [L,B,E]
        lstm_output, _ = self.lstm(embeddings, hidden_state)    # [L,B,H]
        return self.hidden2tag(lstm_output)                     # [L,B,H]->[L,B,T]

    def _crf_forward(self, features):
        """Log partition function: log-sum of potentials over every tag sequence.

        features: emission scores, ``[L, B, T]``.
        Returns: ``[B]``.
        """
        B = features.shape[1]
        # Before the first word only START is reachable (score 0); others get -10000.
        fwd_vars = features.new_full((B, self.tagset_size, 1), -10000.)   # [B,T,1]
        fwd_vars[:, self.START_TAG_IDX, 0] = 0.
        trans = self.transitions.unsqueeze(0)                             # [1,T,T]
        for feat in features:                                             # feat: [B,T]
            # [B,T,1] + [1,T,T] + [B,1,T] = [B,T,T]; entry [b,i,j] scores i -> j.
            next_vars = fwd_vars + trans + feat.unsqueeze(1)
            # Sum out the previous tag i ([B,1,T]), then back to column layout [B,T,1].
            fwd_vars = log_sum_exp(next_vars, dim=1).transpose(1, 2)

        # Add the move into STOP: [B,T,1] + [1,T,1] = [B,T,1].
        terminal_vars = fwd_vars + self.transitions[:, self.STOP_TAG_IDX].view(1, -1, 1)
        # Reduce over tags: [B,T,1] -> [B,1,1] -> [B].
        return log_sum_exp(terminal_vars, dim=1).reshape(B)

    def _crf_decode(self, features):
        """Viterbi decoding of the best-scoring tag sequence per batch item.

        features: emission scores, ``[L, B, T]``.
        Returns: ``(best_scores, best_paths)`` where ``best_scores`` is ``[B]`` and
        ``best_paths`` is a list of L tensors of shape ``[B]``, first word first.
        """
        B = features.shape[1]
        fwd_vars = features.new_full((B, self.tagset_size, 1), -10000.)   # [B,T,1]
        fwd_vars[:, self.START_TAG_IDX, 0] = 0.
        trans = self.transitions.unsqueeze(0)                             # [1,T,T]

        backptrs = []
        for feat in features:                                             # feat: [B,T]
            # The emission for destination tag j is the same for every source tag i,
            # so it can be added after taking the max over i.
            val, idx = torch.max(fwd_vars + trans, dim=1)                 # [B,T] each
            backptrs.append(idx)
            fwd_vars = (val + feat).unsqueeze(2)                          # [B,T,1]

        terminal_vars = fwd_vars.squeeze(2) + self.transitions[:, self.STOP_TAG_IDX]  # [B,T]
        best_scores, best_idx = torch.max(terminal_vars, dim=1)           # [B]

        # Follow back-pointers from the last word back to the implicit START.
        batch_idx = torch.arange(B, device=best_idx.device)
        best_paths = [best_idx]
        for ptrs in reversed(backptrs):                                   # ptrs: [B,T]
            best_idx = ptrs[batch_idx, best_idx]                          # [B]
            best_paths.append(best_idx)

        start_idx = best_paths.pop()
        # Sanity check: every decoded path must trace back to START.
        assert start_idx.tolist() == [self.START_TAG_IDX] * B
        best_paths.reverse()
        return best_scores, best_paths

    def _compute_log_potentials(self, features, tags):
        """Log-space score (transitions + emissions) of the given tag sequences.

        features: emission scores, ``[L, B, T]``.
        tags: tag indices, ``[L, B]``.
        Returns: ``[B]``.
        """
        tags = tags.long()  # index ops below (gather) need int64
        L, B = tags.shape

        # Bracket each sequence with START / STOP so the first and last moves are scored.
        pre_tags = torch.cat([tags.new_full((1, B), self.START_TAG_IDX), tags], dim=0)   # [L+1,B]
        next_tags = torch.cat([tags, tags.new_full((1, B), self.STOP_TAG_IDX)], dim=0)   # [L+1,B]
        tran_sum = self.transitions[pre_tags, next_tags].sum(dim=0)                      # [B]

        emis_scores = features.gather(2, tags.unsqueeze(2)).squeeze(2)                   # [L,B]
        emis_sum = emis_scores.sum(dim=0)                                                # [B]
        return tran_sum + emis_sum

    def neg_log_likelihood(self, words_input, tags):
        """Batch-mean CRF negative log-likelihood (a scalar loss).

        words_input: word indices, ``[L, B]``.
        tags: gold tag indices, ``[L, B]``.
        Returns: ``mean_b(log Z_b - gold_score_b)``.

        NOTE: no padding mask is applied; all L positions of every sequence are
        scored, so padded positions count with whatever tag the caller put there.
        """
        lstm_feat = self._lstm_features(words_input)
        partition_term = self._crf_forward(lstm_feat)
        potentials = self._compute_log_potentials(lstm_feat, tags)
        return torch.mean(partition_term - potentials)

    def forward(self, words_input):
        """Decode word indices ``[L, B]`` -> ``(best_scores [B], list of L tag tensors [B])``."""
        lstm_feats = self._lstm_features(words_input)
        return self._crf_decode(lstm_feats)

# Training

## Dataset

In [None]:

EMBEDDING_DIM = 5
HIDDEN_DIM = 4  # must be even: the BiLSTM gives each direction HIDDEN_DIM // 2 units

# Toy NER corpus: (tokens, BIO tags) pairs with one tag per whitespace-separated token.
training_data = [(sentence.split(), tag_line.split()) for sentence, tag_line in (
    ("the wall street journal reported today that apple corporation made money",
     "B I I I O O O B I O O"),
    ("georgia tech is a university in georgia",
     "B I O O O O B"),
    ("I worked for Shopee an eCommerce company in Singapore as an Engineer",
     "O O O B O O O O B O O O"),
    ("Yichang is famous for Three Gorges and the dam",
     "B O O O B I O B I"),
)]

# Fail fast if a sentence and its tag line fall out of step; otherwise the mismatch
# would only surface later as silently misaligned training targets.
assert all(len(tokens) == len(tags) for tokens, tags in training_data), \
    "every token needs exactly one tag"


In [None]:
from torch.utils.data import Dataset, DataLoader

class NERDataset(Dataset):
    """Map-style dataset of (tokens, tags) pairs, encoded to integer indices.

    Vocabularies are built from the data in first-seen order and exposed as the
    dict attributes ``word2idx`` and ``tag2idx``. ``'<pad>'`` is appended as the
    last word; ``'<START>'`` and ``'<STOP>'`` are appended as the second-to-last
    and last tags respectively.
    """

    def __init__(self, dataset):
        # dataset: a list of (tokens, tags) tuples; tokens and tags are lists of strings.
        self.dataset = dataset

        self.word2idx = {}
        self.tag2idx = {}
        for sentence, tags in self.dataset:
            for word in sentence:
                # setdefault inserts only unseen keys, so indices follow first appearance.
                self.word2idx.setdefault(word, len(self.word2idx))
            for tag in tags:
                self.tag2idx.setdefault(tag, len(self.tag2idx))

        self.word2idx['<pad>'] = len(self.word2idx)
        self.tag2idx['<START>'] = len(self.tag2idx)  # the second to last is start tag
        self.tag2idx['<STOP>'] = len(self.tag2idx)   # the last is stop tag

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Encode item ``idx`` as (word index list, tag index list); new lists every call."""
        tokens = self.dataset[idx][0]
        tags = self.dataset[idx][1]
        return ([self.word2idx[w] for w in tokens],
                [self.tag2idx[t] for t in tags])

    def vocab_size(self):
        """Number of words, including ``'<pad>'``."""
        return len(self.word2idx)

    def tagset_size(self):
        """Number of tags, including ``'<START>'`` and ``'<STOP>'``."""
        return len(self.tag2idx)

    # Lookup helpers. Methods named ``word2idx`` / ``tag2idx`` would be shadowed by
    # the dict attributes of the same name set in __init__ (calling them on an
    # instance raises "'dict' object is not callable"), hence the distinct names.
    def word_index(self, word):
        """Index of ``word``; raises KeyError if it is not in the vocabulary."""
        return self.word2idx[word]

    def tag_index(self, tag):
        """Index of ``tag``; raises KeyError if it is not in the tag set."""
        return self.tag2idx[tag]

    def prepare_sequence(self, sent):
        """Encode one sentence as a ``[L, 1]`` long tensor (time-major, batch of one).

        Raises KeyError for words not in the vocabulary.
        """
        idxs = [self.word2idx[w] for w in sent]
        return torch.tensor(idxs, dtype=torch.long).unsqueeze(1)


In [None]:
dataset = NERDataset(dataset=training_data)  # builds word/tag vocabularies from the toy corpus

# WARNING: unless pad indices are passed explicitly, this collate_fn reads them
# from the module-level `dataset`, which must exist before a DataLoader calls it.
def nerdataset_collate(batch, pad_word_idx=None, pad_tag_idx=None):
    """Pad a batch of (word indices, tag indices) pairs and stack them time-major.

    Args:
        batch: list of (token index list, tag index list) pairs.
        pad_word_idx: index used to pad token lists; defaults to
            ``dataset.word2idx['<pad>']``.
        pad_tag_idx: index used to pad tag lists; defaults to
            ``dataset.tag2idx['O']``.

    Returns:
        ``(tokens, tags)``: two contiguous int64 tensors of shape ``[L, B]``, where
        L is the longest token list in the batch.

    The input lists are left unmodified. No mask is returned: padded positions
    carry a real tag index ('O' by default), so downstream code scores them.
    """
    max_len = max(len(tokens) for tokens, _ in batch)
    if pad_word_idx is None:
        pad_word_idx = dataset.word2idx['<pad>']
    if pad_tag_idx is None:
        pad_tag_idx = dataset.tag2idx['O']

    # Build new padded lists instead of extending the caller's lists in place.
    batched_tokens = [tokens + [pad_word_idx] * (max_len - len(tokens)) for tokens, _ in batch]  # [B,L]
    batched_tags = [tags + [pad_tag_idx] * (max_len - len(tags)) for _, tags in batch]          # [B,L]

    # Transpose to the model's time-major [L,B] layout.
    return (torch.tensor(batched_tokens, dtype=torch.int64).T.contiguous(),
            torch.tensor(batched_tags, dtype=torch.int64).T.contiguous())
 
# Shuffled mini-batches of two sentences, padded and transposed to [L, B] by the collate fn.
loader = DataLoader(dataset=dataset, collate_fn=nerdataset_collate, shuffle=True, batch_size=2)

In [None]:
model = BiLSTM_CRF(
    vocab_size=dataset.vocab_size(),
    tagset_size=dataset.tagset_size(),
    embedding_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
)
optimizer = optim.SGD(params=model.parameters(), lr=0.01, weight_decay=1e-4)

# Baseline: decode the first training sentence with the untrained model.
with torch.no_grad():
    precheck_sent = dataset.prepare_sequence(training_data[0][0])
    # Gold tag indices for the same sentence, in the same [L, 1] layout.
    precheck_tags = torch.tensor(
        [dataset.tag2idx[t] for t in training_data[0][1]], dtype=torch.long
    ).unsqueeze(1)
    print(model(precheck_sent))

In [None]:
# Train on the toy corpus. 300 epochs is far more than real data would need;
# it only lets this tiny example fit its training set.
model.train()
for epoch in range(300):
    for tokens, tags in loader:  # each a time-major [L, B] int64 tensor
        # PyTorch accumulates gradients, so clear them before every batch.
        optimizer.zero_grad()

        # Forward pass: batch-mean CRF negative log-likelihood.
        loss = model.neg_log_likelihood(tokens, tags)

        # Backward pass, then let the optimizer update the parameters.
        loss.backward()
        optimizer.step()

# Check predictions after training (compare with the pre-training output above).
model.eval()
with torch.no_grad():
    precheck_sent = dataset.prepare_sequence(training_data[0][0])
    print(model(precheck_sent))