Let's use Elixir
machine learning capabilities
to build an application
that performs image captioning
and semantic search
to look for uploaded images
with your voice! 🎙️
- Image Captioning & Semantic Search in Elixir
- Why? 🤷
- What? 💭
- Who? 👤
- How? 💻
- Prerequisites
- 🌄 Image Captioning in Elixir
- 0. Creating a fresh Phoenix project
- 1. Installing initial dependencies
- 2. Adding LiveView capabilities to our project
- 3. Receiving image files
- 4. Integrating Bumblebee
- 4.4 Updating the view
- 5. Final Touches
- 6. What about other models?
- 7. How do I deploy this thing?
- 8. Showing example images
- 8.2 Handling the example images list event inside our LiveView
- 9. Store metadata and classification info
- 10. Adding double MIME type check and showing feedback to the person in case of failure
- 11. Benchmarking image captioning models
- 🔍 Semantic search
- Please star the repo! ⭐️
Whilst building our app,
we consider images
an essential medium of communication.
We needed a fully offline-capable (no 3rd party APIs/Services) image captioning service
using state-of-the-art pre-trained image and embedding models to describe images uploaded in our App.
By adding a way of captioning images, we make it easy for people to suggest meta tags that describe images so they become searchable.
A step-by-step tutorial building a fully functional Phoenix LiveView
web application that allows anyone
to upload an image and have it described
and searchable.
In addition, the app lets the person record audio describing the image they want to find.
The audio is transcribed into text,
which is then used as a semantic query.
We do this by encoding the image captions as vectors
and running a k-nearest-neighbour (kNN) search on them.
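To make the idea of "captions as vectors plus nearest-neighbour search" concrete, here is a minimal brute-force sketch in plain Elixir. It is only an illustration: the `KnnSketch` module and the three-dimensional vectors are made-up toy values, whereas the real app gets its vectors from the embedding model and searches them with an approximate index.

```elixir
defmodule KnnSketch do
  @moduledoc "Brute-force nearest-neighbour search over caption vectors (illustration only)."

  # Cosine similarity between two equal-length lists of numbers.
  def cosine(a, b) do
    dot = a |> Enum.zip(b) |> Enum.reduce(0.0, fn {x, y}, acc -> acc + x * y end)
    dot / (norm(a) * norm(b))
  end

  defp norm(v), do: v |> Enum.reduce(0.0, fn x, acc -> acc + x * x end) |> :math.sqrt()

  # Returns the k captions whose vectors are most similar to the query vector.
  def nearest(index, query, k) do
    index
    |> Enum.sort_by(fn {_caption, vector} -> cosine(vector, query) end, :desc)
    |> Enum.take(k)
    |> Enum.map(fn {caption, _vector} -> caption end)
  end
end

# Hypothetical toy "embeddings"; real ones have hundreds of dimensions.
index = [
  {"a dog running on a beach", [0.9, 0.1, 0.0]},
  {"a plate of pasta", [0.0, 0.2, 0.9]},
  {"a puppy playing in the sand", [0.8, 0.3, 0.1]}
]

IO.inspect(KnnSketch.nearest(index, [1.0, 0.2, 0.0], 2))
```

A brute-force scan like this compares the query against every stored vector, which is fine for a handful of images but is the cost an approximate nearest-neighbour index is meant to avoid.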
We'll be using three different models:
- Salesforce's BLIP model blip-image-captioning-large for image captioning.
- OpenAI's speech recognition model whisper-small.
- The sentence-transformers/paraphrase-MiniLM-L6-v2 embedding model.
This tutorial is aimed at Phoenix beginners
who want to start exploring the machine-learning capabilities
of the Elixir language within a Phoenix application.
We use pre-trained models from Hugging Face via Bumblebee
to learn how to:
- run a model, in particular for image captioning.
- use embeddings.
- run a semantic search using an Approximate Nearest Neighbour algorithm.
If you are completely new to Phoenix and LiveView,
we recommend you follow the LiveView Counter Tutorial:
dwyl/phoenix-liveview-counter-tutorial
In these chapters, we'll go over the development process of this small application. You'll learn how to do this yourself, so grab some coffee and let's get cracking!
This tutorial is divided into two parts. The first covers image captioning, while the second expands the application by adding semantic search.
This tutorial requires you to have Elixir and Phoenix installed.
If you don't, please see how to install Elixir and Phoenix.
This guide assumes you know the basics of Phoenix
and have some knowledge of how it works.
If you don't, we highly suggest you follow our other tutorials first, e.g.: github.com/dwyl/phoenix-chat-example
In addition, some knowledge of AWS (what it is, what an S3 bucket is/does) is assumed.
Note
If you have questions or get stuck, please open an issue! /dwyl/image-classifier/issues
In this section, we'll start building our application with Bumblebee, which supports Transformer models. At the end of this section, you'll have a fully functional application that receives an image, processes it accordingly and captions it.
Let's create a fresh Phoenix project.
Run the following command in a given folder:
mix phx.new . --app app --no-dashboard --no-ecto --no-gettext --no-mailer
We're running mix phx.new to generate a new project without a dashboard, database layer (Ecto), translations (gettext) and mailer (email) service,
since we don't need those features in our project.
After this, if you run mix phx.server to start your server, you should be able to see the following page.
We're ready to start building.
Now that we're ready to go, let's start by adding some dependencies.
Head over to mix.exs and add the following dependencies to the deps section.
{:bumblebee, "~> 0.5.0"},
{:exla, "~> 0.7.0"},
{:nx, "~> 0.7.0"},
{:hnswlib, "~> 0.1.5"},
- bumblebee is a framework that allows us to integrate Transformer models in Phoenix. Transformers (from Hugging Face) are APIs that allow us to easily download and use pre-trained models. The Bumblebee package aims to support all Transformer models, even if some are still lacking. You may check which ones are supported by visiting Bumblebee's repository or by visiting https://jonatanklosko-bumblebee-tools.hf.space/apps/repository-inspector and checking if the model is currently supported.
- Nx (Numerical Elixir) is Elixir's library for numerical computing. It supports tensors and numerical computations.
- EXLA provides Elixir bindings to Google's XLA, a compiler for linear algebra that originated in the TensorFlow project. We use it as the compiler backend for Nx, because it allows us to compile models just-in-time and run them on CPU and/or GPU.
- hnswlib provides Elixir bindings for HNSWLib, a library for approximate nearest-neighbour search. We'll need it for the semantic search part.
- Vix is an Elixir extension for libvips, an image processing library. It is not part of the snippet above, so add a :vix entry to deps as well.
In config/config.exs, let's add our :nx configuration to use EXLA.
config :nx, default_backend: EXLA.Backend
As it stands, our project is not using LiveView.
Let's fix this.
LiveView launches a super-powered process that establishes a WebSocket connection between the server and the browser.
In lib/app_web/router.ex, change the scope "/" to the following.
scope "/", AppWeb do
pipe_through :browser
live "/", PageLive
end
Instead of using the PageController,
we are going to create PageLive, a LiveView file.
Let's create our LiveView files.
Inside lib/app_web, create a folder called live
and create the file page_live.ex in it.
#/lib/app_web/live/page_live.ex
defmodule AppWeb.PageLive do
use AppWeb, :live_view
@impl true
def mount(_params, _session, socket) do
{:ok, socket}
end
end
This is a simple LiveView module.
In the same live folder,
create a file called page_live.html.heex
and use the following code.
<div
class="h-full w-full px-4 py-10 flex justify-center sm:px-6 sm:py-28 lg:px-8 xl:px-28 xl:py-32"
>
<div
class="flex justify-center items-center mx-auto max-w-xl w-[50vw] lg:mx-0"
>
<form>
<div class="space-y-12">
<div>
<h2 class="text-base font-semibold leading-7 text-gray-900">
Image Classifier
</h2>
<p class="mt-1 text-sm leading-6 text-gray-600">
Drag your images and we'll run an AI model to caption it!
</p>
<div class="mt-10 grid grid-cols-1 gap-x-6 gap-y-8 sm:grid-cols-6">
<div class="col-span-full">
<div
class="mt-2 flex justify-center rounded-lg border border-dashed border-gray-900/25 px-6 py-10"
>
<div class="text-center">
<svg
class="mx-auto h-12 w-12 text-gray-300"
viewBox="0 0 24 24"
fill="currentColor"
aria-hidden="true"
>
<path
fill-rule="evenodd"
d="M1.5 6a2.25 2.25 0 012.25-2.25h16.5A2.25 2.25 0 0122.5 6v12a2.25 2.25 0 01-2.25 2.25H3.75A2.25 2.25 0 011.5 18V6zM3 16.06V18c0 .414.336.75.75.75h16.5A.75.75 0 0021 18v-1.94l-2.69-2.689a1.5 1.5 0 00-2.12 0l-.88.879.97.97a.75.75 0 11-1.06 1.06l-5.16-5.159a1.5 1.5 0 00-2.12 0L3 16.061zm10.125-7.81a1.125 1.125 0 112.25 0 1.125 1.125 0 01-2.25 0z"
clip-rule="evenodd"
/>
</svg>
<div class="mt-4 flex text-sm leading-6 text-gray-600">
<label
for="file-upload"
class="relative cursor-pointer rounded-md bg-white font-semibold text-indigo-600 focus-within:outline-none focus-within:ring-2 focus-within:ring-indigo-600 focus-within:ring-offset-2 hover:text-indigo-500"
>
<span>Upload a file</span>
<input
id="file-upload"
name="file-upload"
type="file"
class="sr-only"
/>
</label>
<p class="pl-1">or drag and drop</p>
</div>
<p class="text-xs leading-5 text-gray-600">
PNG, JPG, GIF up to 5MB
</p>
</div>
</div>
</div>
</div>
</div>
</div>
</form>
</div>
</div>
This is a simple HTML form that uses Tailwind CSS
to enhance the presentation of the upload form.
We'll also remove the unused header of the page layout,
while we're at it.
Locate the file lib/app_web/components/layouts/app.html.heex
and remove the <header> element.
The file should only have the following code:
<main class="px-4 py-20 sm:px-6 lg:px-8">
<div class="mx-auto max-w-2xl">
<.flash_group flash={@flash} /> <%= @inner_content %>
</div>
</main>
Now you can safely delete the lib/app_web/controllers folder,
which is no longer used.
If you run mix phx.server,
you should see the following screen:
This means we've successfully added LiveView
and changed our view!
Now, let's start by receiving some image files.
With LiveView, we can easily do this by using allow_upload/3
when mounting our LiveView.
With this function, we can easily accept file uploads with progress.
We can define file types, max number of entries, max file size,
validate the uploaded file and much more!
Firstly, let's make some changes to lib/app_web/live/page_live.html.heex.
<div
class="h-full w-full px-4 py-10 flex justify-center sm:px-6 sm:py-28 lg:px-8 xl:px-28 xl:py-32"
>
<div
class="flex justify-center items-center mx-auto max-w-xl w-[50vw] lg:mx-0"
>
<div class="space-y-12">
<div class="border-gray-900/10 pb-12">
<h2 class="text-base font-semibold leading-7 text-gray-900">
Image Classification
</h2>
<p class="mt-1 text-sm leading-6 text-gray-600">
Do simple captioning with this
<a
href="https://hexdocs.pm/phoenix_live_view/Phoenix.LiveView.html"
class="font-mono font-medium text-sky-500"
>LiveView</a
>
demo, powered by
<a
href="https://github.com/elixir-nx/bumblebee"
class="font-mono font-medium text-sky-500"
>Bumblebee</a
>.
</p>
<!-- File upload section -->
<div class="mt-10 grid grid-cols-1 gap-x-6 gap-y-8 sm:grid-cols-6">
<div class="col-span-full">
<div
class="mt-2 flex justify-center rounded-lg border border-dashed border-gray-900/25 px-6 py-10"
phx-drop-target={@uploads.image_list.ref}
>
<div class="text-center">
<svg
class="mx-auto h-12 w-12 text-gray-300"
viewBox="0 0 24 24"
fill="currentColor"
aria-hidden="true"
>
<path
fill-rule="evenodd"
d="M1.5 6a2.25 2.25 0 012.25-2.25h16.5A2.25 2.25 0 0122.5 6v12a2.25 2.25 0 01-2.25 2.25H3.75A2.25 2.25 0 011.5 18V6zM3 16.06V18c0 .414.336.75.75.75h16.5A.75.75 0 0021 18v-1.94l-2.69-2.689a1.5 1.5 0 00-2.12 0l-.88.879.97.97a.75.75 0 11-1.06 1.06l-5.16-5.159a1.5 1.5 0 00-2.12 0L3 16.061zm10.125-7.81a1.125 1.125 0 112.25 0 1.125 1.125 0 01-2.25 0z"
clip-rule="evenodd"
/>
</svg>
<div class="mt-4 flex text-sm leading-6 text-gray-600">
<label
for="file-upload"
class="relative cursor-pointer rounded-md bg-white font-semibold text-indigo-600 focus-within:outline-none focus-within:ring-2 focus-within:ring-indigo-600 focus-within:ring-offset-2 hover:text-indigo-500"
>
<form phx-change="validate" phx-submit="save">
<label class="cursor-pointer">
<.live_file_input upload={@uploads.image_list}
class="hidden" /> Upload
</label>
</form>
</label>
<p class="pl-1">or drag and drop</p>
</div>
<p class="text-xs leading-5 text-gray-600">
PNG, JPG, GIF up to 5MB
</p>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
We've added a few features:
- we used <.live_file_input/> for LiveView file upload. We've wrapped this component with an element that is annotated with the phx-drop-target attribute pointing to the DOM id of the file input.
- because we used <.live_file_input/>, we need to annotate its wrapping form element with phx-submit and phx-change, as per hexdocs.pm/phoenix_live_view/uploads.html#render-reactive-elements
Because we've added these bindings,
we need to add the event handlers in lib/app_web/live/page_live.ex.
Open it and update it to:
defmodule AppWeb.PageLive do
  use AppWeb, :live_view

  @impl true
  def mount(_params, _session, socket) do
    {:ok,
     socket
     |> assign(label: nil, upload_running?: false, task_ref: nil)
     |> allow_upload(:image_list,
       accept: ~w(image/*),
       auto_upload: true,
       progress: &handle_progress/3,
       max_entries: 1,
       chunk_size: 64_000
     )}
  end

  @impl true
  def handle_event("validate", _params, socket) do
    {:noreply, socket}
  end

  @impl true
  def handle_event("remove-selected", %{"ref" => ref}, socket) do
    {:noreply, cancel_upload(socket, :image_list, ref)}
  end

  @impl true
  def handle_event("save", _params, socket) do
    {:noreply, socket}
  end

  # Called on every progress update; consumes the entry once it is fully uploaded.
  defp handle_progress(:image_list, entry, socket) when entry.done? do
    # The result is not used yet, hence the leading underscore.
    _uploaded_file =
      consume_uploaded_entry(socket, entry, fn %{path: _path} = _meta ->
        {:ok, entry}
      end)

    {:noreply, socket}
  end

  defp handle_progress(:image_list, _, socket), do: {:noreply, socket}
end
- when mount/3ing the LiveView, we are creating three socket assigns: label pertains to the model prediction; upload_running? is a boolean referring to whether the model is running or not; task_ref refers to the reference of the task that was created for image classification (we'll delve into this further later down the line). Additionally, we are using the allow_upload/3 function to define our upload configuration. The most important settings here are auto_upload set to true and the progress field. By configuring these two properties, we are telling LiveView that whenever the person uploads a file, it is processed immediately and consumed.
- the progress field is handled by the handle_progress/3 function. It receives chunks from the client with a built-in UploadWriter (as explained in the docs). When the chunks are all consumed, entry.done? is true. We consume the file in this function by using consume_uploaded_entry/3. The anonymous function returns {:ok, data} or {:postpone, message}. Whilst consuming the entry/file, we can access its path and then use its content. For now, we don't need to use it. But we will in the future to feed our image classifier with it! After the callback function is executed, this function "consumes the entry", essentially deleting the image from the temporary folder and removing it from the uploaded files list.
- the "validate", "remove-selected" and "save" event handlers are called whenever the person uploads the image, wants to remove it from the list of uploaded images and wants to submit the form, respectively. You may see that we're not doing much with these handlers; we're simply replying with :noreply because we don't need to do anything with them.
And that's it!
If you run mix phx.server, nothing will change.
Now here comes the fun part! It's time to do some image captioning! 🎉
We first need to add some initial setup in the lib/app/application.ex file.
Head over there and change the start function like so:
@impl true
def start(_type, _args) do
children = [
# Start the Telemetry supervisor
AppWeb.Telemetry,
# Start the PubSub system
{Phoenix.PubSub, name: App.PubSub},
{Nx.Serving, serving: serving(), name: ImageClassifier},
# Start the Endpoint (http/https)
AppWeb.Endpoint
]
opts = [strategy: :one_for_one, name: App.Supervisor]
Supervisor.start_link(children, opts)
end
def serving do
{:ok, model_info} = Bumblebee.load_model({:hf, "microsoft/resnet-50"})
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, "microsoft/resnet-50"})
Bumblebee.Vision.image_classification(model_info, featurizer,
top_k: 1,
compile: [batch_size: 10],
defn_options: [compiler: EXLA]
)
end
We are using Nx.Serving, which simply allows us to encapsulate tasks; it can be networking, machine learning, data processing or any other task.
In this specific case, we are using it to batch requests. This is extremely useful and important because we are using models that typically run on GPU. The GPU is really good at parallelizing tasks. Therefore, instead of sending an image classification request one by one, we can batch them/bundle them together as much as we can and then send it over.
We can define the batch_size and batch_timeout with Nx.Serving.
We're going to use the default values, hence why we're not explicitly defining them.
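For reference, this is roughly what setting those two options explicitly could look like. The sketch below is a standalone script, not part of the app: it assumes network access so Mix.install can fetch nx, and it uses a toy serving that doubles its input (named ToyServing here purely for illustration) instead of the real model.

```elixir
# Assumption: run as a standalone script; Mix.install needs network access.
Mix.install([{:nx, "~> 0.7.0"}])

# A toy serving that doubles its input, standing in for the model serving.
serving = Nx.Serving.jit(&Nx.multiply(&1, 2))

children = [
  {Nx.Serving,
   serving: serving,
   name: ToyServing,
   # Dispatch a batch once 10 requests are queued...
   batch_size: 10,
   # ...or after 100 milliseconds, whichever comes first.
   batch_timeout: 100}
]

{:ok, _pid} = Supervisor.start_link(children, strategy: :one_for_one)

# Requests from any process are queued and run together as a batch.
batch = Nx.Batch.stack([Nx.tensor([1, 2, 3])])
result = Nx.Serving.batched_run(ToyServing, batch)
IO.inspect(Nx.to_flat_list(result))
```

A larger batch_size tends to favour throughput, while a shorter batch_timeout favours latency for a single request; the values above are arbitrary examples.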
With Nx.Serving, we define a serving/0 function that builds the serving.
The serving is started as part of the supervision tree, since we declare it as a child in the Application module.
In the serving/0 function, we are loading the ResNet-50 model and its featurizer.
Note
A featurizer can be seen as a Feature Extractor.
It is essentially a component that is responsible for converting input data
into a format that can be processed by a pre-trained model.
It takes raw information and performs various transformations to prepare the data for model training or inference. For text, this means steps such as tokenization, padding and encoding; for images, it typically means resizing and normalizing pixel values.
Lastly, this function returns a serving for image classification by calling image_classification/3, where we can define our compiler and task batch size.
We gave our serving the name ImageClassifier, as declared in the Application module.
Now we're ready to send the image to the model and get a prediction of it!
Every time we upload an image, we are going to run async processing. This means that the task responsible for image classification will be run in another process, thus asynchronously, meaning that the LiveView won't have to wait for this task to finish to continue working.
For this scenario, we are going to be using the Task module to spawn processes to complete this task.
Go to lib/app_web/live/page_live.ex and change the following code.
defp handle_progress(:image_list, entry, socket) when entry.done? do
  # Consume the entry and get the tensor to feed to the classifier
  tensor =
    consume_uploaded_entry(socket, entry, fn %{} = meta ->
      {:ok, vimage} = Vix.Vips.Image.new_from_file(meta.path)
      pre_process_image(vimage)
    end)

  # Create an async task to classify the image
  task = Task.async(fn -> Nx.Serving.batched_run(ImageClassifier, tensor) end)

  # Update socket assigns to show spinner whilst task is running
  {:noreply, assign(socket, upload_running?: true, task_ref: task.ref)}
end

@impl true
def handle_info({ref, result}, %{assigns: %{task_ref: ref}} = socket) do
  # This is called when the async task finishes and sends its result back.
  # We stop monitoring the task and flush the pending :DOWN message.
  Process.demonitor(ref, [:flush])

  # And then destructure the result from the classifier.
  %{predictions: [%{label: label}]} = result

  # Update the socket assigns with the result and stop the spinner.
  {:noreply, assign(socket, label: label, upload_running?: false)}
end
Note
The pre_process_image/1 function is yet to be defined.
We'll do that in the following section.
In the handle_progress/3 function,
whilst we are consuming the image,
we are first converting it to a Vix.Vips.Image struct
using the file path.
We then feed this image to the pre_process_image/1 function that we'll implement later.
What's important is to notice this line:
task = Task.async(fn -> Nx.Serving.batched_run(ImageClassifier, tensor) end)
We are using Task.async/1 to call the Nx.Serving process we named ImageClassifier earlier,
thus initiating a batched run with the image tensor.
Once the task is spawned,
we update the socket assigns with the reference to the task (:task_ref)
and set the :upload_running? assign to true,
so we can show a spinner or a loading animation.
When the task is spawned using Task.async/1,
a couple of things happen in the background.
The new process is monitored by the caller (our LiveView),
which means that the caller will receive a
{:DOWN, ref, :process, object, reason}
message once the process it is monitoring dies.
And, a link is created between both processes.
Therefore, we don't need to use Task.await/2.
Instead, we create a new handler to receive the message the task sends when it finishes.
That's what we're doing in the
handle_info({ref, result}, %{assigns: %{task_ref: ref}} = socket)
function.
The received message contains a {ref, result} tuple,
where ref is the monitor's reference.
We use this reference to stop monitoring the task,
since we received the result we needed from our task
and we can discard the exit message.
In this same function, we destructure the prediction
from the model, assign it to the :label socket assign
and set :upload_running? to false.
Quite beautiful, isn't it?
With this, we don't have to worry if the person closes the browser tab.
The LiveView process dies, the linked task dies with it,
and the work is automatically cancelled,
meaning no resources are spent
on a process for which nobody expects a result anymore.
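The message flow described above can be tried outside of LiveView with nothing but the standard library. In this minimal sketch the script's own process stands in for the LiveView, and a plain receive block plays the role of handle_info/2; the upcased string is just a placeholder for the model's prediction.

```elixir
# The caller spawns a task and handles its reply as a plain message
# instead of calling Task.await/2.
task = Task.async(fn -> String.upcase("a cat on a sofa") end)

result =
  receive do
    {ref, value} when ref == task.ref ->
      # The result arrived, so stop monitoring and drop any pending :DOWN message.
      Process.demonitor(ref, [:flush])
      value
  after
    1_000 -> :timeout
  end

IO.inspect(result)
```

Matching on task.ref is what the :task_ref assign is for in the LiveView: it lets the handler recognise the reply that belongs to the task it started.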
When a task is spawned using Task.async/1,
it is linked to the caller.
This means that they're related:
if one dies, the other does too.
We ought to take this into account when developing our application.
If we don't have control over the result of the task,
and we don't want our LiveView to crash if the task crashes,
we must use a different alternative to spawn our task.
Task.Supervisor.async_nolink/3 can be used for this purpose,
meaning we can use it if we want to make sure
our LiveView won't die and the error is reported,
even if the task crashes.
We've chosen Task.async/1 because we want this linking behaviour:
we are doing something that takes time/is expensive
and we want to stop the task if the LiveView is closed/crashes.
However, if you are building something
like a report that has to be generated even if the person closes the browser tab,
this is not the right solution.
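For comparison, here is a small standard-library sketch of the non-linked alternative. The task crashes on purpose (the "model failed" error is invented for the example), and the caller, again standing in for the LiveView, survives and learns about the failure through the :DOWN message.

```elixir
{:ok, supervisor} = Task.Supervisor.start_link()

# With async_nolink the task is monitored by the caller but not linked to it.
task = Task.Supervisor.async_nolink(supervisor, fn -> raise "model failed" end)

outcome =
  receive do
    {ref, value} when ref == task.ref ->
      Process.demonitor(ref, [:flush])
      {:ok, value}

    {:DOWN, ref, :process, _pid, reason} when ref == task.ref ->
      # The task crashed; the caller is still alive and can report the error.
      {:error, reason}
  after
    1_000 -> :timeout
  end

IO.inspect(match?({:error, _}, outcome))
IO.inspect(Process.alive?(self()))
```

In a LiveView, the second clause would become a handle_info/2 clause for the :DOWN message, for example to reset the spinner and show an error.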
We are spawning async tasks by calling Task.async/1.
This creates an unsupervised task.
Although that's acceptable for this simple app,
it's best for us to create a Supervisor that manages its child tasks.
This gives more control over the execution
and lifetime of the child tasks.
Additionally, it's better to have these tasks supervised
because it makes it possible to create tests for our LiveView.
For this, we need to make a couple of changes.
First, head over to lib/app/application.ex
and add a supervisor to the children list of the start/2 function.
def start(_type, _args) do
children = [
AppWeb.Telemetry,
{Phoenix.PubSub, name: App.PubSub},
{Nx.Serving, serving: serving(), name: ImageClassifier},
{Task.Supervisor, name: App.TaskSupervisor}, # add this line
AppWeb.Endpoint
]
opts = [strategy: :one_for_one, name: App.Supervisor]
Supervisor.start_link(children, opts)
end
We are creating a Task.Supervisor with the name App.TaskSupervisor.
Now, in lib/app_web/live/page_live.ex,
we create the async task like so:
task = Task.Supervisor.async(App.TaskSupervisor, fn -> Nx.Serving.batched_run(ImageClassifier, tensor) end)
We are now using Task.Supervisor.async,
passing the name of the supervisor defined earlier.
And that's it! We are creating async tasks like before, the only difference is that they're now supervised.
In tests, you can create a small module that waits for the tasks to be completed.
defmodule AppWeb.SupervisorSupport do
@moduledoc """
This is a support module helper that is meant to wait for all the children of a supervisor to complete.
If you go to `lib/app/application.ex`, you'll see that we created a `TaskSupervisor`, where async tasks are spawned.
This module helps us to wait for all the children to finish during tests.
"""
@doc """
Find all children spawned by this supervisor and wait until they finish.
"""
def wait_for_completion() do
pids = Task.Supervisor.children(App.TaskSupervisor)
Enum.each(pids, &Process.monitor/1)
wait_for_pids(pids)
end
defp wait_for_pids([]), do: nil
defp wait_for_pids(pids) do
receive do
{:DOWN, _ref, :process, pid, _reason} -> wait_for_pids(List.delete(pids, pid))
end
end
end
You can call AppWeb.SupervisorSupport.wait_for_completion()
in unit tests so they wait for the tasks to complete.
In our case,
we do that until the prediction is made.
As we've noted before, we need to pre-process the image before passing it to the model. For this, we have three main steps:
- remove the alpha channel from the image, flattening it out.
- convert the image to the sRGB colourspace. This is needed to ensure that the image is consistent and aligns with the model's training data images.
- set the representation of the image as a Tensor to height, width, bands. The image tensor will then be organized as a three-dimensional array, where the first dimension represents the height of the image, the second refers to the width of the image, and the third pertains to the different spectral bands/channels of the image.
Our pre_process_image/1 function will implement these three steps.
Let's implement it now!
In lib/app_web/live/page_live.ex, add the following:
# The struct is written out in full so that no alias (e.g. Vimage) is required.
defp pre_process_image(%Vix.Vips.Image{} = image) do
  # If the image has an alpha channel, flatten it:
  {:ok, flattened_image} =
    case Vix.Vips.Image.has_alpha?(image) do
      true -> Vix.Vips.Operation.flatten(image)
      false -> {:ok, image}
    end

  # Convert the image to sRGB colourspace ----------------
  {:ok, srgb_image} = Vix.Vips.Operation.colourspace(flattened_image, :VIPS_INTERPRETATION_sRGB)

  # Converting image to tensor ----------------
  {:ok, tensor} = Vix.Vips.Image.write_to_tensor(srgb_image)

  # We reshape the tensor given a specific format.
  # In this case, we are using {height, width, channels/bands}.
  %Vix.Tensor{data: binary, type: type, shape: {height, width, bands}} = tensor
  format = [:height, :width, :bands]
  shape = {height, width, bands}

  final_tensor =
    binary
    |> Nx.from_binary(type)
    |> Nx.reshape(shape, names: format)

  {:ok, final_tensor}
end
The function receives a Vix image, as detailed earlier.
We use flatten/1 to flatten the alpha out of the image.
The resulting image has its colourspace changed
by calling colourspace/3,
where we change it to sRGB.
The colourspace-altered image is then converted to a tensor
by calling write_to_tensor/1.
We then reshape the tensor according to the format that was previously mentioned.
This function returns the processed tensor, which is then used as input to the model.
All that's left is updating the view
to reflect these changes we've made to the LiveView.
Head over to lib/app_web/live/page_live.html.heex
and change it to this.
<.flash_group flash={@flash} />
<div
class="h-full w-full px-4 py-10 flex justify-center sm:px-6 sm:py-28 lg:px-8 xl:px-28 xl:py-32"
>
<div
class="flex justify-center items-center mx-auto max-w-xl w-[50vw] lg:mx-0"
>
<div class="space-y-12">
<div class="border-gray-900/10 pb-12">
<h2 class="text-base font-semibold leading-7 text-gray-900">
Image Classification
</h2>
<p class="mt-1 text-sm leading-6 text-gray-600">
Do simple classification with this
<a
href="https://hexdocs.pm/phoenix_live_view/Phoenix.LiveView.html"
class="font-mono font-medium text-sky-500"
>LiveView</a
>
demo, powered by
<a
href="https://github.com/elixir-nx/bumblebee"
class="font-mono font-medium text-sky-500"
>Bumblebee</a
>.
</p>
<!-- File upload section -->
<div class="mt-10 grid grid-cols-1 gap-x-6 gap-y-8 sm:grid-cols-6">
<div class="col-span-full">
<div
class="mt-2 flex justify-center rounded-lg border border-dashed border-gray-900/25 px-6 py-10"
phx-drop-target={@uploads.image_list.ref}
>
<div class="text-center">
<svg
class="mx-auto h-12 w-12 text-gray-300"
viewBox="0 0 24 24"
fill="currentColor"
aria-hidden="true"
>
<path
fill-rule="evenodd"
d="M1.5 6a2.25 2.25 0 012.25-2.25h16.5A2.25 2.25 0 0122.5 6v12a2.25 2.25 0 01-2.25 2.25H3.75A2.25 2.25 0 011.5 18V6zM3 16.06V18c0 .414.336.75.75.75h16.5A.75.75 0 0021 18v-1.94l-2.69-2.689a1.5 1.5 0 00-2.12 0l-.88.879.97.97a.75.75 0 11-1.06 1.06l-5.16-5.159a1.5 1.5 0 00-2.12 0L3 16.061zm10.125-7.81a1.125 1.125 0 112.25 0 1.125 1.125 0 01-2.25 0z"
clip-rule="evenodd"
/>
</svg>
<div class="mt-4 flex text-sm leading-6 text-gray-600">
<label
for="file-upload"
class="relative cursor-pointer rounded-md bg-white font-semibold text-indigo-600 focus-within:outline-none focus-within:ring-2 focus-within:ring-indigo-600 focus-within:ring-offset-2 hover:text-indigo-500"
>
<form id="upload-form" phx-change="noop" phx-submit="noop">
<label class="cursor-pointer">
<.live_file_input upload={@uploads.image_list}
class="hidden" /> Upload
</label>
</form>
</label>
<p class="pl-1">or drag and drop</p>
</div>
<p class="text-xs leading-5 text-gray-600">
PNG, JPG, GIF up to 5MB
</p>
</div>
</div>
</div>
</div>
<!-- Prediction text -->
<div
class="mt-6 flex space-x-1.5 items-center font-bold text-gray-900 text-xl"
>
<span>Description: </span>
<!-- Spinner -->
<%= if @upload_running? do %>
<div role="status">
<div
class="relative w-6 h-6 animate-spin rounded-full bg-gradient-to-r from-purple-400 via-blue-500 to-red-400 "
>
<div
class="absolute top-1/2 left-1/2 transform -translate-x-1/2 -translate-y-1/2 w-3 h-3 bg-gray-200 rounded-full border-2 border-white"
></div>
</div>
</div>
<% else %> <%= if @label do %>
<span class="text-gray-700 font-light"><%= @label %></span>
<% else %>
<span class="text-gray-300 font-light">Waiting for image input.</span>
<% end %> <% end %>
</div>
</div>
</div>
</div>
</div>
In these changes,
we've added the output of the model in the form of text.
We are rendering a spinner
if the :upload_running? socket assign is set to true.
Otherwise, we show the :label, which holds the prediction made by the model.
You may have also noticed that
we've changed the phx event handlers to noop.
This is simply to simplify the LiveView.
Head over to lib/app_web/live/page_live.ex.
You can now remove the "validate", "save" and "remove-selected" handlers,
because we're not going to be needing them.
Replace them with this handler:
@impl true
def handle_event("noop", _params, socket) do
{:noreply, socket}
end
And that's it! Our app is now functional 🎉.
If you run the app, you can drag and drop or select an image. After this, a task will be spawned that will run the model against the image that was submitted.
Once a prediction is made, it is displayed!
You can and should try other models.
ResNet-50 is just one of the many that are supported by Bumblebee.
You can see the supported models at https://github.com/elixir-nx/bumblebee#model-support.
To keep the app as simple as possible, we are receiving the image from the person as is. Although we are processing the image, we only do so to make it processable by the model.
We have to understand that:
- in most cases, full-resolution images are not necessary, because neural networks work on much smaller inputs (e.g. ResNet-50 works with 224px x 224px images). This means that a lot of data is unnecessarily uploaded over the network, increasing workload on the server to potentially downsize a large image.
- decoding an image requires an additional package, meaning more work on the server.
We can avoid both of these downsides by moving this work to the client.
We can leverage the Canvas API
to decode and downsize this image on the client-side,
reducing server workload.
You can see an example implementation of this technique
in Bumblebee's repository
at https://github.com/elixir-nx/bumblebee/blob/main/examples/phoenix/image_classification.exs
However, since we are not using JavaScript for anything,
we can (and should!) properly downsize our images on the server
so they better fit the training dataset of the model we use.
This will allow the model to process them faster,
since larger images carry more data that is ultimately unnecessary
for models to make predictions.
Open lib/app_web/live/page_live.ex,
find the handle_progress/3 function
and resize the image before processing it.
file_binary = File.read!(meta.path)
# Get image and resize
# This is dependent on the resolution of the model's dataset.
# In our case, we want the width to be closer to 640, whilst maintaining aspect ratio.
width = 640
{:ok, thumbnail_vimage} = Vix.Vips.Operation.thumbnail(meta.path, width, size: :VIPS_SIZE_DOWN)
# Pre-process it
{:ok, tensor} = pre_process_image(thumbnail_vimage)
#...
We are using Vix.Vips.Operation.thumbnail/3
to resize our image to a fixed width
whilst maintaining aspect ratio.
The width variable can be dependent on the model that you use.
For example, ResNet-50 is trained on 224px x 224px pictures,
so you may want to resize the image to this width.
Note: We are using the thumbnail/3 function instead of resize/3
because it's much faster.
Check https://github.com/libvips/libvips/wiki/HOWTO----Image-shrinking to know why.
Although our app is functional, we can make it better. 🎨
In order to better control user input, we should add a limit to the size of the image that is being uploaded. It will be easier on our server and ultimately save costs.
Let's add a cap of 5MB
to our app!
Fortunately for you, this is super simple!
You just need to add the max_file_size
option to the allow_upload/3
function
when mounting the LiveView
!
def mount(_params, _session, socket) do
{:ok,
socket
|> assign(label: nil, upload_running?: false, task_ref: nil)
|> allow_upload(:image_list,
accept: ~w(image/*),
auto_upload: true,
progress: &handle_progress/3,
max_entries: 1,
chunk_size: 64_000,
max_file_size: 5_000_000 # add this
)}
end
And that's it!
The number is in bytes
,
hence why we set it as 5_000_000
.
In case a person uploads an image that is too large, we should show this feedback to the person!
For this, we can leverage the
upload_errors/2
function.
This function will return the entry errors for an upload.
We need to add a handler for one of these errors to show it first.
Head over to lib/app_web/live/page_live.ex
and add the following line.
def error_to_string(:too_large), do: "Image too large. Upload a smaller image up to 5MB."
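The :too_large clause only covers one error. As a hedged sketch, the same function can be extended with further clauses; to our knowledge, LiveView uploads can also report :not_accepted and :too_many_files, and a catch-all clause keeps the view from crashing on anything unexpected. The UploadErrors module name and the message wording are assumptions made for this example; in the app, the clauses would sit in the LiveView.

```elixir
defmodule UploadErrors do
  # Hypothetical standalone module so the sketch runs on its own.
  @max_mb 5

  def error_to_string(:too_large),
    do: "Image too large. Upload a smaller image up to #{@max_mb}MB."

  def error_to_string(:not_accepted),
    do: "Unsupported file type. Please upload an image."

  def error_to_string(:too_many_files),
    do: "Please upload one image at a time."

  # Fallback for any error we did not anticipate.
  def error_to_string(other),
    do: "Upload failed (#{inspect(other)})."
end

IO.puts(UploadErrors.error_to_string(:too_large))
IO.puts(UploadErrors.error_to_string(:something_else))
```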
Now, add the following section below the upload form
inside lib/app_web/live/page_live.html.heex
.
<!-- Show errors -->
<%= for entry <- @uploads.image_list.entries do %>
<div class="mt-2">
<%= for err <- upload_errors(@uploads.image_list, entry) do %>
<div class="rounded-md bg-red-50 p-4 mb-2">
<div class="flex">
<div class="flex-shrink-0">
<svg
class="h-5 w-5 text-red-400"
viewBox="0 0 20 20"
fill="currentColor"
aria-hidden="true"
>
<path
fill-rule="evenodd"
d="M10 18a8 8 0 100-16 8 8 0 000 16zM8.28 7.22a.75.75 0 00-1.06 1.06L8.94 10l-1.72 1.72a.75.75 0 101.06 1.06L10 11.06l1.72 1.72a.75.75 0 101.06-1.06L11.06 10l1.72-1.72a.75.75 0 00-1.06-1.06L10 8.94 8.28 7.22z"
clip-rule="evenodd"
/>
</svg>
</div>
<div class="ml-3">
<h3 class="text-sm font-medium text-red-800">
<%= error_to_string(err) %>
</h3>
</div>
</div>
</div>
<% end %>
</div>
<% end %>
We are iterating over the errors returned by upload_errors/2
and invoking error_to_string/1
,
which we've just defined in our LiveView
.
Now, if you run the app and try to upload an image that is too large, an error will show up.
Awesome! 🎉
As of now, even though our app predicts the given images, it is not showing a preview of the image the person submitted. Let's fix this 🛠️.
Let's add a new socket assign variable
pertaining to the base64 representation
of the image in lib/app_web/live/page_live.ex
|> assign(label: nil, upload_running?: false, task_ref: nil, image_preview_base64: nil)
We've added image_preview_base64
as a new socket assign,
initializing it as nil
.
Next, we need to read the file while consuming it, and properly update the socket assign so we can show it to the person.
In the same file,
change the handle_progress/3
function to the following.
def handle_progress(:image_list, entry, socket) when entry.done? do
# Consume the entry and get the tensor to feed to classifier
%{tensor: tensor, file_binary: file_binary} = consume_uploaded_entry(socket, entry, fn %{} = meta ->
file_binary = File.read!(meta.path)
{:ok, vimage} = Vix.Vips.Image.new_from_file(meta.path)
{:ok, tensor} = pre_process_image(vimage)
{:ok, %{tensor: tensor, file_binary: file_binary}}
end)
# Create an async task to classify the image
task = Task.Supervisor.async(App.TaskSupervisor, fn -> Nx.Serving.batched_run(ImageClassifier, tensor) end)
# Encode the image to base64
base64 = "data:image/png;base64, " <> Base.encode64(file_binary)
# Update socket assigns to show spinner whilst task is running
{:noreply, assign(socket, upload_running?: true, task_ref: task.ref, image_preview_base64: base64)}
end
We're using File.read!/1
to retrieve the binary representation of the image that was uploaded.
We use Base.encode64/2
to encode this file binary
and assign the newly created image_preview_base64
socket assign
with this base64 representation of the image.
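Note that the snippet above always labels the data URI as image/png, even when the person uploads a JPG or GIF; browsers tend to tolerate this, but the label can also be derived from the file itself. Below is a minimal sketch that guesses the MIME type from the leading bytes of the binary. DataUri, mime/1 and encode/1 are hypothetical names, and only three formats are recognised.

```elixir
defmodule DataUri do
  # Guess the MIME type from the file's leading "magic" bytes.
  def mime(<<0x89, "PNG", 0x0D, 0x0A, 0x1A, 0x0A, _::binary>>), do: "image/png"
  def mime(<<0xFF, 0xD8, 0xFF, _::binary>>), do: "image/jpeg"
  def mime(<<"GIF8", v, "a", _::binary>>) when v in [?7, ?9], do: "image/gif"
  def mime(_other), do: "application/octet-stream"

  # Build a data URI that can be used as the `src` of an <img> tag.
  def encode(binary) when is_binary(binary) do
    "data:" <> mime(binary) <> ";base64," <> Base.encode64(binary)
  end
end

# The first bytes of a JPEG file, used here as a tiny stand-in for a real upload.
IO.puts(DataUri.encode(<<0xFF, 0xD8, 0xFF, 0xE0>>))
# data:image/jpeg;base64,/9j/4A==
```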
Now, all that's left to do
is to render the image on our view.
In lib/app_web/live/page_live.html.heex
,
locate the line:
<div class="text-center"></div>
We are going to update this <div>
to show the image with the image_preview_base64
socket assign.
<div class="text-center">
<!-- Show image preview -->
<%= if @image_preview_base64 do %>
<form id="upload-form" phx-change="noop" phx-submit="noop">
<label class="cursor-pointer">
<.live_file_input upload={@uploads.image_list} class="hidden" />
<img src={@image_preview_base64} />
</label>
</form>
<% else %>
<svg
class="mx-auto h-12 w-12 text-gray-300"
viewBox="0 0 24 24"
fill="currentColor"
aria-hidden="true"
>
<path
fill-rule="evenodd"
d="M1.5 6a2.25 2.25 0 012.25-2.25h16.5A2.25 2.25 0 0122.5 6v12a2.25 2.25 0 01-2.25 2.25H3.75A2.25 2.25 0 011.5 18V6zM3 16.06V18c0 .414.336.75.75.75h16.5A.75.75 0 0021 18v-1.94l-2.69-2.689a1.5 1.5 0 00-2.12 0l-.88.879.97.97a.75.75 0 11-1.06 1.06l-5.16-5.159a1.5 1.5 0 00-2.12 0L3 16.061zm10.125-7.81a1.125 1.125 0 112.25 0 1.125 1.125 0 01-2.25 0z"
clip-rule="evenodd"
/>
</svg>
<div class="mt-4 flex text-sm leading-6 text-gray-600">
<label
for="file-upload"
class="relative cursor-pointer rounded-md bg-white font-semibold text-indigo-600 focus-within:outline-none focus-within:ring-2 focus-within:ring-indigo-600 focus-within:ring-offset-2 hover:text-indigo-500"
>
<form id="upload-form" phx-change="noop" phx-submit="noop">
<label class="cursor-pointer">
<.live_file_input upload={@uploads.image_list} class="hidden" />
Upload
</label>
</form>
</label>
<p class="pl-1">or drag and drop</p>
</div>
<p class="text-xs leading-5 text-gray-600">PNG, JPG, GIF up to 5MB</p>
<% end %>
</div>
As you can see,
we are checking if @image_preview_base64
is defined.
If so, we simply show the image with it as src
😊.
Now, if you run the application, you'll see that after dragging the image, it is previewed and shown to the person!
Maybe you weren't happy with the results from this model.
That's fair. ResNet-50
is a smaller, "older" model compared to other
image captioning/classification models.
What if you wanted to use others?
Well, as we've mentioned before,
Bumblebee
uses
Transformer models from HuggingFace
.
To know if one is supported
(as shown in Bumblebee
's docs),
we need to check the config.json
file
in the model repository
and copy the class name under "architectures"
and search it on Bumblebee
's codebase.
For example,
here's one of the more popular image captioning models -
Salesforce's BLIP
-
https://huggingface.co/Salesforce/blip-image-captioning-large/blob/main/config.json.
If you visit Bumblebee
's codebase
and search for the class name,
you'll find it is supported.
Awesome! Now we can use it!
If you dig around Bumblebee
's docs as well
(https://hexdocs.pm/bumblebee/Bumblebee.Vision.html#image_to_text/5),
you'll see that we've got to use image_to_text/5
with this model.
It needs a tokenizer
, featurizer
and a generation-config
so we can use it.
Let's do it!
Head over to lib/app/application.ex
,
and change the serving/0
function.
def serving do
{:ok, model_info} = Bumblebee.load_model({:hf, "Salesforce/blip-image-captioning-base"})
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, "Salesforce/blip-image-captioning-base"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Salesforce/blip-image-captioning-base"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "Salesforce/blip-image-captioning-base"})
Bumblebee.Vision.image_to_text(model_info, featurizer, tokenizer, generation_config,
compile: [batch_size: 10],
defn_options: [compiler: EXLA]
)
end
As you can see, we're using the repository name of BLIP
's model
from the HuggingFace website.
If you run mix phx.server
,
you'll see that it will download the new models,
tokenizers, featurizer and configs to run the model.
|======================================================================| 100% (989.82 MB)
[info] TfrtCpuClient created.
|======================================================================| 100% (711.39 KB)
[info] Running AppWeb.Endpoint with cowboy 2.10.0 at 127.0.0.1:4000 (http)
[info] Access AppWeb.Endpoint at http://localhost:4000
[watch] build finished, watching for changes...
You may think we're done here. But we are not! ✋
The destructuring of the output of the model may not be the same.
If you try to submit a photo,
you'll get this error:
no match of right hand side value:
%{results: [%{text: "a person holding a large blue ball on a beach"}]}
This means that we need to make some changes when parsing the output of the model 😀.
Head over to lib/app_web/live/page_live.ex
and change the handle_info/2
function
that is called after the async task is completed.
def handle_info({ref, result}, %{assigns: %{task_ref: ref}} = socket) do
Process.demonitor(ref, [:flush])
%{results: [%{text: label}]} = result # change this line
{:noreply, assign(socket, label: label, upload_running?: false)}
end
As you can see, we are now correctly destructuring the result from the model. And that's it!
If you run mix phx.server
,
you'll see that we got far more accurate results!
Awesome! 🎉
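Since each kind of model returns a differently shaped result, one option is to keep the destructuring in a small function with one clause per shape, so a mismatch becomes an error tuple instead of a crash. This is only a sketch: LabelExtractor is a hypothetical module, the :results shape is taken from the error message above, and the :predictions shape is what we assume an image classification serving such as ResNet-50 returns.

```elixir
defmodule LabelExtractor do
  # Shape returned by an image-to-text serving (e.g. BLIP), as in the error above.
  def extract(%{results: [%{text: label} | _]}), do: {:ok, label}

  # Shape we assume for an image classification serving (e.g. ResNet-50).
  def extract(%{predictions: [%{label: label} | _]}), do: {:ok, label}

  # Anything else is reported instead of raising a MatchError.
  def extract(other), do: {:error, {:unexpected_output, other}}
end

blip_output = %{results: [%{text: "a person holding a large blue ball on a beach"}]}
IO.inspect(LabelExtractor.extract(blip_output))
# {:ok, "a person holding a large blue ball on a beach"}
```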
Note
Be aware that BLIP
is a much larger model than ResNet-50
.
There are more accurate and even larger models out there
(e.g:
blip-image-captioning-large
,
the larger version of the model we've just used).
This is a balancing act: the larger the model, the longer a prediction may take
and the more resources your server will need to handle this heavier workload.
Warning
We've created a small module that allows you to have multiple models cached and downloaded and keep this logic contained.
For this, check the deployment guide
.
There are a few considerations you may want to keep in mind
before deploying this.
Luckily for you,
we've created a small document
that will guide you through deploying this app in fly.io
!
Check the deployment.md
file for more information.
Warning
This section assumes you've made the changes made in the previous section.
Therefore, you should follow the instructions in
7. How do I deploy this thing?
and come back after you're done.
We have a fully functioning application that predicts images. Now we can add some cool touches to show the person some examples if they are inactive.
For this, we are going to need to make three changes.
- create a hook in the client (JavaScript) to send an event when there's inactivity after a given number of seconds.
- change the page_live.ex LiveView to accommodate this new event.
- change the view in page_live.html.heex to show these changes to the person.
Let's go over each one!
We are going to detect the inactivity of the person
with some JavaScript
code.
Head over to assets/js/app.js
and change it to the following.
// If you want to use Phoenix channels, run `mix help phx.gen.channel`
// to get started and then uncomment the line below.
// import "./user_socket.js"
// You can include dependencies in two ways.
//
// The simplest option is to put them in assets/vendor and
// import them using relative paths:
//
// import "../vendor/some-package.js"
//
// Alternatively, you can `npm install some-package --prefix assets` and import
// them using a path starting with the package name:
//
// import "some-package"
//
// Include phoenix_html to handle method=PUT/DELETE in forms and buttons.
import "phoenix_html";
// Establish Phoenix Socket and LiveView configuration.
import { Socket } from "phoenix";
import { LiveSocket } from "phoenix_live_view";
import topbar from "../vendor/topbar";
// Hooks to track inactivity
let Hooks = {};
Hooks.ActivityTracker = {
mounted() {
// Set the inactivity duration in milliseconds
const inactivityDuration = 8000; // 8 seconds
// Set a variable to keep track of the timer and if the process to predict example image has already been sent
let inactivityTimer;
let processHasBeenSent = false;
let ctx = this;
// Function to reset the timer
function resetInactivityTimer() {
// Clear the previous timer
clearTimeout(inactivityTimer);
// Start a new timer
inactivityTimer = setTimeout(() => {
// Perform the desired action after the inactivity duration
// For example, send a message to the Elixir process using Phoenix Socket
if (!processHasBeenSent) {
processHasBeenSent = true;
ctx.pushEvent("show_examples", {});
}
}, inactivityDuration);
}
// Call the function to start the timer initially
resetInactivityTimer();
// Reset the timer whenever there is user activity
document.addEventListener("mousemove", resetInactivityTimer);
document.addEventListener("keydown", resetInactivityTimer);
},
};
let csrfToken = document
.querySelector("meta[name='csrf-token']")
.getAttribute("content");
let liveSocket = new LiveSocket("/live", Socket, {
hooks: Hooks,
params: { _csrf_token: csrfToken },
});
// Show progress bar on live navigation and form submits
topbar.config({ barColors: { 0: "#29d" }, shadowColor: "rgba(0, 0, 0, .3)" });
window.addEventListener("phx:page-loading-start", (_info) => topbar.show(300));
window.addEventListener("phx:page-loading-stop", (_info) => topbar.hide());
// connect if there are any LiveViews on the page
liveSocket.connect();
// expose liveSocket on window for web console debug logs and latency simulation:
// >> liveSocket.enableDebug()
// >> liveSocket.enableLatencySim(1000) // enabled for duration of browser session
// >> liveSocket.disableLatencySim()
window.liveSocket = liveSocket;
- we have added a Hooks variable with an ActivityTracker property. This hook has a mounted() function that is executed when the element bound to this hook is mounted. You can find more information at https://hexdocs.pm/phoenix_live_view/js-interop.html.
- inside the mounted() function, we create a resetInactivityTimer() function that is executed every time the mouse moves (mousemove event) or a key is pressed (keydown event). This function restarts the timer that measures how long the person has been inactive.
- if the person is inactive for 8 seconds, we push a "show_examples" event to the server. We will create a handler in the LiveView to handle this event later.
- we add the Hooks variable to the hooks property when initializing the LiveSocket.
And that's it!
For this hook to actually be executed, we need to create a component that uses it inside our view file.
For this, we can simply create a hidden component
on top of the lib/app_web/live/page_live.html.heex
file.
Add the following hidden component.
<div class="hidden" id="tracker_el" phx-hook="ActivityTracker" />
We use the phx-hook
attribute
to bind the hook we've created in our app.js
file
to the component so it's executed.
When this component is mounted,
the mounted()
function inside the hook
is executed.
And that's it! 👏
Your app won't work yet because
we haven't created a handler
to handle the "show_examples"
event.
Let's do that right now!
Now that we have our client sorted,
let's head over to our LiveView
at lib/app_web/live/page_live.ex
and make the needed changes!
Before anything, let's add socket assigns that we will need throughout our adventure! On top of the file, change the socket assigns to the following.
socket
|> assign(
# Related to the file uploaded by the user
label: nil,
upload_running?: false,
task_ref: nil,
image_preview_base64: nil,
# Related to the list of image examples
example_list_tasks: [],
example_list: [],
display_list?: false
)
We've added three new assigns:
- example_list_tasks is a list of the async tasks that are created for each example image.
- example_list is a list of the example images with their respective predictions.
- display_list? is a boolean that tells us if the list is to be shown or not.
Awesome! Let's continue.
As we've mentioned before,
we need to create a handler for our "show_examples"
event.
Add the following function to the file.
@image_width 640
def handle_event("show_examples", _data, socket) do
# Only run if the user hasn't uploaded anything
if(is_nil(socket.assigns.task_ref)) do
# Retrieves a random image from Picsum with a given `image_width` dimension
random_image = "https://picsum.photos/#{@image_width}/#{@image_width}"
# Spawns prediction tasks for example image from random Picsum image
tasks = for _ <- 1..2 do
{:req, body} = {:req, Req.get!(random_image).body}
predict_example_image(body)
end
# List to change `example_list` socket assign to show skeleton loading
display_example_images = Enum.map(tasks, fn obj -> %{predicting?: true, ref: obj.ref} end)
# Updates the socket assigns
{:noreply, assign(socket, example_list_tasks: tasks, example_list: display_example_images)}
else
{:noreply, socket}
end
end
Warning
We are using the
req
package
to download the file binary from the URL.
Make sure to install it in the mix.exs
file.
- we are using the Picsum API, an awesome image API with lots of photos! They provide a URL that yields a random photo, and in this URL we can also define the dimensions we want (for example, /640/640). That's what we're doing in the first line of the function. We are using the module attribute @image_width 640, so add it to the top of the file. This is relevant because we preferably want to deal with images that have the same resolution as the dataset the model was trained on.
- we are creating two async tasks that retrieve the binary of the image and pass it on to a predict_example_image/1 function (we will create this function next).
- the two tasks that we've created are in a list called tasks. We create another list, display_example_images, with the same number of elements as tasks. Each element has two properties: predicting?, which tells us if the image is still being processed by the model; and ref, the reference of the task.
- we assign the tasks list to the example_list_tasks socket assign and the display_example_images list to the example_list socket assign. So example_list will temporarily hold objects with :predicting? and :ref properties whilst the model is being executed.
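The two data-shaping steps described above can be tried in isolation. The sketch below builds the Picsum URL from the width and turns a list of tasks into skeleton placeholders. The Examples module and its function names are hypothetical, and plain maps with a :ref key stand in for the real Task structs.

```elixir
defmodule Examples do
  @image_width 640

  # Builds the square Picsum URL for the given width (note the leading slash).
  def random_image_url(width \\ @image_width) do
    "https://picsum.photos/#{width}/#{width}"
  end

  # One "skeleton" placeholder per task; each task only needs a :ref key here.
  def placeholders(tasks) do
    Enum.map(tasks, fn task -> %{predicting?: true, ref: task.ref} end)
  end
end

# Stand-ins for the two async tasks.
tasks = for _ <- 1..2, do: %{ref: make_ref()}

IO.inspect(Examples.random_image_url())
# "https://picsum.photos/640/640"
IO.inspect(Examples.placeholders(tasks))
```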
As we've just mentioned,
we are making use of a function called
predict_example_image/1
to make predictions
of a given file binary.
Let's implement it now! In the same file, add:
def predict_example_image(body) do
with {:vix, {:ok, img_thumb}} <-
{:vix, Vix.Vips.Operation.thumbnail_buffer(body, @image_width)},
{:pre_process, {:ok, t_img}} <-
{:pre_process, pre_process_image(img_thumb)} do
# Create an async task to classify the image from Picsum
Task.Supervisor.async(App.TaskSupervisor, fn ->
Nx.Serving.batched_run(ImageClassifier, t_img)
end)
|> Map.merge(%{base64_encoded_url: "data:image/png;base64, " <> Base.encode64(body)})
else
{stage, error} -> {stage, error}
end
end
For the body of the image to be executed by the model, it needs to go through some pre-processing.
- we are using thumbnail_buffer/3 to make sure it's properly resized and then feeding the result to our own pre_process_image/1 function so it can be converted to a tensor that the model can parse.
- after these two operations are successfully completed, we spawn an async task (like we've done before) that feeds the tensor to the model. We add the base64-encoded image to the return value so it can later be shown to the person.
- if these operations fail, we return an error.
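The function above tags each step of the with expression (:vix, :pre_process) so that a failure reveals which stage it came from. Here is a minimal, self-contained sketch of that pattern. Pipeline, resize/1 and to_tensor/1 are hypothetical stand-ins that work on strings rather than real images.

```elixir
defmodule Pipeline do
  # Stand-in for the resizing step.
  def resize(body) when byte_size(body) > 0, do: {:ok, "thumb:" <> body}
  def resize(_), do: {:error, "empty image"}

  # Stand-in for the pre-processing step.
  def to_tensor("thumb:bad"), do: {:error, "cannot decode"}
  def to_tensor("thumb:" <> rest), do: {:ok, byte_size(rest)}

  def run(body) do
    # Each step is wrapped in a tagged tuple so the `else` branch
    # can tell which stage produced the error.
    with {:resize, {:ok, thumb}} <- {:resize, resize(body)},
         {:tensor, {:ok, tensor}} <- {:tensor, to_tensor(thumb)} do
      {:ok, tensor}
    else
      {stage, {:error, reason}} -> {:error, stage, reason}
    end
  end
end

IO.inspect(Pipeline.run("abc"))
# {:ok, 3}
IO.inspect(Pipeline.run(""))
# {:error, :resize, "empty image"}
IO.inspect(Pipeline.run("bad"))
# {:error, :tensor, "cannot decode"}
```

A caller can then pattern match on the stage, for example to skip an example image whose download or decoding failed instead of treating the returned value as a task.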
Great job! 👏
Our example images async tasks have successfully been created and are on their way to the model!
Now we need to handle these newly created async tasks
once they are completed.
As we know, we are handling our async tasks completion
in the def handle_info({ref, result}, %{assigns: %{task_ref: ref}} = socket)
function.
Let's change it like so.
def handle_info({ref, result}, %{assigns: assigns} = socket) do
# Flush async call
Process.demonitor(ref, [:flush])
# You need to change how you destructure the output of the model depending
# on the model you've chosen for `prod` and `test` envs on `models.ex`.
label =
case Application.get_env(:app, :use_test_models, false) do
true ->
App.Models.extract_test_label(result)
# coveralls-ignore-start
false ->
App.Models.extract_prod_label(result)
# coveralls-ignore-stop
end
cond do
# If the upload task has finished executing, we update the socket assigns.
Map.get(assigns, :task_ref) == ref ->
{:noreply, assign(socket, label: label, upload_running?: false)}
# If the example task has finished executing, we update the socket assigns.
img = Map.get(assigns, :example_list_tasks) |> Enum.find(&(&1.ref == ref)) ->
# Update the element in the `example_list` enum to turn "predicting?" to `false`
updated_example_list = Map.get(assigns, :example_list)
|> Enum.map(fn obj ->
if obj.ref == img.ref do
Map.put(obj, :base64_encoded_url, img.base64_encoded_url)
|> Map.put(:label, label)
|> Map.put(:predicting?, false)
else
obj
end end)
{:noreply,
assign(socket,
example_list: updated_example_list,
upload_running?: false,
display_list?: true
)}
end
end
The only change we've made is that
we've added a cond
flow structure.
We are essentially checking
if the task reference that has been completed
is from the uploaded image from the person (:task_ref
socket assign)
or from an example image (inside the :example_list_tasks
socket assign list).
If it's the latter, we update the example_list socket assign list with the prediction (:label), the base64-encoded image from the task list (:base64_encoded_url) and set the :predicting? property to false.
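The list update in the second cond branch is a pure transformation, so it can be sketched and tested outside the LiveView. ExampleList.complete/4 below is a hypothetical helper that marks the entry with a matching reference as done and leaves every other entry untouched.

```elixir
defmodule ExampleList do
  # Marks the entry whose :ref matches `ref` as predicted.
  def complete(list, ref, label, base64_encoded_url) do
    Enum.map(list, fn
      # The pin (^ref) matches only the entry carrying this exact reference.
      %{ref: ^ref} = item ->
        Map.merge(item, %{
          label: label,
          base64_encoded_url: base64_encoded_url,
          predicting?: false
        })

      item ->
        item
    end)
  end
end

r1 = make_ref()
r2 = make_ref()
list = [%{ref: r1, predicting?: true}, %{ref: r2, predicting?: true}]

# Only the second entry changes; the first is still waiting for its prediction.
IO.inspect(ExampleList.complete(list, r2, "a dog on a beach", "data:image/png;base64,..."))
```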
And that's it! Great job! 🥳
Now that we've made all the necessary changes to our LiveView, we need to update our view so it reflects them!
Head over to lib/app_web/live/page_live.html.heex
and change it to the following piece of code.
<div class="hidden" id="tracker_el" phx-hook="ActivityTracker" />
<div
class="h-full w-full px-4 py-10 flex justify-center sm:px-6 sm:py-24 lg:px-8 xl:px-28 xl:py-32"
>
<div class="flex flex-col justify-start">
<div class="flex justify-center items-center w-full">
<div class="2xl:space-y-12">
<div class="mx-auto max-w-2xl lg:text-center">
<p>
<span
class="rounded-full w-fit bg-brand/5 px-2 py-1 text-[0.8125rem] font-medium text-center leading-6 text-brand"
>
<a
href="https://hexdocs.pm/phoenix_live_view/Phoenix.LiveView.html"
target="_blank"
rel="noopener noreferrer"
>
🔥 LiveView
</a>
+
<a
href="https://github.com/elixir-nx/bumblebee"
target="_blank"
rel="noopener noreferrer"
>
🐝 Bumblebee
</a>
</span>
</p>
<p
class="mt-2 text-3xl font-bold tracking-tight text-gray-900 sm:text-4xl"
>
Caption your image!
</p>
<h3 class="mt-6 text-lg leading-8 text-gray-600">
Upload your own image (up to 5MB) and perform image captioning with
<a
href="https://elixir-lang.org/"
target="_blank"
rel="noopener noreferrer"
class="font-mono font-medium text-sky-500"
>
Elixir
</a>
!
</h3>
<p class="text-lg leading-8 text-gray-400">
Powered with
<a
href="https://huggingface.co/"
target="_blank"
rel="noopener noreferrer"
class="font-mono font-medium text-sky-500"
>
HuggingFace🤗
</a>
transformer models, you can run this project locally and perform
machine learning tasks with a handful lines of code.
</p>
</div>
<div class="border-gray-900/10">
<!-- File upload section -->
<div class="col-span-full">
<div
class="mt-2 flex justify-center rounded-lg border border-dashed border-gray-900/25 px-6 py-10"
phx-drop-target={@uploads.image_list.ref}
>
<div class="text-center">
<!-- Show image preview -->
<%= if @image_preview_base64 do %>
<form id="upload-form" phx-change="noop" phx-submit="noop">
<label class="cursor-pointer">
<.live_file_input upload={@uploads.image_list}
class="hidden" />
<img src={@image_preview_base64} />
</label>
</form>
<% else %>
<svg
class="mx-auto h-12 w-12 text-gray-300"
viewBox="0 0 24 24"
fill="currentColor"
aria-hidden="true"
>
<path
fill-rule="evenodd"
d="M1.5 6a2.25 2.25 0 012.25-2.25h16.5A2.25 2.25 0 0122.5 6v12a2.25 2.25 0 01-2.25 2.25H3.75A2.25 2.25 0 011.5 18V6zM3 16.06V18c0 .414.336.75.75.75h16.5A.75.75 0 0021 18v-1.94l-2.69-2.689a1.5 1.5 0 00-2.12 0l-.88.879.97.97a.75.75 0 11-1.06 1.06l-5.16-5.159a1.5 1.5 0 00-2.12 0L3 16.061zm10.125-7.81a1.125 1.125 0 112.25 0 1.125 1.125 0 01-2.25 0z"
clip-rule="evenodd"
/>
</svg>
<div class="mt-4 flex text-sm leading-6 text-gray-600">
<label
for="file-upload"
class="relative cursor-pointer rounded-md bg-white font-semibold text-indigo-600 focus-within:outline-none focus-within:ring-2 focus-within:ring-indigo-600 focus-within:ring-offset-2 hover:text-indigo-500"
>
<form id="upload-form" phx-change="noop" phx-submit="noop">
<label class="cursor-pointer">
<.live_file_input upload={@uploads.image_list}
class="hidden" /> Upload
</label>
</form>
</label>
<p class="pl-1">or drag and drop</p>
</div>
<p class="text-xs leading-5 text-gray-600">
PNG, JPG, GIF up to 5MB
</p>
<% end %>
</div>
</div>
</div>
</div>
<!-- Show errors -->
<%= for entry <- @uploads.image_list.entries do %>
<div class="mt-2">
<%= for err <- upload_errors(@uploads.image_list, entry) do %>
<div class="rounded-md bg-red-50 p-4 mb-2">
<div class="flex">
<div class="flex-shrink-0">
<svg
class="h-5 w-5 text-red-400"
viewBox="0 0 20 20"
fill="currentColor"
aria-hidden="true"
>
<path
fill-rule="evenodd"
d="M10 18a8 8 0 100-16 8 8 0 000 16zM8.28 7.22a.75.75 0 00-1.06 1.06L8.94 10l-1.72 1.72a.75.75 0 101.06 1.06L10 11.06l1.72 1.72a.75.75 0 101.06-1.06L11.06 10l1.72-1.72a.75.75 0 00-1.06-1.06L10 8.94 8.28 7.22z"
clip-rule="evenodd"
/>
</svg>
</div>
<div class="ml-3">
<h3 class="text-sm font-medium text-red-800">
<%= error_to_string(err) %>
</h3>
</div>
</div>
</div>
<% end %>
</div>
<% end %>
<!-- Prediction text -->
<div
class="flex mt-2 space-x-1.5 items-center font-bold text-gray-900 text-xl"
>
<span>Description: </span>
<!-- Spinner -->
<%= if @upload_running? do %>
<div role="status">
<div
class="relative w-6 h-6 animate-spin rounded-full bg-gradient-to-r from-purple-400 via-blue-500 to-red-400 "
>
<div
class="absolute top-1/2 left-1/2 transform -translate-x-1/2 -translate-y-1/2 w-3 h-3 bg-gray-200 rounded-full border-2 border-white"
></div>
</div>
</div>
<% else %> <%= if @label do %>
<span class="text-gray-700 font-light"><%= @label %></span>
<% else %>
<span class="text-gray-300 font-light">Waiting for image input.</span>
<% end %> <% end %>
</div>
</div>
</div>
<!-- Examples -->
<%= if @display_list? do %>
<div class="flex flex-col">
<h3
class="mt-10 text-xl lg:text-center font-light tracking-tight text-gray-900 lg:text-2xl"
>
Examples
</h3>
<div class="flex flex-row justify-center my-8">
<div
class="mx-auto grid max-w-2xl grid-cols-1 gap-x-6 gap-y-20 sm:grid-cols-2"
>
<%= for example_img <- @example_list do %>
<!-- Loading skeleton if it is predicting -->
<%= if example_img.predicting? == true do %>
<div
role="status"
class="flex items-center justify-center w-full h-full max-w-sm bg-gray-300 rounded-lg animate-pulse"
>
<svg
class="w-10 h-10 text-gray-200 dark:text-gray-600"
aria-hidden="true"
xmlns="http://www.w3.org/2000/svg"
fill="currentColor"
viewBox="0 0 20 18"
>
<path
d="M18 0H2a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h16a2 2 0 0 0 2-2V2a2 2 0 0 0-2-2Zm-5.5 4a1.5 1.5 0 1 1 0 3 1.5 1.5 0 0 1 0-3Zm4.376 10.481A1 1 0 0 1 16 15H4a1 1 0 0 1-.895-1.447l3.5-7A1 1 0 0 1 7.468 6a.965.965 0 0 1 .9.5l2.775 4.757 1.546-1.887a1 1 0 0 1 1.618.1l2.541 4a1 1 0 0 1 .028 1.011Z"
/>
</svg>
<span class="sr-only">Loading...</span>
</div>
<% else %>
<div>
<img
id={example_img.base64_encoded_url}
src={example_img.base64_encoded_url}
class="rounded-2xl object-cover"
/>
<h3 class="mt-1 text-lg leading-8 text-gray-900 text-center">
<%= example_img.label %>
</h3>
</div>
<% end %> <% end %>
</div>
</div>
</div>
<% end %>
</div>
</div>
We've made two changes.
- we've added some text to better introduce our application at the top of the page.
- we've added a section to show the example images list. This section is only rendered if the display_list? socket assign is set to true. If so, we iterate over the example_list socket assign list and show a loading skeleton if the image is still :predicting?. If not, it means the image has already been predicted, and we show the base64-encoded image like we do with the image uploaded by the person.
And that's it! 🎉
While our example list is being correctly rendered, we are using additional CPU to base64 encode our images so they can be shown to the person.
Initially, we did this because
https://picsum.photos
resolves into a different URL every time it is called.
This means that the image that was fed into the model
would be different from the one shown in the example list
if we were to use this URL in our view.
To fix this, we need to follow the redirection when the URL is resolved.
Note
We can do this with
Finch
,
if we wanted to.
We could do something like the following.
def rand_splash do
%{scheme: scheme, host: host, path: path} =
Finch.build(:get, "https://picsum.photos")
|> Finch.request!(MyFinch)
|> Map.get(:headers)
|> Enum.filter(fn {a, _b} -> a == "location" end)
|> List.first()
|> elem(1)
|> URI.parse()
scheme <> "://" <> host <> path
end
And then call it, like so.
App.rand_splash()
# https://images.unsplash.com/photo-1694813646634-9558dc7960e3
Because we are already using
req
,
let's make use of it
instead of adding additional dependencies.
Let's first add a function that will do this
in lib/app_web/live/page_live.ex
.
Add the following piece of code
at the end of the file.
defp track_redirected(url) do
# Create request
req = Req.new(url: url)
# Add tracking properties to req object
req = req
|> Req.Request.register_options([:track_redirected])
|> Req.Request.prepend_response_steps(track_redirected: &track_redirected_uri/1)
# Make request
{:ok, response} = Req.request(req)
# Return the final URI
%{url: URI.to_string(response.private.final_uri), body: response.body}
end
defp track_redirected_uri({request, response}) do
{request, %{response | private: Map.put(response.private, :final_uri, request.url)}}
end
This function adds properties to the request object
and tracks the redirection.
It will add a URI
object inside private.final_uri
.
This function returns
the body
of the image
and the final url
it is resolved to
(the URL of the image).
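To illustrate what following the redirection means, independently of any HTTP client, the sketch below resolves the location header of a redirect response against the URL that was requested, using only the standard library. RedirectSketch is a hypothetical module, and the header values and image paths are made up for the example; they are not real Picsum responses.

```elixir
defmodule RedirectSketch do
  # Resolves the `location` header of a redirect against the requested URL.
  # Handles both relative ("/id/...") and absolute ("https://...") locations.
  def resolve(request_url, headers) do
    case List.keyfind(headers, "location", 0) do
      {"location", location} ->
        {:ok, request_url |> URI.merge(location) |> URI.to_string()}

      nil ->
        :error
    end
  end
end

# Made-up response headers of a redirect.
headers = [{"content-type", "text/html"}, {"location", "/id/237/640/640.jpg"}]

IO.inspect(RedirectSketch.resolve("https://picsum.photos/640/640", headers))
# {:ok, "https://picsum.photos/id/237/640/640.jpg"}
```

Keeping the resolved URL next to the downloaded body is what guarantees that the image shown to the person is the same one the model described.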
Now all we need to do is use this function!
Head over to the handle_event("show_examples"...
function
and change the loop to the following.
tasks = for _ <- 1..2 do
%{url: url, body: body} = track_redirected(random_image)
predict_example_image(body, url)
end
We are making use of track_redirected/1
,
the function we've just created.
We pass both body
and url
to predict_example_image/2
,
which we will now change.
def predict_example_image(body, url) do
with {:vix, {:ok, img_thumb}} <-
{:vix, Vix.Vips.Operation.thumbnail_buffer(body, @image_width)},
{:pre_process, {:ok, t_img}} <- {:pre_process, pre_process_image(img_thumb)} do
# Create an async task to classify the image from Picsum
Task.Supervisor.async(App.TaskSupervisor, fn ->
Nx.Serving.batched_run(ImageClassifier, t_img)
end)
|> Map.merge(%{url: url})
else
{:vix, {:error, msg}} -> {:error, msg}
{:pre_process, {:error, msg}} -> {:error, msg}
end
end
Instead of using base64_encoded_url
,
we are now using the url
we've acquired.
The last step we need to do in our LiveView
is to finally use this url
in handle_info/2
.
def handle_info({ref, result}, %{assigns: assigns} = socket) do
Process.demonitor(ref, [:flush])
label =
case Application.get_env(:app, :use_test_models, false) do
true ->
App.Models.extract_test_label(result)
false ->
App.Models.extract_prod_label(result)
end
cond do
Map.get(assigns, :task_ref) == ref ->
{:noreply, assign(socket, label: label, upload_running?: false)}
img = Map.get(assigns, :example_list_tasks) |> Enum.find(&(&1.ref == ref)) ->
updated_example_list = Map.get(assigns, :example_list)
|> Enum.map(fn obj ->
if obj.ref == img.ref do
obj
|> Map.put(:url, img.url) # change here
|> Map.put(:label, label)
|> Map.put(:predicting?, false)
else
obj
end end)
{:noreply,
assign(socket,
example_list: updated_example_list,
upload_running?: false,
display_list?: true
)}
end
end
And that's it!
The last thing we need to do is change our view
so it uses the :url
parameter
instead of the obsolete :base64_encoded_url
.
Head over to lib/app_web/live/page_live.html.heex
and change the <img>
being shown in the example list
so it uses the :url
parameter.
<img
id={example_img.url}
src={example_img.url}
class="rounded-2xl object-cover"
/>
And we're done! 🎉
We are now rendering the image on the client through the URL the Picsum API resolves to, instead of having the LiveView server encode the image. Therefore, we're saving some CPU for the thing that matters the most: running our model.
Now let's see our application in action! We are expecting the examples to be shown after 8 seconds of inactivity. If the person is inactive for this time duration, we fetch a random image from Picsum API and feed it to our model!
You should see different images every time you use the app. Isn't that cool? 😎
Our app is shaping up quite nicely! As it stands, it's an application that does inference on images. However, it doesn't save them.
Let's expand our application so it has a database where image classification is saved/persisted!
We'll use Postgres
for this.
Typically, when you create a new Phoenix
project
with mix phx.new
,
a Postgres
database will be automatically created.
Because we didn't do this,
we'll have to configure this ourselves.
Let's do it!
We'll install all the needed dependencies first.
In mix.exs
, add the following snippet
to the deps
section.
# HTTP Request
{:httpoison, "~> 2.2"},
{:mime, "~> 2.0.5"},
{:ex_image_info, "~> 0.2.4"},
# DB
{:phoenix_ecto, "~> 4.4"},
{:ecto_sql, "~> 3.10"},
{:postgrex, ">= 0.0.0"},
- httpoison, mime and ex_image_info are used to make HTTP requests, get the content type and information from an image file, respectively. These will be needed to upload a given image to imgup, by making multipart requests.
- phoenix_ecto, ecto_sql and postgrex are needed to properly configure our driver in Elixir that will connect to a Postgres database, in which we will persist data.
Run mix deps.get
to install these dependencies.
Now let's create the needed files
to properly connect to a Postgres relational database.
Start by going to lib/app
and create repo.ex
.
defmodule App.Repo do
use Ecto.Repo,
otp_app: :app,
adapter: Ecto.Adapters.Postgres
end
This module will be needed in our configuration files so our app knows where the database is.
Next, in lib/app/application.ex
,
add the following line to the children
array
in the supervision tree.
children = [
AppWeb.Telemetry,
App.Repo, # add this line
{Phoenix.PubSub, name: App.PubSub},
...
]
Awesome! 🎉
Now let's head over to the files inside the config
folder.
In config/config.exs
,
add these lines.
config :app,
ecto_repos: [App.Repo],
generators: [timestamp_type: :utc_datetime]
We are pointing ecto_repos
to the module we've previously created (App.Repo
)
so ecto
knows where the configuration
for the database is located.
In config/dev.exs
,
add:
config :app, App.Repo,
username: "postgres",
password: "postgres",
hostname: "localhost",
database: "app_dev",
stacktrace: true,
show_sensitive_data_on_connection_error: true,
pool_size: 10
We are defining the parameters of the database that is used during development.
In config/runtime.exs
,
add:
if config_env() == :prod do
database_url =
System.get_env("DATABASE_URL") ||
raise """
environment variable DATABASE_URL is missing.
For example: ecto://USER:PASS@HOST/DATABASE
"""
maybe_ipv6 = if System.get_env("ECTO_IPV6") in ~w(true 1), do: [:inet6], else: []
config :app, App.Repo,
# ssl: true,
url: database_url,
pool_size: String.to_integer(System.get_env("POOL_SIZE") || "10"),
socket_options: maybe_ipv6
# ...
Here we're configuring the production database connection, which is resolved at runtime from the DATABASE_URL environment variable.
In config/test.exs
,
add:
config :app, App.Repo,
username: "postgres",
password: "postgres",
hostname: "localhost",
database: "app_test#{System.get_env("MIX_TEST_PARTITION")}",
pool: Ecto.Adapters.SQL.Sandbox,
pool_size: 10
Here we're defining the database used during testing.
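The database name in this snippet depends on the MIX_TEST_PARTITION environment variable. As a small sketch of how the interpolation behaves (the db_name function is only an illustration, not part of the app):

```elixir
# Builds the test database name the same way config/test.exs does.
db_name = fn -> "app_test#{System.get_env("MIX_TEST_PARTITION")}" end

# Without the variable, System.get_env/1 returns nil, which interpolates as "".
System.delete_env("MIX_TEST_PARTITION")
unpartitioned = db_name.()
IO.inspect(unpartitioned)

# With the variable set (for example by a CI job that splits the test suite),
# each partition gets its own database.
System.put_env("MIX_TEST_PARTITION", "2")
partitioned = db_name.()
IO.inspect(partitioned)
```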
Now let's create a migration file to create our database table.
In priv/repo/migrations/
, create a file
called 20231204092441_create_images.exs
(or any other timestamp string)
with the following piece of code.
defmodule App.Repo.Migrations.CreateImages do
use Ecto.Migration
def change do
create table(:images) do
add :url, :string
add :description, :string
add :width, :integer
add :height, :integer
timestamps(type: :utc_datetime)
end
end
end
And that's it! Those are the files our application needs to properly connect to the Postgres database and persist data into it.
You can now run mix ecto.create
and mix ecto.migrate
to create the database
and the "images"
table.
For now, let's create a simple table "images"
in our database
that has the following properties:
- description: the description of the image from the model.
- width: width of the image.
- height: height of the image.
- url: public URL where the image is hosted.
With this in mind, let's create a new file!
In lib/app/
, create a file called image.ex
.
defmodule App.Image do
use Ecto.Schema
alias App.{Image, Repo}
@primary_key {:id, :id, autogenerate: true}
schema "images" do
field(:description, :string)
field(:width, :integer)
field(:url, :string)
field(:height, :integer)
timestamps(type: :utc_datetime)
end
def changeset(image, params \\ %{}) do
image
|> Ecto.Changeset.cast(params, [:url, :description, :width, :height])
|> Ecto.Changeset.validate_required([:url, :description, :width, :height])
end
@doc """
Validates the given image information
and adds it to the database.
"""
def insert(image) do
%Image{}
|> changeset(image)
|> Repo.insert!()
end
end
We've just created the App.Image
schema
with the aforementioned fields.
We've created changeset/1
, which is used to cast
and validate the properties of a given object
before interacting with the database.
insert/1
receives an object,
runs it through the changeset
and inserts it in the database.
Now that we have our database set up,
let's change some of our code so we persist data into it!
In this section, we'll be working in the lib/app_web/live/page_live.ex
file.
First, let's alias App.Image
and create an ImageInfo
struct to hold the information
of the image throughout the process of uploading
and classifying the image.
defmodule AppWeb.PageLive do
use AppWeb, :live_view
alias App.Image # add this alias
alias Vix.Vips.Image, as: Vimage
defmodule ImageInfo do
@moduledoc """
General information for the image that is being analysed.
This information is useful when persisting the image to the database.
"""
defstruct [:mimetype, :width, :height, :url, :file_binary]
end
# ...
We are going to be using ImageInfo
in our socket assigns.
Let's add to it when the LiveView is mounting!
|> assign(
label: nil,
upload_running?: false,
task_ref: nil,
image_info: nil, # add this line
image_preview_base64: nil,
example_list_tasks: [],
example_list: [],
display_list?: false
)
When the person uploads an image,
we want to retrieve its info (namely its height, width)
and upload the image to an S3
bucket (we're doing this through imgup
)
so we can populate the :url
field of the schema in the database.
We can retrieve this information while consuming the entry
/uploading the image file.
For this, go to handle_progress(:image_list, entry, socket)
and change the function to the following.
def handle_progress(:image_list, entry, socket) when entry.done? do
# We've changed the object that is returned from `consume_uploaded_entry/3` to return an `image_info` object.
%{tensor: tensor, image_info: image_info} =
consume_uploaded_entry(socket, entry, fn %{} = meta ->
file_binary = File.read!(meta.path)
# Add this line. It uses `ExImageInfo` to retrieve the info from the file binary.
{mimetype, width, height, _variant} = ExImageInfo.info(file_binary)
{:ok, thumbnail_vimage} =
Vix.Vips.Operation.thumbnail(meta.path, @image_width, size: :VIPS_SIZE_DOWN)
{:ok, tensor} = pre_process_image(thumbnail_vimage)
# Add this line. Uploads the image to the S3, which returns the `url` and `compressed url`.
# (we'll implement this function next)
url = Image.upload_image_to_s3(meta.path, mimetype) |> Map.get("url")
# Add this line. We are constructing the image_info object to be returned.
image_info = %ImageInfo{mimetype: mimetype, width: width, height: height, file_binary: file_binary, url: url}
# Return it
{:ok, %{tensor: tensor, image_info: image_info}}
end)
task =
Task.Supervisor.async(App.TaskSupervisor, fn ->
Nx.Serving.batched_run(ImageClassifier, tensor)
end)
base64 = "data:image/png;base64, " <> Base.encode64(image_info.file_binary)
# Change this line so `image_info` is defined when the image is uploaded
{:noreply, assign(socket, upload_running?: true, task_ref: task.ref, image_preview_base64: base64, image_info: image_info)}
#else
# {:noreply, socket}
#end
end
Check the comment lines for more explanation on the changes that have been made.
We are using ExImageInfo
to fetch the information from the image
and assigning it to the image_info
socket assign we defined earlier.
We are also using Image.upload_image_to_s3/2
to upload our image to imgup
.
Let's define this function in lib/app/image.ex
.
def upload_image_to_s3(file_path, mimetype) do
extension = MIME.extensions(mimetype) |> Enum.at(0)
# Upload to Imgup - https://github.com/dwyl/imgup
upload_response =
HTTPoison.post!(
"https://imgup.fly.dev/api/images",
{:multipart,
[
{
:file,
file_path,
{"form-data", [name: "image", filename: "#{Path.basename(file_path)}.#{extension}"]},
[{"Content-Type", mimetype}]
}
]},
[]
)
# Return URL
Jason.decode!(upload_response.body)
end
We're using HTTPoison
to make a multipart request to the imgup
server,
effectively uploading the image to the server.
If the upload is successful, it returns the url
of the uploaded image.
Let's go back to lib/app_web/live/page_live.ex
.
Now that we have image_info
in the socket assigns,
we can use it to insert a row in the "images"
table in the database.
We only want to do this after the model is done running,
so simply change handle_info/2
function
(which is called after the model is done with classifying the image).
cond do
# If the upload task has finished executing, we update the socket assigns.
Map.get(assigns, :task_ref) == ref ->
# Insert image to database
image = %{
url: assigns.image_info.url,
width: assigns.image_info.width,
height: assigns.image_info.height,
description: label
}
Image.insert(image)
# Update socket assigns
{:noreply, assign(socket, label: label, upload_running?: false)}
# ...
In the cond do
statement,
we want to change the one pertaining to the image that is uploaded,
not the example list that is defined below.
We simply create an image
variable with information
that is passed down to Image.insert/1
,
effectively adding the row to the database.
And that's it!
Now every time a person uploads an image
and the model is executed,
we are saving its location (:url
),
information (:width
and :height
)
and the result of the classifying model
(:description
).
🥳
Note
If you're curious and want to see the data in your database,
we recommend using DBeaver
,
an open-source database manager.
You can learn more about it at https://github.com/dwyl/learn-postgresql.
Currently, we are not handling any errors
in case the upload of the image to imgup
fails.
Although this is not critical,
it'd be better if we could show feedback to the person
in case the upload to imgup
fails.
This is good for us as well,
because we can monitor and locate the error faster
if we log the errors.
For this, let's head over to lib/app/image.ex
and update the upload_image_to_s3/2
function we've implemented.
def upload_image_to_s3(file_path, mimetype) do
extension = MIME.extensions(mimetype) |> Enum.at(0)
# Upload to Imgup - https://github.com/dwyl/imgup
upload_response =
HTTPoison.post(
"https://imgup.fly.dev/api/images",
{:multipart,
[
{
:file,
file_path,
{"form-data", [name: "image", filename: "#{Path.basename(file_path)}.#{extension}"]},
[{"Content-Type", mimetype}]
}
]},
[]
)
# Process the response and return error if there was a problem uploading the image
case upload_response do
# In case it's successful
{:ok, %HTTPoison.Response{status_code: 200, body: body}} ->
%{"url" => url, "compressed_url" => _} = Jason.decode!(body)
{:ok, url}
# In case it returns HTTP 400 with specific reason it failed
{:ok, %HTTPoison.Response{status_code: 400, body: body}} ->
%{"errors" => %{"detail" => reason}} = Jason.decode!(body)
{:error, reason}
# In case the request fails for whatever other reason
{:error, %HTTPoison.Error{reason: reason}} ->
{:error, reason}
end
end
As you can see,
we are returning {:error, reason}
if an error occurs,
and providing feedback alongside it.
If it's successful, we return {:ok, url}
.
Because we've just changed this function,
we need to also update def handle_progress(:image_list...
inside lib/app_web/live/page_live.ex
to properly handle this new function output.
We are also introducing a double MIME type check to ensure that only image files are uploaded and processed.
We use GenMagic. It provides supervised and customisable access to libmagic
using a supervised external process.
This gist explains that magic numbers are the first bytes of a file,
which uniquely identify the type of the file.
We use the GenMagic server as a daemon; it is started in the Application module.
It is referenced by its name.
When we run perform
, we obtain a map and compare the mime type with the one read by ExImageInfo
.
If they correspond with each other, we continue, or else we stop the process.
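To make the idea of magic numbers concrete, here is a minimal sketch in plain Elixir that recognises three image formats from their leading bytes and then applies the same "both detectors must agree" rule. MagicSketch is a hypothetical module for illustration only; libmagic knows far more formats and performs far more checks than this.

```elixir
defmodule MagicSketch do
  # Leading bytes ("magic numbers") of a few common image formats.
  def mime(<<0x89, "PNG", 0x0D, 0x0A, 0x1A, 0x0A, _rest::binary>>), do: {:ok, "image/png"}
  def mime(<<0xFF, 0xD8, 0xFF, _rest::binary>>), do: {:ok, "image/jpeg"}
  def mime(<<"GIF8", v, "a", _rest::binary>>) when v in [?7, ?9], do: {:ok, "image/gif"}
  def mime(_other), do: {:error, :unknown}

  # Double check: the MIME type found here must match the one reported elsewhere.
  def agree?(binary, other_mime), do: mime(binary) == {:ok, other_mime}
end

# First bytes of a JPEG file (header only, not a decodable image).
jpeg_header = <<0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10>>

IO.inspect(MagicSketch.mime(jpeg_header))
IO.inspect(MagicSketch.agree?(jpeg_header, "image/jpeg"))
# A file claiming to be a PNG while carrying JPEG bytes is rejected.
IO.inspect(MagicSketch.agree?(jpeg_header, "image/png"))
IO.inspect(MagicSketch.mime("plain text"))
```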
On your computer, for this to work locally, you should install the package libmagic-dev
.
Note
Depending on your OS, you may install libmagic
in different ways.
A quick Google search will suffice,
but here are a few resources nonetheless:
- Mac: https://gist.github.com/eparreno/1845561
- Windows: https://github.com/nscaife/file-windows
- Linux: https://zoomadmin.com/HowToInstall/UbuntuPackage/libmagic-dev
Definitely read gen_magic
's installation section in https://github.com/evadne/gen_magic#installation.
You may need to perform additional steps.
You'll need to add gen_magic
to mix.exs
.
This dependency will allow us to access libmagic
through Elixir
.
def deps do
[
{:gen_magic, "~> 1.1.1"}
]
end
In the Application
module, you should add the GenMagic
daemon
(the C lib is loaded once for all and referenced by its name).
#application.ex
children = [
...,
{GenMagic.Server, name: :gen_magic},
]
In the Dockerfile (needed to deploy this app), we will install the libmagic-dev
as well:
RUN apt-get update -y && \
apt-get install -y libstdc++6 openssl libncurses5 locales ca-certificates libmagic-dev \
&& apt-get clean && rm -f /var/lib/apt/lists/*_*
Add the following function in the module App.Image:
@doc """
Check file type via magic number. It uses a GenServer running the `C` lib "libmagic".
"""
def gen_magic_eval(path, accepted_mime) do
GenMagic.Server.perform(:gen_magic, path)
|> case do
{:error, reason} ->
{:error, reason}
{:ok,
%GenMagic.Result{
mime_type: mime,
encoding: "binary",
content: _content
}} ->
if Enum.member?(accepted_mime, mime),
do: {:ok, %{mime_type: mime}},
else: {:error, "Not accepted mime type."}
{:ok, %GenMagic.Result{} = res} ->
require Logger
Logger.warning("⚠️ MIME type error: #{inspect(res)}")
{:error, "Not acceptable."}
end
end
end
In the page_live.ex
module, add the functions:
@doc """
Calls `App.Image.gen_magic_eval/2` and returns its GenMagic response.
"""
def magic_check(path) do
App.Image.gen_magic_eval(path, @accepted_mime)
|> case do
{:ok, %{mime_type: mime}} ->
{:ok, %{mime_type: mime}}
{:error, msg} ->
{:error, msg}
end
end
@doc """
Double-checks the MIME type of uploaded file to ensure that the file
is an image and is not corrupted.
"""
def check_mime(magic_mime, info_mime) do
if magic_mime == info_mime, do: :ok, else: :error
end
We are now ready to double-check the file input
with ExImageInfo
and GenMagic
to ensure the safety of the uploads.
def handle_progress(:image_list, entry, socket) when entry.done? do
# We consume the entry only if the entry is done uploading from the image
# and if consuming the entry was successful.
with %{tensor: tensor, image_info: image_info} <-
consume_uploaded_entry(socket, entry, fn %{path: path} ->
with {:magic, {:ok, %{mime_type: mime}}} <-
{:magic, magic_check(path)},
{:read, {:ok, file_binary}} <-
{:read, File.read(path)},
{:image_info, {mimetype, width, height, _variant}} <-
{:image_info, ExImageInfo.info(file_binary)},
{:check_mime, :ok} <-
{:check_mime, check_mime(mime, mimetype)},
# Get image and resize
{:ok, thumbnail_vimage} <-
Vix.Vips.Operation.thumbnail(path, @image_width, size: :VIPS_SIZE_DOWN),
# Pre-process it
{:ok, tensor} <-
pre_process_image(thumbnail_vimage) do
# Upload image to S3
Image.upload_image_to_s3(path, mimetype)
|> case do
{:ok, url} ->
image_info = %ImageInfo{
mimetype: mimetype,
width: width,
height: height,
file_binary: file_binary,
url: url
}
{:ok, %{tensor: tensor, image_info: image_info}}
# If S3 upload fails, we return error
{:error, reason} ->
{:ok, %{error: reason}}
end
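else
# Failures arrive wrapped in the tag of the step that produced them
{:magic, {:error, reason}} -> {:postpone, %{error: reason}}
{:read, {:error, reason}} -> {:postpone, %{error: inspect(reason)}}
{:image_info, nil} -> {:postpone, %{error: "Could not read the image information."}}
{:check_mime, :error} -> {:postpone, %{error: "MIME types do not match."}}
# Untagged failures from the thumbnail and pre-processing steps
{:error, reason} -> {:postpone, %{error: inspect(reason)}}
end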
end) do
# If consuming the entry was successful, we spawn a task to classify the image
# and update the socket assigns
task =
Task.Supervisor.async(App.TaskSupervisor, fn ->
Nx.Serving.batched_run(ImageClassifier, tensor)
end)
# Encode the image to base64
base64 = "data:image/png;base64, " <> Base.encode64(image_info.file_binary)
{:noreply,
assign(socket,
upload_running?: true,
task_ref: task.ref,
image_preview_base64: base64,
image_info: image_info
)}
# Otherwise, if there was an error uploading the image, we log the error and show it to the person.
else
%{error: reason} ->
Logger.warning("⚠️ Error uploading image. #{inspect(reason)}")
{:noreply, push_event(socket, "toast", %{message: "Image couldn't be uploaded to S3.\n#{reason}"})}
_ ->
{:noreply, socket}
end
end
Phew! That's a lot! Let's go through the changes we've made.
- we are using the with statement to only feed the image to the model for classification in case the upload to imgup succeeds. We've changed what consume_uploaded_entry/3 returns in case the upload fails - we return {:ok, %{error: reason}}.
- in case the upload fails, we pattern match the resulting %{error: reason} map and push a "toast" event to the Javascript client (we'll implement these changes shortly).
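The tagged-tuple pattern used in the with statement can be sketched on its own, without LiveView. In this sketch, WithSketch and its check/1 and upload/1 helpers are hypothetical stand-ins for the real MIME check and the imgup upload. Note that a failing step reaches the else block still wrapped in its tag, so the else clauses have to match the tag as well.

```elixir
defmodule WithSketch do
  # Hypothetical stand-ins for the real steps (MIME check and upload).
  defp check(:bad_mime), do: {:error, "Not accepted mime type."}
  defp check(_input), do: {:ok, "image/png"}

  defp upload(:upload_down), do: {:error, :timeout}
  defp upload(_input), do: {:ok, "https://example.com/image.png"}

  def process(input) do
    with {:magic, {:ok, mime}} <- {:magic, check(input)},
         {:upload, {:ok, url}} <- {:upload, upload(input)} do
      {:ok, %{mime: mime, url: url}}
    else
      # The tag tells us which step failed.
      {:magic, {:error, reason}} -> {:error, "MIME check failed: #{reason}"}
      {:upload, {:error, reason}} -> {:error, "upload failed: #{inspect(reason)}"}
    end
  end
end

IO.inspect(WithSketch.process(:fine))
IO.inspect(WithSketch.process(:bad_mime))
IO.inspect(WithSketch.process(:upload_down))
```

If an else clause only matched a bare {:error, reason}, a tagged failure such as {:magic, {:error, reason}} would not match it and the with statement would raise a WithClauseError.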
Because we push an event in case the upload fails, we are going to make some changes to the Javascript client. We are going to show a toast with the error when the upload fails.
To show a toast component,
we are going to use
toastify.js
.
Navigate to assets
folder
and run:
pnpm install toastify-js
With this installed, we need to import toastify
styles
in assets/css/app.css
.
@import "../node_modules/toastify-js/src/toastify.css";
All that's left is handle the "toast"
event in assets/js/app.js
.
Add the following snippet of code to do so.
import Toastify from "toastify-js";

// Hook to show message toast
Hooks.MessageToaster = {
mounted() {
this.handleEvent("toast", (payload) => {
Toastify({
text: payload.message,
gravity: "bottom",
position: "right",
style: {
background: "linear-gradient(to right, #f27474, #ed87b5)",
},
duration: 4000,
}).showToast();
});
},
};
We use the payload.message
received from the LiveView
(remember when we executed push_event/3
in our LiveView?)
to create a Toastify
object
that is shown in case the upload fails.
And that's it! Quite easy, isn't it? 😉
If imgup
is down or the image that was sent is, for example, invalid, an error should be shown.
You may be wondering: which model is most suitable for me?
Depending on the use case,
Bumblebee
supports different models
for different scenarios.
To help you make up your mind,
we've created a guide
that benchmarks some of Bumblebee
-supported models
for image captioning.
Although few models are supported at the moment, this comparison table will grow as more are added. So any contribution is more than welcome! 🎉
You may check the guide
and all of the code
inside the
_comparison
folder.
In this section, we will focus on implementing a full-text search query through the captions of the images. At the end of this, you'll be able to transcribe audio, create embeddings from the audio transcription and search the closest related image.
Note
This section was kindly implemented and documented by @ndrean. It is based on articles written by Sean Moriarty and published on the DockYard blog. Do check him out! 🎉
We can leverage machine learning to greatly improve this search process: we'll look for images whose captions are close in terms of meaning to the search.
In this section, you'll learn how to perform semantic search with machine learning. These techniques are widely used in search engines, including in widespread tools like Elasticsearch.
Let's go over the process in detail so we know what to expect.
As it stands, when images are uploaded and captioned, the URL is saved, as well as the caption, in our local database.
Here's an overview of how semantic search usually works (which is exactly what we'll implement).
We will use the following toolchain:
We simply let the user start and stop the recording by using a submit button in a form. This can of course be greatly refined by using Voice Detection. You may find an example here.
Firstly, we will:
- record an audio with MediaRecorder API.
- run a Speech-To-Text process to produce a text transcription from the audio.
We will use the pre-trained model openai/whisper-small
from https://huggingface.co
and use it with the help of the Bumblebee.Audio.speech_to_text_whisper function.
We get an Nx.Serving
that we will use to run this model with an input.
We then want to find images whose captions
approximate this text in terms of meaning.
This transcription is the "target text"
.
This is where embeddings come into play:
they are vector representations of certain inputs,
which in our case, are the text transcription of the audio file recorded by the user.
We encode each transcription as an embedding
and then use an approximation algorithm to find the closest neighbours.
Our next step is to prepare the symmetric semantic search. We will use a transformer model, more specifically the pre-trained SBERT system available on Hugging Face.
We transform a text into a vector with the sentence-transformer model sentence-transformers/paraphrase-MiniLM-L6-v2
.
Note
You may find models in the MTEB English leaderboard. We looked for "small" models in terms of file size and dimensions. You may want to try and use GTE small.
We will run the model with the help of the Bumblebee.Text.TextEmbedding.text_embedding function.
This encoding is done for each image caption.
At this point, we have:
- the embedding of the text transcription of the recording made by the user (e.g. "a dog").
- all the embeddings of all the images in our "image bank".
To search for the images that are related to "a dog"
,
we need to apply an algorithm that compares these embeddings!
For this, we will run a k-nearest neighbours (kNN) search. There are several ways to do this.
- we can use pgvector, a vector extension of Postgres. It is used to store vectors (the embeddings) and to run similarity searches. With pgvector, we can run:
  - a full exact search with the cosine distance operator <=>,
  - or an Approximate Nearest Neighbour search with an index. The extension offers the IVFFlat and HNSW index types. You can find some explanations on both algorithms in https://tembo.io/blog/vector-indexes-in-pgvector and https://neon.tech/blog/understanding-vector-search-and-hnsw-index-with-pgvector.
Note
Note that Supabase can use the pgvector
extension, and you can use Supabase with Fly.io.
Warning
Note that you need to save the embeddings (as vectors) into the database, so the database will be intensively used. This may lead to scaling problems and potential race conditions.
- we can alternatively use the hnswlib library and its Elixir binding HNSWLib. This "externalises" the ANN search from the database as it uses an in-memory file. This file needs to be persisted on disk, at the expense of using the filesystem, again with potential race conditions. It works with an index struct: this struct will allow us to efficiently retrieve vector data.
We will use this last option,
mostly because we use Fly.io
and pgvector
is hard to come by on this platform.
We will use a GenServer to wrap all the calls to hnswlib
so that every write runs sequentially, one at a time.
Additionally, you don't rely on a framework that does the heavy lifting for you.
We're here to learn, aren't we? 😃
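As a minimal sketch of why a GenServer helps here: a GenServer handles its messages one at a time, so concurrent writers are serialised for free. IndexServerSketch is a hypothetical module, and its state is a plain list standing in for the real HNSWLib index.

```elixir
defmodule IndexServerSketch do
  use GenServer

  # Client API
  def start_link(opts \\ []), do: GenServer.start_link(__MODULE__, [], opts)
  def add(server, embedding), do: GenServer.call(server, {:add, embedding})
  def count(server), do: GenServer.call(server, :count)

  # Server callbacks: the state is a plain list standing in for the index.
  @impl true
  def init(items), do: {:ok, items}

  @impl true
  def handle_call({:add, embedding}, _from, items) do
    position = length(items)
    # Calls are handled one at a time, so two writers never get the same position.
    {:reply, {:ok, position}, items ++ [embedding]}
  end

  def handle_call(:count, _from, items), do: {:reply, length(items), items}
end

{:ok, pid} = IndexServerSketch.start_link()

# Five concurrent writers, each appending one (made-up) embedding.
results =
  1..5
  |> Enum.map(fn i -> Task.async(fn -> IndexServerSketch.add(pid, [i * 1.0]) end) end)
  |> Enum.map(&Task.await/1)

IO.inspect(Enum.sort(results))
IO.inspect(IndexServerSketch.count(pid))
```

Each writer receives a distinct position even though the five tasks run concurrently.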
We will incrementally append the embeddings computed from the captions to the Index.
Each append gives us an index, which is simply the position of this embedding in the Index.
We then run a "knn_search" algorithm; the input will be the embedding of the audio transcript.
This algorithm returns the most relevant positions (indices)
among the Index
indices: the ones that minimise the chosen distance between this input and the existing vectors.
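To illustrate what such a search returns, here is a brute-force kNN sketch in plain Elixir. It is an exact search over made-up 3-dimensional vectors, whereas hnswlib runs an approximate search over a graph structure; KnnSketch is a hypothetical module used only for this illustration.

```elixir
defmodule KnnSketch do
  defp dot(a, b), do: Enum.zip(a, b) |> Enum.reduce(0.0, fn {x, y}, acc -> acc + x * y end)

  # Cosine distance = 1 - cosine similarity (0.0 means "same direction").
  defp distance(a, b) do
    1.0 - dot(a, b) / (:math.sqrt(dot(a, a)) * :math.sqrt(dot(b, b)))
  end

  # Returns the positions of the k stored vectors closest to the query.
  def search(vectors, query, k) do
    vectors
    |> Enum.with_index()
    |> Enum.sort_by(fn {vector, _index} -> distance(vector, query) end)
    |> Enum.take(k)
    |> Enum.map(fn {_vector, index} -> index end)
  end
end

# The position in this list plays the role of the index obtained on insertion.
stored = [[0.0, 1.0, 0.1], [0.9, 0.1, 0.8], [0.5, 0.5, 0.5]]
query = [1.0, 0.0, 1.0]

nearest = KnnSketch.search(stored, query, 2)
IO.inspect(nearest)
```

The result is a list of positions, not vectors: the vector at position 1 points in almost the same direction as the query, followed by the one at position 2.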
This is why we need to save either:
- the index,
- or the embedding
to look up the corresponding image(s), depending on whether you append items one by one or in batches.
In our case, we append items one by one, so we will use the index to uniquely recover the image whose caption is semantically closest to our audio.
Do note that the measured distance is dependent on the similarity metric used by the embedding model. Because the "sentence-transformer" model we've chosen was trained with cosine_similarity, this is what we'll use. Bumblebee may have options to correctly use this metric, but we used a normalisation process which fits our needs.
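One common way to reconcile a Euclidean index with cosine similarity is to scale every vector to unit length first: for unit vectors, the squared L2 distance equals 2 - 2 * cosine similarity, so both metrics rank neighbours identically. The sketch below checks this on two made-up vectors; it is an assumption about what such a normalisation step can look like, not necessarily the exact one used in the app.

```elixir
defmodule NormSketch do
  def dot(a, b) do
    Enum.zip(a, b) |> Enum.reduce(0.0, fn {x, y}, acc -> acc + x * y end)
  end

  # Scales a vector to length 1 (L2 normalisation).
  def normalize(v) do
    norm = :math.sqrt(dot(v, v))
    Enum.map(v, fn x -> x / norm end)
  end

  def squared_l2(a, b) do
    Enum.zip(a, b) |> Enum.reduce(0.0, fn {x, y}, acc -> acc + (x - y) * (x - y) end)
  end
end

u = NormSketch.normalize([3.0, 4.0])
v = NormSketch.normalize([1.0, 0.0])

# For unit vectors, the dot product is the cosine similarity.
cosine = NormSketch.dot(u, v)
squared_distance = NormSketch.squared_l2(u, v)

IO.inspect({cosine, squared_distance, 2 - 2 * cosine})
```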
We have already installed all the dependencies that we need.
Warning
You will also need to install ffmpeg. Bumblebee uses ffmpeg under the hood to process audio files into tensors, but it uses it as an external dependency.
And now we're ready to rock and roll! 🎸
We first need to capture the audio and upload it to the server.
The process is quite similar to the image upload, except that we use a special Javascript hook to record the audio and upload it to the Phoenix LiveView.
We use a live_file_input
in a form to capture the audio and use the Javascript MediaRecorder API
.
The Javascript code is triggered by an attached hook Audio
declared in the HTML.
We also let the user listen to his audio by adding an embedded audio element <audio>
in the HTML.
Its source is the audio blob as a URL object.
We also add a spinner to display that the transcription process is running,
in the same way as we did for the captioning process.
To avoid code duplication, we introduce a Phoenix component "Spinner".
Create the file
spinner.ex
in lib/app_web/components/
and create the Spinner
component,
like so:
# /lib/app_web/components/spinner.ex
defmodule AppWeb.Spinner do
use Phoenix.Component
attr :spin, :boolean, default: false
def spin(assigns) do
~H"""
<div :if={@spin} role="status">
<div class="relative w-6 h-6 animate-spin rounded-full bg-gradient-to-r from-purple-400 via-blue-500 to-red-400 ">
<div class="absolute top-1/2 left-1/2 transform -translate-x-1/2 -translate-y-1/2 w-3 h-3 bg-gray-200 rounded-full border-2 border-white">
</div>
</div>
</div>
"""
end
end
In page_live.html.heex
, add the following snippet of code.
# page_live.html.heex
<form phx-change="noop">
<.live_file_input upload={@uploads.speech} class="hidden" />
<button
id="record"
class="bg-blue-500 hover:bg-blue-700 text-white font-bold px-4 rounded"
type="button"
phx-hook="Audio"
disabled={@mic_off?}
>
<Heroicons.microphone
outline
class="w-6 h-6 text-white font-bold group-active:animate-pulse"
/>
<span id="text">Record</span>
</button>
</form>
<p class="flex flex-col items-center">
<audio id="audio" controls></audio>
<AppWeb.Spinner.spin spin={@audio_running?} />
</p>
You can also use this component to display the spinner when the captioning task is running, so this part of your code will shrink to:
<!-- Spinner -->
<%= if @upload_running? do %>
<AppWeb.Spinner.spin spin={@upload_running?} />
<% else %>
<%= if @label do %>
<span class="text-gray-700 font-light"><%= @label %></span>
<% else %>
<span class="text-gray-300 font-light text-justify">Waiting for image input.</span>
<% end %>
<% end %>
Currently, we provide a basic user experience: we let the user click on a button to start and stop the recording.
When doing image captioning, we carefully worked on the size of the image file used for the captioning model to optimize the app's latency.
In the same spirit,
we can downsize the original audio file
so it's easier for the model to process it.
This will have the benefit of less overhead in our application.
Even though the whisper
model does downsize the audio's sampling rate,
we can do this on our client side to skip this step
and marginally reduce the overhead of running the model in our application.
The main parameters we're dealing with are:
- a lower sampling rate (the higher the rate, the more accurate the sound),
- use mono instead of stereo,
- and the file type (WAV, MP3)
Since most microphones on PC have a single channel (mono) and sample at 48kHz
,
we will focus on resampling to 16kHz
.
We will not make the conversion to mp3 here.
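As a rough back-of-the-envelope sketch of what resampling saves, the size of uncompressed PCM audio is the sample rate times the number of channels times the bytes per sample. The figures below assume 16-bit samples, which is an assumption about the WAV layout rather than something measured in the app.

```elixir
# Bytes per second of uncompressed PCM audio.
pcm_bytes_per_second = fn sample_rate, channels, bits_per_sample ->
  sample_rate * channels * div(bits_per_sample, 8)
end

# Assumption: mono recording with 16-bit samples.
original = pcm_bytes_per_second.(48_000, 1, 16)
resampled = pcm_bytes_per_second.(16_000, 1, 16)

IO.puts("48 kHz mono: #{original} B/s, 16 kHz mono: #{resampled} B/s")
IO.puts("a 10 second clip shrinks from #{original * 10} to #{resampled * 10} bytes")
```

Under these assumptions, the upload is three times smaller, before the server does any work.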
Next, we define the hook in a new JS file, located in the assets/js
folder.
To resample to 16kHz
, we use an AudioContext,
and pass the desired sampleRate
.
We then use the method decodeAudioData
which receives an ArrayBuffer and returns an AudioBuffer.
We get the ArrayBuffer from the Blob
method arrayBuffer().
The important part is the Phoenix.js
function upload
,
to which we pass an identifier "speech"
:
this sends the data as Blob
via a channel to the server.
We use an action button in the HTML,
and attach Javascript listeners to it on the "click"
, "dataavailable"
and "stop"
events.
We also play with the CSS classes to modify the appearance of the action button when recording or not.
Navigate to the assets
folder and run the following command.
We will use this package
to convert the resampled audio into a WAV file.
npm add "audiobuffer-to-wav"
Create a file called assets/js/micro.js
and use the code below.
// /assets/js/micro.js
import toWav from "audiobuffer-to-wav";
export default {
mounted() {
let mediaRecorder,
audioChunks = [];
// Defining the elements and styles to be used during recording
// and shown on the HTML.
const recordButton = document.getElementById("record"),
audioElement = document.getElementById("audio"),
text = document.getElementById("text"),
blue = ["bg-blue-500", "hover:bg-blue-700"],
pulseGreen = ["bg-green-500", "hover:bg-green-700", "animate-pulse"],
_this = this;
// Adding event listener for "click" event
recordButton.addEventListener("click", () => {
// Check if it's recording.
// If it is, we stop the record and update the elements.
if (mediaRecorder && mediaRecorder.state === "recording") {
mediaRecorder.stop();
text.textContent = "Record";
}
// Otherwise, it means the user wants to start recording.
else {
navigator.mediaDevices.getUserMedia({ audio: true }).then((stream) => {
// Instantiate MediaRecorder
mediaRecorder = new MediaRecorder(stream);
mediaRecorder.start();
// And update the elements
recordButton.classList.remove(...blue);
recordButton.classList.add(...pulseGreen);
text.textContent = "Stop";
// Add "dataavailable" event handler
mediaRecorder.addEventListener("dataavailable", (event) => {
audioChunks.push(event.data);
});
// Add "stop" event handler for when the recording stops.
mediaRecorder.addEventListener("stop", async () => {
const audioBlob = new Blob(audioChunks);
// update the source of the Audio tag for the user to listen to his audio
audioElement.src = URL.createObjectURL(audioBlob);
// create an AudioContext with a sampleRate of 16000
const audioContext = new AudioContext({ sampleRate: 16000 });
// async read the Blob as ArrayBuffer to feed the "decodeAudioData"
const arrayBuffer = await audioBlob.arrayBuffer();
// decodes the ArrayBuffer into the AudioContext format
const audioBuffer = await audioContext.decodeAudioData(arrayBuffer);
// converts the AudioBuffer into a WAV format
const wavBuffer = toWav(audioBuffer);
// builds a Blob to pass to the Phoenix.JS.upload
const wavBlob = new Blob([wavBuffer], { type: "audio/wav" });
// upload to the server via a channel with the hook's built-in upload method
_this.upload("speech", [wavBlob]);
// close the MediaRecorder instance
mediaRecorder.stop();
// cleanups
audioChunks = [];
recordButton.classList.remove(...pulseGreen);
recordButton.classList.add(...blue);
});
});
}
});
},
};
Now let's import this file and declare our hook
object in our liveSocket
object.
In our assets/js/app.js
file, let's do:
// /assets/js/app.js
...
import Audio from "./micro.js";
...
let liveSocket = new LiveSocket("/live", Socket, {
params: { _csrf_token: csrfToken },
hooks: { Audio },
});
We now need to add some server-side code.
The uploaded audio file will be saved on disk
as a temporary file in the /priv/static/uploads
folder.
We will also make this file unique every time a user records an audio.
We append an Ecto.UUID
string to the file name and pass it into the LiveView socket.
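To make the unique-name idea concrete, here is a minimal sketch that builds a different path for every recording. It is dependency-free, so it swaps Ecto.UUID.generate/0 for random bytes from :crypto; the module UniqueWav and the function unique_path/1 are hypothetical names used only for illustration.

```elixir
defmodule UniqueWav do
  @moduledoc "Hypothetical helper: builds a unique .wav path per recording."

  # Assumption: 16 random bytes, hex-encoded, stand in for Ecto.UUID.generate/0.
  def unique_path(dir) do
    suffix = :crypto.strong_rand_bytes(16) |> Base.encode16(case: :lower)
    Path.join(dir, "tmp-" <> suffix <> ".wav")
  end
end

a = UniqueWav.unique_path("priv/static/uploads")
b = UniqueWav.unique_path("priv/static/uploads")

# Two calls give two different file names in the same folder.
IO.puts(a)
IO.puts(a != b)
```

Because each recording gets its own file, two users recording at the same time cannot overwrite each other's audio.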
The LiveView mount/3
function returns a socket.
Let's update it
and pass extra arguments -
typically booleans for the UI such as the button disabling and the spinner -
as well as another allow_upload/3
to handle the upload process of the audio file.
In lib/app_web/live/page_live.ex
,
we change the code like so:
#page_live.ex
@upload_dir Application.app_dir(:app, ["priv", "static", "uploads"])
@tmp_wav Path.expand("priv/static/uploads/tmp.wav")
def mount(_, _, socket) do
  # mount/3 must return an {:ok, socket} tuple
  {:ok,
   socket
   |> assign(
     ...,
     transcription: nil,
     mic_off?: false,
     audio_running?: false,
     tmp_wav: @tmp_wav
   )
   |> allow_upload(:speech,
     accept: :any,
     auto_upload: true,
     progress: &handle_progress/3,
     max_entries: 1
   )
   |> allow_upload(:image_list, ...)}
end
We then create a specific handle_progress
for the :speech
event
as we did with the :image_list
event.
It will launch a task to run the Automatic Speech Recognition model
on this audio file.
We named the serving "Whisper"
.
def handle_progress(:speech, entry, %{assigns: assigns} = socket) when entry.done? do
  # Copy the upload to a unique file so concurrent recordings don't collide
  tmp_wav =
    socket
    |> consume_uploaded_entry(entry, fn %{path: path} ->
      tmp_wav = assigns.tmp_wav <> Ecto.UUID.generate() <> ".wav"
      :ok = File.cp!(path, tmp_wav)
      {:ok, tmp_wav}
    end)

  # Transcribe the copied file (not the @tmp_wav base path) in a supervised task
  audio_task =
    Task.Supervisor.async(
      App.TaskSupervisor,
      fn ->
        Nx.Serving.batched_run(Whisper, {:file, tmp_wav})
      end
    )

  {:noreply,
   socket
   |> assign(
     audio_ref: audio_task.ref,
     mic_off?: true,
     audio_running?: true,
     tmp_wav: tmp_wav
   )}
end
And that's it for the LiveView portion!
Now that we are adding several models,
let's refactor our models.ex
module
that manages the models.
Since we're dealing with multiple models,
we want our app to shut down if there's any problem loading them.
We now add the model Whisper
in the
lib/app/application.ex
so it's available throughout the application at runtime.
# lib/app/application.ex
def check_models_on_startup do
App.Models.verify_and_download_models()
|> case do
{:error, msg} ->
Logger.error("⚠️ #{msg}")
System.stop(0)
:ok ->
:ok
end
end
def start(_type, _args) do
# model check-up
:ok = check_models_on_startup()
children = [
...,
# Nx serving for Speech-to-Text
{Nx.Serving,
serving:
if Application.get_env(:app, :use_test_models) == true do
App.Models.audio_serving_test()
else
App.Models.audio_serving()
end,
name: Whisper},
...
]
...
As you can see, we're using a serving similar
to the captioning model we've implemented earlier.
For this to work, we need to make some changes to the
models.ex
module.
Recall that this module simply manages the models that are
downloaded locally and used in our application.
To implement the functions above,
we change the lib/app/models.ex
module so it looks like so.
defmodule ModelInfo do
@moduledoc """
Information regarding the model being loaded.
It holds the name of the model repository and the directory it will be saved into.
It also has booleans to load each model parameter at will - this is because some models (like BLIP) require featurizer, tokenizations and generation configuration.
"""
defstruct [:name, :cache_path, :load_featurizer, :load_tokenizer, :load_generation_config]
end
defmodule App.Models do
@moduledoc """
Manages loading the models and their location according to env.
"""
require Logger
# IMPORTANT: This should be the same directory as defined in the `Dockerfile`
# where the models will be downloaded into.
@models_folder_path Application.compile_env!(:app, :models_cache_dir)
# Embedding-------
@embedding_model %ModelInfo{
name: "sentence-transformers/paraphrase-MiniLM-L6-v2",
cache_path: Path.join(@models_folder_path, "paraphrase-MiniLM-L6-v2"),
load_featurizer: false,
load_tokenizer: true,
load_generation_config: true
}
# Captioning --
@captioning_test_model %ModelInfo{
name: "microsoft/resnet-50",
cache_path: Path.join(@models_folder_path, "resnet-50"),
load_featurizer: true
}
@captioning_prod_model %ModelInfo{
name: "Salesforce/blip-image-captioning-base",
cache_path: Path.join(@models_folder_path, "blip-image-captioning-base"),
load_featurizer: true,
load_tokenizer: true,
load_generation_config: true
}
# Audio transcription --
@audio_test_model %ModelInfo{
name: "openai/whisper-small",
cache_path: Path.join(@models_folder_path, "whisper-small"),
load_featurizer: true,
load_tokenizer: true,
load_generation_config: true
}
@audio_prod_model %ModelInfo{
name: "openai/whisper-small",
cache_path: Path.join(@models_folder_path, "whisper-small"),
load_featurizer: true,
load_tokenizer: true,
load_generation_config: true
}
def extract_captioning_test_label(result) do
%{predictions: [%{label: label}]} = result
label
end
def extract_captioning_prod_label(result) do
%{results: [%{text: label}]} = result
label
end
@doc """
Verifies and downloads the models according to configuration
and if they are already cached locally or not.
The models that are downloaded are hardcoded in this function.
"""
def verify_and_download_models() do
{
Application.get_env(:app, :force_models_download, false),
Application.get_env(:app, :use_test_models, false)
}
|> case do
{true, true} ->
# Delete any cached pre-existing models
File.rm_rf!(@models_folder_path)
with :ok <- download_model(@captioning_test_model),
:ok <- download_model(@embedding_model),
:ok <- download_model(@audio_test_model) do
:ok
else
{:error, msg} -> {:error, msg}
end
{true, false} ->
# Delete any cached pre-existing models
File.rm_rf!(@models_folder_path)
with :ok <- download_model(@captioning_prod_model),
:ok <- download_model(@audio_prod_model),
:ok <- download_model(@embedding_model) do
:ok
else
{:error, msg} -> {:error, msg}
end
{false, false} ->
# Download the prod models only if their cache directory
# is missing or empty.
with :ok <- check_folder_and_download(@captioning_prod_model),
:ok <- check_folder_and_download(@audio_prod_model),
:ok <- check_folder_and_download(@embedding_model) do
:ok
else
{:error, msg} -> {:error, msg}
end
{false, true} ->
# Download the test models only if their cache directory
# is missing or empty.
with :ok <- check_folder_and_download(@captioning_test_model),
:ok <- check_folder_and_download(@audio_test_model),
:ok <- check_folder_and_download(@embedding_model) do
:ok
else
{:error, msg} -> {:error, msg}
end
end
end
@doc """
Serving function that serves the `Bumblebee` captioning model used throughout the app.
This function is meant to be called and served by `Nx` in `lib/app/application.ex`.
This assumes the models that are being used exist locally, in the @models_folder_path.
"""
def caption_serving do
load_offline_model(@captioning_prod_model)
|> then(fn response ->
case response do
{:ok, model} ->
%Nx.Serving{} =
Bumblebee.Vision.image_to_text(
model.model_info,
model.featurizer,
model.tokenizer,
model.generation_config,
compile: [batch_size: 1],
defn_options: [compiler: EXLA],
# needed to run on `Fly.io`
preallocate_params: true
)
{:error, msg} ->
{:error, msg}
end
end)
end
@doc """
Serving function that serves the `Bumblebee` audio transcription model used throughout the app.
"""
def audio_serving do
load_offline_model(@audio_prod_model)
|> then(fn response ->
case response do
{:ok, model} ->
%Nx.Serving{} =
Bumblebee.Audio.speech_to_text_whisper(
model.model_info,
model.featurizer,
model.tokenizer,
model.generation_config,
chunk_num_seconds: 30,
task: :transcribe,
defn_options: [compiler: EXLA],
preallocate_params: true
)
{:error, msg} ->
{:error, msg}
end
end)
end
@doc """
Serving function for tests only. It uses a test audio transcription model.
"""
def audio_serving_test do
load_offline_model(@audio_test_model)
|> then(fn response ->
case response do
{:ok, model} ->
%Nx.Serving{} =
Bumblebee.Audio.speech_to_text_whisper(
model.model_info,
model.featurizer,
model.tokenizer,
model.generation_config,
chunk_num_seconds: 30,
task: :transcribe,
defn_options: [compiler: EXLA],
preallocate_params: true
)
{:error, msg} ->
{:error, msg}
end
end)
end
@doc """
Serving function for tests only. It uses a test captioning model.
This function is meant to be called and served by `Nx` in `lib/app/application.ex`.
This assumes the models that are being used exist locally, in the @models_folder_path.
"""
def caption_serving_test do
load_offline_model(@captioning_test_model)
|> then(fn response ->
case response do
{:ok, model} ->
%Nx.Serving{} =
Bumblebee.Vision.image_classification(
model.model_info,
model.featurizer,
top_k: 1,
compile: [batch_size: 10],
defn_options: [compiler: EXLA],
# needed to run on `Fly.io`
preallocate_params: true
)
{:error, msg} ->
{:error, msg}
end
end)
end
# Loads the models from the cache folder.
# It will load the model and the respective featurizer, tokenizer and generation config if needed,
# and return a map with all of these at the end.
@spec load_offline_model(map()) ::
{:ok, map()} | {:error, String.t()}
defp load_offline_model(model) do
Logger.info("ℹ️ Loading #{model.name}...")
# Loading model
loading_settings = {:hf, model.name, cache_dir: model.cache_path, offline: true}
Bumblebee.load_model(loading_settings)
|> case do
{:ok, model_info} ->
info = %{model_info: model_info}
# Load featurizer, tokenizer and generation config if needed
info =
if Map.get(model, :load_featurizer) do
{:ok, featurizer} = Bumblebee.load_featurizer(loading_settings)
Map.put(info, :featurizer, featurizer)
else
info
end
info =
if Map.get(model, :load_tokenizer) do
{:ok, tokenizer} = Bumblebee.load_tokenizer(loading_settings)
Map.put(info, :tokenizer, tokenizer)
else
info
end
info =
if Map.get(model, :load_generation_config) do
{:ok, generation_config} =
Bumblebee.load_generation_config(loading_settings)
Map.put(info, :generation_config, generation_config)
else
info
end
# Return a map with the model and respective parameters.
{:ok, info}
{:error, msg} ->
{:error, msg}
end
end
# Downloads the pre-trained models according to a given %ModelInfo struct.
# It will load the model and the respective featurizer, tokenizer and generation config if needed.
@spec download_model(map()) :: :ok | {:error, binary()}
defp download_model(model) do
Logger.info("ℹ️ Downloading #{model.name}...")
# Download model
downloading_settings = {:hf, model.name, cache_dir: model.cache_path}
# Download featurizer, tokenizer and generation config if needed
Bumblebee.load_model(downloading_settings)
|> case do
{:ok, _} ->
if Map.get(model, :load_featurizer) do
{:ok, _} = Bumblebee.load_featurizer(downloading_settings)
end
if Map.get(model, :load_tokenizer) do
{:ok, _} = Bumblebee.load_tokenizer(downloading_settings)
end
if Map.get(model, :load_generation_config) do
{:ok, _} = Bumblebee.load_generation_config(downloading_settings)
end
:ok
{:error, msg} ->
{:error, msg}
end
end
# Checks if the folder exists and downloads the model if it doesn't.
def check_folder_and_download(model) do
:ok = File.mkdir_p!(@models_folder_path)
model_location =
Path.join(model.cache_path, "huggingface")
if File.ls(model_location) == {:error, :enoent} or File.ls(model_location) == {:ok, []} do
download_model(model)
|> case do
:ok -> :ok
{:error, msg} -> {:error, msg}
end
else
Logger.info("ℹ️ No download needed: #{model.name}")
:ok
end
end
end
That's a lot! But we just need to focus on some new parts we've added:
- we've created audio_serving_test/0 and audio_serving/0, our audio serving functions that are used in the application.ex file.
- added @audio_prod_model and @audio_test_model, the Whisper model definitions to be used to download the models locally.
- refactored the image captioning model definitions to be clearer.
Now we're successfully serving audio-to-text capabilities in our application!
We expect the response of this task to be in the following form:
%{
chunks:
[%{
text: "Hi there",
#^^^the text of our audio
start_timestamp_seconds: nil,
end_timestamp_seconds: nil
}]
}
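The handler we write next pattern-matches on a result holding exactly one chunk. As a hedged sketch, should a result ever contain several chunks, the text could be joined like this. The module Transcript and its text/1 function are hypothetical and not part of the app.

```elixir
defmodule Transcript do
  @moduledoc "Hypothetical helper for reading a speech-to-text result map."

  # Joins the text of every chunk, trimming the surrounding whitespace.
  def text(%{chunks: chunks}) when is_list(chunks) do
    Enum.map_join(chunks, " ", fn %{text: text} -> String.trim(text) end)
  end
end

result = %{
  chunks: [
    %{text: " Hi there", start_timestamp_seconds: nil, end_timestamp_seconds: nil},
    %{text: " a dog on a beach ", start_timestamp_seconds: nil, end_timestamp_seconds: nil}
  ]
}

IO.puts(Transcript.text(result))
```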
We capture this response in a handle_info
callback
where we simply prune the temporary audio file
and update the socket state with the result,
and update the booleans used for our UI
(the spinner element, the button availability and reset of the task once done).
def handle_info({ref, %{chunks: [%{text: text}]} = _result}, %{assigns: assigns} = socket)
when assigns.audio_ref == ref do
Process.demonitor(ref, [:flush])
File.rm!(assigns.tmp_wav)
{:noreply,
assign(socket,
transcription: String.trim(text),
mic_off?: false,
audio_running?: false,
audio_ref: nil,
tmp_wav: @tmp_wav
)}
end
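The {ref, result} tuple matched by this callback comes from how tasks reply to the process that started them. The standalone sketch below shows the same message arriving in a plain process; the anonymous function is a stand-in for the model call, which in the app is Nx.Serving.batched_run/2.

```elixir
# A stand-in for the transcription; no model is involved in this sketch.
task = Task.async(fn -> %{chunks: [%{text: " Hi there"}]} end)
task_ref = task.ref

# The task replies to its caller with a {ref, result} message.
transcription =
  receive do
    {^task_ref, %{chunks: [%{text: text}]}} ->
      # Stop monitoring the finished task and drop its :DOWN message.
      Process.demonitor(task_ref, [:flush])
      String.trim(text)
  after
    1_000 -> :timeout
  end

IO.inspect(transcription)
```

In a LiveView, the process that started the task is the LiveView itself, which is why the message lands in handle_info/2 and why we compare the reference with the one stored in the assigns.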
And that's it for this section! Our application is now able to record audio and transcribe it. 🎉
We want to encode every caption and the input text into an embedding which is a vector of a specific vector space. In other words, we encode a string into a list of numbers.
We chose the transformer "sentence-transformers/paraphrase-MiniLM-L6-v2"
model.
This transformer uses a 384
dimensional vector space.
Since this transformer is trained with a cosine metric
,
we equip the vector space of embeddings with the same distance.
You can read more about cosine_similarity here.
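To get a feel for the metric, here is a small sketch of cosine similarity on plain lists, without Nx. The module Cosine is hypothetical. It also shows why we normalise embeddings later on: once two vectors have length 1, their dot product is their cosine similarity.

```elixir
defmodule Cosine do
  @moduledoc "Plain-list sketch of cosine similarity (no Nx), for illustration only."

  def dot(a, b) do
    Enum.zip(a, b) |> Enum.reduce(0.0, fn {x, y}, acc -> acc + x * y end)
  end

  def norm(a), do: :math.sqrt(dot(a, a))

  # Scales a vector to unit length.
  def normalize(a) do
    n = norm(a)
    Enum.map(a, &(&1 / n))
  end

  def similarity(a, b), do: dot(a, b) / (norm(a) * norm(b))
end

a = [3.0, 4.0]
b = [4.0, 3.0]

# Both expressions compute the same quantity (up to floating point rounding).
IO.inspect(Cosine.similarity(a, b))
IO.inspect(Cosine.dot(Cosine.normalize(a), Cosine.normalize(b)))
```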
This model is loaded and served by an Nx.Serving
started in the Application module like all other models.
This library HNSWLib
works with an index.
We instantiate the Index file in a GenServer
which holds the index in the state.
We will use an Index file that is saved locally in our file system. This file will be updated any time we append an embedding; all the client calls and writes to the HNSWLib index are handled by the GenServer. Because a GenServer processes its calls one at a time, these writes happen sequentially, which minimizes race conditions when several users interact with the app. This app is only meant to run on a single node.
It is started in the Application module (application.ex
).
When the app starts, we either read or create this file. The file is saved in the "/priv/static/uploads" folder.
Because we are deploying with Fly.io, we need to persist the Index file in the database because the machine - thus its attached volume - is pruned when inactive.
It is crucial to save the correspondence between the Image
table and the Index file to retrieve the correct images.
In simple terms, the file in the Index
table in the DB must correspond to the Index file in the system.
We therefore prevent a user from uploading the same file several times; otherwise, we would have several index entries for the same picture. This is done by computing the SHA of the file.
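As a rough illustration of the idea, the sketch below hashes the file contents and refuses a digest it has already seen. The module Dedupe is hypothetical, and the MapSet is only a stand-in for looking the digest up in the database.

```elixir
defmodule Dedupe do
  @moduledoc "Hypothetical sketch: detect repeated uploads by hashing file contents."

  # Hex-encoded SHA-1 digest of a binary.
  def sha1(binary), do: :crypto.hash(:sha, binary) |> Base.encode16()

  # `seen` is a MapSet of digests already stored. Returns the updated set,
  # or an error when the same content was uploaded before.
  def register(seen, binary) do
    digest = sha1(binary)

    if MapSet.member?(seen, digest) do
      {:error, :already_uploaded}
    else
      {:ok, MapSet.put(seen, digest)}
    end
  end
end

{:ok, seen} = Dedupe.register(MapSet.new(), "image bytes")

# The second upload of the same bytes is rejected.
IO.inspect(Dedupe.register(seen, "image bytes"))
```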
Since computations using models are long-running, and because several users may interact with the app, we need several steps to ensure that the information is synchronized between the database and the index file.
We also endow the vector space with a :cosine
pseudo-metric.
Add the following GenServer
file:
it will load the Index file,
and also provide a client API to interact with the Index,
which is held in the state of the GenServer.
Again, this solution works for a single node only.
defmodule App.KnnIndex do
use GenServer
@moduledoc """
A GenServer to load and handle the Index file for HNSWLib.
It loads the index from the FileSystem if existing or from the table HnswlibIndex.
It creates a new one if no Index file is found in the FileSystem
and if the table HnswlibIndex is empty.
It holds the index, the HnswlibIndex singleton row and the space in the state.
"""
require Logger
@dim 384
@max_elements 200
@upload_dir Application.app_dir(:app, ["priv", "static", "uploads"])
@saved_index if Application.compile_env(:app, :knnindex_indices_test, false),
do: Path.join(@upload_dir, "indexes_test.bin"),
else: Path.join(@upload_dir, "indexes.bin")
# Client API ------------------
def start_link(args) do
:ok = File.mkdir_p!(@upload_dir)
GenServer.start_link(__MODULE__, args, name: __MODULE__)
end
def index_path do
@saved_index
end
def save_index_to_db do
GenServer.call(__MODULE__, :save_index_to_db)
end
def get_count do
GenServer.call(__MODULE__, :get_count)
end
def add_item(embedding) do
GenServer.call(__MODULE__, {:add_item, embedding})
end
def knn_search(input) do
GenServer.call(__MODULE__, {:knn_search, input})
end
def not_empty_index do
GenServer.call(__MODULE__, :not_empty)
end
# ---------------------------------------------------
@impl true
def init(args) do
# Trying to load the index file
index_path = Keyword.fetch!(args, :index)
space = Keyword.fetch!(args, :space)
case File.exists?(index_path) do
# If the index file doesn't exist, we try to load from the database.
false ->
{:ok, index, index_schema} =
App.HnswlibIndex.maybe_load_index_from_db(space, @dim, @max_elements)
{:ok, {index, index_schema, space}}
# If the index file does exist, we compare it with the table and check for inconsistencies.
true ->
Logger.info("ℹ️ Index file found on disk. Let's compare it with the database...")
App.Repo.get_by(App.HnswlibIndex, id: 1)
|> case do
nil ->
{:stop,
{:error,
"Error comparing the index file with the one on the database. Incoherence on table."}}
schema ->
check_integrity(index_path, schema, space)
end
end
end
defp check_integrity(path, schema, space) do
# We check the count of the images in the database and the one in the index.
with db_count <-
App.Repo.all(App.Image) |> length(),
{:ok, index} <-
HNSWLib.Index.load_index(space, @dim, path),
{:ok, index_count} <-
HNSWLib.Index.get_current_count(index),
true <-
index_count == db_count do
Logger.info("ℹ️ Integrity: ✅")
{:ok, {index, schema, space}}
# If it fails, we return an error.
else
false ->
{:stop,
{:error, "Integrity error. The count of images from index differs from the database."}}
{:error, msg} ->
Logger.error("⚠️ #{msg}")
{:stop, {:error, msg}}
end
end
@impl true
def handle_call(:save_index_to_db, _, {index, index_schema, space} = state) do
# We read the index file and try to update the index on the table as well.
File.read(@saved_index)
|> case do
{:ok, file} ->
{:ok, updated_schema} =
index_schema
|> App.HnswlibIndex.changeset(%{file: file})
|> App.Repo.update()
{:reply, {:ok, updated_schema}, {index, updated_schema, space}}
{:error, msg} ->
{:reply, {:error, msg}, state}
end
end
def handle_call(:get_count, _, {index, _, _} = state) do
{:ok, count} = HNSWLib.Index.get_current_count(index)
{:reply, count, state}
end
def handle_call({:add_item, embedding}, _, {index, _, _} = state) do
# We add the new item to the index and update it.
with :ok <-
HNSWLib.Index.add_items(index, embedding),
{:ok, idx} <-
HNSWLib.Index.get_current_count(index),
:ok <-
HNSWLib.Index.save_index(index, @saved_index) do
{:reply, {:ok, idx}, state}
else
{:error, msg} ->
{:reply, {:error, msg}, state}
end
end
def handle_call({:knn_search, nil}, _, state) do
{:reply, {:error, "No index found"}, state}
end
def handle_call({:knn_search, input}, _, {index, _, _} = state) do
# We search for the nearest neighbors of the input embedding.
case HNSWLib.Index.knn_query(index, input, k: 1) do
{:ok, labels, _distances} ->
response =
labels[0]
|> Nx.to_flat_list()
|> hd()
|> then(fn idx ->
App.Repo.get_by(App.Image, %{idx: idx + 1})
end)
# TODO: add threshold on "distances"
{:reply, response, state}
{:error, msg} ->
{:reply, {:error, msg}, state}
end
end
def handle_call(:not_empty, _, {index, _, _} = state) do
case HNSWLib.Index.get_current_count(index) do
{:ok, 0} ->
Logger.warning("⚠️ Empty index.")
{:reply, :error, state}
{:ok, _} ->
{:reply, :ok, state}
end
end
end
Let's unpack a bit of what we are doing here.
- we first define the module constants. Here, we add the dimensions of the embedding vector space (these depend on the model you choose). Check the model you've used to tune these settings.
- we define the upload directory where the index file will be saved inside the filesystem.
- when the GenServer is initialized (init/1 function), we perform several integrity verifications, comparing the Index file in the filesystem with the Index table in the database (from now on, this table will be called HnswlibIndex, after the schema of the same name). These checks essentially make sure that the index on disk and the database agree.
- the other functions provide a basic API for callers to add items to the index (which is then saved to the file), to save it to the database and to query it.
As you may have seen from the previous GenServer,
we are calling functions from a module called
App.HnswlibIndex
that we have not yet created.
This module pertains to the schema that will hold
information of the HNSWLib
table.
This table will only have a single row,
with the file contents.
As we've discussed earlier,
we will compare the Index file in this row
with the one in the filesystem
to check for any inconsistencies that may arise.
Let's implement this module now!
Inside lib/app
, create a file called hnswlib_index.ex
and use the following code.
defmodule App.HnswlibIndex do
use Ecto.Schema
alias App.HnswlibIndex
require Logger
@moduledoc """
Ecto schema to save the HNSWLib Index file into a singleton table
with utility functions
"""
schema "hnswlib_index" do
field(:file, :binary)
field(:lock_version, :integer, default: 1)
end
def changeset(struct \\ %__MODULE__{}, params \\ %{}) do
struct
|> Ecto.Changeset.cast(params, [:id, :file])
|> Ecto.Changeset.optimistic_lock(:lock_version)
|> Ecto.Changeset.validate_required([:id])
end
@doc """
Tries to load index from DB.
If the table is empty, it creates a new one.
If the table is not empty but there's no file, an index is created from scratch.
If there's one, we use it and load it to be used throughout the application.
"""
def maybe_load_index_from_db(space, dim, max_elements) do
# Check if the table has an entry
App.Repo.get_by(HnswlibIndex, id: 1)
|> case do
# If the table is empty
nil ->
Logger.info("ℹ️ No index file found in DB. Creating new one...")
create(space, dim, max_elements)
# If the table is not empty but has no file
response when response.file == nil ->
Logger.info("ℹ️ Empty index file in DB. Recreating one...")
# Purge the table and create a new file row in it
App.Repo.delete_all(App.HnswlibIndex)
create(space, dim, max_elements)
# If the table is not empty and has a file
index_db ->
Logger.info("ℹ️ Index file found in DB. Loading it...")
# We get the path of the index
with path <- App.KnnIndex.index_path(),
# Save the file on disk
:ok <- File.write(path, index_db.file),
# And load it
{:ok, index} <- HNSWLib.Index.load_index(space, dim, path) do
{:ok, index, index_db}
end
end
end
defp create(space, dim, max_elements) do
# Inserting the row in the table
{:ok, schema} =
HnswlibIndex.changeset(%__MODULE__{}, %{id: 1})
|> App.Repo.insert()
# Creates index
{:ok, index} =
HNSWLib.Index.new(space, dim, max_elements)
# Builds index for testing only
if Application.get_env(:app, :use_test_models, false) do
empty_index =
Application.app_dir(:app, ["priv", "static", "uploads"])
|> Path.join("indexes_empty.bin")
HNSWLib.Index.save_index(index, empty_index)
end
{:ok, index, schema}
end
end
In this module:
- we are creating two fields: lock_version, to track the version of the row; and file, the binary content of the index file.
- lock_version is used to perform optimistic locking, which is what we do in the changeset/2 function. An update made from a stale version of the row is rejected, so when two people upload images at the same time, one write cannot silently overwrite the other. This helps keep the data in the Index file consistent.
- maybe_load_index_from_db/3 fetches the singleton row on this table and checks if the file exists in the row. If it doesn't, it creates a new one. Otherwise, it just loads the existing one inside the row.
- create/3 creates a new index file. It's a private function that encapsulates creating the Index file so it can be used in the singleton row inside the table.
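To see what optimistic locking buys us, here is an in-memory sketch of the idea. It does not use Ecto: the module VersionedStore is hypothetical, and an Agent stands in for the singleton row. A write only succeeds if the caller read the latest version.

```elixir
defmodule VersionedStore do
  @moduledoc "Hypothetical in-memory sketch of optimistic locking (not Ecto)."

  def start_link(value), do: Agent.start_link(fn -> %{version: 1, value: value} end)

  def read(pid), do: Agent.get(pid, & &1)

  # The write succeeds only if `expected_version` is still the current one.
  def write(pid, expected_version, new_value) do
    Agent.get_and_update(pid, fn
      %{version: ^expected_version} = state ->
        {:ok, %{state | version: expected_version + 1, value: new_value}}

      state ->
        {{:error, :stale}, state}
    end)
  end
end

{:ok, pid} = VersionedStore.start_link("index v1")
%{version: v} = VersionedStore.read(pid)

# Two writers both read version 1; only the first write goes through.
IO.inspect(VersionedStore.write(pid, v, "index from writer A"))
IO.inspect(VersionedStore.write(pid, v, "index from writer B"))
IO.inspect(VersionedStore.read(pid))
```

With Ecto, the version check happens in the database when the row is updated, and a stale update raises Ecto.StaleEntryError by default instead of returning a tuple.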
And that's it! We've added additional code to conditionally create different indexes according to the environment (useful for testing), but you can safely ignore those conditional calls if you're not interested in testing (though you should 😛).
We provide a serving for the embedding model in the App.Models
module.
It should look like this:
#App.Models
@embedding_model %ModelInfo{
name: "sentence-transformers/paraphrase-MiniLM-L6-v2",
cache_path: Path.join(@models_folder_path, "paraphrase-MiniLM-L6-v2"),
load_featurizer: false,
load_tokenizer: true,
load_generation_config: true
}
def embedding() do
load_offline_model(@embedding_model)
|> then(fn response ->
case response do
{:ok, model} ->
# returns an %Nx.Serving{} struct
%Nx.Serving{} =
Bumblebee.Text.TextEmbedding.text_embedding(
model.model_info,
model.tokenizer,
defn_options: [compiler: EXLA],
preallocate_params: true
)
{:error, msg} ->
{:error, msg}
end
end)
end
def verify_and_download_models() do
force_models_download = Application.get_env(:app, :force_models_download, false)
use_test_models = Application.get_env(:app, :use_test_models, false)
case {force_models_download, use_test_models} do
{true, true} ->
File.rm_rf!(@models_folder_path)
download_model(@captioning_test_model)
download_model(@audio_test_model)
{true, false} ->
File.rm_rf!(@models_folder_path)
download_model(@embedding_model)
# ^^^ added
download_model(@captioning_prod_model)
download_model(@audio_prod_model)
{false, false} ->
check_folder_and_download(@embedding_model)
# ^^^ added
check_folder_and_download(@captioning_prod_model)
check_folder_and_download(@audio_prod_model)
{false, true} ->
check_folder_and_download(@captioning_test_model)
check_folder_and_download(@audio_test_model)
end
end
You then add the Nx.Serving
for the embeddings:
#application.ex
children = [
...,
{Nx.Serving,
serving: App.Models.embedding(),
name: Embedding,
batch_size: 5
},
...
]
Your application.ex
file should look like so:
defmodule App.Application do
# See https://hexdocs.pm/elixir/Application.html
# for more information on OTP Applications
@moduledoc false
require Logger
use Application
@upload_dir Application.app_dir(:app, ["priv", "static", "uploads"])
@saved_index if Application.compile_env(:app, :knnindex_indices_test, false),
do: Path.join(@upload_dir, "indexes_test.bin"),
else: Path.join(@upload_dir, "indexes.bin")
def check_models_on_startup do
App.Models.verify_and_download_models()
|> case do
{:error, msg} ->
Logger.error("⚠️ #{msg}")
System.stop(0)
:ok ->
Logger.info("ℹ️ Models: ✅")
:ok
end
end
@impl true
def start(_type, _args) do
:ok = check_models_on_startup()
children = [
# Start the Telemetry supervisor
AppWeb.Telemetry,
# Setup DB
App.Repo,
# Start the PubSub system
{Phoenix.PubSub, name: App.PubSub},
# Nx serving for the embedding
{Nx.Serving, serving: App.Models.embedding(), name: Embedding, batch_size: 1},
# Nx serving for Speech-to-Text
{Nx.Serving,
serving:
if Application.get_env(:app, :use_test_models) == true do
App.Models.audio_serving_test()
else
App.Models.audio_serving()
end,
name: Whisper},
# Nx serving for image classifier
{Nx.Serving,
serving:
if Application.get_env(:app, :use_test_models) == true do
App.Models.caption_serving_test()
else
App.Models.caption_serving()
end,
name: ImageClassifier},
{GenMagic.Server, name: :gen_magic},
# Adding a supervisor
{Task.Supervisor, name: App.TaskSupervisor},
# Start the Endpoint (http/https)
AppWeb.Endpoint
# Start a worker by calling: App.Worker.start_link(arg)
# {App.Worker, arg}
]
# We only start the HNSWLib Index GenServer when we are not testing.
# Because this GenServer needs the database to be seeded first,
# we leave it out of the supervision tree during tests.
# When testing, you need to spawn this process manually (it is done in the test_helper.exs file).
children =
if Application.get_env(:app, :start_genserver, true) == true do
Enum.concat(children, [{App.KnnIndex, [space: :cosine, index: @saved_index]}])
else
children
end
# See https://hexdocs.pm/elixir/Supervisor.html
# for other strategies and supported options
opts = [strategy: :one_for_one, name: App.Supervisor]
Supervisor.start_link(children, opts)
end
# Tell Phoenix to update the endpoint configuration
# whenever the application is updated.
@impl true
def config_change(changed, _new, removed) do
AppWeb.Endpoint.config_change(changed, removed)
:ok
end
end
Note
We have added a few alterations to how the supervision tree
in application.ex
is initialized.
This is because we test our code,
so that's why you see some of these changes above.
If you don't want to test the code, you can ignore the conditional changes that are made to the supervision tree according to the environment (which we do to check if the code is being tested or not).
In this section, we'll go over how to use the Index and the embeddings and tie everything together to have a working application 😍.
If you want to better understand embeddings and
how to use HNSWLib
,
the math behind it and see a working example
of running an embedding model,
you can check the next section.
However, it is entirely optional
and not necessary for our app.
For a working example on how to use the index in hnswlib
,
you can run the ".exs" file there.
@tmp_wav Path.expand("priv/static/uploads/tmp.wav")
def mount(_, _, socket) do
{:ok,
socket
|> assign(
...,
# Related to the Audio
transcription: nil,
mic_off?: false,
audio_running?: false,
audio_search_result: nil,
tmp_wav: @tmp_wav
)
|> allow_upload(:speech,...)
[...]
}
end
Recall that every time you upload an image,
you get back a URL from our bucket
and you compute a caption as a string.
We will now compute an embedding from this string
and save it in the Index.
This is done in the handle_info
callback.
Update the LiveView handle_info
callback where we handle the captioning results:
def handle_info({ref, result}, %{assigns: assigns} = socket) do
  # Flush async call
  Process.demonitor(ref, [:flush])

  cond do
    # If the upload task has finished executing,
    # we update the socket assigns.
    Map.get(assigns, :task_ref) == ref ->
      # Extract the caption with the helper that matches the model in use
      label =
        if Application.get_env(:app, :use_test_models, false) do
          App.Models.extract_captioning_test_label(result)
        else
          App.Models.extract_captioning_prod_label(result)
        end

      image =
        %{
          url: assigns.image_info.url,
          width: assigns.image_info.width,
          height: assigns.image_info.height,
          description: label,
          # assumption: the SHA1 was stored in `image_info` when the file was uploaded
          sha1: assigns.image_info.sha1
        }

      with %{embedding: data} <-
             Nx.Serving.batched_run(Embedding, label),
           # compute a normed embedding (cosine case only) on the text result
           normed_data <-
             Nx.divide(data, Nx.LinAlg.norm(data)),
           {:check_used, {:ok, _pending_image}} <-
             {:check_used, App.Image.check_before_append_to_index(image.sha1)},
           # append the embedding to the index and get back its position
           {:ok, idx} <-
             App.KnnIndex.add_item(normed_data) do
        # save the App.Image to the DB
        Map.merge(image, %{idx: idx, caption: label})
        |> App.Image.insert()

        {:noreply,
         socket
         |> assign(
           upload_running?: false,
           task_ref: nil,
           label: label
         )}
      else
        {:error, msg} ->
          {:noreply,
           socket
           |> put_flash(:error, msg)
           |> assign(
             upload_running?: false,
             task_ref: nil,
             label: nil
           )}
      end

    [...]
  end
end
Every time we produce an audio file, we transcribe it into a text.
We then compute the embedding of the audio input transcription and run an ANN search.
The last step should return a (possibly) populated %App.Image{}
struct with a look-up in the database.
We then update the "audio_search_result"
assign with it and display the transcription.
Modify the following handler:
def handle_info({ref, %{chunks: [%{text: text}]} = _result}, %{assigns: assigns} = socket)
when assigns.audio_ref == ref do
Process.demonitor(ref, [:flush])
# remove the unique temporary file created for this recording
File.rm!(assigns.tmp_wav)
# compute a normed embedding (cosine case only) on the text result
# and return an App.Image{} as the result of a "knn_search"
with %{embedding: input_embedding} <-
Nx.Serving.batched_run(Embedding, text),
normed_input_embedding <-
Nx.divide(input_embedding, Nx.LinAlg.norm(input_embedding)),
{:not_empty_index, :ok} <-
{:not_empty_index, App.KnnIndex.not_empty_index()},
# {:not_empty_index, App.HnswlibIndex.not_empty_index(index)},
%App.Image{} = result <-
App.KnnIndex.knn_search(normed_input_embedding) do
{:noreply,
assign(socket,
transcription: String.trim(text),
mic_off?: false,
audio_running?: false,
audio_search_result: result,
audio_ref: nil,
tmp_wav: @tmp_wav
)}
else
# record without entries
{:not_empty_index, :error} ->
{:noreply,
assign(socket,
mic_off?: false,
audio_search_result: nil,
audio_running?: false,
audio_ref: nil,
tmp_wav: @tmp_wav
)}
nil ->
{:noreply,
assign(socket,
transcription: String.trim(text),
mic_off?: false,
audio_search_result: nil,
audio_running?: false,
audio_ref: nil,
tmp_wav: @tmp_wav
)}
end
end
We next come back to the knn_search function we defined in the "KnnIndex" GenServer.
The "approximate nearest neighbour" search uses the function HNSWLib.Index.knn_query/3.
It returns a tuple {:ok, indices, distances} where "indices" and "distances" are lists.
Their length is the number of neighbours you want to find, set by the k parameter.
With k=1, we ask for a single neighbour.
Note
You may further use a cut-off distance to exclude responses that might not be meaningful.
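As a minimal sketch of such a cut-off, assuming the indices and distances have already been converted to plain lists (for instance with Nx.to_flat_list/1 if they come back as tensors), we could drop every neighbour that is too far away. The CutoffSketch module and the 0.5 threshold are hypothetical names and values, not part of the tutorial's code.

```elixir
defmodule CutoffSketch do
  # Hypothetical helper: keeps only the neighbours whose distance
  # is at most `max_distance`. Works on plain lists, not tensors.
  def filter_neighbours({:ok, indices, distances}, max_distance \\ 0.5) do
    indices
    |> Enum.zip(distances)
    |> Enum.filter(fn {_idx, dist} -> dist <= max_distance end)
  end
end

# Two neighbours were found, but only the first one is close enough.
IO.inspect(CutoffSketch.filter_neighbours({:ok, [3, 7], [0.12, 0.91]}))
# [{3, 0.12}]
```

An empty list then means "no meaningful match", which can be treated the same way as the nil branch of the handler.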
We will now display the found image with the URL field of the %App.Image{} struct.
Add this to "page_live.html.heex":
<!-- /lib/app_web/live/page_live.html.heex -->
<div :if={@audio_search_result}>
<img src={@audio_search_result.url} alt="found_image" />
</div>
Now we'll save the index found.
Let's add a column to the Image table.
To do this, run a mix task to generate a timestamped migration file.
mix ecto.gen.migration add_idx_to_images
In the "/priv/repo/migrations" folder, open the newly created file and add:
defmodule App.Repo.Migrations.AddIdxToImages do
use Ecto.Migration
def change do
alter table(:images) do
add(:idx, :integer, default: 0)
add(:sha1, :string)
end
end
end
and run the migration with mix ecto.migrate.
Modify the App.Image
struct and the changeset:
@primary_key {:id, :id, autogenerate: true}
schema "images" do
field(:description, :string)
field(:width, :integer)
field(:url, :string)
field(:height, :integer)
field(:idx, :integer)
field(:sha1, :string)
timestamps(type: :utc_datetime)
end
def changeset(image, params \\ %{}) do
image
|> Ecto.Changeset.cast(params, [:url, :description, :width, :height, :idx, :sha1])
|> Ecto.Changeset.validate_required([:width, :height])
|> Ecto.Changeset.unique_constraint(:sha1, name: :images_sha1_index)
|> Ecto.Changeset.unique_constraint(:idx, name: :images_idx_index)
end
We've added the fields idx and sha1 to the image schema.
The former is the index of the image within the HNSWLib Index file, so we can look the image up after a search.
The latter is the sha1 hash of the image contents.
It lets us check whether two images are the same, so we can avoid adding duplicate images and save some throughput in our application.
In our changeset/2 function, we've added two unique_constraint/3 calls to check the uniqueness of the newly added idx and sha1 fields.
These are enforced at the database level, so we don't end up with duplicated images.
Note that unique_constraint/3 only turns a database uniqueness violation into a changeset error: it relies on unique indexes with the given names (images_sha1_index and images_idx_index) existing on the images table.
In addition to these changes,
we are going to need functions to
calculate the sha1
of the image.
Add the following functions to the same file.
def calc_sha1(file_binary) do
:crypto.hash(:sha, file_binary)
|> Base.encode16()
end
def check_sha1(sha1) when is_binary(sha1) do
App.Repo.get_by(App.Image, %{sha1: sha1})
|> case do
nil ->
nil
%App.Image{} = image ->
{:ok, image}
end
end
- calc_sha1/1 uses the :crypto package to hash the file binary and encode it.
- check_sha1/1 fetches an image according to a given sha1 code and returns the result.
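To see why the hash works as a duplicate detector, here is a small self-contained sketch that repeats the same hashing steps outside of the App.Image module. The Sha1Sketch module name is hypothetical; the :crypto.hash/2 and Base.encode16/1 calls are the ones used above.

```elixir
defmodule Sha1Sketch do
  # Same steps as calc_sha1/1: SHA-1 digest of the bytes, hex-encoded in uppercase.
  def calc_sha1(file_binary) when is_binary(file_binary) do
    :crypto.hash(:sha, file_binary)
    |> Base.encode16()
  end
end

same_a = Sha1Sketch.calc_sha1("same bytes")
same_b = Sha1Sketch.calc_sha1("same bytes")
other = Sha1Sketch.calc_sha1("other bytes")

IO.inspect(same_a == same_b)       # true: identical content, identical hash
IO.inspect(same_a == other)        # false: different content
IO.inspect(String.length(same_a))  # 40 hex characters (20 bytes)
```

Because the hash only depends on the file contents, the same picture uploaded under two different file names still produces the same sha1.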
And that's all we need to deal with our images!
Now we have:
- all the embedding models ready to be used,
- our Index file correctly created and maintained through the filesystem and in the database in the hnswlib_index schema,
- the needed sha1 functions to check for duplicated images.
It's time to bring everything together and use all of these tools to implement semantic search into our application.
We are going to be working inside lib/app_web/live/page_live.ex
from now on.
First, we are going to update our socket assigns on mount/3
.
@image_width 640
@accepted_mime ~w(image/jpeg image/jpg image/png image/webp)
@tmp_wav Path.expand("priv/static/uploads/tmp.wav")
@impl true
def mount(_params, _session, socket) do
{:ok,
socket
|> assign(
# Related to the file uploaded by the user
label: nil,
upload_running?: false,
task_ref: nil,
image_info: nil,
image_preview_base64: nil,
# Related to the list of image examples
example_list_tasks: [],
example_list: [],
display_list?: false,
# Related to the Audio
transcription: nil,
mic_off?: false,
audio_running?: false,
audio_search_result: nil,
tmp_wav: @tmp_wav
)
|> allow_upload(:image_list,
accept: ~w(image/*),
auto_upload: true,
progress: &handle_progress/3,
max_entries: 1,
chunk_size: 64_000,
max_file_size: 5_000_000
)
|> allow_upload(:speech,
accept: :any,
auto_upload: true,
progress: &handle_progress/3,
max_entries: 1
)}
end
To reiterate:
- we've added a few fields related to audio.
  - transcription holds the text obtained by transcribing the person's audio.
  - mic_off? is a toggle that visually shows the person whether the microphone is recording or not.
  - audio_running? is a boolean that shows the person whether the audio transcription and semantic search are in progress (loading).
  - audio_search_result is the image whose caption is semantically closest to the transcribed audio.
  - tmp_wav is the path of the temporary audio file that is saved in the filesystem while the audio is being transcribed.
- additionally, we have added an allow_upload/3 call for the audio upload (it is tagged as :speech and is handled by the same handle_progress/3 callback as the :image_list upload).
These are the socket assigns that let us keep the person informed about what the app is doing.
As you can see, we are using handle_progress/3 with allow_upload/3.
As we know, handle_progress/3 is called whenever an upload happens (whether an image or a recording of the person's voice).
We define two different clauses, one for :image_list uploads and one for :speech uploads.
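The dispatch itself is plain Elixir pattern matching on the upload name. A minimal sketch, using a bare map in place of the real upload entry and a hypothetical ProgressSketch module (the catch-all clause is an assumption added for illustration, not something the tutorial defines):

```elixir
defmodule ProgressSketch do
  # One clause per upload name; the guard skips entries that are still uploading.
  def handle_progress(:image_list, entry, state) when entry.done?, do: {:image, state}
  def handle_progress(:speech, entry, state) when entry.done?, do: {:speech, state}
  # Hypothetical fallback for entries that have not finished yet.
  def handle_progress(_name, _entry, state), do: {:pending, state}
end

IO.inspect(ProgressSketch.handle_progress(:speech, %{done?: true}, :socket))
# {:speech, :socket}
IO.inspect(ProgressSketch.handle_progress(:image_list, %{done?: false}, :socket))
# {:pending, :socket}
```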
Let's start with the first one.
We have added sha1 and idx as fields to our image schema.
Therefore, we need to make some changes to the handle_progress/3 clause for :image_list uploads.
Change it like so:
def handle_progress(:image_list, entry, socket) when entry.done? do
# We consume the entry only if the entry is done uploading from the image
# and if consuming the entry was successful.
consume_uploaded_entry(socket, entry, fn %{path: path} ->
with {:magic, {:ok, %{mime_type: mime}}} <- {:magic, magic_check(path)},
# Check if file can be properly read
{:read, {:ok, file_binary}} <- {:read, File.read(path)},
# Check the image info
{:image_info, {mimetype, width, height, _variant}} <-
{:image_info, ExImageInfo.info(file_binary)},
# Check mime type
{:check_mime, :ok} <- {:check_mime, check_mime(mime, mimetype)},
# Get SHA1 code from the image and check it
sha1 = App.Image.calc_sha1(file_binary),
{:sha_check, nil} <- {:sha_check, App.Image.check_sha1(sha1)},
# Get image and resize
{:ok, thumbnail_vimage} <- Vops.thumbnail(path, @image_width, size: :VIPS_SIZE_DOWN),
# Pre-process the image as tensor
{:pre_process, {:ok, tensor}} <- {:pre_process, pre_process_image(thumbnail_vimage)} do
# Create image info to be saved as partial image
image_info = %{
mimetype: mimetype,
width: width,
height: height,
sha1: sha1,
description: nil,
url: nil,
# set a random big int to the "idx" field
idx: :rand.uniform(1_000_000_000_000) * 1_000
}
# Save partial image
App.Image.insert(image_info)
|> case do
{:ok, _} ->
image_info =
Map.merge(image_info, %{
file_binary: file_binary
})
{:ok, %{tensor: tensor, image_info: image_info, path: path}}
{:error, changeset} ->
{:error, changeset.errors}
end
|> handle_upload()
else
{:magic, {:error, msg}} -> {:postpone, %{error: msg}}
{:read, msg} -> {:postpone, %{error: inspect(msg)}}
{:image_info, nil} -> {:postpone, %{error: "image_info error"}}
{:check_mime, :error} -> {:postpone, %{error: "Bad mime type"}}
{:sha_check, {:ok, %App.Image{}}} -> {:postpone, %{error: "Image already uploaded"}}
{:pre_process, {:error, _msg}} -> {:postpone, %{error: "pre_processing error"}}
{:error, reason} -> {:postpone, %{error: inspect(reason)}}
end
end)
|> case do
# If consuming the entry was successful, we spawn a task to classify the image
# and update the socket assigns
%{tensor: tensor, image_info: image_info} ->
task =
Task.Supervisor.async(App.TaskSupervisor, fn ->
Nx.Serving.batched_run(ImageClassifier, tensor)
end)
# Encode the image to base64
base64 = "data:image/png;base64, " <> Base.encode64(image_info.file_binary)
{:noreply,
assign(socket,
upload_running?: true,
task_ref: task.ref,
image_preview_base64: base64,
image_info: image_info
)}
# Otherwise, if there was an error uploading the image, we log the error and show it to the person.
%{error: error} ->
Logger.warning("⚠️ Error uploading image. #{inspect(error)}")
{:noreply, push_event(socket, "toast", %{message: "Image couldn't be uploaded to S3.\n#{error}"})}
end
end
Let's go over these changes. Some of this code was written earlier but, for clarity, we'll go over it again.
- we use consume_uploaded_entry/3 to consume the image that the person uploads. To be consumed successfully, the image goes through a series of validations.
  - we use magic_check/1 to check the validity of the image's MIME type.
  - we read the contents of the file and extract the image information using ExImageInfo.info/1.
  - we check if the MIME type is valid using check_mime/2.
  - we calculate the sha1 with the App.Image.calc_sha1/1 function we've developed earlier, and use App.Image.check_sha1/1 to reject images that were already uploaded.
  - we resize the image, scaling it down to the same width as the images the chosen image captioning model was trained on (to yield better results and to save memory bandwidth). We use Vix.Vips.Operation.thumbnail/3 (Vops in the code above) to resize the image.
  - finally, we convert the resized image to a tensor using pre_process_image/1 so it can be consumed by our image captioning model.
- after this series of validations, we use the image info we've obtained earlier to create an "early-save" of the image. With this, we save the image and associate it with the sha1 computed from the image contents. We do this "partial image saving" in case two identical images are uploaded at the same time. Because sha1 is enforced to be unique at the database level, the database resolves this race condition: only one of the inserts succeeds.
- afterwards, we call handle_upload/1. This function uploads the image to the S3 bucket. We are going to implement this function in just a second 😉.
- if the upload is successful, using the tensor and the image information from the previous steps, we spawn the async task to run the model. This step should be familiar to you since we've already implemented this. Finally, we update the socket assigns accordingly.
- we handle all possible errors in the else block of the with statement before the image is uploaded.
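The tagged tuples ({:magic, ...}, {:read, ...} and so on) are what let the else block tell the failing step apart. Here is a stripped-down sketch of the same pattern; check_size/1 and check_header/1 are hypothetical validators standing in for magic_check/1, check_mime/2 and friends.

```elixir
defmodule TaggedWithSketch do
  # Hypothetical validators; each returns a plain :ok / {:ok, _} or an error.
  defp check_size(bin) when byte_size(bin) > 0, do: :ok
  defp check_size(_bin), do: :error

  defp check_header(<<0x89, "PNG", _rest::binary>>), do: {:ok, "image/png"}
  defp check_header(_bin), do: {:error, "unknown header"}

  def validate(bin) do
    # Wrapping each result in {tag, result} makes every failure unique,
    # so `else` can match on the step that went wrong.
    with {:size, :ok} <- {:size, check_size(bin)},
         {:header, {:ok, mime}} <- {:header, check_header(bin)} do
      {:ok, mime}
    else
      {:size, :error} -> {:postpone, %{error: "empty file"}}
      {:header, {:error, msg}} -> {:postpone, %{error: msg}}
    end
  end
end

IO.inspect(TaggedWithSketch.validate(<<0x89, "PNG", 13, 10>>))
# {:ok, "image/png"}
IO.inspect(TaggedWithSketch.validate(""))
# {:postpone, %{error: "empty file"}}
```

Without the tags, two steps that both return a bare :error would be indistinguishable in the else block.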
Hopefully, this demystifies some of the code we've just implemented!
Because we are using handle_upload/1 in this function to upload the image to our S3 bucket, let's implement it right now!
def handle_upload({:ok, %{path: path, tensor: tensor, image_info: image_info} = map})
when is_map(map) do
# Upload the image to S3
Image.upload_image_to_s3(path, image_info.mimetype)
|> case do
# If the upload is successful, we update the socket assigns with the image info
{:ok, url} ->
image_info =
struct(
%ImageInfo{},
Map.merge(image_info, %{url: url})
)
{:ok, %{tensor: tensor, image_info: image_info}}
# If S3 upload fails, we return error
{:error, reason} ->
Logger.warning("⚠️ Error uploading image: #{inspect(reason)}")
{:postpone, %{error: "Bucket error"}}
end
end
def handle_upload({:error, error}) do
Logger.warning("⚠️ Error creating partial image: #{inspect(error)}")
{:postpone, %{error: "Error creating partial image"}}
end
This function is fairly easy to understand.
We upload the image by calling Image.upload_image_to_s3/2 and, if successful, we add the returned URL to the image struct.
Otherwise, we handle the error and return it.
After this small detour,
let's implement the handle_progress/3
for the :speech
uploads,
that is, the audio the person records.
def handle_progress(:speech, entry, %{assigns: assigns} = socket) when entry.done? do
# We consume the audio file
tmp_wav =
socket
|> consume_uploaded_entry(entry, fn %{path: path} ->
tmp_wav = assigns.tmp_wav <> Ecto.UUID.generate() <> ".wav"
:ok = File.cp!(path, tmp_wav)
{:ok, tmp_wav}
end)
# After consuming the audio file, we spawn a task to transcribe the audio
audio_task =
Task.Supervisor.async(
App.TaskSupervisor,
fn ->
Nx.Serving.batched_run(Whisper, {:file, tmp_wav})
end
)
# Update the socket assigns
{:noreply,
assign(socket,
audio_ref: audio_task.ref,
mic_off?: true,
tmp_wav: tmp_wav,
audio_running?: true,
audio_search_result: nil,
transcription: nil
)}
end
As we know, this function is called after the upload is completed.
In the case of audio uploads, the upload is triggered by the hook in assets/js/app.js when the person records their voice.
Similarly to the handle_progress/3 function of the :image_list uploads, we also use consume_uploaded_entry/3 to consume the audio file.
- we consume the audio file and save it in our filesystem as a .wav file.
- we spawn the async task and use the whisper audio transcription model with the audio file we've just saved.
- we update the socket assigns accordingly.
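The reason for generating a fresh file name on every recording is that two people recording at the same time must not overwrite each other's file. A dependency-free sketch of that step follows; it swaps Ecto.UUID.generate/0 for random bytes from :crypto so it runs without Ecto, and the TmpCopySketch module and file names are hypothetical.

```elixir
defmodule TmpCopySketch do
  # Copies `source` to a uniquely named .wav file in `dir` and returns the new path.
  def copy_to_unique(source, dir \\ System.tmp_dir!()) do
    unique = :crypto.strong_rand_bytes(8) |> Base.encode16(case: :lower)
    dest = Path.join(dir, "tmp-" <> unique <> ".wav")
    # File.cp!/2 returns :ok or raises, so the match documents the happy path.
    :ok = File.cp!(source, dest)
    dest
  end
end

# Usage: stand-in for the uploaded entry's temporary path.
source = Path.join(System.tmp_dir!(), "sketch-source.wav")
File.write!(source, "fake audio")

dest = TmpCopySketch.copy_to_unique(source)
IO.inspect(File.exists?(dest)) # true

# Clean up, just like the handler removes the file once it is transcribed.
File.rm!(dest)
File.rm!(source)
```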
Pretty simple, right?
In this section, we'll finally use our embedding model and semantically search for our images!
As you've seen in the previous section, we've spawned the task that transcribes the audio with the whisper model.
Now we need a handler!
For this scenario, add the following function.
@impl true
def handle_info({ref, %{chunks: [%{text: text}]} = _result}, %{assigns: assigns} = socket)
when assigns.audio_ref == ref do
Process.demonitor(ref, [:flush])
File.rm!(assigns.tmp_wav)
# Compute a normed embedding (cosine case only) on the text result
# and return an %App.Image{} as the result of a "knn_search"
with {:not_empty_index, :ok} <-
{:not_empty_index, App.KnnIndex.not_empty_index()},
%{embedding: input_embedding} <-
Nx.Serving.batched_run(Embedding, text),
%Nx.Tensor{} = normed_input_embedding <-
Nx.divide(input_embedding, Nx.LinAlg.norm(input_embedding)),
%App.Image{} = result <-
App.KnnIndex.knn_search(normed_input_embedding) do
{:noreply,
assign(socket,
transcription: String.trim(text),
mic_off?: false,
audio_running?: false,
audio_search_result: result,
audio_ref: nil,
tmp_wav: @tmp_wav
)}
else
# Stop transcription if no entries in the Index
{:not_empty_index, :error} ->
{:noreply,
socket
|> push_event("toast", %{message: "No images yet"})
|> assign(
mic_off?: false,
transcription: "!! The image bank is empty. Please upload some !!",
audio_search_result: nil,
audio_running?: false,
audio_ref: nil,
tmp_wav: @tmp_wav
)}
nil ->
{:noreply,
assign(socket,
transcription: String.trim(text),
mic_off?: false,
audio_search_result: nil,
audio_running?: false,
audio_ref: nil,
tmp_wav: @tmp_wav
)}
end
end
Let's break down this function:
- given the recording's text transcription:
  - we check that the Index file is not empty.
  - we run the text transcription through the embedding model and get its result.
  - we normalize the embedding we've received from the model.
  - we run the normalized embedding through a knn search. For this, we call the App.KnnIndex.knn_search/1 function we've defined in the App.KnnIndex GenServer we've implemented earlier on.
  - the knn search returns the image that is semantically closest (through its caption) to the audio transcription.
- upon the success of this process, we update the socket assigns.
- otherwise, we handle each error case accordingly and update the socket assigns.
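Why normalize at all? Once two vectors have length 1, their dot product is exactly their cosine similarity, which is the "cosine case" the code comments refer to. The sketch below shows this with plain lists instead of Nx tensors so it runs without dependencies; CosineSketch is a hypothetical module, not part of the app.

```elixir
defmodule CosineSketch do
  # Divide each component by the L2 norm, mirroring
  # Nx.divide(data, Nx.LinAlg.norm(data)) on a plain list.
  def normalize(vec) do
    norm = vec |> Enum.map(&(&1 * &1)) |> Enum.sum() |> :math.sqrt()
    Enum.map(vec, &(&1 / norm))
  end

  # Dot product of two lists of the same length.
  def dot(a, b) do
    Enum.zip(a, b)
    |> Enum.map(fn {x, y} -> x * y end)
    |> Enum.sum()
  end
end

query = CosineSketch.normalize([3.0, 4.0])
same_direction = CosineSketch.normalize([6.0, 8.0])
orthogonal = CosineSketch.normalize([-4.0, 3.0])

# Close to 1.0: the vectors point the same way, whatever their original length.
IO.inspect(CosineSketch.dot(query, same_direction))
# Close to 0.0: the vectors are unrelated.
IO.inspect(CosineSketch.dot(query, orthogonal))
```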
And that's it! We just had to sequentially call the functions that we've implemented prior!
Now that we have used the embeddings, there's one thing we forgot: keeping track of the embeddings of each image that is uploaded. These embeddings are saved in the Index file.
To fix this, we need to create an embedding of the image after it is uploaded and captioned.
Head over to the handle_info/2 pertaining to the image captioning, and change it to the following piece of code:
def handle_info({ref, result}, %{assigns: assigns} = socket) do
# Flush async call
Process.demonitor(ref, [:flush])
# You need to change how you destructure the output of the model depending
# on the model you've chosen for `prod` and `test` envs on `models.ex`.
label =
case Application.get_env(:app, :use_test_models, false) do
true ->
App.Models.extract_captioning_test_label(result)
# coveralls-ignore-start
false ->
App.Models.extract_captioning_prod_label(result)
# coveralls-ignore-stop
end
%{image_info: image_info} = assigns
cond do
# If the upload task has finished executing, we run the embedding model on the image
Map.get(assigns, :task_ref) == ref ->
image =
%{
url: image_info.url,
width: image_info.width,
height: image_info.height,
description: label,
sha1: image_info.sha1
}
# Create embedding task
with %{embedding: data} <- Nx.Serving.batched_run(Embedding, label),
# Compute a normed embedding (cosine case only) on the text result
normed_data <- Nx.divide(data, Nx.LinAlg.norm(data)),
# Check the SHA1 of the image
{:check_used, {:ok, pending_image}} <-
{:check_used, App.Image.check_sha1(image.sha1)} do
Ecto.Multi.new()
# Save updated Image to DB
|> Ecto.Multi.run(:update_image, fn _, _ ->
idx = App.KnnIndex.get_count() + 1
Ecto.Changeset.change(pending_image, %{
idx: idx,
description: image.description,
url: image.url
})
|> App.Repo.update()
end)
# Save Index file to DB
|> Ecto.Multi.run(:save_index, fn _, _ ->
{:ok, _idx} = App.KnnIndex.add_item(normed_data)
App.KnnIndex.save_index_to_db()
end)
|> App.Repo.transaction()
|> case do
{:error, :update_image, _changeset, _} ->
{:noreply,
socket
|> push_event("toast", %{message: "Invalid entry"})
|> assign(
upload_running?: false,
task_ref: nil,
label: nil
)}
{:error, :save_index, _, _} ->
{:noreply,
socket
|> push_event("toast", %{message: "Please retry"})
|> assign(
upload_running?: false,
task_ref: nil,
label: nil
)}
{:ok, _} ->
{:noreply,
socket
|> assign(
upload_running?: false,
task_ref: nil,
label: label
)}
end
else
{:check_used, nil} ->
{:noreply,
socket
|> push_event("toast", %{message: "Race condition"})
|> assign(
upload_running?: false,
task_ref: nil,
label: nil
)}
{:error, msg} ->
{:noreply,
socket
|> push_event("toast", %{message: msg})
|> assign(
upload_running?: false,
task_ref: nil,
label: nil
)}
end
# If the example task has finished executing, we update the socket assigns.
img = Map.get(assigns, :example_list_tasks) |> Enum.find(&(&1.ref == ref)) ->
# Update the element in the `example_list` enum to turn "predicting?" to `false`
updated_example_list = update_example_list(assigns, img, label)
{:noreply,
assign(socket,
example_list: updated_example_list,
upload_running?: false,
display_list?: true
)}
end
end
Let's go over the flow of this function:
- we extract the captioning label from the result of the image captioning model. This code is the same as it was before.
- afterwards, we feed the label into the embedding model.
- the embedding model yields the embedding; we normalize it and check that a partial image with the same sha1 code already exists in the database (the "early-save" made in handle_progress/3).
- if these three steps succeed, we perform a database transaction in which we save the updated image to the database, add the embedding to the Index (incrementing its count) and save the Index file to the database.
- we finally update the socket assigns accordingly.
- if any of the previous calls fail, we handle these error scenarios and update the socket assigns.
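The {:error, :update_image, _, _} and {:error, :save_index, _, _} tuples matched above come from running named steps in order and stopping at the first failure. The following dependency-free sketch imitates only that shape; it is not Ecto.Multi, it does not roll anything back, and StepsSketch is a hypothetical name.

```elixir
defmodule StepsSketch do
  # Runs {name, fun} steps in order. Each fun receives the results so far and
  # must return {:ok, value} or {:error, reason}. Stops at the first error.
  def run(steps) do
    Enum.reduce_while(steps, {:ok, %{}}, fn {name, fun}, {:ok, acc} ->
      case fun.(acc) do
        {:ok, value} -> {:cont, {:ok, Map.put(acc, name, value)}}
        {:error, reason} -> {:halt, {:error, name, reason, acc}}
      end
    end)
  end
end

result =
  StepsSketch.run(
    update_image: fn _changes -> {:ok, %{idx: 1}} end,
    save_index: fn _changes -> {:error, :disk_full} end
  )

IO.inspect(result)
# {:error, :save_index, :disk_full, %{update_image: %{idx: 1}}}
```

In the real handler, App.Repo.transaction/1 additionally rolls back the image update when the :save_index step fails, which this sketch does not attempt.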
And that's it! Our app is fully loaded with semantic search capabilities! 🔋
All that's left is updating our view. We are going to add basic elements to make this transition as smooth as possible.
Head over to lib/app_web/live/page_live.html.heex and update it like so:
<div class="hidden" id="tracker_el" phx-hook="ActivityTracker" />
<div
class="h-full w-full px-4 py-10 flex justify-center sm:px-6 sm:py-24 lg:px-8 xl:px-28 xl:py-32"
>
<div class="flex flex-col justify-start">
<div class="flex justify-center items-center w-full">
<div class="2xl:space-y-12">
<div class="mx-auto max-w-2xl lg:text-center">
<p>
<span
class="rounded-full w-fit bg-brand/5 px-2 py-1 text-[0.8125rem] font-medium text-center leading-6 text-brand"
>
<a
href="https://hexdocs.pm/phoenix_live_view/Phoenix.LiveView.html"
target="_blank"
rel="noopener noreferrer"
>
🔥 LiveView
</a>
+
<a
href="https://github.com/elixir-nx/bumblebee"
target="_blank"
rel="noopener noreferrer"
>
🐝 Bumblebee
</a>
</span>
</p>
<p
class="mt-2 text-3xl font-bold tracking-tight text-gray-900 sm:text-4xl"
>
Caption your image!
</p>
<h3 class="mt-6 text-lg leading-8 text-gray-600">
Upload your own image (up to 5MB) and perform image captioning with
<a
href="https://elixir-lang.org/"
target="_blank"
rel="noopener noreferrer"
class="font-mono font-medium text-sky-500"
>
Elixir
</a>
!
</h3>
<p class="text-lg leading-8 text-gray-400">
Powered with
<a
href="https://huggingface.co/"
target="_blank"
rel="noopener noreferrer"
class="font-mono font-medium text-sky-500"
>
HuggingFace🤗
</a>
transformer models, you can run this project locally and perform
machine learning tasks with a handful of lines of code.
</p>
</div>
<div></div>
<div class="border-gray-900/10">
<!-- File upload section -->
<div class="col-span-full">
<div
class="mt-2 flex justify-center rounded-lg border border-dashed border-gray-900/25 px-6 py-10"
phx-drop-target={@uploads.image_list.ref}
>
<div class="text-center">
<!-- Show image preview -->
<%= if @image_preview_base64 do %>
<form id="upload-form" phx-change="noop" phx-submit="noop">
<label class="cursor-pointer">
<%= if not @upload_running? do %>
<.live_file_input upload={@uploads.image_list} class="hidden" />
<% end %>
<img src={@image_preview_base64} />
</label>
</form>
<% else %>
<svg
class="mx-auto h-12 w-12 text-gray-300"
viewBox="0 0 24 24"
fill="currentColor"
aria-hidden="true"
>
<path
fill-rule="evenodd"
d="M1.5 6a2.25 2.25 0 012.25-2.25h16.5A2.25 2.25 0 0122.5 6v12a2.25 2.25 0 01-2.25 2.25H3.75A2.25 2.25 0 011.5 18V6zM3 16.06V18c0 .414.336.75.75.75h16.5A.75.75 0 0021 18v-1.94l-2.69-2.689a1.5 1.5 0 00-2.12 0l-.88.879.97.97a.75.75 0 11-1.06 1.06l-5.16-5.159a1.5 1.5 0 00-2.12 0L3 16.061zm10.125-7.81a1.125 1.125 0 112.25 0 1.125 1.125 0 01-2.25 0z"
clip-rule="evenodd"
/>
</svg>
<div class="mt-4 flex text-sm leading-6 text-gray-600">
<label
for="file-upload"
class="relative cursor-pointer rounded-md bg-white font-semibold text-indigo-600 focus-within:outline-none focus-within:ring-2 focus-within:ring-indigo-600 focus-within:ring-offset-2 hover:text-indigo-500"
>
<form id="upload-form" phx-change="noop" phx-submit="noop">
<label class="cursor-pointer">
<.live_file_input upload={@uploads.image_list}
class="hidden" /> Upload
</label>
</form>
</label>
<p class="pl-1">or drag and drop</p>
</div>
<p class="text-xs leading-5 text-gray-600">
PNG, JPG, GIF up to 5MB
</p>
<% end %>
</div>
</div>
</div>
</div>
<!-- Show errors -->
<%= for entry <- @uploads.image_list.entries do %>
<div class="mt-2">
<%= for err <- upload_errors(@uploads.image_list, entry) do %>
<div class="rounded-md bg-red-50 p-4 mb-2">
<div class="flex">
<div class="flex-shrink-0">
<svg
class="h-5 w-5 text-red-400"
viewBox="0 0 20 20"
fill="currentColor"
aria-hidden="true"
>
<path
fill-rule="evenodd"
d="M10 18a8 8 0 100-16 8 8 0 000 16zM8.28 7.22a.75.75 0 00-1.06 1.06L8.94 10l-1.72 1.72a.75.75 0 101.06 1.06L10 11.06l1.72 1.72a.75.75 0 101.06-1.06L11.06 10l1.72-1.72a.75.75 0 00-1.06-1.06L10 8.94 8.28 7.22z"
clip-rule="evenodd"
/>
</svg>
</div>
<div class="ml-3">
<h3 class="text-sm font-medium text-red-800">
<%= error_to_string(err) %>
</h3>
</div>
</div>
</div>
<% end %>
</div>
<% end %>
<!-- Prediction text -->
<div
class="flex mt-2 space-x-1.5 items-center font-bold text-gray-900 text-xl"
>
<span>Description: </span>
<!-- conditional Spinner or display caption text or waiting text-->
<AppWeb.Spinner.spin spin={@upload_running?} />
<%= if @label do %>
<span class="text-gray-700 font-light"><%= @label %></span>
<% else %>
<span class="text-gray-300 font-light">Waiting for image input.</span>
<% end %>
</div>
</div>
</div>
<!-- Audio -->
<br />
<div class="mx-auto max-w-2xl lg">
<h2
class="mt-2 text-3xl font-bold tracking-tight text-gray-900 sm:text-4xl text-center"
>
Semantic search using audio
</h2>
<br />
<p>
Please record a phrase. You can listen to your audio. It will be
transcribed automatically into text and appear below. The semantic
search for matching images will then run automatically and the found
image will appear below.
</p>
<br />
<form
id="audio-upload-form"
phx-change="noop"
class="flex flex-col items-center"
>
<.live_file_input upload={@uploads.speech} class="hidden" />
<button
id="record"
class="bg-blue-500 hover:bg-blue-700 text-white font-bold px-4 rounded flex"
type="button"
phx-hook="Audio"
disabled={@mic_off?}
>
<Heroicons.microphone
outline
class="w-6 h-6 text-white font-bold group-active:animate-pulse"
/>
<span id="text">Record</span>
</button>
</form>
<br />
<p class="flex flex-col items-center">
<audio id="audio" controls></audio>
</p>
<br />
<div
class="flex mt-2 space-x-1.5 items-center font-bold text-gray-900 text-xl"
>
<span>Transcription: </span>
<%= if @audio_running? do %>
<AppWeb.Spinner.spin spin={@audio_running?} />
<% else %>
<%= if @transcription do %>
<span class="text-gray-700 font-light"><%= @transcription %></span>
<% else %>
<span class="text-gray-300 font-light text-justify">Waiting for audio input.</span>
<% end %>
<% end %>
</div>
<br />
<div :if={@audio_search_result}>
<div class="border-gray-900/10">
<div
class="mt-2 flex justify-center rounded-lg border border-dashed border-gray-900/25 px-6 py-10"
>
<img src={@audio_search_result.url} alt="found_image" />
</div>
</div>
</div>
</div>
<!-- Examples -->
<div :if={@display_list?} class="flex flex-col">
<h3
class="mt-10 text-xl lg:text-center font-light tracking-tight text-gray-900 lg:text-2xl"
>
Examples
</h3>
<div class="flex flex-row justify-center my-8">
<div
class="mx-auto grid max-w-2xl grid-cols-1 gap-x-6 gap-y-20 sm:grid-cols-2"
>
<%= for example_img <- @example_list do %>
<!-- Loading skeleton if it is predicting -->
<%= if example_img.predicting? == true do %>
<div
role="status"
class="flex items-center justify-center w-full h-full max-w-sm bg-gray-300 rounded-lg animate-pulse"
>
<img src={~p"/images/spinner.svg"} alt="spinner" />
<span class="sr-only">Loading...</span>
</div>
<% else %>
<div>
<img
id={example_img.url}
src={example_img.url}
class="rounded-2xl object-cover"
/>
<h3 class="mt-1 text-lg leading-8 text-gray-900 text-center">
<%= example_img.label %>
</h3>
</div>
<% end %>
<% end %>
</div>
</div>
</div>
</div>
</div>
As you may have noticed, we've made some changes to the Audio portion of the HTML.
- we check if the @transcription assign exists. If so, we display the text to the person.
- we check if the @audio_search_result assign is not nil. If that's the case, the image that is semantically closest to the audio transcription is shown to the person.
And that's it! We are simply showing the person the results.
And with that, you've successfully added semantic search into the application! Pat yourself on the back! 👏
You've expanded your knowledge in key areas of machine learning and artificial intelligence, which are becoming increasingly prevalent!
Now that we have all the features we want in our application, let's make it prettier! As it stands, it's responsive enough. But we can always make it better!
We're going to show you the changes you need to make and then explain what they mean!
Head over to lib/app_web/live/page_live.html.heex and change it like so:
<div class="hidden" id="tracker_el" phx-hook="ActivityTracker" />
<div class="h-full w-full px-4 py-10 flex justify-center sm:px-6 xl:px-28">
<div class="flex flex-col justify-start lg:w-full">
<div class="flex justify-center items-center w-full">
<div class="w-full 2xl:space-y-12">
<div class="mx-auto lg:text-center">
<!-- Title pill -->
<p class="text-center">
<span class="rounded-full w-fit bg-brand/5 px-2 py-1 text-[0.8125rem] font-medium text-center leading-6 text-brand">
<a
href="https://hexdocs.pm/phoenix_live_view/Phoenix.LiveView.html"
target="_blank"
rel="noopener noreferrer"
>
🔥 LiveView
</a>
+
<a
href="https://github.com/elixir-nx/bumblebee"
target="_blank"
rel="noopener noreferrer"
>
🐝 Bumblebee
</a>
</span>
</p>
<!-- Toggle Buttons -->
<div class="flex justify-center lg:invisible">
<span class="isolate inline-flex rounded-md shadow-sm mt-2">
<button id="upload_option" type="button" class="relative inline-flex items-center gap-x-1.5 rounded-l-md bg-blue-500 text-white hover:bg-blue-600 px-3 py-2 text-sm font-semibold ring-1 ring-inset ring-gray-300 focus:z-10">
<svg fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" class="-ml-0.5 h-5 w-5 text-white">
<path stroke-linecap="round" stroke-linejoin="round" d="M3 16.5v2.25A2.25 2.25 0 0 0 5.25 21h13.5A2.25 2.25 0 0 0 21 18.75V16.5m-13.5-9L12 3m0 0 4.5 4.5M12 3v13.5" />
</svg>
Upload
</button>
<button id="search_option" type="button" class="relative -ml-px inline-flex items-center rounded-r-md bg-white px-3 py-2 text-sm font-semibold text-gray-900 ring-1 ring-inset ring-gray-300 hover:bg-gray-50 focus:z-10">
<svg fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" class="-ml-0.5 h-5 w-5 text-gray-400">
<path stroke-linecap="round" stroke-linejoin="round" d="m21 21-5.197-5.197m0 0A7.5 7.5 0 1 0 5.196 5.196a7.5 7.5 0 0 0 10.607 10.607Z" />
</svg>
Search
</button>
</span>
</div>
<!-- Containers -->
<div class="flex flex-col lg:flex-row lg:justify-around">
<!-- UPLOAD CONTAINER -->
<div id="upload_container" class="mb-6 lg:px-10">
<p class="mt-2 text-center text-3xl font-bold tracking-tight text-gray-900 sm:text-4xl">
Caption your image!
</p>
<div class="flex gap-x-4 rounded-xl bg-black/5 px-6 py-2 mt-2">
<div class="flex flex-col justify-center items-center">
<svg fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" class="w-7 h-7 text-indigo-400">
<path stroke-linecap="round" stroke-linejoin="round" d="M9 8.25H7.5a2.25 2.25 0 0 0-2.25 2.25v9a2.25 2.25 0 0 0 2.25 2.25h9a2.25 2.25 0 0 0 2.25-2.25v-9a2.25 2.25 0 0 0-2.25-2.25H15m0-3-3-3m0 0-3 3m3-3V15" />
</svg>
</div>
<div class="text-sm leading-2 text-justify flex flex-col justify-center">
<p class="text-slate-700">
Upload your own image (up to 5MB) and perform image captioning with
<a
href="https://elixir-lang.org/"
target="_blank"
rel="noopener noreferrer"
class="font-mono font-medium text-sky-500"
>
Elixir
</a>
!
</p>
</div>
</div>
<p class="mt-4 text-center text-sm leading-2 text-gray-400">
Powered with
<a
href="https://huggingface.co/"
target="_blank"
rel="noopener noreferrer"
class="font-mono font-medium text-sky-500"
>
HuggingFace🤗
</a>
transformer models,
you can run this project locally and perform machine learning tasks with a handful of lines of code.
</p>
<!-- File upload section -->
<div class="border-gray-900/10 mt-4">
<div class="col-span-full">
<div
class="mt-2 flex justify-center rounded-lg border border-dashed border-gray-900/25 px-6 py-10"
phx-drop-target={@uploads.image_list.ref}
>
<div class="text-center">
<!-- Show image preview -->
<%= if @image_preview_base64 do %>
<form id="upload-form" phx-change="noop" phx-submit="noop">
<label class="cursor-pointer">
<%= if not @upload_running? do %>
<.live_file_input upload={@uploads.image_list} class="hidden" />
<% end %>
<img src={@image_preview_base64} />
</label>
</form>
<% else %>
<svg
class="mx-auto h-12 w-12 text-gray-300"
viewBox="0 0 24 24"
fill="currentColor"
aria-hidden="true"
>
<path
fill-rule="evenodd"
d="M1.5 6a2.25 2.25 0 012.25-2.25h16.5A2.25 2.25 0 0122.5 6v12a2.25 2.25 0 01-2.25 2.25H3.75A2.25 2.25 0 011.5 18V6zM3 16.06V18c0 .414.336.75.75.75h16.5A.75.75 0 0021 18v-1.94l-2.69-2.689a1.5 1.5 0 00-2.12 0l-.88.879.97.97a.75.75 0 11-1.06 1.06l-5.16-5.159a1.5 1.5 0 00-2.12 0L3 16.061zm10.125-7.81a1.125 1.125 0 112.25 0 1.125 1.125 0 01-2.25 0z"
clip-rule="evenodd"
/>
</svg>
<div class="mt-4 flex text-sm leading-6 text-gray-600">
<label
for="file-upload"
class="relative cursor-pointer rounded-md bg-white font-semibold text-indigo-600 focus-within:outline-none focus-within:ring-2 focus-within:ring-indigo-600 focus-within:ring-offset-2 hover:text-indigo-500"
>
<form id="upload-form" phx-change="noop" phx-submit="noop">
<label class="cursor-pointer">
<.live_file_input upload={@uploads.image_list} class="hidden" /> Upload
</label>
</form>
</label>
<p class="pl-1">or drag and drop</p>
</div>
<p class="text-xs leading-5 text-gray-600">PNG, JPG, GIF up to 5MB</p>
<% end %>
</div>
</div>
</div>
</div>
<!-- Show errors -->
<%= for entry <- @uploads.image_list.entries do %>
<div class="mt-2">
<%= for err <- upload_errors(@uploads.image_list, entry) do %>
<div class="rounded-md bg-red-50 p-4 mb-2">
<div class="flex">
<div class="flex-shrink-0">
<svg
class="h-5 w-5 text-red-400"
viewBox="0 0 20 20"
fill="currentColor"
aria-hidden="true"
>
<path
fill-rule="evenodd"
d="M10 18a8 8 0 100-16 8 8 0 000 16zM8.28 7.22a.75.75 0 00-1.06 1.06L8.94 10l-1.72 1.72a.75.75 0 101.06 1.06L10 11.06l1.72 1.72a.75.75 0 101.06-1.06L11.06 10l1.72-1.72a.75.75 0 00-1.06-1.06L10 8.94 8.28 7.22z"
clip-rule="evenodd"
/>
</svg>
</div>
<div class="ml-3">
<h3 class="text-sm font-medium text-red-800">
<%= error_to_string(err) %>
</h3>
</div>
</div>
</div>
<% end %>
</div>
<% end %>
<!-- Prediction text -->
<div class="flex mt-2 space-x-1.5 items-center">
<span class="font-bold text-gray-900">Description: </span>
<!-- conditional Spinner or display caption text or waiting text-->
<%= if @upload_running? do %>
<AppWeb.Spinner.spin spin={@upload_running?} />
<% else %>
<%= if @label do %>
<span class="text-gray-700 font-light"><%= @label %></span>
<% else %>
<span class="text-gray-300 font-light text-justify">Waiting for image input.</span>
<% end %>
<% end %>
</div>
<!-- Examples -->
<%= if @display_list? do %>
<div :if={@display_list?} class="mt-16 flex flex-col">
<h3 class="text-xl text-center font-bold tracking-tight text-gray-900 lg:text-2xl">
Examples
</h3>
<div class="flex flex-row justify-center my-8">
<div class="mx-auto grid max-w-2xl grid-cols-1 gap-x-6 gap-y-10 sm:grid-cols-2">
<%= for example_img <- @example_list do %>
<!-- Loading skeleton if it is predicting -->
<%= if example_img.predicting? == true do %>
<div
role="status"
class="flex items-center justify-center w-full h-full max-w-sm bg-gray-300 rounded-lg animate-pulse"
>
<img src={~p"/images/spinner.svg"} alt="spinner" />
<span class="sr-only">Loading...</span>
</div>
<% else %>
<div>
<img id={example_img.url} src={example_img.url} class="rounded-2xl object-cover" />
<h3 class="mt-1 text-lg leading-8 text-gray-900 text-center">
<%= example_img.label %>
</h3>
</div>
<% end %>
<% end %>
</div>
</div>
</div>
<% end %>
</div>
<!-- AUDIO SEMANTIC SEARCH CONTAINER -->
<div id="search_container" class="hidden mb-6 mx-auto lg:block lg:px-10">
<h2 class="mt-2 text-3xl font-bold tracking-tight text-gray-900 sm:text-4xl text-center">
...or search it!
</h2>
<div class="flex gap-x-4 rounded-xl bg-black/5 px-6 py-2 mt-2">
<div class="flex flex-col justify-center items-center">
<svg fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" class="w-7 h-7 text-indigo-400">
<path stroke-linecap="round" stroke-linejoin="round" d="M12 18.75a6 6 0 0 0 6-6v-1.5m-6 7.5a6 6 0 0 1-6-6v-1.5m6 7.5v3.75m-3.75 0h7.5M12 15.75a3 3 0 0 1-3-3V4.5a3 3 0 1 1 6 0v8.25a3 3 0 0 1-3 3Z" />
</svg>
</div>
<div class="text-sm leading-2 text-justify flex flex-col justify-center">
<p class="text-slate-700">
Record a phrase or some key words.
We'll detect them and semantically search for them in our database of images!
</p>
</div>
</div>
<p class="mt-4 text-center text-sm leading-2 text-gray-400">
After recording your audio, you can listen to it. It will be transcribed automatically into text and appear below.
</p>
<p class="text-center text-sm leading-2 text-gray-400">
Semantic search will automatically kick in and the resulting image will be shown below.
</p>
<!-- Audio recording button -->
<form id="audio-upload-form" phx-change="noop" class="mt-8 flex flex-col items-center">
<.live_file_input upload={@uploads.speech} class="hidden" />
<button
id="record"
class="bg-blue-500 hover:bg-blue-700 text-white font-bold p-4 rounded flex"
type="button"
phx-hook="Audio"
disabled={@mic_off?}
>
<Heroicons.microphone
outline
class="w-6 h-6 text-white font-bold group-active:animate-pulse"
/>
<span id="text">Record</span>
</button>
</form>
<!-- Audio preview -->
<p class="flex flex-col items-center mt-6">
<audio id="audio" controls></audio>
</p>
<!-- Audio transcription -->
<div class="flex mt-2 space-x-1.5 items-center">
<span class="font-bold text-gray-900">Transcription: </span>
<%= if @audio_running? do %>
<AppWeb.Spinner.spin spin={@audio_running?} />
<% else %>
<%= if @transcription do %>
<span id="output" class="text-gray-700 font-light"><%= @transcription %></span>
<% else %>
<span class="text-gray-300 font-light text-justify">Waiting for audio input.</span>
<% end %>
<% end %>
</div>
<!-- Semantic search result -->
<div :if={@audio_search_result}>
<div class="border-gray-900/10">
<div class="mt-2 flex justify-center rounded-lg border border-dashed border-gray-900/25 px-6 py-10">
<img src={@audio_search_result.url} alt="found_image" />
</div>
</div>
<span class="text-gray-700 font-light"><%= @audio_search_result.description %></span>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
That may look like a lot, but we've made just a handful of changes!
- we've restructured our HTML so it's easier to read. There are just a few key elements: the pill at the top of the page, a toggle that we've added (to switch between the upload section and the search section) and two containers holding the upload and search sections, respectively. The code inside each section is practically intact.
- we've made minor styling changes to each section. In the upload section, we added a small callout. In the search section, we added a small callout as well, plus the description of the image that is found after the audio transcription occurs.
And that's it!
What's important is that only one section is shown at a time on mobile devices, while both sections are shown on desktop devices (over 1024 pixels, which corresponds to the lg breakpoint of TailwindCSS).
On desktop devices, the toggle button should disappear.
To accomplish this, we need to add a bit of JavaScript and CSS magic. 🪄
Head over to assets/js/app.js and add the following code.
document.getElementById("upload_option").addEventListener("click", function () {
  document.getElementById("upload_container").style.display = "block";
  document.getElementById("search_container").style.display = "none";

  document
    .getElementById("upload_option")
    .classList.replace("bg-white", "bg-blue-500");
  document
    .getElementById("upload_option")
    .classList.replace("text-gray-900", "text-white");
  document
    .getElementById("upload_option")
    .classList.replace("hover:bg-gray-50", "hover:bg-blue-600");
  document
    .getElementById("upload_option")
    .getElementsByTagName("svg")[0]
    .classList.replace("text-gray-400", "text-white");

  document
    .getElementById("search_option")
    .classList.replace("bg-blue-500", "bg-white");
  document
    .getElementById("search_option")
    .classList.replace("text-white", "text-gray-900");
  document
    .getElementById("search_option")
    .classList.replace("hover:bg-blue-600", "hover:bg-gray-50");
  document
    .getElementById("search_option")
    .getElementsByTagName("svg")[0]
    .classList.replace("text-white", "text-gray-400");
});

document.getElementById("search_option").addEventListener("click", function () {
  document.getElementById("upload_container").style.display = "none";
  document.getElementById("search_container").style.display = "block";

  document
    .getElementById("search_option")
    .classList.replace("bg-white", "bg-blue-500");
  document
    .getElementById("search_option")
    .classList.replace("text-gray-900", "text-white");
  document
    .getElementById("search_option")
    .classList.replace("hover:bg-gray-50", "hover:bg-blue-600");
  document
    .getElementById("search_option")
    .getElementsByTagName("svg")[0]
    .classList.replace("text-gray-400", "text-white");

  document
    .getElementById("upload_option")
    .classList.replace("bg-blue-500", "bg-white");
  document
    .getElementById("upload_option")
    .classList.replace("text-white", "text-gray-900");
  document
    .getElementById("upload_option")
    .classList.replace("hover:bg-blue-600", "hover:bg-gray-50");
  document
    .getElementById("upload_option")
    .getElementsByTagName("svg")[0]
    .classList.replace("text-white", "text-gray-400");
});
The code is self-explanatory. We are changing the styles of the toggle buttons according to the button that is clicked.
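If the repetition in the two listeners bothers you, the same toggling can be sketched more compactly. This is only one possible refactor, not the code used in the project: the helper names (`swapClasses`, `styleButton`, `showSection`) and the two class tables are made up for illustration, while the element ids and Tailwind classes are the ones used above. It also skips the wiring when a button is missing, on the assumption that `app.js` may be loaded on pages that do not render the toggle.

```javascript
// Class pairs swapped when a toggle button changes state: [inactive, active].
const BUTTON_CLASSES = [
  ["bg-white", "bg-blue-500"],
  ["text-gray-900", "text-white"],
  ["hover:bg-gray-50", "hover:bg-blue-600"],
];
const ICON_CLASSES = [["text-gray-400", "text-white"]];

// Swap each class pair on an element; `activate` picks the direction.
function swapClasses(element, pairs, activate) {
  for (const [inactive, active] of pairs) {
    if (activate) element.classList.replace(inactive, active);
    else element.classList.replace(active, inactive);
  }
}

// Restyle a toggle button and the first <svg> icon inside it, if any.
function styleButton(button, activate) {
  swapClasses(button, BUTTON_CLASSES, activate);
  const icon = button.getElementsByTagName("svg")[0];
  if (icon) swapClasses(icon, ICON_CLASSES, activate);
}

// Show the selected container ("upload" or "search"), hide the other one
// and restyle both toggle buttons accordingly.
function showSection(doc, selected) {
  const other = selected === "upload" ? "search" : "upload";
  doc.getElementById(`${selected}_container`).style.display = "block";
  doc.getElementById(`${other}_container`).style.display = "none";
  styleButton(doc.getElementById(`${selected}_option`), true);
  styleButton(doc.getElementById(`${other}_option`), false);
}

// Wire up the listeners only when running in a page that has the toggle.
if (typeof document !== "undefined") {
  for (const name of ["upload", "search"]) {
    const button = document.getElementById(`${name}_option`);
    if (button) {
      button.addEventListener("click", () => showSection(document, name));
    }
  }
}
```

Passing the document in as `doc` keeps `showSection` easy to exercise outside the browser. The behaviour is meant to match the longer version: clicking a button shows its container, hides the other one and swaps the active and inactive classes.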
The other thing we need to make sure of is that both sections are shown on desktop devices, regardless of which section is currently selected.
Luckily, we can override the inline styles set by the toggle by adding this piece of code to assets/css/app.css.
@media (min-width: 1024px) {
  #upload_container,
  #search_container {
    display: block !important; /* Override any inline styles */
  }
}
And that's it!
We can see our slightly refactored UI in all of its glory
by running mix phx.server
!
If you find this package/repo useful, please star it on GitHub, so that we know! ⭐
Thank you! 🙏