In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as f
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.autograd import Variable
from keras.preprocessing import text
from keras.utils import np_utils
from keras.preprocessing import sequence
from nltk.corpus import gutenberg
from string import punctuation
import nltk
from underthesea import word_tokenize
from collections import defaultdict

In [93]:

def cleaning(raw_texts):
    '''
    Clean other punct, and other simple processs
    '''
    data = []
    for sent in raw_texts:
        sent = sent.replace('\n','')
        sent = sent.replace(':',',')
        sent = sent.replace('!','.')
        sent = sent.replace('?','.')
        sent = sent.replace(';',',')
        sent = sent.replace('"','')
        sent = sent.replace(')','')
        sent = sent.replace('(','')
        sent = sent.replace('“','')
        sent = sent.replace('”','')
        sent = sent.replace('-','')
        sent = sent.replace('_','')
        sent = sent.replace('+','')
        sent = sent.replace('=','')
        sent = sent.replace('[','')
        sent = sent.replace(']','')
        sent = sent.replace('{','')
        sent = sent.replace('}','')
        sent = sent.replace('*','')
        sent = sent.replace('&','')
        sent = sent.replace('^','')
        sent = sent.replace('%','')
        sent = sent.replace('$','')
        sent = sent.replace('#','')
        sent = sent.replace('@','')
        sent = sent.replace('!','')
        sent = sent.replace('`','')
        sent = sent.replace('~','')
        sent = sent.replace('/','')
        sent = sent.replace('|','')
        sent = sent.replace('…','')
        sent = sent.replace(',.','')
        if count_comma(sent) >= 2:
            if count_word(sent) >= 20 and count_word(sent) <= 50:
                sent = word_tokenize(sent)
                sent = ' '.join(sent)
                sent = sent+'\n'
                data.append(sent)
    return data

In [2]:
torch.manual_seed(16)

<torch._C.Generator at 0x7f94f3c23290>

In [111]:
with open('./data/Data_byADuc.txt', 'r') as f:
    lines = f.read().splitlines()

In [100]:
def count_comma(sent):
    return sent.count(",")

def count_word(string):
    return(len(string.split()))

def cleaning(raw_texts):
    '''
    Clean other punct, and other simple processs
    '''
    data = []
    for sent in raw_texts:
        sent = sent.replace('\n','')
        sent = sent.replace(':',',')
        sent = sent.replace('!','.')
        sent = sent.replace('?','.')
        sent = sent.replace(';',',')
        sent = sent.replace('"','')
        sent = sent.replace(')','')
        sent = sent.replace('(','')
        sent = sent.replace('“','')
        sent = sent.replace('”','')
        sent = sent.replace('-','')
        sent = sent.replace('_','')
        sent = sent.replace('+','')
        sent = sent.replace('=','')
        sent = sent.replace('[','')
        sent = sent.replace(']','')
        sent = sent.replace('{','')
        sent = sent.replace('}','')
        sent = sent.replace('*','')
        sent = sent.replace('&','')
        sent = sent.replace('^','')
        sent = sent.replace('%','')
        sent = sent.replace('$','')
        sent = sent.replace('#','')
        sent = sent.replace('@','')
        sent = sent.replace('!','')
        sent = sent.replace('`','')
        sent = sent.replace('~','')
        sent = sent.replace('/','')
        sent = sent.replace('|','')
        sent = sent.replace('…','')
        sent = sent.replace(',.','')
        if count_comma(sent) >= 2:
            if count_word(sent) >= 20 and count_word(sent) <= 50:
                sent = word_tokenize(sent)
                sent = ' '.join(sent)
                sent = sent+'\n'
                data.append(sent)
    return data



def create_label(text):

    '''
    Take a string -> intext and label
    '''
    tokens = word_tokenize(text)
    words = []
    ids_punct = {',':[], '.':[]}
    i = 0
    for token in tokens:
        if token not in ids_punct.keys():
            words.append(token)
            i+=1
        else:
            ids_punct[token].append(i-1)

    label = [0]*len(words)
    for pun, ids in ids_punct.items():
        for index in ids:
            label[index] = 1 if pun == ',' else 2
    
    in_text = '<fff>'.join(words)
    return in_text, label



def preprocessing_train_data(RAW_PATH = './data/Data_byADuc.txt', IN_TEXT_PATH = './demo_data/text.txt', 
                             LABEL_PATH = './demo_data/label.txt'):
    # start processing
    with open(RAW_PATH, 'r') as f:
        texts = f.read().splitlines()


    texts = cleaning(texts)
    texts, labels = [], []
    for text in lines:
        in_text, label = create_label(text)
        texts.append(in_text)
        labels.append(label)


    with open(IN_TEXT_PATH, 'w') as f:
        for text in texts:
            f.write(text)
            f.write('\n')

    with open(LABEL_PATH, 'w') as f:
        for label in labels :
            label = [str(ele) for ele in label]
            label = ' '.join(label)
            f.write(label)
            f.write('\n')


In [112]:
mylines = lines[50000:51000]

In [113]:
lines = cleaning(mylines)
texts, labels = [], []
for text in lines:
    in_text, label = create_label(text)
    texts.append(in_text)
    labels.append(label)

In [114]:
with open('./demo_data/testtext.txt', 'w') as f:
    for text in texts:
        f.write(text)
        f.write('\n')

with open('./demo_data/testlabel.txt', 'w') as f:
    for label in labels :
        label = [str(ele) for ele in label]
        label = ' '.join(label)
        f.write(label)
        f.write('\n')

In [64]:
class MyDataset(Dataset):
    def __init__(self, data_path, label_path):
        
        self.data_path = data_path
        self.label_path = label_path
        
        with open(self.data_path, 'r') as f:
            sents = f.read().splitlines()
        
        
        self.sents = [sent.split('<fff>') for sent in sents]
        
        self.word2id = dict()
        i = 1
        for sent in self.sents:
            for word in sent:
                if word not in self.word2id:
                    self.word2id[word] = i
                    i+=1
        self.word2id['<pad>'] = 0
        self.id2word = {v:k for (k,v) in self.word2id.items()}
        
        self.vocab_size = len(self.word2id)
        
        self.sents = [[self.word2id[word] for word in sent] for sent in self.sents]
        
        self.sents = sequence.pad_sequences(self.sents, maxlen=32, padding="post")
                
        
        with open(self.label_path, 'r') as f:
            labels = f.read().splitlines()
        
        self.labels = [list(map(int,label.split())) for label in labels]
        self.labels = sequence.pad_sequences(self.labels, maxlen=32, padding="post", value=3)
        
        
    def __getitem__(self, index):
        
        return {'data':self.sents[index], 'label':self.labels[index]}
    
    def __len__(self):
        return len(self.labels)

In [80]:
traindataset = MyDataset(data_path ='./demo_data/text.txt', label_path ='./demo_data/label.txt')

## testing with batchsize dataset

In [81]:
def dataset_batch_iter(dataset, batch_size):
    b_words = []
    b_labels = []
    for data in dataset:
        b_words.append(data['data'])
        b_labels.append(data['label'])
        
        if len(b_words) == batch_size:
            yield {'data':np.array(b_words, dtype=int), 'label':np.array(b_labels, dtype=int)}
            b_words, b_labels = [], []
    

In [83]:
for batch, data in enumerate(dataset_batch_iter(traindataset, 5)):
    print(batch)
    print(data['data'])
    print(data['label'])
    break
    

0
[[  3   4   5   6   7   8   9  10  11   6  12  13  14  15   8  16  17  18
    4  19  20  21  22  23  24  25  26  27  28  29  30   6]
 [ 31   8  32  10  33  34  35  36  37  38  39  40   2  38  12  41  42  43
   44  43  45  46  47  38  48  41  42  49   0   0   0   0]
 [ 50  51  52  53  45  54  55  56  57  16  58  59  60  61  45  62  63  17
   64  65  66  67  48  68  69   0   0   0   0   0   0   0]
 [ 71  72  73  74  75  12  72  76  77  78  12  79  80  81  82  83  84  85
   86  12  87  88  89  82  90  65  91  92  93  27  94  95]
 [ 97  98  99 100  45 101  20  65 102 103 104  33 105 106 107 108 109  98
  110   8 111 106  31  10 112 113  45 114 115 116   8 117]]
[[1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 2]
 [0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 3 3 3 3]
 [0 0 0 0 0 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 2 3 3 3 3 3 3 3]
 [1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 2]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 

## Model

In [85]:
class RNNModel(nn.Module):
    def __init__(self, vocab_size, embedding_size, output_size,
                 hidden_dim, n_layers):
        
        super(RNNModel, self).__init__()
        
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.output_size = output_size
        self.embedding_size = embedding_size
        
        self.embedding = nn.Embedding(num_embeddings=vocab_size,
                                      embedding_dim=embedding_size)
        
        
        self.rnn = nn.RNN(embedding_size, hidden_dim, n_layers, batch_first=True)   
        
        self.fc = nn.Linear(hidden_dim, output_size)
        self.softmax = nn.LogSoftmax(dim=1)
    
    def forward(self, x, hidden):
        embedded = self.embedding(x)
        
        
        output, hidden = self.rnn(embedded, hidden)
        
        output = self.fc(output)
        prob = self.softmax(output)
        
        
        return prob, hidden
    
    def init_hidden(self, batch_size):
       
        hidden = torch.zeros(self.n_layers, batch_size, self.hidden_dim)
        return hidden

In [86]:
def train_epoch(model, train_dataset, batch_size):
    # training
    model.train()
    hidden = model.init_hidden(batch_size)
    train_loss = 0.
    for batch, data in enumerate(dataset_batch_iter(train_dataset, batch_size)):
        input_tensor = torch.Tensor(data['data']).type(torch.LongTensor)
        target_tensor = torch.Tensor(data['label']).type(torch.LongTensor)
        
        optimizer.zero_grad()
        output, hidden = model(input_tensor, hidden)
        hidden = Variable(hidden.data, requires_grad=True)

        loss = nll_loss(output.view(-1, num_categories), target_tensor.view(-1))
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
    
    return train_loss

In [87]:
def infere(model, input_tensor):
    # infering
    model.eval()
    batch_size = input_tensor.shape[0]
    hidden = model.init_hidden(batch_size)
    output, _ = model(input_tensor, hidden)
    prediction = output.argmax(dim=-1)
    return prediction

In [88]:
def train(epochs, train_dataset, test_dataset):
    
    for epoch in range(epochs):
        train_loss = train_epoch(model, train_dataset, batch_size)
        
        if epoch % 10 == 0:
            test_loss, accuracy = test(model, test_dataset, batch_size)
            print(f"Epoch {epoch} --train loss {train_loss} -- test loss {test_loss}-- test acc {accuracy}")

In [89]:
def restore(tokens, punct):
    id2word = dataset.id2word
    convert = {0:'', 1:',', 2:'.', 3:''}
    seq = [id2word[token]+convert[punct[i]] for i, token in enumerate(tokens)]
    seq = ' '.join(seq)
    return seq

In [90]:
train(101, traindataset, traindataset)

Epoch 0 --train loss 2916.7114140987396 -- test loss 2900.8115725517273-- test acc 0


KeyboardInterrupt: 

In [None]:
for batch, data in enumerate(dataset_batch_iter(dataset, 1)):
    data, label = data['data'], data['label']
    tokens = list(data[0])
    label = list(label[0])
    data = torch.Tensor(data).type(torch.LongTensor)
    pred = infere(model, data)
    pred = np.array(pred[0])
    print("true sent:")
    print(restore(tokens, label))
    print("--------")
    print("prediction:")
    print(restore(tokens, pred))
    break