In [15]:
%load_ext autoreload
%autoreload 2
%env ANYWIDGET_HMR=1
%env ANYWIDGET_DEV=1

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
env: ANYWIDGET_HMR=1
env: ANYWIDGET_DEV=1


In [16]:
from typing import Any

import networkx as nx
import numpy as np
from datasets import load_dataset
from scipy.cluster.hierarchy import dendrogram, fcluster, leaves_list, linkage
from scipy.spatial.distance import pdist, squareform
from sklearn.metrics import DistanceMetric

from seq import Widget
from seq.data.language import get_featured_ids, get_ids, get_tokenizer


In [17]:
# IMDB movie-review corpus loaded via `datasets`; only its "test" split is kept.
ds = load_dataset(
  "ajaykarthick/imdb-movie-reviews",
)["test"]

In [18]:
tokenizer = get_tokenizer()
# NOTE(review): max_tokens presumably caps/pads every review to 32 ids so the
# later np.array(...) calls yield a 2-D matrix — confirm in seq.data.language.
ids, tokens = get_ids(ds["review"], tokenizer, max_tokens=32)
featured_ids = get_featured_ids(ids, tokenizer, n_features=10)
# One {"id", "label"} record per featured id; the label is the decoded token text.
labels = [
  dict(id=token_id, label=tokenizer.id_to_token(token_id))
  for token_id in featured_ids
]

100%|██████████| 10000/10000 [00:03<00:00, 2543.79it/s]


In [19]:
# Interactive viewer over the raw token-id sequences; the cells below replace
# `w.sequences` with progressively processed versions.
w = Widget(sequences=ids, labels=labels)
w

Widget(labels=[{'id': 59, 'label': 'movie'}, {'id': 5, 'label': 'film'}, {'id': 30, 'label': 'one'}, {'id': 10…

In [20]:
def mask_non_featured_sequences(
  sequences: list[Any], label_ids: list[int]
) -> np.ndarray:
  """Zero out every non-featured token and drop rows with no featured token.

  Args:
    sequences: 2-D array-like of token ids, one sequence per row.
    label_ids: Token ids to keep; every other token is replaced by 0.

  Returns:
    A new integer array holding only the rows that contain at least one id
    from ``label_ids``, with all other positions set to 0.
  """
  # NOTE(review): assumes 0 is not itself a featured id, otherwise masked and
  # featured positions become indistinguishable — confirm.
  token_matrix = np.array(sequences)
  is_featured = np.isin(token_matrix, label_ids)
  keep_row = is_featured.any(axis=1)
  return (token_matrix * is_featured)[keep_row]


# Keep only featured tokens (others become 0) and drop rows with none of them.
labeld_sequences = mask_non_featured_sequences(
  sequences=ids, label_ids=featured_ids
)
w.sequences = labeld_sequences.tolist()

In [21]:
def distance(x: np.ndarray) -> np.ndarray:
  """Pairwise fraction of positions where two sequences do NOT share a token.

  A position counts as a match only when both sequences hold the same non-zero
  token. Positions where both hold 0 (the value the masking steps write) count
  as mismatches, so shared blanks never make two rows look similar. As a
  consequence two identical rows that contain zeros are at a non-zero distance
  from each other; only the diagonal is forced to 0.

  Args:
    x: 2-D array of shape (n_sequences, sequence_length).

  Returns:
    Symmetric (n_sequences, n_sequences) float array with values in [0, 1] and
    a zero diagonal.
  """
  x = np.asarray(x)
  n_sequences, sequence_length = x.shape
  # Count, per pair, the positions holding the same non-zero token. Working one
  # column at a time keeps peak memory at O(n^2) instead of materialising the
  # (n, n, length) boolean cubes of a full 3-D broadcast.
  matches = np.zeros((n_sequences, n_sequences), dtype=np.intp)
  for column in x.T:
    same_token = column[:, np.newaxis] == column[np.newaxis, :]
    matches += same_token & (column != 0)[:, np.newaxis]
  # Every position that is not a non-zero match is a mismatch.
  distances = (sequence_length - matches) / sequence_length
  np.fill_diagonal(distances, 0)
  return distances


In [22]:
def cluster_sequences(sequences: list[Any]) -> np.ndarray:
  """Reorder rows so that similar sequences end up next to each other.

  Runs average-linkage hierarchical clustering on the pairwise ``distance``
  matrix and returns the rows in the dendrogram's left-to-right leaf order.

  Args:
    sequences: 2-D array-like of token ids, one sequence per row.

  Returns:
    A new array with the same rows as ``sequences``, reordered.
  """
  sequences = np.array(sequences)
  if len(sequences) < 2:
    # linkage needs at least two observations; 0 or 1 rows are already ordered.
    return sequences
  condensed = squareform(distance(sequences))
  linkage_matrix = linkage(condensed, method="average")
  # leaves_list yields the same left-to-right leaf order that
  # dendrogram(...)["leaves"] reports, without dendrogram's recursive Python
  # traversal (which can hit the recursion limit on deep trees) or its plot
  # bookkeeping.
  return sequences[leaves_list(linkage_matrix)]


# Reorder rows by hierarchical clustering so similar sequences are adjacent.
clustered_sequences = cluster_sequences(sequences=labeld_sequences)
w.sequences = clustered_sequences.tolist()

In [23]:
def mask_sequences(sequences: list[Any], window_length: int) -> np.ndarray:
  """Zero every token that is not constant across its neighbouring rows.

  For each row ``c`` that has ``window_length`` rows on both sides, a column is
  kept only if rows ``c - window_length .. c + window_length`` all hold the same
  value there; otherwise it is set to 0. The first and last ``window_length``
  rows have no full window and are returned unchanged.

  Args:
    sequences: Array-like of token ids, one sequence per row. Row order matters
      (neighbours are adjacent rows).
    window_length: Number of rows compared on each side of the centre row.
      Values <= 0 disable masking.

  Returns:
    A new array with the same shape as ``sequences``; the input is not modified.
  """
  result = np.array(sequences)  # np.array copies, so the caller's data is untouched
  n_rows = len(result)
  if window_length <= 0 or n_rows <= 2 * window_length:
    # No row has a full window on both sides, so nothing can be masked.
    return result
  # changed[k] marks the positions where row k and row k + 1 differ.
  changed = result[:-1] != result[1:]
  n_centre = n_rows - 2 * window_length
  # Centre row window_length + r is spanned by the adjacent pairs
  # r .. r + 2 * window_length - 1; any disagreement among them masks the position.
  disagree = np.zeros((n_centre,) + result.shape[1:], dtype=bool)
  for offset in range(2 * window_length):
    disagree |= changed[offset : offset + n_centre]
  result[window_length : window_length + n_centre][disagree] = 0
  return result


# Keep a token only where it agrees with the row directly above and below.
masked_sequences = mask_sequences(
  sequences=clustered_sequences,
  window_length=1,
)
w.sequences = masked_sequences.tolist()

In [26]:
def filter_sequences(sequences: list[Any], filter_length: int) -> list[Any]:
  """Keep only rows with strictly more than ``filter_length`` non-zero tokens.

  Args:
    sequences: 2-D array-like of token ids, one sequence per row.
    filter_length: Exclusive lower bound on the number of non-zero tokens.

  Returns:
    A numpy array containing the surviving rows, in their original order.
  """
  token_matrix = np.array(sequences)
  nonzero_per_row = (token_matrix != 0).sum(axis=1)
  return token_matrix[nonzero_per_row > filter_length]


# Drop rows left with 2 or fewer non-zero tokens after masking.
filtered_sequences = filter_sequences(
  sequences=masked_sequences, filter_length=2
)
print(filtered_sequences.shape)
w.sequences = filtered_sequences.tolist()

(35, 32)


In [27]:
def sort_sequences(sequences: list[Any]) -> np.ndarray:
  """Order rows so similar ones are adjacent, keeping duplicate rows together.

  Unique rows are arranged along an approximate travelling-salesman tour
  (networkx's Christofides approximation) over their pairwise ``distance``
  matrix. Each unique row is then repeated as many times as it occurred in the
  input.

  Args:
    sequences: 2-D array-like of token ids, one sequence per row.

  Returns:
    Array with the same rows (and multiplicities) as ``sequences``, reordered.
  """
  sequences = np.array(sequences)
  unique_sequences, counts = np.unique(sequences, axis=0, return_counts=True)
  n_unique = len(unique_sequences)

  if n_unique < 2:
    # With fewer than two unique rows the graph would get no edges (and hence no
    # nodes), leaving nothing for christofides to tour.
    tour = np.arange(n_unique)
  else:
    dist_matrix = distance(unique_sequences)
    graph = nx.Graph()
    # Complete graph over the unique rows, edges added in (i, j), i < j order.
    graph.add_weighted_edges_from(
      (i, j, dist_matrix[i, j])
      for i in range(n_unique)
      for j in range(i + 1, n_unique)
    )
    # christofides returns a closed tour (start node repeated at the end); drop it.
    tour = np.array(nx.algorithms.approximation.christofides(graph)[:-1])

  # counts[k] belongs to unique_sequences[k], so it must be permuted by the tour
  # too; indexing it by tour position would assign the wrong multiplicities.
  return np.repeat(unique_sequences[tour], counts[tour], axis=0)


# Arrange unique rows along an approximate shortest tour so neighbours are similar.
sorted_sequences = sort_sequences(sequences=filtered_sequences)
w.sequences = sorted_sequences.tolist()