Skip to content

Commit

Permalink
GH-61: added highlighter over paragraph
Browse files Browse the repository at this point in the history
  • Loading branch information
Duncan Blythe committed Aug 13, 2018
1 parent d54f4cc commit 61ae2d8
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 13 deletions.
3 changes: 2 additions & 1 deletion flair/visual/__init__.py
@@ -1 +1,2 @@
from .manifold import tSNE, uMap, show, prepare_word_embeddings, prepare_char_embeddings, word_contexts, char_contexts
from .manifold import tSNE, show, prepare_word_embeddings, prepare_char_embeddings, word_contexts, char_contexts
from .activations import Highlighter
66 changes: 66 additions & 0 deletions flair/visual/activations.py
@@ -0,0 +1,66 @@
import numpy


class Highlighter:
def __init__(self):
self.color_map = [
"#ff0000",
"#ff4000",
"#ff8000",
"#ffbf00",
"#ffff00",
"#bfff00",
"#80ff00",
"#40ff00",
"#00ff00",
"#00ff40",
"#00ff80",
"#00ffbf",
"#00ffff",
"#00bfff",
"#0080ff",
"#0040ff",
"#0000ff",
"#4000ff",
"#8000ff",
"#bf00ff",
"#ff00ff",
"#ff00bf",
"#ff0080",
"#ff0040",
"#ff0000",
]

def highlight(self, activation, text, file_='resources/data/highlight.html'):

activation = activation.detach().numpy()

step_size = (max(activation) - min(activation)) / len(self.color_map)

lookup = numpy.array(list(
numpy.arange(min(activation), max(activation), step_size)
))

colors = []

for i, act in enumerate(activation):

try:
colors.append(
self.color_map[numpy.where(act > lookup)[0][-1]]
)
except IndexError:
colors.append(len(self.color_map) - 1)

str_ = ''

for i, (char, color) in enumerate(zip(list(text), colors)):
str_ += self._render(char, color)

with open(file_, 'w') as f:
f.write(str_)

def _render(self, char, color):
return '<span style="background-color: {}">{}</span>'.format(color, char)


12 changes: 0 additions & 12 deletions flair/visual/manifold.py
Expand Up @@ -89,18 +89,6 @@ def __init__(self):
TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)


class uMap(_Transform):
def __init__(self):

super().__init__()

self.transform = UMAP(
n_neighbors = 5,
min_dist = 0.3,
metric = 'correlation',
)


def show(X, contexts):
import matplotlib.pyplot
import mpld3
Expand Down
15 changes: 15 additions & 0 deletions tests/test_visual.py
Expand Up @@ -176,6 +176,21 @@ def test_uni(self):
show(reduced, contexts)


class TestHighlighter(unittest.TestCase):
def test(self):

i = numpy.random.choice(2048)

with open('resources/data/snippet.txt') as f:
sentences = [x for x in f.read().split('\n') if x]

embeddings = CharLMEmbeddings('news-forward')

features = embeddings.lm.get_representation(sentences[0]).squeeze()

Highlighter().highlight(features[:, i], sentences[0])



if __name__ == '__main__':
unittest.main()
Expand Down

0 comments on commit 61ae2d8

Please sign in to comment.