# Metal Band Logo Identifier

This script is the final step of my band logo identifier app. If you hit 'Run all' and scroll down to the bottom, there will be a link to a Gradio applet where you can upload an image of a band logo, and my model will try to guess which band it belongs to. It works by finding the three bands whose logos it deems 'closest' to the input logo, based on cosine similarity; the bands are listed from most to least likely.

![](https://drive.google.com/uc?export=view&id=1uK2TdenSsdZ1rBB334NJ3U0kZidgbfII)



In [None]:
# import necessary packages
import os
import io
import torch
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from IPython.display import display
from torchvision.transforms import v2
from transformers import CLIPProcessor, CLIPModel
from sklearn.neighbors import NearestNeighbors
from sklearn.model_selection import train_test_split
from IPython.display import display, Image as IPImage, clear_output

In [None]:
# directory holding the fine-tuned CLIP weights and processor config
model_dir = "/content/drive/MyDrive/Logos/final_model"

# load the model and its processor (image + text preprocessing for CLIP) once;
# the processor used to be loaded a second time further down, which was redundant
model = CLIPModel.from_pretrained(model_dir)
processor = CLIPProcessor.from_pretrained(model_dir)

# run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# inference only; from_pretrained already returns the model in eval mode,
# this just makes the intent explicit
model.eval()

def get_image_embedding(image):
    """Embed a single PIL image with the CLIP image encoder.

    Args:
        image: a ``PIL.Image.Image``.

    Returns:
        A 1-D numpy array holding the model's image features.

    Raises:
        TypeError: if ``image`` is not a PIL image.
    """
    if not isinstance(image, Image.Image):
        raise TypeError("Input must be a PIL Image.")

    # the processor returns a dict-like batch of tensors; move it to the model's device.
    # `padding` is a tokenizer (text) option, so it is not passed for image-only input.
    inputs = processor(images=image, return_tensors='pt').to(device)

    # inference only: skip gradient bookkeeping
    with torch.no_grad():
        image_features = model.get_image_features(**inputs)

    # move to CPU, convert to numpy, and flatten the single-image batch into a vector
    return image_features.cpu().numpy().flatten()

def image_paths(local_download_path):
    """List the image files directly inside a directory.

    Args:
        local_download_path: directory to scan (with or without a trailing slash).

    Returns:
        Full paths of every entry whose name ends in "jpg" or "png", in
        ``os.listdir`` order.
    """
    # NOTE(review): os.listdir order is arbitrary; presumably it matches the order
    # used when vector_db_correct.npy was built, since labels are paired with
    # embeddings by position. Do not sort or change the suffix filter without
    # regenerating the embeddings.
    return [
        # os.path.join works whether or not the directory ends with a separator
        os.path.join(local_download_path, filename)
        for filename in os.listdir(local_download_path)
        if filename.endswith(("jpg", "png"))
    ]

# the vector db was built by embedding every photo in the dataset folder, so each
# embedding is labelled with the path of the photo it came from; the band ID is
# then parsed out of that path
img_paths = image_paths('/content/drive/MyDrive/Logos/Small_Dataset/')
image_embeddings = np.load('/content/drive/MyDrive/Logos/vector_db_correct.npy')
labels = img_paths

# paths and embeddings are paired purely by position, so a count mismatch means the
# folder and the saved vector db are out of sync; fail here with a clear message
if len(img_paths) != len(image_embeddings):
    raise ValueError(
        f"Found {len(img_paths)} images but {len(image_embeddings)} embeddings; "
        "the dataset folder and vector_db_correct.npy are out of sync."
    )

# helper function to strip the band's ID from the file name
def id_stripper(file_name, file_path):
    """Extract the band ID from an image path.

    The ID is taken to be everything after the directory prefix up to the first
    underscore, i.e. files are named ``<file_path><band_id>_<anything>``.

    Args:
        file_name: full path of the image file.
        file_path: directory prefix of ``file_name``; only its length is used.

    Returns:
        The band ID as a string (empty if the name starts with an underscore).

    Raises:
        ValueError: if no underscore follows the directory prefix.
    """
    # skip the prefix, then split once on the first underscore
    band_id, sep, _ = file_name[len(file_path):].partition('_')
    if not sep:
        raise ValueError(
            f"No '_' separator after the directory prefix in {file_name!r}"
        )
    return band_id


# parse the band ID out of each image path and pair it, row by row, with that
# image's embedding (stored as a plain Python list per row)
ids = [id_stripper(path, '/content/drive/MyDrive/Logos/Small_Dataset/') for path in img_paths]
small_df = pd.DataFrame({
    'ID': ids,
    'Embedding': image_embeddings.tolist(),
})

def mean_embedding(group):
    """Average a group of equal-length embedding vectors element-wise.

    Args:
        group: sequence of 1-D embeddings (one group from the groupby below).

    Returns:
        numpy array holding the element-wise mean.
    """
    stacked = np.vstack(group)  # one row per embedding
    return stacked.mean(axis=0)

# collapse the per-image embeddings into a single averaged embedding per band ID
mean_df = (
    small_df.groupby('ID')['Embedding']
    .apply(mean_embedding)
    .reset_index()
)
mean_embeddings = np.vstack(mean_df['Embedding'])
mean_labels = mean_df['ID']

# index the per-band embeddings so the 3 closest bands (by cosine distance)
# can be retrieved for a new logo; fit() returns the fitted estimator itself
nn = NearestNeighbors(n_neighbors=3, metric='cosine').fit(mean_embeddings)

def query_new_image(img_path):
    """Return the band IDs whose mean logo embedding is closest to an image.

    Args:
        img_path: despite the name, the image itself; it is passed straight to
            ``get_image_embedding``, which requires a PIL image.

    Returns:
        List of band IDs ordered from closest to furthest by cosine distance,
        one per neighbour the fitted ``nn`` index returns.
    """
    new_img_embedding = get_image_embedding(img_path)

    # kneighbors expects a 2-D batch of queries; we send a single one
    _, indices = nn.kneighbors([new_img_embedding])
    # use whatever neighbours came back rather than a hard-coded count, so this
    # stays correct if nn's n_neighbors is changed
    return [mean_labels[idx] for idx in indices[0]]

# load the band roster used to turn predicted IDs into band names and links
bands_df = pd.read_csv('/content/drive/MyDrive/metal_bands_roster.csv', low_memory=False)

# 200 Stab Wounds is missing from the roster CSV, so add it by hand
SW_200 = { 'Band ID': '3540465014', 'Name': '200 Stab Wounds',
          'URL': 'https://www.metal-archives.com/bands/200_Stab_Wounds/3540465014',
           'Country': 'United States', 'Genre' : 'Death Metal',
           'Status': 'Active', 'Photo_URL' : 'https://www.metal-archives.com/images/3/5/4/0/3540465014_photo.jpg?4349',
           'Label ID': '3'}
# only append when the roster doesn't already have this ID (e.g. a refreshed CSV);
# compare as strings since the CSV column may have been parsed as numbers
if not (bands_df['Band ID'].astype(str) == SW_200['Band ID']).any():
    bands_df.loc[len(bands_df)] = SW_200

In [None]:
# install Gradio into the Colab runtime (IPython `!` shell escape, not plain Python)
!pip install gradio



In [None]:
import gradio as gr

def _lookup_band(band_id):
    """Return ``(name, url)`` for a band ID from ``bands_df``.

    Only the 'Band ID' column is searched, compared as strings, because IDs parsed
    from file names are strings while the CSV column may be numeric. (Searching
    every column could match an unrelated field such as 'Label ID'.)

    Raises:
        LookupError: if no roster row has that ID.
    """
    matches = bands_df[bands_df['Band ID'].astype(str) == str(band_id)]
    if matches.empty:
        raise LookupError(f"Band ID {band_id} not found in the band roster.")
    first = matches.iloc[0]
    return first['Name'], first['URL']


def gradio_model_guess(query_img):
    """Gradio callback: turn an uploaded logo into a Markdown answer.

    Args:
        query_img: PIL image from the Gradio input, or None when there is no image.

    Returns:
        Markdown naming the three closest bands with links to their Metallum
        pages, a prompt to upload when no image is given, or the error text if
        something fails.
    """
    # with live=True this runs on every input change, including when the image
    # is cleared, in which case there is no image to classify
    if query_img is None:
        return "Please upload a band logo image."

    try:
        humanized = [_lookup_band(band_id) for band_id in query_new_image(query_img)]
        response = (f"That could be **{humanized[0][0]}**, **{humanized[1][0]}**, or **{humanized[2][0]}**. Here are the links to their Metallum pages for you to check out!\n\n"
            f"[{humanized[0][0]}]({humanized[0][1]})<br> "
            f"[{humanized[1][0]}]({humanized[1][1]})<br>"
            f"[{humanized[2][0]}]({humanized[2][1]})")
        return response
    except Exception as e:
        # top-level UI boundary: log and show the message instead of crashing the app
        print(f"Error in gradio_model_guess: {e}")
        return str(e)


# input/output components: a PIL image in, a Markdown answer out
logo_input = gr.Image(type="pil")
answer_output = gr.Markdown(label="Answer Markdown")

# wire the classifier callback into a live (re-runs on every change) interface
iface = gr.Interface(
    fn=gradio_model_guess,
    inputs=logo_input,
    outputs=answer_output,
    live=True,
    title="Band Logo Recognition",
    description="Upload a logo image and I'll try to guess which band it belongs to based on my model's predictions.",
)

# start the app; in Colab, Gradio enables a public share link automatically
iface.launch()

Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://64759a58dbc0a3fac5.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


