v1 with notebook support #69

Merged
merged 10 commits
Aug 23, 2024
263 changes: 263 additions & 0 deletions Untitled12.ipynb
@@ -0,0 +1,263 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 11,
"id": "a17c690c-5a86-43ec-88f9-73e493ca5aa2",
"metadata": {},
"outputs": [],
"source": [
"# %pip install sentence-transformers umap-learn embetter"
]
},
{
"cell_type": "markdown",
"id": "2184ee7a-18a5-4edd-934f-6d3c16fe9c09",
"metadata": {},
"source": [
"# Intro to bulk from a notebook\n",
"\n",
"In an attempt to come to a quick demo, we ran some code beforehand that does some encoding. \n",
"\n",
"<details>\n",
" <summary><b>Show me the code.</b></summary>\n",
"\n",
"```python \n",
"import pandas as pd\n",
"from umap import UMAP\n",
"from sklearn.pipeline import make_pipeline \n",
"\n",
"# pip install \"embetter[text]\"\n",
"from embetter.text import SentenceEncoder\n",
"\n",
"# Build a sentence encoder pipeline with UMAP at the end.\n",
"enc = SentenceEncoder('all-MiniLM-L6-v2')\n",
"umap = UMAP()\n",
"\n",
"text_emb_pipeline = make_pipeline(\n",
" enc, umap\n",
")\n",
"\n",
"# Load sentences\n",
"sentences = list(pd.read_csv(\"tests/data/text.csv\")['text'])\n",
"\n",
"# Calculate embeddings \n",
"X_tfm = text_emb_pipeline.fit_transform(sentences)\n",
"\n",
"# Write to disk. Note! Text column must be named \"text\"\n",
"df = pd.DataFrame({\"text\": sentences})\n",
"df['x'] = X_tfm[:, 0]\n",
"df['y'] = X_tfm[:, 1]\n",
"\n",
"X = enc.transform(sentences)\n",
"```\n",
"\n",
"This gives us a dataframe `df` that contains sentences, but also contains 2D UMAP representations of sentence embeddings. We also have a sentence encoder `enc` loaded and we also have access to our original embeddings `X`. Computing these can take a while on a CPU so we will store these on disk.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"np.save(\"utils/X\", X)\n",
"np.save(\"utils/X_tfm\", X_tfm)\n",
"df.to_csv(\"utils/df.csv\", index=False)\n",
"```\n",
"\n",
"</details>\n",
"\n",
"\n",
"For now, all you need to know is that we have some files on disk with some precomputer embeddings and representations that we can use to make a nice interactive plot. "
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "7e5ba2db-883f-4184-86dc-95d3560b2c24",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "markdown",
"id": "0abb4012-9c2b-4c68-ba72-91f5f61ddd0b",
"metadata": {},
"source": [
"So lets load some data first."
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "ae6a77ec-33ec-446d-8db9-147677c93492",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"X = np.load(\"utils/X.npy\")\n",
"X_tfm = np.load(\"utils/X_tfm.npy\")\n",
"df = pd.read_csv(\"utils/df.csv\")"
]
},
{
"cell_type": "markdown",
"id": "33f196ed-7e72-415b-b749-c5f8be20d99c",
"metadata": {},
"source": [
"Next, lets use these variables to conjure up a basic text explorer. This will allow us to quickly explore the clusters that appear in our data. You can hold the mouse cursor to go into selection mode and when you select items you will see a random subset appear on the right. You can resample from your selection by clicking the resample button."
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "daaf355b-ac0c-4399-be72-647eea085a38",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "adf74d2814ba41e093e3c926e07b4556",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(HBox(children=(VBox(children=(Button(button_style='primary', icon='arrows', layout=Layout(width…"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from bulk.widgets import BaseTextExplorer\n",
"\n",
"widget = BaseTextExplorer(df)\n",
"widget.show()"
]
},
{
"cell_type": "markdown",
"id": "3301a2b2-1bec-400f-bc0d-09885e251178",
"metadata": {},
"source": [
"Being able to explore these clusters is neat, but it feels like we might more easily explore everything if we have some more tools at our disposal. In particular, we want to have an encoder around so that we may use queries in our embedded space. \n",
"\n",
"The UI below will allow us to explore much more interactively."
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "9c48e600-59f8-44fc-acc6-a813b1547e45",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/vincent/Development/bulk/venv/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
" warnings.warn(\n"
]
}
],
"source": [
"from embetter.text import SentenceEncoder\n",
"\n",
"enc = SentenceEncoder('all-MiniLM-L6-v2')"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "c07e82b7-0735-4c12-bbd0-591efca09436",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e2cc738d482b4feb98a9eb7f3b45b28b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(VBox(children=(Text(value='', description='String:', placeholder='Type something'), HBox(childr…"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Pay attention here! The rows in df needs to align with the rows in X!\n",
"widget = BaseTextExplorer(df, X=X, encoder=enc)\n",
"widget.show()"
]
},
{
"cell_type": "markdown",
"id": "cc7abf74-5dd1-4d14-9d54-3b2422389472",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"Thanks to tools like ipywidget and anywidget, we can really start building some tools to make the notebook hard to beat as the go-to place for your data needs. My primary interest is to work on tools for data quality and being able to select datapoints in bulk feels like a great place to start. Maybe you can find an interesting subset to annotate first, maybe you get suprised when you see two distinct clusters that should be one. All that good stuff can happen in the notebook.\n",
"\n",
"More UI will follow, but this `BaseTextExplorer` feels like a nice place to start!"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
62 changes: 62 additions & 0 deletions bulk/__init__.py
@@ -0,0 +1,62 @@
import jscatter
from ipywidgets import HBox, VBox, HTML, Layout, Button, Text
from IPython.display import display
from sklearn.metrics.pairwise import cosine_similarity

class BaseTextExplorer:
    """
    Early preview of the Jupyter widget explorer.
    """

    def __init__(self, dataf, X=None, encoder=None):
        self.dataf = dataf
        self.scatter = jscatter.Scatter(data=self.dataf, x="x", y="y", width=500, height=500)
        self.html = HTML(layout=Layout(width='600px', overflow_y='scroll', height='400px'))
        self.sample_btn = Button(description='resample')
        self.elem = HBox([self.scatter.show(), VBox([self.sample_btn, self.html])])
        self.X = X
        self.encoder = encoder

        if self.encoder and (self.X is not None):
            # With an encoder and embeddings available, add a text box for similarity queries.
            self.text_input = Text(value='', placeholder='Type something', description='String:')
            self.elem = HBox([VBox([self.text_input, self.scatter.show()]), VBox([self.sample_btn, self.html])])

            def update_text(change):
                X_tfm = encoder.transform([self.text_input.value])
                dists = cosine_similarity(X, X_tfm).reshape(1, -1)
                self.dists = dists
                # Shift and scale the similarities into a strictly positive range so they can drive color and size.
                norm_dists = 0.01 + (dists - dists.min()) / (0.1 + dists.max() - dists.min())
                self.scatter.color(by=norm_dists[0])
                self.scatter.size(by=norm_dists[0])

            self.text_input.observe(update_text)

        self.scatter.widget.observe(lambda d: self.update(), ['selection'])
        self.sample_btn.on_click(lambda d: self.update())

    def show(self):
        return self.elem

    def update(self):
        if len(self.scatter.selection()) > 10:
            texts = self.dataf.iloc[self.scatter.selection()].sample(10)["text"]
        else:
            texts = self.dataf.iloc[self.scatter.selection()]["text"]
        self.html.value = ''.join([f'<p style="margin: 0px">{t}</p>' for t in texts])

    def observe(self, func):
        self.scatter.widget.observe(func, ['selection'])

    @property
    def selected_idx(self):
        return self.scatter.selection()

    @property
    def selected_texts(self):
        return list(self.dataf.iloc[self.selected_idx]["text"])

    @property
    def selected_dataframe(self):
        return self.dataf.iloc[self.selected_idx]

    def _repr_html_(self):
        return display(self.elem)
64 changes: 64 additions & 0 deletions bulk/widgets.py
@@ -0,0 +1,64 @@
import jscatter
import numpy as np
import pandas as pd
from ipywidgets import HBox, VBox, HTML, Layout, Button, Text
from IPython.display import display
from sklearn.metrics.pairwise import cosine_similarity

class BaseTextExplorer:
    """
    Interface for basic text exploration in embedded space.

    Arguments:
        dataf: dataframe with a "text" column and 2D coordinates in "x" and "y" columns
        X: optional array of embeddings whose rows align with the rows of `dataf`
        encoder: optional sentence encoder, used to turn text queries into embeddings
    """

    def __init__(self, dataf, X=None, encoder=None):
        self.dataf = dataf
        self.scatter = jscatter.Scatter(data=self.dataf, x="x", y="y", width=500, height=500)
        self.html = HTML(layout=Layout(width='600px', overflow_y='scroll', height='400px'))
        self.sample_btn = Button(description='resample')
        self.elem = HBox([self.scatter.show(), VBox([self.sample_btn, self.html])])
        self.X = X
        self.encoder = encoder

        if self.encoder and (self.X is not None):
            # With an encoder and embeddings available, add a text box for similarity queries.
            self.text_input = Text(value='', placeholder='Type something', description='String:')
            self.elem = HBox([VBox([self.text_input, self.scatter.show()]), VBox([self.sample_btn, self.html])])

            def update_text(change):
                X_tfm = encoder.transform([self.text_input.value])
                dists = cosine_similarity(X, X_tfm).reshape(1, -1)
                self.dists = dists
                # Shift and scale the similarities into a strictly positive range so they can drive color and size.
                norm_dists = 0.01 + (dists - dists.min()) / (0.1 + dists.max() - dists.min())
                self.scatter.color(by=norm_dists[0])
                self.scatter.size(by=norm_dists[0])

            self.text_input.observe(update_text)

        self.scatter.widget.observe(lambda d: self.update(), ['selection'])
        self.sample_btn.on_click(lambda d: self.update())

    def show(self):
        return self.elem

    def update(self):
        if len(self.scatter.selection()) > 10:
            texts = self.dataf.iloc[self.scatter.selection()].sample(10)["text"]
        else:
            texts = self.dataf.iloc[self.scatter.selection()]["text"]
        self.html.value = ''.join([f'<p style="margin: 0px">{t}</p>' for t in texts])

    def observe(self, func):
        self.scatter.widget.observe(func, ['selection'])

    @property
    def selected_idx(self):
        return self.scatter.selection()

    @property
    def selected_texts(self):
        return list(self.dataf.iloc[self.selected_idx]["text"])

    @property
    def selected_dataframe(self):
        return self.dataf.iloc[self.selected_idx]

    def _repr_html_(self):
        return display(self.elem)