In [21]:

import sys
sys.path.append("..")
from pathlib import Path

import tensorflow as tf
import functools
import numpy as np

from models.lstm_crf.main import input_fn, fwords, ftags
from models.utils import tags_dictionaries, words_dictionaries

# from models.lstm_crf.main import input_fn


DATADIR = "../data/processed_data/gmb"
def fwords(name):
    return str(Path(DATADIR, "{}.sentences.csv".format(name)))

def ftags(name):
    return str(Path(DATADIR, "{}.labels.csv".format(name)))

def parse_fn(line_words, line_tags):
    words = np.array([word2idx.get(w, 0) for w in line_words.strip().split()])
    tags = np.array([tag2idx[t] for t in line_tags.strip().split()])
    assert len(words) == len(tags), "Words and tags lengths don't match"
    return words, tags

def generator_fn(words, tags):
    with Path(words).open('r') as f_words, Path(tags).open('r') as f_tags:
        for line_words, line_tags in zip(f_words, f_tags):
            yield parse_fn(line_words, line_tags)

def input_fn(words, tags, params=None, shuffle_and_repeat=False):
    params = params if params is not None else {}
    output_signature = (
        tf.TensorSpec(shape=([None]), dtype=tf.int32),
        tf.TensorSpec(shape=([None]), dtype=tf.int32))

    dataset = tf.data.Dataset.from_generator(
        functools.partial(generator_fn, words, tags),
        output_signature=output_signature
    )

    if shuffle_and_repeat:
        dataset = dataset.shuffle(params['buffer']).repeat(params['epochs'])

    shapes = (tf.TensorShape([None]),tf.TensorShape([None]))
    dataset = (dataset
               .padded_batch(batch_size=32,
                             padded_shapes=([params["max_len"]], [params["max_len"]]),
                             padding_values=(params['vocab_size']-1,params['pad_index'])
                            )
               .prefetch(1))
    return dataset

params = {
    "dim": 100,
    "dropout": 0.2,
    "max_len": 60,
    "epochs": 3,
    "batch_size": 32,
    "buffer": 15000,
    "lstm_size": 100,
    "words": str(Path(DATADIR, "vocabulary.txt")),
    "tags": str(Path(DATADIR, "tags.txt")),
    "embeddings_dim": 50
}

tag2idx, idx2tag , tags_len= tags_dictionaries(params)
word2idx, idx2word, vocab_size = words_dictionaries(params)
params['vocab_size']=vocab_size
params['pad_index']=tag2idx['O']
params["labels_size"]=tags_len
    
functools.partial(input_fn, fwords('test'), ftags('test'), params, shuffle_and_repeat=True)()

model=tf.keras.models.load_model("../models/lstm_crf/results/model")
valid_dataset = functools.partial(input_fn, fwords('test'), ftags('test'), params, shuffle_and_repeat=True)()
predictions = model.predict(valid_dataset)

In [33]:
for a in valid_dataset:
    print(a)
    break

(<tf.Tensor: shape=(32, 60), dtype=int32, numpy=
array([[   0,   23,  896, ..., 9232, 9232, 9232],
       [   0, 3065,   36, ..., 9232, 9232, 9232],
       [1636, 2365,   86, ..., 9232, 9232, 9232],
       ...,
       [1636, 4910,  754, ..., 9232, 9232, 9232],
       [   0,   36,  345, ..., 9232, 9232, 9232],
       [ 293,  294,  592, ..., 9232, 9232, 9232]], dtype=int32)>, <tf.Tensor: shape=(32, 60), dtype=int32, numpy=
array([[0, 0, 0, ..., 0, 0, 0],
       [0, 7, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [5, 5, 0, ..., 0, 0, 0]], dtype=int32)>)


In [24]:
predictions[0].shape

(60, 10)

In [27]:
print(predictions[0][1])
print(np.argmax(predictions[0][1]))

[9.7509730e-01 9.5932992e-05 3.5371160e-04 8.2796952e-03 3.8849862e-04
 8.1546344e-03 7.7372818e-04 4.0629865e-03 2.6539089e-03 1.3964865e-04]
0


In [28]:
print(predictions[0][0])

[9.1956699e-01 2.6130176e-04 1.0085534e-03 2.7530482e-02 1.0763111e-03
 2.5814289e-02 2.2129915e-03 1.3319554e-02 8.8139372e-03 3.9564437e-04]


In [35]:
prediction = predictions[0]

In [36]:
words, labelprediction

array([[9.19566989e-01, 2.61301757e-04, 1.00855343e-03, 2.75304820e-02,
        1.07631111e-03, 2.58142892e-02, 2.21299147e-03, 1.33195538e-02,
        8.81393719e-03, 3.95644369e-04],
       [9.75097299e-01, 9.59329918e-05, 3.53711599e-04, 8.27969518e-03,
        3.88498622e-04, 8.15463439e-03, 7.73728185e-04, 4.06298647e-03,
        2.65390892e-03, 1.39648648e-04],
       [9.89059210e-01, 4.60408082e-05, 1.67081540e-04, 3.32809379e-03,
        1.87030208e-04, 3.53970798e-03, 3.82059545e-04, 1.96453626e-03,
        1.26061321e-03, 6.57335331e-05],
       [9.98252094e-01, 9.93773756e-06, 3.26682239e-05, 4.43168130e-04,
        3.90208188e-05, 6.25927525e-04, 7.95064188e-05, 2.89988442e-04,
        2.14287284e-04, 1.33123231e-05],
       [9.77243483e-01, 7.03303594e-05, 2.91842240e-04, 3.84338433e-03,
        3.50762217e-04, 6.30280096e-03, 8.27198324e-04, 6.61455747e-03,
        4.34702914e-03, 1.08663546e-04],
       [9.95280325e-01, 1.79217841e-05, 7.76058441e-05, 7.29622319e-04,
   