In [1]:
import denver

In [5]:
from denver.data import DenverDataSource

# CSV files holding the train and test splits, relative to this notebook.
train_path = '../data/cometv3/train.csv'
test_path = '../data/cometv3/test.csv'

# Column names used in both CSV files, plus the `lowercase` flag that is
# forwarded unchanged to the reader.
_csv_options = dict(
    text_col='text',
    intent_col='intent',
    tag_col='tags',
    lowercase=True,
)

# Build a single data source covering both splits.
data_source = DenverDataSource.from_csv(
    train_path=train_path,
    test_path=test_path,
    **_csv_options,
)

In [6]:
from denver.learners import OnenetLearner

# Hyper-parameters, grouped by the part of the model their names refer to.
# NOTE(review): the grouping is inferred from the keyword names only — confirm
# against the OnenetLearner documentation.
_rnn_options = dict(
    rnn_type='lstm',
    dropout=0.5,
    bidirectional=True,
    hidden_size=200,
)
_word_options = dict(
    word_embedding_dim=50,
    word_pretrained_embedding='vi-glove-50d',
)
_char_options = dict(
    char_encoder_type='cnn',
    char_embedding_dim=3,
    num_filters=128,
    ngram_filter_sizes=[3],
    conv_layer_activation='relu',
)

# Training-mode learner bound to the data source created in the previous cell.
learn = OnenetLearner(
    mode='training',
    data_source=data_source,
    **_rnn_options,
    **_word_options,
    **_char_options,
)

In [7]:
from denver.trainers.trainer import ModelTrainer

# Directory and file name for the saved model archive.
_output_options = dict(
    base_path='./models/onenet/',
    model_file='denver-onenet.tar.gz',
)

# Optimisation settings handed to the trainer.
_fit_options = dict(
    learning_rate=0.001,
    batch_size=64,
    num_epochs=200,
)

# Wrap the training-mode learner and start training.
trainer = ModelTrainer(learn=learn)
trainer.train(**_output_options, **_fit_options)

0it [00:00, ?it/s]


➖➖➖➖➖➖➖➖➖➖ TRAINING ➖➖➖➖➖➖➖➖➖➖

2020-12-17 10:51:16,536 INFO  denver.data.dataset_reader:92 - Reading instances from lines in file at: /tmp/tmpxmzb9ha2/train.csv


6047it [00:00, 14938.61it/s]
0it [00:00, ?it/s]

2020-12-17 10:51:16,941 INFO  denver.data.dataset_reader:92 - Reading instances from lines in file at: /tmp/tmpxmzb9ha2/test.csv


1068it [00:00, 6326.45it/s]
7115it [00:00, 26097.78it/s]
22386it [00:00, 229908.45it/s]
Epoch   0: main_score: 0.0000, loss: 3.7564 |: 100%|██████████| 95/95 [00:15<00:00,  6.02it/s]
Epoch   1: main_score: 0.0231, loss: 2.3874 |: 100%|██████████| 95/95 [00:15<00:00,  5.94it/s]


2020-12-17 10:51:50,123 INFO  denver.learners.onenet_learner:295 - Path to the saved model: models/onenet/denver-onenet.tar.gz

⏰  The trained time: 0:00:33.604110



In [9]:
## Evaluate the saved model on the test split

data_path = '../data/cometv3/test.csv'

# Column names in the evaluation CSV, plus the `lowercase` flag.
_eval_options = dict(
    text_col='text',
    intent_col='intent',
    tag_col='tags',
    lowercase=True,
)

# Rebind `learn` to an inference-mode learner loaded from the model archive;
# the following cells use this instance.
learn = OnenetLearner(
    mode='inference',
    model_path='./models/onenet/denver-onenet.tar.gz',
)

metrics = learn.evaluate(data=data_path, **_eval_options)

2020-12-17 10:52:18,425 INFO  denver.learners.onenet_learner:677 - Reading evaluation data from /tmp/tmpbcpjep33/data.csv


0it [00:00, ?it/s]

2020-12-17 10:52:18,426 INFO  denver.data.dataset_reader:92 - Reading instances from lines in file at: /tmp/tmpbcpjep33/data.csv


1068it [00:00, 16299.23it/s]

2020-12-17 10:52:18,492 INFO  denver.learners.onenet_learner:686 - Evaluating...
2020-12-17 10:52:18,493 INFO  denver.learners.onenet_learner:555 - Iterating over dataset



100%|██████████| 17/17 [00:00<00:00, 30.52it/s]


In [11]:
from denver.utils.print_utils import view_table

# Display the `metrics` object returned by `learn.evaluate(...)` in the
# evaluation cell.
view_table(metrics)

+--------+------------+
|  loss  | main_score |
+--------+------------+
| 1.8119 |   0.0756   |
+--------+------------+
Intent results: 
+----------+----------+-----------+--------+
| accucary | f1-score | precision | recall |
+----------+----------+-----------+--------+
|  0.6854  |  0.6464  |  0.6684   | 0.6854 |
+----------+----------+-----------+--------+
Intent detailed results: 
                    precision    recall  f1-score   support

             agree       0.62      0.70      0.65        23
       ask_confirm       0.62      0.94      0.75       316
        ask_is_bot       1.00      0.43      0.61        23
              deny       0.98      0.96      0.97       179
             greet       0.89      0.70      0.78        23
            inform       0.70      0.57      0.63       223
request#age_of_use       0.38      0.13      0.19        23
     request#brand       0.58      0.26      0.36        27
     request#color       1.00      0.26      0.42        19
 request#gu

In [12]:
from pprint import pprint

## Raw prediction for a single utterance

_sample_text = "xe day con mau vàng k sh"

# `predict` returns the full prediction structure, which is pretty-printed
# below.
prediction = learn.predict(sample=_sample_text, lowercase=True)
pprint(prediction)

{'intent': 'ask_confirm',
 'intent_probs': array([8.472486e-01, 1.390546e-01, 2.941240e-03, 8.740714e-04, 4.832490e-04, 7.965628e-06, 3.421558e-04, 1.891766e-03,
       2.839373e-05, 7.605435e-05, 2.445794e-04, 3.626556e-04, 1.484683e-04, 3.453332e-04, 7.620518e-04, 2.558392e-03,
       1.409377e-03, 1.057163e-03, 1.639659e-04], dtype=float32),
 'mask': array([1, 1, 1, 1, 1, 1, 1]),
 'nlu': {'ask_confirm': []},
 'span_tags': [],
 'tag_logits': array([[ 5.203199,  3.418806,  1.580193,  0.336525, ..., -5.362263, -4.325233, -6.058815, -5.410845],
       [ 3.871637,  1.896619,  1.413062,  0.336331, ..., -5.042706, -4.29088 , -5.715288, -5.234156],
       [ 2.919298,  0.356529,  0.793495,  0.193774, ..., -4.422326, -4.09605 , -5.016703, -4.827662],
       [ 1.816658, -1.193515, -0.293841,  0.048384, ..., -3.35841 , -3.424273, -3.803444, -3.96379 ],
       [ 2.115446, -2.421565, -1.239655, -0.452128, ..., -3.155025, -3.254362, -3.146738, -3.738115],
       [ 7.332287, -2.590415, -2.130936, -

In [13]:
from pprint import pprint

## Processed output for a single utterance

_sample_text = "xe day con mau vàng k sh"

# Same sample as the `predict` cell, run through `process` instead.
output = learn.process(sample=_sample_text, lowercase=True)
pprint(output)

{'entities': [],
 'intent': {'confidence': 0.8472485542297363, 'name': 'ask_confirm'},
 'text': 'xe day con mau vàng k sh'}


In [14]:
## Batch predictions for a DataFrame or a path to a .csv file

data_path = '../data/cometv3/test.csv'

# Only the text column is named; the intent and tag columns are passed as None.
_predict_options = dict(
    text_cols='text',
    intent_cols=None,
    tag_cols=None,
    lowercase=True,
)

data_df = learn.predict_on_df(data=data_path, **_predict_options)

# Preview the first rows of the result.
data_df.head()

2020-12-17 10:52:56,791 INFO  denver.learners.onenet_learner:770 - Get-prediction...


100%|██████████| 1068/1068 [00:05<00:00, 191.58it/s]


Unnamed: 0,text,intent_pred,tag_pred
0,mẫu này thì là của hq phải ko ạ,ask_confirm,O O O O O O O O O
1,nôi này cho bé đến 2 tuổi không,inform,O O O O O O O O
2,có màu xanh lá k,ask_confirm,O B-ask_confirm#color I-ask_confirm#color I-as...
3,bên bạn có cọ bình sữa ko,inform,O O O O O O O
4,bé nhà mình 5.5 tháng chưa biết ngồi thì có dù...,ask_confirm,O O O O O O O O O O O O O O
