### Find and load a 3D model from a text description using CLIP embeddings (with a cached datastore to speed up repeated searches)
##### by: Ben Gibbons (https://bgibbonsweb.github.io/)

To run this notebook, you'll need both a Sketchfab API token and an OpenAI API key.

In [1]:
# When True, downloadModel never queries Sketchfab: it only ranks models already cached.
freeze_model_database_models = True
# Face-count window passed to the Sketchfab search API.
maxModelFaces = 5_000
minModelFaces = 1_000
# Thumbnails narrower or shorter than this many pixels are not indexed.
min_image_size = 350
# OpenAI chat model used to condense long descriptions into short search queries.
ai_model = "gpt-4o-mini"
model_download_path = "./downloaded_models"

# Description of the model to find.
modelToDownload = "Silver car"

# CLIP variant used for embeddings: "ViT-B/32" loads OpenAI CLIP, anything else is
# built with open_clip using the `pretrained` tag.
model_name = "ViT-H/14-378"
pretrained = "dfn5b"

# model_name = "ViT-B/32"
# pretrained = "dfn5b"

In [2]:
# Dependencies (uncomment to install; %pip targets this kernel's environment).
# `import clip` in the next cell needs OpenAI's CLIP package, which is not in the
# pinned list and is installed from GitHub (see the openai/CLIP README).
# %pip install numpy==1.26.4 requests==2.32.3 torch==2.5.1 pillow==10.3.0 open_clip_torch==2.30.0 openai torchvision==0.20.1
# %pip install git+https://github.com/openai/CLIP.git

In [3]:
import gc
import json
import os
from io import BytesIO

import numpy as np
import requests
import torch
from PIL import Image

import clip
import open_clip
from openai import OpenAI
import zipfile
from getpass import getpass

# Prefer environment variables so "Restart & Run All" works without interaction;
# fall back to a hidden interactive prompt when a variable is unset or empty.
open_ai_key = os.environ.get("OPENAI_API_KEY") or getpass("Enter your OpenAI API key: ")
sketchfab_api_token = os.environ.get("SKETCHFAB_API_TOKEN") or getpass("Enter your Sketchfab API token: ")

In [4]:

# Print every architecture name known to open_clip, one per line.
available_models = open_clip.list_models()
for available_name in available_models:
    print(available_name)

coca_base
coca_roberta-ViT-B-32
coca_ViT-B-32
coca_ViT-L-14
convnext_base
convnext_base_w
convnext_base_w_320
convnext_large
convnext_large_d
convnext_large_d_320
convnext_small
convnext_tiny
convnext_xlarge
convnext_xxlarge
convnext_xxlarge_320
EVA01-g-14
EVA01-g-14-plus
EVA02-B-16
EVA02-E-14
EVA02-E-14-plus
EVA02-L-14
EVA02-L-14-336
MobileCLIP-B
MobileCLIP-S1
MobileCLIP-S2
mt5-base-ViT-B-32
mt5-xl-ViT-H-14
nllb-clip-base
nllb-clip-base-siglip
nllb-clip-large
nllb-clip-large-siglip
RN50
RN50-quickgelu
RN50x4
RN50x4-quickgelu
RN50x16
RN50x16-quickgelu
RN50x64
RN50x64-quickgelu
RN101
RN101-quickgelu
roberta-ViT-B-32
swin_base_patch4_window7_224
ViT-B-16
ViT-B-16-plus
ViT-B-16-plus-240
ViT-B-16-quickgelu
ViT-B-16-SigLIP
ViT-B-16-SigLIP-256
ViT-B-16-SigLIP-384
ViT-B-16-SigLIP-512
ViT-B-16-SigLIP-i18n-256
ViT-B-32
ViT-B-32-256
ViT-B-32-plus-256
ViT-B-32-quickgelu
ViT-bigG-14
ViT-bigG-14-CLIPA
ViT-bigG-14-CLIPA-336
ViT-bigG-14-quickgelu
ViT-e-14
ViT-g-14
ViT-H-14
ViT-H-14-378
ViT-H-14-378-q

In [5]:
client = OpenAI(api_key=open_ai_key)
# makedirs(exist_ok=True) is idempotent (safe to re-run) and also creates any
# missing parent directories, unlike an exists()/mkdir() pair.
os.makedirs(model_download_path, exist_ok=True)

In [6]:
def askMultiQuestion(questions, the_model = ai_model):
    """Ask a sequence of questions within one chat conversation.

    Each question is sent as a user turn, and the assistant's reply is appended
    to the history before the next question is sent. Later questions can
    therefore refer to earlier answers.

    Args:
        questions: iterable of question strings, asked in order. It is not
            modified.
        the_model: OpenAI chat model name. The default is the value of
            ``ai_model`` at the time this function was defined.

    Returns:
        ``(ai_resp, messages)``: the stripped text of the final reply ("" when
        ``questions`` is empty) and the full message history.
    """
    # Work on a copy so the caller's list is not consumed.
    questions = list(questions)
    print("askMultiQuestion", the_model, questions[0] if questions else None)
    messages = [{ "role": "system", "content": "You are a helpful assistant." }]
    ai_resp = ""

    for question in questions:
        print("")
        print("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<")
        print("")
        print(question)
        messages.append({"role": "user", "content": question})

        completion = client.chat.completions.create(
            model=the_model,
            messages=messages
        )

        # message.content may be None (e.g. refusals); treat that as empty text.
        ai_resp = (completion.choices[0].message.content or "").strip()
        messages.append({"role": "assistant", "content": ai_resp})
        print("")
        print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
        print("")
        print(ai_resp)

    return ai_resp, messages

### Load the CLIP model and the cached search database (speeds up repeated searches)

In [7]:
# Load the CLIP model used for both text and image embeddings.
device = "cuda" if torch.cuda.is_available() else "cpu"

if model_name == "ViT-B/32":
    clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
else:
    # create_model_and_transforms returns (model, train_transform, eval_transform);
    # only the eval transform is used, to preprocess thumbnails.
    clip_model, _train_preprocess, clip_preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
    clip_model = clip_model.to(device)
# Per the open_clip README, models are created in train mode; inference must run in
# eval mode so that dropout-like layers are disabled and embeddings are deterministic.
clip_model.eval()

# Attempt to load the database
num_model_loads = 0
existing_image_features = None
existing_text_features = None
list_of_all_models = { }
existing_searches_done = []
index_to_model_uid = []
num_models_downloaded = 0
stack_existing_image_features = True

model_name_for_file = model_name.replace("/", "_")
saved_model_data_filename = f'./saved_model_data_{model_name_for_file}.json'
existing_image_features_filename = f'./existing_image_features_{model_name_for_file}.pt'
existing_text_features_filename = f"./existing_text_features_{model_name_for_file}.pt"

try:
    with open(saved_model_data_filename, 'r') as json_file:
        saved_model_data = json.load(json_file)
    existing_searches_done = saved_model_data["existing_searches_done"]
    list_of_all_models = saved_model_data["list_of_all_models"]
    index_to_model_uid = saved_model_data["index_to_model_uid"]
    print("index_to_model_uid", len(index_to_model_uid))
    # The .pt files hold plain tensors, so the safe weights_only loader suffices (and
    # avoids torch's FutureWarning). map_location places them on the query device.
    existing_image_features = torch.load(existing_image_features_filename, map_location=device, weights_only=True)
    print("existing_image_features", existing_image_features.shape)
    existing_text_features = torch.load(existing_text_features_filename, map_location=device, weights_only=True)
    print("existing_text_features", existing_text_features.shape)
    assert existing_image_features.shape == existing_text_features.shape
    # Row i of both feature matrices belongs to index_to_model_uid[i].
    assert existing_image_features.shape[0] == len(index_to_model_uid), "feature rows and index_to_model_uid are out of sync"
except Exception as e:
    print(e)
    # Never keep a half-loaded database (e.g. JSON loaded but tensors missing or
    # mismatched); start from an empty, self-consistent state instead.
    existing_image_features = None
    existing_text_features = None
    list_of_all_models = { }
    existing_searches_done = []
    index_to_model_uid = []



index_to_model_uid 1235
existing_image_features torch.Size([1235, 1024])
existing_text_features torch.Size([1235, 1024])


  existing_image_features = torch.load(existing_image_features_filename)
  existing_text_features = torch.load(existing_text_features_filename)


In [8]:
def addSearchableModel(uid, image_url, name):
    """Embed a model's thumbnail and name and add them to the search database.

    Does nothing in these cases:
      * ``uid`` is already in ``list_of_all_models``;
      * the thumbnail cannot be fetched (a message is printed);
      * the thumbnail is smaller than ``min_image_size`` in either dimension.

    Otherwise it:
      * appends one L2-normalised row each to ``existing_text_features`` and
        ``existing_image_features``;
      * records the model in ``list_of_all_models`` / ``index_to_model_uid``;
      * increments ``num_model_loads``;
      * rewrites the on-disk database files.

    Args:
        uid: Sketchfab model uid (the database key).
        image_url: URL of the thumbnail image to embed.
        name: model name; its text embedding is stored alongside the image one.
    """
    global existing_image_features
    global existing_text_features
    global num_model_loads

    if uid in list_of_all_models:
        return

    response = requests.get(image_url, timeout=30)
    if response.status_code != 200:
        print(f"Failed to fetch {image_url}. Status code: {response.status_code}")
        return

    image = Image.open(BytesIO(response.content))
    if image.width < min_image_size or image.height < min_image_size:
        return

    # Compute BOTH embeddings before touching any global state. If either encoder
    # fails, the text and image matrices are left unchanged, which keeps them in
    # lock-step with index_to_model_uid.
    tokenized_text = open_clip.tokenize(name).to(device)
    image_tensor = clip_preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():  # inference only: don't build an autograd graph
        new_text_features = clip_model.encode_text(tokenized_text)
        new_image_features = clip_model.encode_image(image_tensor)
    # L2-normalise so dot products between features are cosine similarities.
    new_text_features = new_text_features / new_text_features.norm(dim=-1, keepdim=True)
    new_image_features = new_image_features / new_image_features.norm(dim=-1, keepdim=True)

    if existing_text_features is not None:
        existing_text_features = torch.cat([existing_text_features.to(device), new_text_features], dim=0)
    else:
        existing_text_features = new_text_features

    if existing_image_features is not None:
        existing_image_features = torch.cat([existing_image_features.to(device), new_image_features], dim=0)
    else:
        existing_image_features = new_image_features

    list_of_all_models[uid] = { "name": name, "image": image_url, "uid": uid }
    index_to_model_uid.append(uid)
    num_model_loads += 1

    del image, image_tensor, tokenized_text, new_text_features, new_image_features
    gc.collect()
    torch.cuda.empty_cache()  # Clear GPU memory
    print("addSearchableModel", name, uid, image_url)

    # Persist after every addition so an interrupted run keeps its progress.
    with open(saved_model_data_filename, "w") as f:
        saved_model_data = { "existing_searches_done": existing_searches_done, "list_of_all_models": list_of_all_models, "index_to_model_uid": index_to_model_uid }
        json.dump(saved_model_data, f)
    torch.save(existing_image_features.cpu().detach(), existing_image_features_filename)
    torch.save(existing_text_features.cpu().detach(), existing_text_features_filename)

In [9]:
# Object name -> uid of the model downloaded for it (filled by download_one_model).
modelsDownloaded = { }

# respect users who don't want their models used in any AI-related things
def hasNoAITag(model):
    """Return True if any of the model's tag names contains "noai" (case-insensitive).

    A missing, null or empty ``tags`` field counts as "no opt-out tag".

    Args:
        model: a Sketchfab search-result dict whose ``tags`` entries have a ``"name"``.
    """
    tags = model.get("tags") or []
    return any("noai" in tag["name"].lower() for tag in tags)

def getModelList(objectName, retry_delay_s=60):
    """Search Sketchfab for models matching ``objectName`` and index them.

    Queries the Sketchfab v3 search API for downloadable models with a face
    count between ``minModelFaces`` and ``maxModelFaces``. Results tagged
    "noai" are dropped. Each remaining result is passed to
    ``addSearchableModel`` so that it becomes searchable by embedding.

    Args:
        objectName: free-text search query (URL-encoded by ``requests``).
        retry_delay_s: seconds to wait before the single retry after an
            HTTP 429 (rate-limited) response.

    Returns:
        List of the raw search-result dicts that were kept (possibly empty).
        Request failures are printed and yield ``[]``.
    """
    import time

    url = "https://api.sketchfab.com/v3/search"
    # Passing params lets requests URL-encode the free-text query (spaces, &, quotes).
    params = {
        "type": "models",
        "q": objectName,
        "downloadable": "true",
        "min_face_count": str(minModelFaces),
        "max_face_count": str(maxModelFaces),
        "archives_flavours": "false",
    }

    try:
        response = requests.get(url, params=params, timeout=30)
        print("url", response.url)
        if response.status_code == 429:
            print(f"response.status_code == 429 sleeping for {retry_delay_s}s and retrying once")
            time.sleep(retry_delay_s)
            response = requests.get(url, params=params, timeout=30)

        if response.status_code != 200:
            print(f"Failed to fetch {response.url}. Status code: {response.status_code}")
            return []

        results = response.json().get("results") or []
    except Exception as e:
        print("except Exception as e", e)
        return []

    return_data = []
    for result in results:
        # Per-result guard: one bad entry or thumbnail must not abort the rest.
        try:
            if hasNoAITag(result):
                continue
            return_data.append(result)
            addSearchableModel(result["uid"], result["thumbnails"]["images"][0]["url"], result["name"])
        except Exception as e:
            print("except Exception as e", e)

    return return_data

def download_one_model(model):
    """Download and unpack the glTF archive of one Sketchfab model.

    Caches at two levels:
      * an existing ``<model_download_path>/<uid>`` directory is reused;
      * otherwise an existing ``zipped_models/<uid>.zip`` is extracted;
      * otherwise the archive URL is requested from the Sketchfab download API
        (using ``sketchfab_api_token``), then downloaded, saved and extracted.

    Args:
        model: dict with at least ``"uid"`` and ``"name"`` keys.

    Returns:
        The model uid on success; it is also recorded as
        ``modelsDownloaded[name]``. Returns None when the API offers no glTF
        archive or a request fails.
    """
    global num_model_loads

    objectName = model["name"]
    modelUID = model["uid"]
    baseDir = model_download_path
    zippedDir = "zipped_models"
    print("os.mkdir(baseDir): " + baseDir)

    os.makedirs(zippedDir, exist_ok=True)
    os.makedirs(baseDir, exist_ok=True)

    unzipped_file_path = os.path.join(baseDir, modelUID)
    if os.path.exists(unzipped_file_path):
        print("os.path.exists(unzipped_file_path): " + unzipped_file_path)
        modelsDownloaded[objectName] = modelUID
        return modelUID

    zipped_file_path = os.path.join(zippedDir, modelUID + ".zip")
    if os.path.exists(zipped_file_path):
        with zipfile.ZipFile(zipped_file_path, 'r') as zip_ref:
            zip_ref.extractall(unzipped_file_path)
        print("os.path.exists(file_path): " + zipped_file_path)
        modelsDownloaded[objectName] = modelUID
        return modelUID

    url = "https://api.sketchfab.com/v3/models/" + modelUID + "/download"
    headers = {
        'Authorization': f"Token {sketchfab_api_token}"
    }
    params = {
        'mode': 'cors'
    }

    response = requests.get(url, headers=headers, params=params, timeout=30)
    if response.status_code != 200:
        print(f"Failed to get download info for {modelUID}. Status code: {response.status_code}")
        return None
    modelData = response.json()
    print("modelData", modelData, "response", response)

    if "gltf" not in modelData:
        return None

    file_url = modelData["gltf"]['url']
    print(file_url)

    response = requests.get(file_url, timeout=300)
    if response.status_code != 200:
        print(f"Failed to download the file. Status code: {response.status_code}")
        return None

    with open(zipped_file_path, "wb") as file:
        file.write(response.content)
    # Extract only after the `with` block has closed (and flushed) the archive;
    # reading it while the write handle is still open can see a truncated zip.
    with zipfile.ZipFile(zipped_file_path, 'r') as zip_ref:
        zip_ref.extractall(unzipped_file_path)
    print("Extracted", unzipped_file_path)
    modelsDownloaded[objectName] = modelUID

    num_model_loads += 1
    return modelUID
        
def zscore(similarity):
    """Standardise scores to zero mean and unit (population) standard deviation.

    Args:
        similarity: numpy array of scores.

    Returns:
        ``(similarity - mean) / std``. When every score is identical (std == 0,
        which includes a single score), an all-zero array is returned instead of
        dividing by zero and producing NaNs.
    """
    std = similarity.std()
    if std == 0:
        return np.zeros_like(similarity, dtype=float)
    return (similarity - similarity.mean()) / std

def load_model_from_vec(objectName, num_top_models = 50):
    """Rank the indexed models against a text description using CLIP embeddings.

    The query vector is built in three steps:
      * a weighted sum of normalised text embeddings of the description plus
        generic "good model" prompts;
      * minus a weighted sum of "bad model" prompts;
      * re-normalised.

    Each stored image embedding is scored by its dot product with that vector.
    The scores are z-scored, and the ``num_top_models`` best entries are
    returned.

    Args:
        objectName: free-text description of the wanted model.
        num_top_models: maximum number of results to return.

    Returns:
        List of ``list_of_all_models`` entries (dicts with "name", "image",
        "uid"), best match first. Empty when no image features are indexed yet.
    """
    pos_prompts = [objectName, "cohesive attractive design", 'high poly high quality', 'attractive and pretty colors', 'cohesive single object', 'object with a solid base']
    pos_weights = [1.1, 0.2, 0.2, 0.2, 0.4, 0.2]

    neg_prompts = ["object in a room", "a complete room", 'ugly or low quality', 'low poly', 'abstract and disconnected', 'black and white', 'greyscale']
    neg_weights = [0.3, 0.3, 0.3, 0.3, 0.2, 0.1, 0.1]

    print("load_model_from_vec - Positive Prompts:", pos_prompts)
    print("load_model_from_vec - Negative Prompts:", neg_prompts)

    # Nothing indexed yet, so there is nothing to rank.
    if existing_image_features is None or not index_to_model_uid:
        return []

    def _weighted_text_embedding(prompts, weights):
        # Weighted sum of the L2-normalised CLIP text embeddings of `prompts`.
        # open_clip.tokenize is the tokenizer addSearchableModel uses. It
        # truncates over-long text, whereas clip.tokenize raises by default.
        tokens = open_clip.tokenize(prompts).to(device)
        with torch.no_grad():
            embeddings = clip_model.encode_text(tokens)
            embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
        return torch.stack([w * emb for w, emb in zip(weights, embeddings)], dim=0).sum(dim=0)

    combined_pos = _weighted_text_embedding(pos_prompts, pos_weights)
    if neg_prompts:
        combined_neg = _weighted_text_embedding(neg_prompts, neg_weights)
    else:
        combined_neg = torch.zeros_like(combined_pos)

    combined_text = combined_pos - combined_neg
    combined_text = combined_text / combined_text.norm()

    # Stored features may live on the CPU (loaded from disk) or have another
    # dtype, so match the query. combined_text is 1-D, so `features @ text`
    # yields one score per model (no deprecated `.T` on a 1-D tensor, and no
    # squeeze to 0-d when only one model is indexed).
    image_features = existing_image_features.to(device=combined_text.device, dtype=combined_text.dtype)
    combined_similarity = zscore((image_features @ combined_text).cpu().numpy())

    top_X_indices = np.argsort(-combined_similarity)[:num_top_models]
    top_X_models = [list_of_all_models[index_to_model_uid[idx]] for idx in top_X_indices]

    gc.collect()  # Run garbage collection
    torch.cuda.empty_cache()  # Clear GPU memory

    return top_X_models


def downloadModel(objectName, num_top_models = 5):
    """Return the cached models that best match a text description.

    The database is grown first, unless it is frozen
    (``freeze_model_database_models``) or this exact description was searched
    before. Growing means running Sketchfab searches for:
      * the full description;
      * LLM-condensed 5/4/3/2-word versions of it (each only when the
        description has more than that many words);
      * each individual word.

    Every distinct, non-empty query is searched at most once and recorded in
    ``existing_searches_done``. The database is then ranked with
    ``load_model_from_vec``.

    Args:
        objectName: free-text description of the wanted model.
        num_top_models: number of ranked results to return.

    Returns:
        Normally a list of up to ``num_top_models`` model dicts ("name",
        "image", "uid"), best first. When ``objectName`` is already in
        ``modelsDownloaded``, the model is re-downloaded and its uid string is
        returned instead.
    """
    if objectName in modelsDownloaded:
        print("return modelsDownloaded[objectName]", objectName)
        download_one_model({ "uid": modelsDownloaded[objectName], 'name': objectName})
        # NOTE(review): this path returns a uid string, unlike the lists returned
        # below; callers that iterate model dicts will break on it.
        return modelsDownloaded[objectName]

    if objectName in existing_searches_done or freeze_model_database_models:
        return load_model_from_vec(objectName, num_top_models)

    def _search_once(query):
        # Run each distinct, non-empty Sketchfab query at most once.
        if query and query not in existing_searches_done:
            existing_searches_done.append(query)
            getModelList(query)

    _search_once(objectName)

    # split() rather than split(' '): repeated spaces must not create empty terms.
    modelTokens = objectName.split()

    for word_count, word in ((5, "five"), (4, "four"), (3, "three"), (2, "two")):
        if len(modelTokens) <= word_count:
            continue
        answer = askMultiQuestion([
            f"If I was searching in a model database for \"{objectName}\" and could only search using {word} words, which words should I use? Avoid using plural words.",
            f"Say the words.  Just say {word} words.  Don't say anything else.  Just {word}.  That's it.",
        ])[0]
        # Remove quotes, and any surrounding spaces or trailing period, that the LLM may add.
        _search_once(answer.strip().replace("\"", "").replace("\'", "").strip(" ."))

    for token in modelTokens:
        _search_once(token)

    return load_model_from_vec(objectName, num_top_models)

In [10]:
# Search for the models that best match the configured description.
model_uids = downloadModel(objectName=modelToDownload)

load_model_from_vec - Positive Prompts: ['Silver car', 'cohesive attractive design', 'high poly high quality', 'attractive and pretty colors', 'cohesive single object', 'object with a solid base']
load_model_from_vec - Negative Prompts: ['object in a room', 'a complete room', 'ugly or low quality', 'low poly', 'abstract and disconnected', 'black and white', 'greyscale']


  combined_similarity = (existing_image_features @ combined_text.T).squeeze()


In [11]:
# Show the description followed by each ranked result, best match first.
print(modelToDownload)
for ranked_model in model_uids:
    print(ranked_model)

Silver car
{'name': 'Mercedes-Benz Silver Lightning', 'image': 'https://media.sketchfab.com/models/cfba13ac9ad6440db68d6a65c22a01ed/thumbnails/47f98454fcd94794b680de12082b4724/d858ecb090c44f14bb48eb659f9dda52.jpeg', 'uid': 'cfba13ac9ad6440db68d6a65c22a01ed'}
{'name': '2001 Mercedes-Benz A160 Hatchback', 'image': 'https://media.sketchfab.com/models/2cff998e65d44480af9d021a003ef000/thumbnails/c379468970af45ae87d7e12202018f52/ce8c462711e24101afec9650db66e723.jpeg', 'uid': '2cff998e65d44480af9d021a003ef000'}
{'name': 'Sport Car', 'image': 'https://media.sketchfab.com/models/5bf9eb1870594f91a937e8cc2eb922f9/thumbnails/1ddfe8bef1ab4618b31ed864c613bf8a/219a3e6332ca43d6b85dd6a45e02a572.jpeg', 'uid': '5bf9eb1870594f91a937e8cc2eb922f9'}
{'name': 'Kazē Strike fighter (The Koryu Unity)', 'image': 'https://media.sketchfab.com/models/50d5f90d49b44f7589d19b620a625c78/thumbnails/03d46510208f4a5c959d06193a965c33/19d19a379a044ca48a7dc92f8f8886e6.jpeg', 'uid': '50d5f90d49b44f7589d19b620a625c78'}
{'name':

In [12]:
# Hand-labelled evaluation set: each search description paired with the uid of the
# Sketchfab model considered its best match.
ideal_dataset = [
    dict(search=search_text, uid=expected_uid)
    for search_text, expected_uid in (
        ("Wall Lamp in a cozy bedroom", "ac14ea025c754e45a78306903f595a0c"),
        ("chandelier in a fancy grand ballroom", "94a3c2f9d503434b87120f05e937db62"),
        ("Dragon's Tooth with glowing red streaks", "f5fc6e1cd95f44de8319ddb1a2bedb26"),
        ("Tree growing out of an acid swamp", "feb56a05db3546d390d0957e03cf48a3"),
        ("Glittering Blue Geode", "22ff9332332542489dae0076180f7102"),
        ("Tree in a cartoon or anime style", "0a91ecb8cd954fc19cca1a367f91555d"),
        ("A robot dog carrying a bone in it's mouth", "22815eb3a5724962bb72db0b0570c097"),
        ("A large troll carrying a club made of bones", "c8d0996144a848b8975727257427473f"),
        ("An alien carrying a ray gun", "4979ba26b80d407f80cafc9fc5997f37"),
        ("Large Ice sculpture of a figure skater", "9c69ffc8a44a431f9cfdf0eb78d82b8f"),
        ("Christmas tree with a shining star at the top", "2683bbd6f802480cb36c901d357ff119"),
        ("A dumptruck with yellow paint", "4dd5b3f4a84d473f9c73f2a1d43667ea"),
        ("small cottage surrounded by grass", "816a9104b90d4a449f5162f7c144a54a"),
        ("Silver car", "5bf9eb1870594f91a937e8cc2eb922f9"),
    )
]
print(len(ideal_dataset))

14


In [13]:
# freeze_model_database_models = True
# Score each labelled search by the rank of its expected model in the results:
# rank 0 (top result) costs nothing, and a miss costs max_score_per_item (the worst
# rank). The miss cost must not depend on how many results came back, or the
# accuracy computed from max_score_per_item below would be inflated.
# NOTE: with freeze_model_database_models = False this loop also grows the
# database with the evaluation queries themselves.
drift_score = 0
max_score_per_item = 10
for item in ideal_dataset:
    tests_uids = downloadModel(item["search"], max_score_per_item)
    print(tests_uids)
    result_uids = [result['uid'] for result in tests_uids]
    if item['uid'] in result_uids:
        drift_score += result_uids.index(item['uid'])
    else:
        drift_score += max_score_per_item
    print("drift_score", drift_score)

print("drift_score", drift_score)

load_model_from_vec - Positive Prompts: ['Wall Lamp in a cozy bedroom', 'cohesive attractive design', 'high poly high quality', 'attractive and pretty colors', 'cohesive single object', 'object with a solid base']
load_model_from_vec - Negative Prompts: ['object in a room', 'a complete room', 'ugly or low quality', 'low poly', 'abstract and disconnected', 'black and white', 'greyscale']
[{'name': 'Retro/Vintage Wall Lamp', 'image': 'https://media.sketchfab.com/models/ac14ea025c754e45a78306903f595a0c/thumbnails/7892261363a74dbeafd521de47e38b3b/0879165e78844e84a32dcf43da4da53d.jpeg', 'uid': 'ac14ea025c754e45a78306903f595a0c'}, {'name': 'CC0 - Wall Lamp', 'image': 'https://media.sketchfab.com/models/e619e1b84af14194ac739712de4330b2/thumbnails/73091eb7f58145c99bd001a8b5490e88/f8b1d5ee6c8f40548dc709994e8ebf5f.jpeg', 'uid': 'e619e1b84af14194ac739712de4330b2'}, {'name': 'Unseen: Wall Lamp', 'image': 'https://media.sketchfab.com/models/dda408533ef54915b69d90f230b5e929/thumbnails/fe3515c2234448

In [14]:
print("drift_score", drift_score)
# max_score assumes each labelled item costs at most max_score_per_item.
max_score = len(ideal_dataset) * max_score_per_item
accuracy = (max_score - drift_score) / max_score
print("accuracy", accuracy)

drift_score 40
accuracy 0.7142857142857143


In [15]:
print("saved_model_data_filename", saved_model_data_filename)
print("existing_image_features_filename", existing_image_features_filename)
print("existing_text_features_filename", existing_text_features_filename)

# Final save of the search database. download_one_model also increments
# num_model_loads, so num_model_loads > 0 does not guarantee that any features
# were indexed. Require both tensors before calling .cpu() on them.
features_present = existing_image_features is not None and existing_text_features is not None

if num_model_loads > 0 and stack_existing_image_features and not freeze_model_database_models and features_present:
    with open(saved_model_data_filename, "w") as f:
        saved_model_data = { "existing_searches_done": existing_searches_done, "list_of_all_models": list_of_all_models, "index_to_model_uid": index_to_model_uid }
        json.dump(saved_model_data, f)
    torch.save(existing_image_features.cpu().detach(), existing_image_features_filename)
    torch.save(existing_text_features.cpu().detach(), existing_text_features_filename)
    print("num_model_loads", num_model_loads, "list_of_all_models", len(list_of_all_models))
else:
    print("none saved")
    if freeze_model_database_models:
        print("model database frozen")

saved_model_data_filename ./saved_model_data_ViT-H_14-378.json
existing_image_features_filename ./existing_image_features_ViT-H_14-378.pt
existing_text_features_filename ./existing_text_features_ViT-H_14-378.pt
none saved
model database frozen
