# Fruit-Fly-Net demo 

In this notebook I'll show how to implement the fruit fly network (Fruit-Fly-Net) introduced in [Can a Fruit Fly Learn Word Embeddings](https://arxiv.org/abs/2101.06887) and use it to create word embeddings from Pokedex entries.

**Note: This notebook requires a GPU runtime for training** (`Runtime > Change runtime type > Hardware accelerator > GPU`)

## Installations and imports
**Note**: Restart the runtime after installing the packages

In [None]:
pip install -U einops gradio numpy spacy git+https://github.com/Ramos-Ramos/fruit-fly-net

In [None]:
!python -m spacy download en_core_web_sm

In [None]:
from einops import rearrange
import cupy as cp
import cupy as xp
import numpy as np
import gradio as gr
import pandas as pd
from cupyx.scipy.sparse import csr_matrix, vstack
import spacy
from sklearn.metrics.pairwise import cosine_similarity
from tqdm.notebook import tqdm

from collections import Counter, OrderedDict
import pickle

# from fruit_fly_net import FruitFlyNet, bio_hash_loss

## Preprocess and Prepare Dataset

The corpus from which we'll create word embeddings consists of several Pokedex entries. To create word embeddings, we first need a list of words. We get it by tokenizing the corpus, that is, converting it into a vocabulary of words, or "tokens". For Fruit-Fly-Net to work, we also need the probability of each token.

The corpus is a CSV file. Here I load it from Google Drive and open it with Pandas.

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
df = pd.read_csv('/content/gdrive/MyDrive/Colab Notebooks/dataset/dex.csv')
print('shape:', df.shape)
df.head()

shape: (10177, 3)


|   | national_id | name | description |
|---|---|---|---|
| 0 | 1 | Bulbasaur | For some time after its birth, it grows by gai... |
| 1 | 1 | Bulbasaur | There is a plant seed on its back right from t... |
| 2 | 1 | Bulbasaur | It can go for days without eating a single mor... |
| 3 | 1 | Bulbasaur | Bulbasaur can be seen napping in bright sunlig... |
| 4 | 1 | Bulbasaur | For some time after its birth, it grows by gai... |


To create longer pieces of text, I combine Pokedex entries coming from the same Pokemon. This list of concatenated entries will be our corpus.

In [None]:
corpus = df.groupby('name').description.apply(' '.join)
print('shape:', corpus.shape)
corpus.head()

shape: (898,)


name
Abomasnow    They appear when the snow flowers bloom. When ...
Abra         Sleeps 18 hours a day. If it senses danger, it...
Absol        Every time ABSOL appears before people, it is ...
Accelgor     When its body dries out, it weakens. So, to pr...
Aegislash    Generations of kings were attended by these Po...
Name: description, dtype: object

To tokenize the corpus, I use SpaCy.

In [None]:
nlp = spacy.load('en_core_web_sm')

The tokenization function splits a piece of text into tokens using SpaCy, lemmatizes and lowercases them, and ignores tokens that are punctuation, numbers, or stop words.

In [None]:
def create_tokens_from_text(text):
  """Tokenizes text by:
  - splitting with SpaCy
  - ignoring punctuations, numbers, and stop words
  """
  return [w.lemma_.lower() for w in nlp(text) 
          if not w.is_punct and 
          not w.like_num and 
          not w.lemma_.lower() in nlp.Defaults.stop_words]

In [None]:
create_tokens_from_text("They appear when the snow flowers bloom")

['appear', 'snow', 'flower', 'bloom']

To create the vocabulary, I iterate over the corpus and tokenize the texts with our tokenization function. I then calculate the probability of each token, that is, the fraction of the corpus it makes up; these probabilities form the vector $\mathbf p$. Following the paper, I cap the initial $N_{voc}$ at $20,000$, but the final vocabulary ends up much smaller (~$6,500$).


In [None]:
tokens = []
init_vocabulary_size = 20000
batch_size = 100

# create tokens ['appear', 'snow', 'flower', 'bloom', ...]
for batch_start in tqdm(range(0, len(corpus), batch_size)):
  tokens += create_tokens_from_text(
      ' '.join(corpus.iloc[batch_start:batch_start+batch_size])
  )

# clip vocabulary if necessary and calculate probabilities
tokens_to_counts = dict(Counter(tokens).most_common(init_vocabulary_size))
total_count = sum(tokens_to_counts.values())
tokens_to_probabilities = {token : count / total_count for token, count in tokens_to_counts.items()}

# finalize vocabulary size, vocabulary and probabilities
vocabulary = list(tokens_to_probabilities.keys())
probabilities = xp.tile(xp.array(list(tokens_to_probabilities.values())), 2)
vocabulary_size = len(tokens_to_counts)

print(f'vocabulary size: {vocabulary_size}')

vocabulary size: 6584


In [None]:
vocabulary.index("snow")

325

In [None]:
probabilities.shape

(13168,)
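
The length of `probabilities` is twice the vocabulary size because the same per-token probabilities are repeated, once for the context half and once for the target half of the input vector. A minimal sketch of that step with NumPy, using a made-up token list (the `toy_*` names are hypothetical and not part of the notebook):

```python
from collections import Counter

import numpy as np

# Hypothetical toy token stream (not taken from the Pokedex corpus)
toy_tokens = ["fire", "wing", "fire", "night", "fire", "wing"]

# Count tokens (most frequent first) and convert counts to relative frequencies
toy_counts = dict(Counter(toy_tokens).most_common())
toy_total = sum(toy_counts.values())
toy_probs = np.array([count / toy_total for count in toy_counts.values()])

# Repeat the probabilities: once for the context half, once for the target half
toy_probs_tiled = np.tile(toy_probs, 2)

print(list(toy_counts))       # ['fire', 'wing', 'night']
print(toy_probs_tiled.shape)  # (6,)
```

With the real vocabulary of 6584 tokens, the same tiling gives the shape `(13168,)` shown above.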

### Preparing trainset

We have the vocabulary and probabilities. Now we need to create a training set in the format accepted by Fruit-Fly-Net.

Given a vocabulary of $N_{voc}$ tokens, Fruit-Fly-Net takes in a token and its context in the form of a binary input vector $v^A$ of length $2 \times N_{voc}$, where the first $N_{voc}$ dimensions form a bag-of-words representation of the context words and the remaining $N_{voc}$ dimensions form a one-hot encoding of the target word. These input vectors are created from n-grams (which the authors refer to as w-grams) taken from the training corpus. The center element of each w-gram becomes the target while the surrounding elements comprise the context.

<table>
  <tr><td colspan=6><center>"Charizard breathes flames"</center></td></tr>
  <tr>
    <td>breathes</td><td>charizard</td><td>flames</td>
    <td>breathes</td><td>charizard</td><td>flames</td>
  </tr>
  <tr>
    <td>0</td><td>1</td><td>1</td>
    <td>1</td><td>0</td><td>0</td>
  </tr>
</table>
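
As a rough illustration, the input vector from the table can be built by hand with NumPy. The three-word vocabulary below simply mirrors the table (the real tokenizer would also lemmatize the words), and the `toy_*` names are assumptions made for this sketch:

```python
import numpy as np

# Toy vocabulary mirroring the table above
toy_vocabulary = ["breathes", "charizard", "flames"]
n_voc = len(toy_vocabulary)

context_tokens = ["charizard", "flames"]
target_token = "breathes"

v_a = np.zeros(2 * n_voc, dtype=bool)
# First N_voc entries: bag-of-words representation of the context
for token in context_tokens:
    v_a[toy_vocabulary.index(token)] = True
# Last N_voc entries: one-hot encoding of the target
v_a[n_voc + toy_vocabulary.index(target_token)] = True

print(v_a.astype(int))  # [0 1 1 1 0 0]
```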


I create two helper functions for creating token ids (unique numbers for each token in our vocabulary) and the actual input training embeddings.

In [None]:
def create_token_ids_from_text(text, vocabulary):
  """Creates tokens from text then gets corresponding indices for tokens in the
  vocabulary
  """
  tokens = create_tokens_from_text(text)
  token_ids = [vocabulary.index(token) for token in tokens if token in vocabulary]
  return token_ids

In [None]:
create_token_ids_from_text("Charizard breathes flames", vocabulary)

[2138, 413, 63]
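
Note that `vocabulary.index` scans the list on every call, which can get slow for long texts. One possible alternative, sketched here with a hypothetical helper `build_token_to_id` and a made-up toy vocabulary, is to precompute a dictionary once and look tokens up in it:

```python
def build_token_to_id(vocabulary):
    """Maps each token to its position in the vocabulary (hypothetical helper)."""
    return {token: idx for idx, token in enumerate(vocabulary)}

# Made-up vocabulary and tokens, for illustration only
toy_vocabulary = ["pokemon", "flame", "breathe", "charizard"]
token_to_id = build_token_to_id(toy_vocabulary)

toy_tokens = ["charizard", "breathe", "flame", "unknown"]
# Out-of-vocabulary tokens are skipped, as in create_token_ids_from_text
toy_token_ids = [token_to_id[t] for t in toy_tokens if t in token_to_id]
print(toy_token_ids)  # [3, 2, 1]
```

The resulting ids are the same as with `list.index`; only the lookup cost changes.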

In [None]:
def create_training_embeddings_from_token_ids(token_ids, w_gram_size, vocabulary_size):
  """Creates several w-grams, then creates input training embeddings by having
  the middle token be the target and the rest be the context
  """
  # create w-grams
  w_gram_size = min(w_gram_size, len(token_ids))
  middle_idx = w_gram_size//2
  w_grams = xp.array(np.lib.stride_tricks.sliding_window_view(token_ids, w_gram_size))
  w_grams[:, middle_idx] += vocabulary_size

  # create training embeddings
  training_embeddings = xp.zeros((w_grams.shape[0], vocabulary_size*2))
  training_embeddings[xp.indices(w_grams.shape)[0], w_grams] = 1
  training_embeddings = training_embeddings.astype(xp.bool_)

  return training_embeddings
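
To make the indexing easier to follow, here is a small CPU-only sketch of the same idea with plain NumPy. The token ids, vocabulary size, and `toy_*` names are made up for illustration:

```python
import numpy as np

toy_token_ids = np.asarray([4, 0, 2, 1, 3])  # made-up ids for a 5-token vocabulary
toy_vocabulary_size = 5
toy_w_gram_size = 3

# Each row is one window of consecutive token ids (copy to get a writable array)
toy_w_grams = np.lib.stride_tricks.sliding_window_view(
    toy_token_ids, toy_w_gram_size
).copy()
print(toy_w_grams)
# [[4 0 2]
#  [0 2 1]
#  [2 1 3]]

# Shift the middle (target) id into the second half of the input vector
toy_w_grams[:, toy_w_gram_size // 2] += toy_vocabulary_size

# Set one entry per token in each row: context ids < N_voc, target id >= N_voc
toy_inputs = np.zeros((toy_w_grams.shape[0], 2 * toy_vocabulary_size), dtype=bool)
toy_inputs[np.arange(toy_w_grams.shape[0])[:, None], toy_w_grams] = True
print(toy_inputs.astype(int)[0])  # [0 0 1 0 1 1 0 0 0 0]
```

In the first row, tokens 4 and 2 are the context (first half) and token 0 is the target, so it lands at index 0 + 5 = 5 in the second half.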

Next, I iterate over the corpus and create the training embeddings. The w-gram size is $15$.

In [None]:
w_gram_size = 15

training_embeddings = []
for text in tqdm(corpus):
  token_ids = create_token_ids_from_text(text, vocabulary)
  
  training_embeddings.append(
      csr_matrix(create_training_embeddings_from_token_ids(
          token_ids, w_gram_size, len(vocabulary)
      ))
  )
  # print(text)
  # print((create_training_embeddings_from_token_ids(token_ids, w_gram_size, len(vocabulary))))

training_embeddings = vstack(training_embeddings)

In [None]:
# print(training_embeddings)

## Building Fruit-Fly-Net

Fruit-Fly-Net creates word embeddings by learning the correlations between words and their contexts. It projects the input vector $\mathbf v^A$ to $K$ dimensions, of which the top $k$ activations are set to 1 while the rest are suppressed to 0. To update the projection weights, Fruit-Fly-Net also requires the token probabilities $\mathbf p$.
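
The projection and top-$k$ step can be sketched on a toy scale with NumPy. The sizes, random weights, and `toy_*` names below are assumptions chosen for illustration and are unrelated to the trained model:

```python
import numpy as np

rng = np.random.default_rng(0)

toy_input_dim, toy_K, toy_k = 6, 5, 2
toy_weights = rng.random((toy_K, toy_input_dim))     # projection weights, K x input_dim
toy_v_a = np.array([0, 1, 1, 1, 0, 0], dtype=float)  # one binary input vector

# Project the input to K dimensions
toy_activations = toy_weights @ toy_v_a

# Set the k largest activations to 1 and suppress the rest to 0
toy_hash = np.zeros(toy_K)
toy_hash[np.argsort(toy_activations)[-toy_k:]] = 1

print(toy_hash)  # a binary vector with exactly toy_k ones
```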

In [None]:
from einops import rearrange, reduce

from typing import Union, Dict

try:
  cp_array_class = cp.core.core.ndarray
except AttributeError:
  # fall back to the private `_core` module when `cp.core` is unavailable
  cp_array_class = cp._core.core.ndarray
Array = Union[np.ndarray, cp_array_class]


class Flyvec():
  """Fruit fly network as described in "Can a Fruit Fly Learn Word Embeddings?"
  (arXiv:2101.06887)
  Args:
    input_dim: number of input features
    output_dim: number of output features
    k: number of top output activations to keep
    lr: learning rate
  """

  def __init__(self, input_dim: int, output_dim: int,  k: int, lr: float) -> None:
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.k = k
    self.lr = lr
    self.weights = np.random.rand(output_dim, input_dim)
    self.training = True
    self.xp = np

  def __call__(self, x: Array, probs: Array) -> Array:
    """Creates bio-hash. If `self.training` is True, updates weights after
    creating bio-hash.
    
    Args:
      x: input of shape batch x input_features
      probs: probabilities of each element in input; has shape 
             batch x input_features, where each row should sum up to 1
    """

    b = x.shape[0]
    activations = self.xp.inner(self.weights, x)
    out = self.xp.zeros_like(activations)
    # self.xp.put_along_axis(out, activations.argsort(axis=0)[-self.k:], 1, axis=0)
    out[rearrange(activations.argsort(axis=0)[-self.k:], 'k b -> (b k)'),
        rearrange(self.xp.indices((self.k, b))[1], 'k b -> (b k)')] = 1

    if self.training:
      self._backward(x, probs, activations)

    return rearrange(out, 'd b -> b d')

  def _backward(self, x: Array, probs: Array, activations: Array) -> None:
    """Updates weights
    Args:
      x: input of shape batch x input_features
      probs: probabilities of each element in input; has shape 
             batch x input_features, where each row should sum up to 1
      activations: output activations of shape output_features x batch
    """

    assert self.training, "Cannot update weights in eval mode"
    normalized_x = rearrange(x / probs, 'b d -> b () d')
    activations = rearrange(activations == activations.max(axis=0), 'd b -> b d ()')
    # activations = rearrange(activations == reduce(activations, 'd b -> b', 'max'), 'd b -> b d ()')
    normalized_weights = rearrange(self.xp.inner(self.weights, normalized_x), 'd b () -> b d ()')
    self.weights += self.lr * (activations * (normalized_x - normalized_weights * self.weights)).sum(axis=0)
    # self.weights += reduce(activations * (normalized_x - normalized_weights * self.weights), 'b o i -> o i', 'sum')
  
  def state_dict(self) -> Dict[str, Array]:
    """Returns dictionary of key "weights" and the weight array as the value"""

    return {'weights': cp.asnumpy(self.weights).copy()}

  def load_state_dict(self, state_dict: Dict[str, Array]) -> None:
    """Loads weights
    
    Args:
      state_dict: dictionary with key "weights" and weight array as the value
    """

    curr_shape, new_shape = self.weights.shape, state_dict['weights'].shape
    assert curr_shape == new_shape, f"Incorrect size for `weights`. Expected {curr_shape}, got {new_shape}."
    self.weights = state_dict['weights']
    self.to('cpu' if self.xp==np else 'gpu')

  def eval(self) -> None:
    """Turns off training mode"""

    self.training = False

  def train(self) -> None:
    """Turns on training mode"""

    self.training = True

  def to(self, device: str) -> None:
    """Moves weight array to device
    
    Args:
      device: device to move weights to; must be "cpu" or "gpu"
    """
    
    if device == 'cpu':
      self.weights = cp.asnumpy(self.weights)
    elif device == 'gpu':
      self.weights = cp.asarray(self.weights)
    else:
      raise ValueError("'device' must be either 'cpu' or 'gpu'")
    self.xp = cp.get_array_module(self.weights)


def bio_hash_loss(weights: Array, x: Array, probs: Array) -> Array:
  """Calculates bio-hash loss from "Bio-Inspired Hashing for Unsupervised 
  Similarity Search"
  (arXiv:2001.04907)
  Args:
    weights: model weights of shape output_features x input_features
    x: input of shape batch x input_features
    probs: probabilities of each element in input; has shape 
           batch x input_features, where each row should sum up to 1
  Returns:
    Energy/bio-hash loss summed over all input vectors in the batch
  """
  
  xp = cp.get_array_module(weights)
  max_activation_indices = xp.inner(weights, x).argmax(axis=0)
  max_activation_weights = weights[max_activation_indices]
  energy = -xp.inner(max_activation_weights, (x / probs)).diagonal()/xp.sqrt(xp.inner(max_activation_weights, max_activation_weights).diagonal())
  return energy.sum()
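
For reference, the update and the loss implemented above can be written out as formulas. This is a transcription of the code in `_backward` and `bio_hash_loss`, not a formula quoted from the paper; $\varepsilon$ stands for `lr`, $\mathbf W_\mu$ for row $\mu$ of `weights`, and the division $\mathbf v^A / \mathbf p$ is element-wise. For each input $\mathbf v^A$ in the batch, only the unit $\hat\mu$ with the largest activation is updated:

$$
\hat\mu = \arg\max_\mu \langle \mathbf W_\mu, \mathbf v^A \rangle, \qquad
\Delta \mathbf W_{\hat\mu} = \varepsilon \left( \frac{\mathbf v^A}{\mathbf p} - \left\langle \mathbf W_{\hat\mu}, \frac{\mathbf v^A}{\mathbf p} \right\rangle \mathbf W_{\hat\mu} \right)
$$

The value returned by `bio_hash_loss` is the sum over the batch of the corresponding energy:

$$
E = -\sum_{A \in \text{batch}} \frac{\left\langle \mathbf W_{\hat\mu}, \mathbf v^A / \mathbf p \right\rangle}{\sqrt{\left\langle \mathbf W_{\hat\mu}, \mathbf W_{\hat\mu} \right\rangle}}
$$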

Let's instantiate Fruit-Fly-Net with $K=400$, $k=51$, and a learning rate of $10^{-6}$.

In [None]:
model = Flyvec(
  input_dim=vocabulary_size*2,  # input dimension size (vocab_size * 2)
  output_dim=400,               # output dimension size
  k=51,                         # top k cells to be left active in output layer
  lr=1e-6                       # learning rate (learning is performed internally)
)
model.to('gpu')

## Training Fruit-Fly-Net

In each epoch of the training loop, I shuffle the training set and iterate over it in batches. For each batch, I feed the inputs to the model; the weight update is performed internally. The loss is printed every 1000 batches and at the end of each epoch. The batch size is $32$.

In [None]:
batch_size = 32

loss = 0
epochs = 10
for epoch in range(epochs):
  
  # shuffle trainset
  shuffled_idxs = xp.random.permutation(training_embeddings.shape[0])
  training_embeddings = training_embeddings[shuffled_idxs]
    
  for batch_start in tqdm(range(0, training_embeddings.shape[0], batch_size)):
    
    # train step
    input = training_embeddings[batch_start:batch_start+batch_size].toarray()
    model(input, probabilities)
    
    # get loss
    loss += bio_hash_loss(model.weights, input, probabilities)
    
    # print metrics every 1000 batches
    if batch_start//batch_size % 1000 == 999:
      print(f'epoch {epoch:2d} batch {batch_start//batch_size:4d}:\t{loss/(batch_size*1000):.3f}')
      loss = 0
        
  # print metrics after each epoch
  print(f'epoch {epoch:2d} batch {batch_start//batch_size:4d}:\t{loss/(batch_size*((training_embeddings.shape[0]//batch_size)%1000)):.3f}')
  loss = 0

epoch  0 batch  999:	-1336.219
epoch  0 batch 1999:	-3323.667
epoch  0 batch 2701:	-8256.735
epoch  1 batch  999:	-14936.088
epoch  1 batch 1999:	-17652.456
epoch  1 batch 2701:	-18164.971
epoch  2 batch  999:	-18821.551
epoch  2 batch 1999:	-18177.335
epoch  2 batch 2701:	-18492.804
epoch  3 batch  999:	-18756.235
epoch  3 batch 1999:	-18698.504
epoch  3 batch 2701:	-18390.152
epoch  4 batch  999:	-18854.647
epoch  4 batch 1999:	-18540.150
epoch  4 batch 2701:	-18706.279
epoch  5 batch  999:	-18991.669
epoch  5 batch 1999:	-18735.543
epoch  5 batch 2701:	-18354.811
epoch  6 batch  999:	-18957.981
epoch  6 batch 1999:	-18742.592
epoch  6 batch 2701:	-18546.843
epoch  7 batch  999:	-18966.561
epoch  7 batch 1999:	-18774.675
epoch  7 batch 2701:	-18605.649
epoch  8 batch  999:	-18900.188
epoch  8 batch 1999:	-18818.466
epoch  8 batch 2701:	-18702.689
epoch  9 batch  999:	-19004.965
epoch  9 batch 1999:	-18660.482
epoch  9 batch 2701:	-18855.069


## Optional: Switching to a CPU runtime 

1. Save the vocabulary, probabilities, and model weights. Make sure to download the files after saving.

In [None]:
# save vocabulary
with open('vocab.pkl', 'wb') as vocab_file:
  pickle.dump(vocabulary, vocab_file)

# save probabilities
with open('prob.npy', 'wb') as prob_file:
  xp.save(prob_file, cp.asnumpy(probabilities))

# save model weights
with open('weights.pkl', 'wb') as file:
  pickle.dump(model.state_dict(), file)

2. Shut down this runtime by going to `Runtime > Factory Reset Runtime`. Then switch to a CPU runtime by going to `Runtime > Change Runtime Type > Hardware Accelerator > None`. After starting a new runtime, upload the `vocab.pkl`, `prob.npy`, and `weights.pkl` files to `/content/`.

3. Redo some installations, imports, and downloads.

In [None]:
pip install -U einops gradio spacy git+https://github.com/Ramos-Ramos/fruit-fly-net

In [None]:
!python -m spacy download en_core_web_sm

In [None]:
from einops import rearrange
import cupy as cp
import numpy as xp
import numpy as np
import gradio as gr
import pandas as pd
import spacy
from sklearn.metrics.pairwise import cosine_similarity
from tqdm.notebook import tqdm

from collections import OrderedDict
import pickle

from fruit_fly_net import FruitFlyNet, bio_hash_loss

4. Redefine functions, reinstantiate classes, and load vocabulary, probabilities, and model weights.

In [None]:
# load vocabulary
with open('vocab.pkl', 'rb') as vocab_file:
  vocabulary = pickle.load(vocab_file)
  vocabulary_size = len(vocabulary)

# load probabilities
with open('prob.npy', 'rb') as prob_file:
  probabilities = xp.load(prob_file)

# tokenization functions and classes
nlp = spacy.load('en_core_web_sm')

def create_tokens_from_text(text):
  """Tokenizes text by:
  - splitting with SpaCy
  - ignoring punctuations, numbers, and stop words
  """
  return [w.lemma_.lower() for w in nlp(text) 
          if not w.is_punct and 
          not w.like_num and 
          not w.lemma_.lower() in nlp.Defaults.stop_words]

# reinstantiate model and load weights
model = FruitFlyNet(
  input_dim=vocabulary_size*2,  # input dimension size (vocab_size * 2)
  output_dim=400,               # output dimension size
  k=51,                         # top k cells to be left active in output layer
  lr=1e-6                       # learning rate (learning is performed internally)
)

with open('weights.pkl', 'rb') as file:
  model.load_state_dict(pickle.load(file))

## Interactive demo

Gradio is used to perform similarity search with static word embeddings.

The inputs for static embeddings differ from the training inputs: they ignore context (the first $N_{voc}$ dimensions are all zero) and contain only the one-hot encoded target word in the last $N_{voc}$ dimensions. Let's start with a helper function that creates this type of input from a token and a vocabulary.

In [None]:
def create_static_input_embedding_from_token(token, vocabulary):
  token = (create_tokens_from_text(token)+[None])[0]
  id = None if token not in vocabulary else vocabulary.index(token) + len(vocabulary)
  input_embedding = xp.zeros(len(vocabulary)*2)
  if id is not None:
    input_embedding[id] = 1
  return input_embedding

Now let's create static input embeddings for each token in our vocabulary.

In [None]:
static_input_embeddings = []
for token in tqdm(vocabulary):
  static_input_embeddings.append(
      create_static_input_embedding_from_token(token, vocabulary)
  )
static_input_embeddings = xp.stack(static_input_embeddings)

In [None]:
static_input_embeddings

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 1.]])

We then feed each input embedding into our model to create a static embedding for each token.

In [None]:
batch_size = 32
model.eval()
static_embeddings = []
for batch_start in tqdm(range(0, static_input_embeddings.shape[0], batch_size)):
  input = static_input_embeddings[batch_start:batch_start+batch_size]
  static_embeddings.append(model(input, probabilities))
static_embeddings = xp.concatenate(static_embeddings)

In [None]:
static_embeddings[0]

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 1., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
       1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
       0., 0., 1., 1., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
       0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1.,
       1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0.

In [None]:
pip install -U einops gradio spacy git+https://github.com/Ramos-Ramos/fruit-fly-net


In [None]:
!python -m pip install git+https://github.com/kudkudak/word-embeddings-benchmarks


In [None]:
! pip3 install -U scikit-learn



Now we can find the $n$ most similar words for a given input word (e.g. "fire", "wing", "night").
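
As a toy-scale sketch of the idea, the search ranks all vocabulary tokens by the cosine similarity between their binary embeddings and the query's embedding. The four-word vocabulary, the embeddings, and the helper `toy_top_similar` below are made up for illustration and do not come from the trained model:

```python
import numpy as np

# Made-up binary embeddings for a 4-token toy vocabulary
toy_vocabulary = ["fire", "flame", "wing", "night"]
toy_embeddings = np.array([
    [1, 1, 0, 0, 1, 0],
    [1, 1, 0, 0, 0, 1],
    [0, 0, 1, 1, 0, 1],
    [0, 0, 1, 0, 1, 0],
], dtype=float)

def toy_top_similar(query, n):
    """Returns the n tokens with the highest cosine similarity to `query`."""
    query_id = toy_vocabulary.index(query)
    norms = np.linalg.norm(toy_embeddings, axis=1)
    similarities = toy_embeddings @ toy_embeddings[query_id] / (norms * norms[query_id])
    similarities[query_id] = -np.inf  # exclude the query itself
    top_ids = np.argsort(similarities)[::-1][:n]
    return [(toy_vocabulary[i], float(similarities[i])) for i in top_ids]

toy_result = toy_top_similar("fire", 2)
print(toy_result)  # "flame" ranks first, then "night"
```

Here "flame" shares two active units with "fire" and "night" shares one, so they rank first and second, while "wing" shares none.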

In [None]:
def get_top_similar_tokens_with_scores(token, top_similar):
  
  token = (create_tokens_from_text(token)+[None])[0]
  id = None if token not in vocabulary else vocabulary.index(token)
  if id is None:
    return {'out of vocabulary': 1.0}
  
  input_embedding = create_static_input_embedding_from_token(token, vocabulary)
  input_embedding = rearrange(input_embedding, 'd -> () d')
  
  model.eval()
  embedding = model(input_embedding, probabilities)
  
  similarities = cosine_similarity(
      cp.asnumpy(embedding), cp.asnumpy(static_embeddings)
  )
  similarities = rearrange(similarities, '() i -> i')
  
  current_vocabulary = vocabulary
  if id is not None:
    similarities = np.concatenate((similarities[:id], similarities[id+1:]))
    current_vocabulary = current_vocabulary[:id]+current_vocabulary[id+1:]

  top_similar_ids = similarities.argsort(kind='stable')[-top_similar:].tolist()
  top_similar_scores = similarities[top_similar_ids]
  top_similar_tokens = [current_vocabulary[id] for id in top_similar_ids]
  return OrderedDict(zip(top_similar_tokens, top_similar_scores))

r = gr.inputs.Slider(1, 20, step=1, default=10)
gr.Interface(fn=get_top_similar_tokens_with_scores, inputs=['text', r], outputs='label').launch()

Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`
Running on public URL: https://18130.gradio.app

This share link expires in 72 hours. For free permanent hosting, check out Spaces (https://huggingface.co/spaces)


(<Flask 'gradio.networking'>,
 'http://127.0.0.1:7860/',
 'https://18130.gradio.app')