# Unsplash Joint Query Search

Using this notebook you can search for images from the [Unsplash Dataset](https://unsplash.com/data) using natural language queries. The search is powered by OpenAI's [CLIP](https://github.com/openai/CLIP) neural network.

This notebook uses the precomputed feature vectors for almost 2 million images from the full version of the [Unsplash Dataset](https://unsplash.com/data). If you want to compute the features yourself, see [here](https://github.com/haltakov/natural-language-image-search#on-your-machine).

This project is largely based on the [project](https://github.com/haltakov/natural-language-image-search) created by [Vladimir Haltakov](https://twitter.com/haltakov). The full code is open-sourced on [GitHub](https://github.com/haofanwang/natural-language-joint-query-search).

In [1]:
!git clone https://github.com/haofanwang/natural-language-joint-query-search.git

Cloning into 'natural-language-joint-query-search'...
remote: Enumerating objects: 116, done.
remote: Counting objects: 100% (116/116), done.
remote: Compressing objects: 100% (106/106), done.
remote: Total 116 (delta 37), reused 43 (delta 5), pack-reused 0
Receiving objects: 100% (116/116), 13.12 MiB | 29.52 MiB/s, done.
Resolving deltas: 100% (37/37), done.


In [2]:
cd natural-language-joint-query-search

/content/natural-language-joint-query-search


## Setup Environment

In this section we set up the environment.

First we upgrade torch to 1.7.1 with CUDA support and install the remaining CLIP dependencies (`ftfy`, `regex` and `tqdm`). By default, installing CLIP pulls in torch 1.7.1 without CUDA, and Google Colab currently ships torch 1.7.0, which doesn't work well with CLIP. The CLIP code itself is imported from the `CLIP` folder of the cloned repository.

In [3]:
!pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 -f https://download.pytorch.org/whl/torch_stable.html
!pip install ftfy regex tqdm

Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.7.1+cu101
  Downloading https://download.pytorch.org/whl/cu101/torch-1.7.1%2Bcu101-cp36-cp36m-linux_x86_64.whl (735.4MB)
     |████████████████████████████████| 735.4MB 24kB/s
Collecting torchvision==0.8.2+cu101
  Downloading https://download.pytorch.org/whl/cu101/torchvision-0.8.2%2Bcu101-cp36-cp36m-linux_x86_64.whl (12.8MB)
     |████████████████████████████████| 12.8MB 114kB/s
Installing collected packages: torch, torchvision
  Found existing installation: torch 1.7.0+cu101
    Uninstalling torch-1.7.0+cu101:
      Successfully uninstalled torch-1.7.0+cu101
  Found existing installation: torchvision 0.8.1+cu101
    Uninstalling torchvision-0.8.1+cu101:
      Successfully uninstalled torchvision-0.8.1+cu101
Successfully installed torch-1.7.1+cu101 torchvision-0.8.2+cu101
Collecting ftfy
  Downloading https://files.pythonhosted.org/packages/04/06/e5c80e2e0f979628d47

## Download the Precomputed Data

In this section we download the precomputed feature vectors for all photos.

To compare the photos from the Unsplash dataset to a text query, we need a CLIP feature vector for each photo. Instead of computing these for almost 2 million images, we download vectors that have already been computed.

We need to download two files:
* `photo_ids.csv` - a list of the photo IDs for all images in the dataset. The photo ID can be used to get the actual photo from Unsplash.
* `features.npy` - a matrix containing the precomputed 512-element feature vector for each photo in the dataset.

The files are available on [Google Drive](https://drive.google.com/drive/folders/1WQmedVCDIQKA2R33dkS1f980YsJXRZ-q?usp=sharing).
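The two files are used together: `find_best_matches` returns `photo_ids[i]` for the best-scoring row `i` of the feature matrix, so row *i* of `features.npy` must belong to the *i*-th ID in `photo_ids.csv`. The sketch below mimics that layout with random toy data (not the real files). That the stored vectors are L2-normalized is an assumption, consistent with the cosine-similarity comment in `find_best_matches`.

```python
import numpy as np

# Toy stand-ins for photo_ids.csv and features.npy (random data, NOT the real files)
rng = np.random.default_rng(0)
photo_ids = ["photo_a", "photo_b", "photo_c", "photo_d"]
photo_features = rng.standard_normal((len(photo_ids), 512)).astype(np.float32)

# Assumption: each stored vector has unit length (L2-normalized)
photo_features /= np.linalg.norm(photo_features, axis=1, keepdims=True)

# One 512-element row per photo ID, in the same order
assert photo_features.shape == (len(photo_ids), 512)

# The feature vector of a given photo is the row at that photo's index
row = photo_ids.index("photo_c")
print(row, photo_features[row].shape)  # 2 (512,)
```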

In [4]:
from pathlib import Path

# Create a folder for the precomputed features
!mkdir -p unsplash-dataset

# Download the photo IDs and the feature vectors
!gdown --id 1FdmDEzBQCf3OxqY9SbU-jLfH_yZ6UPSj -O unsplash-dataset/photo_ids.csv
!gdown --id 1L7ulhn4VeN-2aOM-fYmljza_TQok-j9F -O unsplash-dataset/features.npy

# Download from alternative source, if the download doesn't work for some reason (for example download quota limit exceeded)
if not Path('unsplash-dataset/photo_ids.csv').exists():
  !wget https://transfer.army/api/download/TuWWFTe2spg/EDm6KBjc -O unsplash-dataset/photo_ids.csv

if not Path('unsplash-dataset/features.npy').exists():
  !wget https://transfer.army/api/download/LGXAaiNnMLA/AamL9PpU -O unsplash-dataset/features.npy

Downloading...
From: https://drive.google.com/uc?id=1FdmDEzBQCf3OxqY9SbU-jLfH_yZ6UPSj
To: /content/natural-language-joint-query-search/unsplash-dataset/photo_ids.csv
23.8MB [00:00, 111MB/s] 
Downloading...
From: https://drive.google.com/uc?id=1L7ulhn4VeN-2aOM-fYmljza_TQok-j9F
To: /content/natural-language-joint-query-search/unsplash-dataset/features.npy
2.03GB [00:40, 50.3MB/s]


## Define Functions

Here we define two helper functions that use CLIP to process the data.

The `encode_search_query` function takes a text description and encodes it into a normalized feature vector using the CLIP model.

In [10]:
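def encode_search_query(search_query):
    # Uses `torch`, `clip`, `model` and `device`, which are defined in the model-loading cell below
    with torch.no_grad():
        # Encode the search query using CLIP. The CLIP code bundled with this repository
        # appears to return a tuple here (note the two-value unpacking); only the features are used.
        text_encoded, weight = model.encode_text(clip.tokenize(search_query).to(device))

        # Normalize to unit length so that a dot product equals the cosine similarity
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

        # Retrieve the feature vector from the GPU and convert it to a numpy array
        return text_encoded.cpu().numpy()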
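The normalization step in `encode_search_query` is what lets `find_best_matches` use a plain matrix product: for unit-length vectors, the dot product equals the cosine similarity. A minimal NumPy illustration with two hand-picked 2-D vectors (the helper name `l2_normalize` is hypothetical, not part of the repository):

```python
import numpy as np

def l2_normalize(features):
    """Scale each row to unit length, like `x /= x.norm(dim=-1, keepdim=True)` in PyTorch."""
    return features / np.linalg.norm(features, axis=-1, keepdims=True)

a = l2_normalize(np.array([[3.0, 4.0]]))
b = l2_normalize(np.array([[4.0, 3.0]]))

print(a)        # [[0.6 0.8]]
# Dot product of unit vectors = cosine of the angle between them
print(a @ b.T)  # [[0.96]]
```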

The `find_best_matches` function compares the text feature vector to the feature vectors of all images and finds the best matches. The function returns the IDs of the best matching photos.

In [6]:
def find_best_matches(text_features, photo_features, photo_ids, results_count=3):
  # Compute the cosine similarity between the search query and each photo (a dot product, since all vectors are normalized)
  similarities = (photo_features @ text_features.T).squeeze(1)

  # Sort the photos by their similarity score
  best_photo_idx = (-similarities).argsort()

  # Return the photo IDs of the best matches
  return [photo_ids[i] for i in best_photo_idx[:results_count]]
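`find_best_matches` sorts all of the roughly 2 million similarity scores just to keep the top few. One possible alternative (not what the notebook does) is `np.argpartition`, which selects the top `k` indices in linear time and then sorts only those. A sketch of such a variant, `find_best_matches_topk` (a hypothetical name), checked against a full sort on random unit vectors:

```python
import numpy as np

def find_best_matches_topk(text_features, photo_features, photo_ids, results_count=3):
    """Hypothetical variant of find_best_matches that avoids sorting every score."""
    similarities = (photo_features @ text_features.T).squeeze(1)
    # argpartition moves the k largest scores into the last k slots without a full sort...
    top_k = np.argpartition(similarities, -results_count)[-results_count:]
    # ...and only those k indices are sorted, highest similarity first
    top_k = top_k[np.argsort(-similarities[top_k])]
    return [photo_ids[i] for i in top_k]

# Random unit vectors stand in for the real photo features
rng = np.random.default_rng(42)
photo_features = rng.standard_normal((10_000, 512)).astype(np.float32)
photo_features /= np.linalg.norm(photo_features, axis=1, keepdims=True)
photo_ids = [f"id{i}" for i in range(len(photo_features))]
text_features = photo_features[[123]]  # a query identical to photo 123, shape (1, 512)

# Same ranking as the full argsort used in find_best_matches
similarities = (photo_features @ text_features.T).squeeze(1)
full_sort = [photo_ids[i] for i in (-similarities).argsort()[:5]]
topk = find_best_matches_topk(text_features, photo_features, photo_ids, 5)
print(topk == full_sort)  # True
```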

Next, we load the pretrained public CLIP model (ViT-B/32).

In [7]:
import torch

from CLIP.clip import clip

# Load the open CLIP model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

100%|████████████████████████████████████████| 354M/354M [00:02<00:00, 138MiB/s]


We can now load the precomputed Unsplash image features.

In [8]:
import pandas as pd
import numpy as np

# Load the photo IDs
photo_ids = pd.read_csv("unsplash-dataset/photo_ids.csv")
photo_ids = list(photo_ids['photo_id'])

# Load the feature vectors
photo_features = np.load("unsplash-dataset/features.npy")

# Print some statistics
print(f"Photos loaded: {len(photo_ids)}")

Photos loaded: 1981161
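As a sanity check on the downloaded data, the size `gdown` reported for `features.npy` (2.03GB) matches 1,981,161 photos × 512 values × 2 bytes. In other words, the file is consistent with float16 storage; this is an inference from the sizes, not something the dataset documentation is quoted on here. The arithmetic:

```python
# Back-of-the-envelope size check (assumption: float16 storage, 2 bytes per value)
num_photos = 1_981_161  # "Photos loaded" above
dims = 512              # length of each CLIP ViT-B/32 feature vector
bytes_per_value = 2     # float16

total_bytes = num_photos * dims * bytes_per_value
print(total_bytes)                    # 2028708864
print(f"{total_bytes / 1e9:.2f} GB")  # 2.03 GB
```

`np.load` reads the whole matrix into RAM. If memory is tight, `np.load("unsplash-dataset/features.npy", mmap_mode="r")` memory-maps the file instead and reads rows from disk as they are needed.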


## Search Unsplash



Now we are ready to search the dataset using natural language. Check out the examples below and feel free to try out your own queries.

This project supports more search types than the [original project](https://github.com/haltakov/natural-language-image-search):

1. Text-to-Image Search
2. Image-to-Image Search
3. Text+Text-to-Image Search
4. Image+Text-to-Image Search

Notes:

1. Because the Unsplash API rate limit is hit from time to time, we don't display the images; instead we print a download link for each result.
2. The pretrained CLIP model was trained mainly on English text, so if you want to search in another language, translate the query into English first (for example, with the Google Translate API or a neural machine translation model).
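Each search cell below formats its results with the same `https://unsplash.com/photos/{id}/download` pattern. If you add your own queries, a small helper avoids repeating that loop; `download_links` is a hypothetical name, not part of the repository:

```python
def download_links(photo_ids):
    """Hypothetical helper: build the Unsplash download URL for each photo ID."""
    return ["https://unsplash.com/photos/{}/download".format(photo_id) for photo_id in photo_ids]

# Example with two IDs taken from the "Tokyo Tower at night" results
for url in download_links(["Hfjoa3qqytM", "9tOyu48-P7M"]):
    print(url)
```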

### Text-to-Image Search

#### "Tokyo Tower at night"

In [11]:
search_query = "Tokyo Tower at night."

text_features = encode_search_query(search_query)

# Find the best matches
best_photo_ids = find_best_matches(text_features, photo_features, photo_ids, 5)

for photo_id in best_photo_ids:
  print("https://unsplash.com/photos/{}/download".format(photo_id))

https://unsplash.com/photos/Hfjoa3qqytM/download
https://unsplash.com/photos/9tOyu48-P7M/download
https://unsplash.com/photos/OCgMGflYgVg/download
https://unsplash.com/photos/msYlh78QagI/download
https://unsplash.com/photos/UYmsWq6Cf1c/download


#### "Two children are playing in the amusement park."

In [12]:
search_query = "Two children are playing in the amusement park."

text_features = encode_search_query(search_query)

# Find the best matches
best_photo_ids = find_best_matches(text_features, photo_features, photo_ids, 5)

for photo_id in best_photo_ids:
  print("https://unsplash.com/photos/{}/download".format(photo_id))

https://unsplash.com/photos/VPq1DiHNShY/download
https://unsplash.com/photos/nQlKkqq6qEw/download
https://unsplash.com/photos/lgXRsUVWl88/download
https://unsplash.com/photos/b10qqhvwWg4/download
https://unsplash.com/photos/xUDUhI_qsKQ/download


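### Image-to-Image Search

Here the query is an image from the repository's `images` folder. It is encoded with CLIP's image encoder and normalized exactly like a text query, so the same `find_best_matches` function can be used.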

In [13]:
from PIL import Image

source_image = "./images/borna-hrzina-8IPrifbjo-0-unsplash.jpg"
with torch.no_grad():
  image_feature = model.encode_image(preprocess(Image.open(source_image)).unsqueeze(0).to(device))
  image_feature = (image_feature / image_feature.norm(dim=-1, keepdim=True)).cpu().numpy()

# Find the best matches
best_photo_ids = find_best_matches(image_feature, photo_features, photo_ids, 5)

for photo_id in best_photo_ids:
  print("https://unsplash.com/photos/{}/download".format(photo_id))

https://unsplash.com/photos/8IPrifbjo-0/download
https://unsplash.com/photos/2Hzzw1qfVTQ/download
https://unsplash.com/photos/q1gXY48Ej78/download
https://unsplash.com/photos/OYaw40WnhSc/download
https://unsplash.com/photos/DpeXitxtix8/download
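The top result, `8IPrifbjo-0`, is the source image itself (its ID is part of the file name `borna-hrzina-8IPrifbjo-0-unsplash.jpg`), apparently because the query photo is also in the dataset. If you only want *other* similar photos, one option is to request one extra result and filter out the query's ID. A sketch with a hypothetical helper `exclude_query_photo`, applied to the IDs printed above:

```python
def exclude_query_photo(best_photo_ids, query_photo_id, results_count=5):
    """Hypothetical helper: drop the query photo from its own results, keep at most results_count."""
    return [photo_id for photo_id in best_photo_ids if photo_id != query_photo_id][:results_count]

# IDs printed by the Image-to-Image cell above
results = ["8IPrifbjo-0", "2Hzzw1qfVTQ", "q1gXY48Ej78", "OYaw40WnhSc", "DpeXitxtix8"]
print(exclude_query_photo(results, "8IPrifbjo-0"))
# ['2Hzzw1qfVTQ', 'q1gXY48Ej78', 'OYaw40WnhSc', 'DpeXitxtix8']
```

In the notebook, `exclude_query_photo(find_best_matches(image_feature, photo_features, photo_ids, 6), "8IPrifbjo-0")` would still return five photos.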


### Text+Text-to-Image Search

In [14]:
search_query = "red flower"
search_query_extra = "blue sky"

text_features = encode_search_query(search_query)
text_features_extra = encode_search_query(search_query_extra)

mixed_features = text_features + text_features_extra

# Find the best matches
best_photo_ids = find_best_matches(mixed_features, photo_features, photo_ids, 5)

for photo_id in best_photo_ids:
  print("https://unsplash.com/photos/{}/download".format(photo_id))

https://unsplash.com/photos/NewdN4HJaWM/download
https://unsplash.com/photos/r6DXsecvS4w/download
https://unsplash.com/photos/Ye-PdCxCmEQ/download
https://unsplash.com/photos/AFT4cSrnVZk/download
https://unsplash.com/photos/qKBVUBtZJCU/download
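Adding the two query vectors works because the dot product is linear: scoring photos against `text_features + text_features_extra` gives each photo the *sum* of its cosine similarities to "red flower" and "blue sky", so the top results are photos that match both reasonably well. The sum is no longer unit length, but that only rescales every score and leaves the ranking unchanged. A NumPy check of both points, using random unit vectors as stand-ins for real CLIP features:

```python
import numpy as np

rng = np.random.default_rng(0)

# Random unit vectors standing in for the real photo features
photo_features = rng.standard_normal((1000, 512))
photo_features /= np.linalg.norm(photo_features, axis=1, keepdims=True)

# Stand-ins for encode_search_query("red flower") and encode_search_query("blue sky")
q1 = rng.standard_normal((1, 512))
q1 /= np.linalg.norm(q1)
q2 = rng.standard_normal((1, 512))
q2 /= np.linalg.norm(q2)

# Scoring against the summed query vector...
scores_mixed = (photo_features @ (q1 + q2).T).squeeze(1)
# ...is the same as adding the two cosine similarities photo by photo
scores_separate = (photo_features @ q1.T).squeeze(1) + (photo_features @ q2.T).squeeze(1)
print(np.allclose(scores_mixed, scores_separate))  # True

# Normalizing the mixed query scales every score by the same positive factor,
# so the ranking produced by find_best_matches does not change
mixed_unit = (q1 + q2) / np.linalg.norm(q1 + q2)
scores_unit = (photo_features @ mixed_unit.T).squeeze(1)
print(np.array_equal(np.argsort(-scores_mixed), np.argsort(-scores_unit)))
```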


### Image+Text-to-Image Search

In [16]:
source_image = "./images/borna-hrzina-8IPrifbjo-0-unsplash.jpg"
search_text = "cars"

with torch.no_grad():
  image_feature = model.encode_image(preprocess(Image.open(source_image)).unsqueeze(0).to(device))
  image_feature = (image_feature / image_feature.norm(dim=-1, keepdim=True)).cpu().numpy()

text_feature = encode_search_query(search_text)

# Combine the image and text queries by adding their feature vectors
modified_feature = image_feature + text_feature

# Find the best matches
best_photo_ids = find_best_matches(modified_feature, photo_features, photo_ids, 5)

for photo_id in best_photo_ids:
  print("https://unsplash.com/photos/{}/download".format(photo_id))

https://unsplash.com/photos/8IPrifbjo-0/download
https://unsplash.com/photos/2Hzzw1qfVTQ/download
https://unsplash.com/photos/6FpUtZtjFjM/download
https://unsplash.com/photos/Qm8pvpJ-uGs/download
https://unsplash.com/photos/c3ddbxzQtdM/download
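The Image+Text cell gives the image and the text equal weight: `image_feature + text_feature` ranks photos exactly like `alpha=0.5` below, since it is the same vector scaled by 2. In the output above, the first two results are the same as in the image-only search, so it may help to weight the text more. A sketch of a tunable mix; the function `combine` and its `alpha` parameter are hypothetical, and the 2-D vectors are toy values chosen to make the effect visible:

```python
import numpy as np

def combine(image_feature, text_feature, alpha=0.5):
    """Hypothetical weighted mix: alpha=0 is pure image search, alpha=1 is pure text search."""
    mixed = (1 - alpha) * image_feature + alpha * text_feature
    # Re-normalizing does not change the ranking, but keeps the scores in cosine range
    return mixed / np.linalg.norm(mixed, axis=-1, keepdims=True)

# Toy 2-D "features" (real CLIP vectors have 512 dimensions)
image_feature = np.array([[1.0, 0.0]])
text_feature = np.array([[0.0, 1.0]])
photo_features = np.array([[1.0, 0.0],                    # looks like the source image
                           [np.sqrt(0.5), np.sqrt(0.5)],  # a bit of both
                           [0.0, 1.0]])                   # matches the text only
photo_ids = ["image-like", "both", "text-like"]

for alpha in (0.0, 0.5, 1.0):
    scores = (photo_features @ combine(image_feature, text_feature, alpha).T).squeeze(1)
    print(alpha, photo_ids[int(np.argmax(scores))])
# 0.0 image-like
# 0.5 both
# 1.0 text-like
```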
