In [3]:
import math
import torch
import torch.nn as nn
import torch.optim as optim

In [4]:
# Module-level batch size. make_batch1 reads this global when it is called,
# so a later reassignment (the __main__ cell sets it to 512) is what
# make_batch1 actually sees.
batch_size = 128
# File read by give_valid.
valid_path = 'valid.txt'
# File read by give_test. The identifier is spelled `test_psth` where it is
# used, so the spelling must stay in sync with give_test.
test_psth = 'test.txt'
def make_batch1(train_path, word2number_dict, n_step):
    """Build fixed-size batches of (context, next-word) id samples from a text file.

    Every line of the file is split on single spaces. A line with at most
    ``n_step`` tokens is left-padded with ``"<pad>"`` up to ``n_step + 1``
    tokens. A window of ``n_step`` tokens then slides over the line and the
    token right after the window is the target.

    The batch size is the module-level ``batch_size`` global, looked up when
    the function is called.

    Args:
        train_path: Path of the UTF-8 text file to read.
        word2number_dict: Mapping from token to integer id. Tokens that are
            not in the mapping are encoded as 1.
        n_step: Number of context tokens per sample.

    Returns:
        A pair ``(all_input_batch, all_target_batch)``. The first is a list of
        batches, each a list of ``batch_size`` id lists of length ``n_step``;
        the second is the parallel list of batches of target ids. Samples left
        over after the last full batch are dropped, so every batch has the
        same size.
    """
    unk_id = 1  # id used for tokens missing from word2number_dict

    all_input_batch = []
    all_target_batch = []
    input_batch = []
    target_batch = []

    # The context manager closes the file even if processing a line raises.
    with open(train_path, 'r', encoding='utf-8') as text:
        for sen in text:
            word = sen.strip().split(" ")
            if len(word) <= n_step:
                word = ["<pad>"] * (n_step + 1 - len(word)) + word

            # Encode the line once; every window below is a slice of it.
            ids = [word2number_dict.get(w, unk_id) for w in word]

            for word_index in range(len(ids) - n_step):
                input_batch.append(ids[word_index:word_index + n_step])
                target_batch.append(ids[word_index + n_step])

                if len(input_batch) == batch_size:
                    all_input_batch.append(input_batch)
                    all_target_batch.append(target_batch)
                    input_batch = []
                    target_batch = []

    return all_input_batch, all_target_batch

In [5]:
def give_valid(word2number_dict, n_step):
    """Return the validation batches as a pair of LongTensors.

    The batches are built by make_batch1 from the file named by the
    module-level ``valid_path``.

    Args:
        word2number_dict: Mapping from token to integer id.
        n_step: Number of context tokens per sample.

    Returns:
        ``(inputs, targets)``, both ``torch.LongTensor``.
    """
    batches = make_batch1(valid_path, word2number_dict, n_step)
    inputs, targets = (torch.LongTensor(part) for part in batches)
    return inputs, targets

def give_test(word2number_dict, n_step):
    """Return the test batches as a pair of LongTensors.

    The batches are built by make_batch1 from the file named by the
    module-level ``test_psth``.

    Args:
        word2number_dict: Mapping from token to integer id.
        n_step: Number of context tokens per sample.

    Returns:
        ``(inputs, targets)``, both ``torch.LongTensor``.
    """
    raw_inputs, raw_targets = make_batch1(test_psth, word2number_dict, n_step)
    return torch.LongTensor(raw_inputs), torch.LongTensor(raw_targets)

In [6]:
# Device the model and the training tensors are moved to; fixed to the CPU.
device = torch.device("cpu")

In [7]:
def make_batch(train_path, word2number_dict, batch_size, n_step):
    """Build fixed-size batches of (context, next-word) id samples from a text file.

    Every line of the file is split on single spaces. A line with at most
    ``n_step`` tokens is left-padded with ``"<pad>"`` up to ``n_step + 1``
    tokens. A window of ``n_step`` tokens then slides over the line and the
    token right after the window is the target (a causal language-model
    sample).

    Args:
        train_path: Path of the UTF-8 text file to read.
        word2number_dict: Mapping from token to integer id. Tokens that are
            not in the mapping are encoded as 1 (the ``<unk_word>`` id).
        batch_size: Number of samples per batch.
        n_step: Number of context tokens per sample.

    Returns:
        A pair ``(all_input_batch, all_target_batch)``. The first is a list of
        batches, each a list of ``batch_size`` id lists of length ``n_step``;
        the second is the parallel list of batches of target ids. Samples left
        over after the last full batch are dropped, so every batch has the
        same size.
    """
    unk_id = 1  # <unk_word>

    all_input_batch = []
    all_target_batch = []
    input_batch = []
    target_batch = []

    # The context manager closes the file even if processing a line raises.
    with open(train_path, 'r', encoding='utf-8') as text:
        for sen in text:
            word = sen.strip().split(" ")  # space tokenizer

            if len(word) <= n_step:  # pad short sentences on the left
                word = ["<pad>"] * (n_step + 1 - len(word)) + word

            # Encode the line once; every window below is a slice of it.
            ids = [word2number_dict.get(w, unk_id) for w in word]

            for word_index in range(len(ids) - n_step):
                # tokens (1 .. n-1) are the input, token n is the target
                input_batch.append(ids[word_index:word_index + n_step])
                target_batch.append(ids[word_index + n_step])

                if len(input_batch) == batch_size:
                    all_input_batch.append(input_batch)
                    all_target_batch.append(target_batch)
                    input_batch = []
                    target_batch = []

    return all_input_batch, all_target_batch

In [8]:
def make_dict(train_path):
    """Build the token <-> id vocabularies from a text file.

    Every line is split on single spaces. The distinct tokens are sorted and
    numbered from 2 upward; id 0 is reserved for ``"<pad>"`` and id 1 for
    ``"<unk_word>"``.

    Args:
        train_path: Path of the UTF-8 text file to read.

    Returns:
        ``(word2number_dict, number2word_dict)``: the token -> id mapping and
        the id -> token mapping.
    """
    vocab = set()
    # The context manager closes the file even if reading raises.
    with open(train_path, 'r', encoding='utf-8') as text:
        for line in text:
            # In-place update avoids rebuilding the whole set for every line.
            vocab.update(line.strip().split(" "))

    word_list = sorted(vocab)
    word2number_dict = {w: i for i, w in enumerate(word_list, start=2)}
    number2word_dict = {i: w for i, w in enumerate(word_list, start=2)}

    # Special tokens are assigned last so they always own ids 0 and 1.
    word2number_dict["<pad>"] = 0
    number2word_dict[0] = "<pad>"
    word2number_dict["<unk_word>"] = 1
    number2word_dict[1] = "<unk_word>"
    return word2number_dict, number2word_dict

In [26]:

class TextLSTM(nn.Module):
    """Two stacked LSTM layers written out gate by gate, plus an output projection.

    The constructor reads the module-level globals ``n_class`` (vocabulary
    size), ``emb_size`` (embedding width) and ``n_hidden`` (hidden width), so
    they must be defined before the model is instantiated.
    """

    def __init__(self):
        super(TextLSTM, self).__init__()
        self.C = nn.Embedding(n_class, embedding_dim=emb_size)

        # ---- first LSTM layer: input is an embedding of width emb_size ----
        # candidate cell state g
        self.W_ig = nn.Linear(emb_size, n_hidden, bias=False)
        self.W_hg = nn.Linear(n_hidden, n_hidden, bias=False)
        self.b_g = nn.Parameter(torch.ones([n_hidden]))
        # input gate i
        self.W_xi = nn.Linear(emb_size, n_hidden, bias=False)
        self.W_hi = nn.Linear(n_hidden, n_hidden, bias=False)
        self.b_i = nn.Parameter(torch.ones([n_hidden]))
        # forget gate f
        self.W_xf = nn.Linear(emb_size, n_hidden, bias=False)
        self.W_hf = nn.Linear(n_hidden, n_hidden, bias=False)
        self.b_f = nn.Parameter(torch.ones([n_hidden]))
        # output gate o
        self.W_xo = nn.Linear(emb_size, n_hidden, bias=False)
        self.W_ho = nn.Linear(n_hidden, n_hidden, bias=False)
        self.b_o = nn.Parameter(torch.ones([n_hidden]))

        # ---- second LSTM layer: input is the first layer's hidden state ----
        # candidate cell state g
        self.W_ig_t = nn.Linear(n_hidden, n_hidden, bias=False)
        self.W_hg_t = nn.Linear(n_hidden, n_hidden, bias=False)
        self.b1_t = nn.Parameter(torch.ones([n_hidden]))
        # input gate i
        self.W_xi_t = nn.Linear(n_hidden, n_hidden, bias=False)
        self.W_hi_t = nn.Linear(n_hidden, n_hidden, bias=False)
        self.b_i_t = nn.Parameter(torch.ones([n_hidden]))
        # forget gate f
        self.W_xf_t = nn.Linear(n_hidden, n_hidden, bias=False)
        self.W_hf_t = nn.Linear(n_hidden, n_hidden, bias=False)
        self.b_f_t = nn.Parameter(torch.ones([n_hidden]))
        # output gate o
        self.W_xo_t = nn.Linear(n_hidden, n_hidden, bias=False)
        self.W_ho_t = nn.Linear(n_hidden, n_hidden, bias=False)
        self.b_o_t = nn.Parameter(torch.ones([n_hidden]))

        # ---- projection of the last hidden state onto the vocabulary ----
        self.W = nn.Linear(n_hidden, n_class, bias=False)
        self.b = nn.Parameter(torch.ones([n_class]))

    def forward(self, X):
        """Return next-token logits for a batch of token-id contexts.

        Args:
            X: Integer tensor of token ids. It is embedded, dimensions 0 and 1
                are swapped, and the result is iterated step by step, so the
                expected layout is ``[batch, n_step]``.

        Returns:
            Tensor of shape ``[batch, n_class]`` computed from the second
            layer's hidden state after the last step.
        """
        X = self.C(X)           # [batch, n_step, emb_size]
        X = X.transpose(0, 1)   # [n_step, batch, emb_size]

        # Size the states from the actual input and the layer weights rather
        # than from module-level globals, so any batch size works.
        sample_size = X.size(1)
        hidden_size = self.W_hg.in_features

        def zero_state():
            return torch.zeros(sample_size, hidden_size, dtype=X.dtype, device=X.device)

        h_t, c_t = zero_state(), zero_state()
        h_t_t, c_t_t = zero_state(), zero_state()

        # First layer: keep the hidden state of every step for the second layer.
        layer1_outputs = []
        for x in X:
            g_t = torch.tanh(self.W_ig(x) + self.W_hg(h_t) + self.b_g)
            i_t = torch.sigmoid(self.W_xi(x) + self.W_hi(h_t) + self.b_i)
            f_t = torch.sigmoid(self.W_xf(x) + self.W_hf(h_t) + self.b_f)
            o_t = torch.sigmoid(self.W_xo(x) + self.W_ho(h_t) + self.b_o)
            c_t = c_t * f_t + g_t * i_t
            h_t = o_t * torch.tanh(c_t)
            layer1_outputs.append(h_t)

        # Second layer: must use its OWN forget/output gates (f_t_t, o_t_t),
        # not the first layer's gates left over from its final step.
        for y in layer1_outputs:
            g_t_t = torch.tanh(self.W_ig_t(y) + self.W_hg_t(h_t_t) + self.b1_t)
            i_t_t = torch.sigmoid(self.W_xi_t(y) + self.W_hi_t(h_t_t) + self.b_i_t)
            f_t_t = torch.sigmoid(self.W_xf_t(y) + self.W_hf_t(h_t_t) + self.b_f_t)
            o_t_t = torch.sigmoid(self.W_xo_t(y) + self.W_ho_t(h_t_t) + self.b_o_t)
            c_t_t = c_t_t * f_t_t + g_t_t * i_t_t
            h_t_t = o_t_t * torch.tanh(c_t_t)

        logits = self.W(h_t_t) + self.b
        return logits


In [27]:
def train():
    """Train a TextLSTM on the module-level training batches.

    Reads the module-level globals ``device``, ``learn_rate``, ``all_epoch``,
    ``save_checkpoint_epoch``, ``all_input_batch``, ``all_target_batch``,
    ``word2number_dict`` and ``n_step``. After every epoch the mean validation
    loss is printed; every ``save_checkpoint_epoch`` epochs the whole model is
    saved to ``nnlm_model_epoch<N>.ckpt`` in the working directory.
    """
    model = TextLSTM()
    model.to(device)
    print(model)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learn_rate)

    # The validation set does not change between epochs, so build it once.
    all_valid_batch, all_valid_target = give_valid(word2number_dict, n_step)
    # Tensor.to returns the moved tensor; the result has to be kept.
    all_valid_batch = all_valid_batch.to(device)
    all_valid_target = all_valid_target.to(device)
    # Count the real number of validation samples instead of assuming a
    # particular batch size.
    total_valid = all_valid_target.numel()

    # Training
    batch_number = len(all_input_batch)
    for epoch in range(all_epoch):
        batches = zip(all_input_batch, all_target_batch)
        for count_batch, (input_batch, target_batch) in enumerate(batches, start=1):
            optimizer.zero_grad()
            output = model(input_batch)
            loss = criterion(output, target_batch)
            if count_batch % 100 == 0:
                print('Epoch:', '%04d' % (epoch + 1), 'Batch:', '%02d' % count_batch, f'/{batch_number}',
                      'loss =', '{:.6f}'.format(loss.item()))
            loss.backward()
            optimizer.step()

        # Validation: no gradients are needed.
        total_loss = 0.0
        count_loss = 0
        with torch.no_grad():
            for valid_batch, valid_target in zip(all_valid_batch, all_valid_target):
                valid_output = model(valid_batch)
                valid_loss = criterion(valid_output, valid_target)
                total_loss += valid_loss.item()
                count_loss += 1

        # With no full validation batch there is nothing to average; report
        # nan instead of raising ZeroDivisionError in the middle of training.
        mean_valid_loss = total_loss / count_loss if count_loss else float('nan')
        print(f'Valid {total_valid} samples after epoch:', '%04d' % (epoch + 1), 'loss =',
              '{:.6f}'.format(mean_valid_loss))

        if (epoch+1) % save_checkpoint_epoch == 0:
            torch.save(model, f'nnlm_model_epoch{epoch+1}.ckpt')
            print('**')

In [28]:
def test(select_model_path):
    model = torch.load(select_model_path, map_location="cpu")  
    all_test_batch, all_test_target = give_test(word2number_dict, n_step)
    total_test = len(all_test_target)*batch_size
    model.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0
    count_loss = 0
    for test_batch, test_target in zip(all_test_batch, all_test_target):
        test_output = model(test_batch)
        test_loss = criterion(test_output, test_target)
        total_loss += test_loss.item()
        count_loss += 1

    print(f"Test {total_test} samples with {select_model_path}")
    print('loss =','{:.6f}'.format(total_loss / count_loss))

In [29]:
if __name__ == '__main__':
    # ---- configuration: these become module-level globals that the model
    # ---- class and the train/test helpers read directly
    n_step = 2  # number of context tokens per sample
    n_hidden = 2  # width of the LSTM hidden state
    m = 2  # not referenced by the code visible in this notebook section
    batch_size = 512  # replaces the earlier module-level value of 128
    learn_rate = 0.001
    all_epoch = 10
    save_checkpoint_epoch = 5  # a checkpoint is written every this many epochs
    emb_size=32
    train_path = 'train.txt'

    word2number_dict, number2word_dict = make_dict(train_path)
    print("The size of the dictionary is", len(word2number_dict))

    n_class = len(word2number_dict)

    all_input_batch, all_target_batch = make_batch(train_path, word2number_dict, batch_size, n_step)
    print("The number of the train batch is", len(all_input_batch))

    all_input_batch = torch.LongTensor(all_input_batch).to(device)
    all_target_batch = torch.LongTensor(all_target_batch).to(device)

    print("Train")
    train()
    print("Test")
    # Evaluate every checkpoint train() wrote. The names are derived from the
    # settings above so they stay correct when all_epoch or
    # save_checkpoint_epoch change (with 10 and 5 this is epochs 5 and 10).
    for saved_epoch in range(save_checkpoint_epoch, all_epoch + 1, save_checkpoint_epoch):
        select_model_path = f"nnlm_model_epoch{saved_epoch}.ckpt"
        test(select_model_path)
    print('The end')

The size of the dictionary is 7613
The number of the train batch is 159
Train
TextLSTM(
  (C): Embedding(7613, 32)
  (W_ig): Linear(in_features=32, out_features=2, bias=False)
  (W_hg): Linear(in_features=2, out_features=2, bias=False)
  (W_xi): Linear(in_features=32, out_features=2, bias=False)
  (W_hi): Linear(in_features=2, out_features=2, bias=False)
  (W_xf): Linear(in_features=32, out_features=2, bias=False)
  (W_hf): Linear(in_features=2, out_features=2, bias=False)
  (W_xo): Linear(in_features=32, out_features=2, bias=False)
  (W_ho): Linear(in_features=2, out_features=2, bias=False)
  (W_ig_t): Linear(in_features=2, out_features=2, bias=False)
  (W_hg_t): Linear(in_features=2, out_features=2, bias=False)
  (W_xi_t): Linear(in_features=2, out_features=2, bias=False)
  (W_hi_t): Linear(in_features=2, out_features=2, bias=False)
  (W_xf_t): Linear(in_features=2, out_features=2, bias=False)
  (W_hf_t): Linear(in_features=2, out_features=2, bias=False)
  (W_xo_t): Linear(in_feature