# Set up

In [None]:
import os
import shutil
import time
import sys
import json

In [None]:
import os
# Working directory at the moment this cell is executed.
HOME = os.getcwd()

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

# change "/content/gdrive/MyDrive/"  to "/mydrive so you can use directly /mydrive"
!ln -s /content/gdrive/MyDrive/ /mydrive

Mounted at /content/gdrive


In [None]:
import shutil

In [None]:
# print(len(os.listdir('heineken-images')))

# Deploy ResNet50

In [None]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import os
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# Pre-trained ResNet50, moved to the chosen device and put in eval mode.
# NOTE(review): recent torchvision releases deprecate `pretrained=True` in
# favour of `weights=...` — confirm the installed version before switching.
resnet50 = models.resnet50(pretrained=True)
resnet50 = resnet50.to(device)
resnet50.eval()

# Feature extractor: every child module of ResNet50 except the last one
# (the classification layer), chained in their original order.
_resnet_children = list(resnet50.children())
model = nn.Sequential(*_resnet_children[:-1])
model = model.to(device)
model.eval()

# Image preprocessing: resize -> center crop -> tensor -> per-channel normalisation.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
preprocess = transforms.Compose(_preprocess_steps)


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 116MB/s]


In [None]:
def preprocess_image(image_path):
    """Load an image file and turn it into a single-image batch on `device`.

    Args:
        image_path: Path of an image file that PIL can open.

    Returns:
        torch.Tensor: the output of the module-level `preprocess` pipeline
        with a leading batch axis added, moved to `device`. With the pipeline
        defined in this notebook the layout is (1, C, H, W).
    """
    # The context manager closes the underlying file handle as soon as the
    # pixels have been read; convert() returns an independent in-memory copy,
    # so the RGB image stays usable after the file is closed.
    with Image.open(image_path) as source_image:
        rgb_image = source_image.convert("RGB")

    return preprocess(rgb_image).unsqueeze(0).to(device)

def get_embedding(image_path):
    """Compute the feature vector of one image file.

    The image is preprocessed with `preprocess_image`, passed through the
    module-level `model` without gradient tracking, and the result is
    returned as a flat (1-dimensional) numpy.ndarray on the CPU.
    """
    batch = preprocess_image(image_path)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        features = model(batch)
        flat_features = features.cpu().numpy().flatten()

    return flat_features


# Vector DB

In [None]:
from collections import Counter
import sys, time, os
from dataclasses import dataclass, field
from typing import List, Tuple
from pydantic import BaseModel
import numpy as np
import pickle

@dataclass
class Record:
    """One reference image stored in a VectorDatabase bucket."""
    # Path of the image file the vector was computed from.
    source: str
    # Brand label, taken from the first dot-separated part of the file name.
    brand: str
    # Embedding of the image as produced by get_embedding().
    vector: np.ndarray

class VectorDatabase:
    """In-memory store of reference image embeddings, grouped by product type.

    `self.db` maps each supported product type ("bottle", "can", ...) to a
    list of `Record` objects. Images whose type is "logo" are stored in every
    bucket so that a logo can match a query of any product type.
    """

    def __init__(self):
        # One (initially empty) bucket of records per supported product type.
        self.db = {
            "bottle": [],
            "can": [],
            "carton": [],
            "icebox": [],
            "icebucket": [],
            "standee": [],
            "banner": []
        }

    def add(self, filepath: str):
        """Embed the image at `filepath` and store it in the matching bucket(s).

        The file name must look like ``<brand>.<type>.<...>.<ext>``
        (e.g. ``tiger.can.1.png``). Hidden entries (names starting with "."),
        such as ``.DS_Store`` or ``.ipynb_checkpoints``, are silently skipped.

        Args:
            filepath: Path of the image file to add.

        Raises:
            ValueError: If the file name does not follow the naming scheme or
                names a product type that is neither "logo" nor a known bucket.
        """
        filename = os.path.basename(filepath)
        if filename.startswith("."):
            return

        # tiger.can.1.png -> brand "tiger", type "can"; anything between the
        # type and the extension is ignored.
        parts = filename.split('.')
        if len(parts) < 3:
            raise ValueError(
                f"Invalid filename (expected <brand>.<type>.<...>.<ext>): {filename}"
            )
        brand, kind = parts[0], parts[1]

        # Validate before embedding so invalid files do not cost a forward pass.
        if kind != "logo" and kind not in self.db:
            raise ValueError(f"Invalid type: {kind}")

        vector = get_embedding(filepath)

        # A logo is registered under every product type; anything else only
        # under its own type.
        target_keys = list(self.db.keys()) if kind == "logo" else [kind]
        for key in target_keys:
            self.db[key].append(
                Record(filepath, brand, vector)
            )

    def brand(self, type, vector, threshold: float = 0.5, top_k: int = 5) -> "str | None":
        """Predict the brand of a query embedding by majority vote.

        The query is compared (cosine similarity) with every record stored for
        `type`. Records whose similarity is not strictly above `threshold` are
        discarded, the `top_k` most similar survivors are printed and vote, and
        the most frequent brand among them is returned. On a tie the brand that
        appears first in similarity order wins (Counter keeps insertion order).

        Args:
            type: Product type bucket to search (a key of `self.db`).
            vector: Query embedding; it is reshaped to a single row.
            threshold: Minimum similarity (exclusive) for a record to vote.
            top_k: Maximum number of records that vote.

        Returns:
            The winning brand, or None when the bucket is empty or no record
            passes the threshold.

        Raises:
            KeyError: If `type` is not a key of `self.db`.
        """
        records = self.db[type]
        if not records:
            return None  # Nothing to compare against for this type

        query_vector = vector.reshape(1, -1)  # Encode -> 2D array
        record_vectors = np.array([record.vector for record in records])

        similarities = cosine_similarity(query_vector, record_vectors).flatten()
        ranked_indices = similarities.argsort()[::-1]  # most similar first

        # Keep only sufficiently similar records, then the best top_k of them.
        top_k_indices = [
            idx for idx in ranked_indices if similarities[idx] > threshold
        ][:top_k]

        if not top_k_indices:
            return None  # No similar vectors above the threshold

        for idx in top_k_indices:
            print(similarities[idx], records[idx].brand, records[idx].source)

        # Majority vote over the brands of the retained records.
        top_k_brands = [records[idx].brand for idx in top_k_indices]
        most_common_brand, _ = Counter(top_k_brands).most_common(1)[0]

        return most_common_brand


In [None]:
# Root folder of the reference images; the embedding loop below lists one
# sub-folder per entry found here.
BASE_FOLDER = '/mydrive/base'

In [None]:
# Names of all entries directly under BASE_FOLDER, in the order returned by
# os.listdir. This can include non-brand entries such as '.DS_Store'.
BRANDS = os.listdir(BASE_FOLDER)
BRANDS

['.DS_Store',
 'biaviet',
 'tiger',
 'larue',
 'edelweiss',
 'bivina',
 'strongbow',
 'heineken']

In [None]:
# Empty database; it is filled with reference embeddings by the loop below.
db = VectorDatabase()

Embed images

In [None]:
# Embed every reference image: BASE_FOLDER/<brand>/<brand>.<type>.<n>.<ext>
for brand in BRANDS:
    brand_dir = os.path.join(BASE_FOLDER, brand)
    # Only visible directories hold brand images; this skips '.DS_Store' as
    # well as any other stray file or hidden folder under BASE_FOLDER.
    if brand.startswith('.') or not os.path.isdir(brand_dir):
        continue
    for filename in os.listdir(brand_dir):
        filepath = os.path.join(brand_dir, filename)
        db.add(filepath)

In [None]:
import pickle as pkl
# Persist the populated database to 'vdb.pkl' in the current working directory.
# NOTE(review): the pickle stores VectorDatabase/Record instances, so loading it
# presumably needs these class definitions to be available — confirm at load time.
with open('vdb.pkl', 'wb') as f:
    pkl.dump(db, f)

In [None]:
# Folder of evaluation images; file names start with <brand>.<type>.
TEST_FOLDER = '/mydrive/brands-tests'

In [None]:
# One tuple per test image: (is_correct, type, true_brand, predicted_brand, filepath)
test_records = []

for filename in os.listdir(TEST_FOLDER):
    # Hidden entries (e.g. '.DS_Store') do not follow the naming scheme.
    if filename.startswith('.'):
        continue

    # <brand>.<type>.<...> -> ground-truth brand and product type.
    # `item_type` (not `type`) keeps the builtin type() usable in later cells.
    true_brand, item_type, *_ = filename.split('.')
    filepath = os.path.join(TEST_FOLDER, filename)
    brand = db.brand(item_type, get_embedding(filepath), top_k=5)

    test_records.append(
        (true_brand == brand, item_type, true_brand, brand, filepath)
    )
    print(f"EXPECTED: {true_brand:<20}GOT: {brand}")

0.9999999 tiger /mydrive/base/tiger/tiger.bottle.14.png
0.88013136 tiger /mydrive/base/tiger/tiger.bottle.12.png
0.8502703 tiger /mydrive/base/tiger/tiger.bottle.1.png
0.8298529 tiger /mydrive/base/tiger/tiger.bottle.6.png
0.8037513 tiger /mydrive/base/tiger/tiger.bottle.2.png
EXPECTED: tiger               GOT: tiger
0.9999999 tiger /mydrive/base/tiger/tiger.bottle.1.png
0.87543255 tiger /mydrive/base/tiger/tiger.bottle.2.png
0.8502703 tiger /mydrive/base/tiger/tiger.bottle.14.png
0.8348228 tiger /mydrive/base/tiger/tiger.bottle.12.png
0.7963221 tiger /mydrive/base/tiger/tiger.bottle.6.png
EXPECTED: tiger               GOT: tiger
0.81426775 tiger /mydrive/base/tiger/tiger.bottle.14.png
0.7876244 biaviet /mydrive/base/biaviet/biaviet.bottle.5.png
0.78752494 tiger /mydrive/base/tiger/tiger.bottle.12.png
0.7867473 tiger /mydrive/base/tiger/tiger.bottle.1.png
0.77587634 tiger /mydrive/base/tiger/tiger.bottle.3.png
EXPECTED: edelweiss           GOT: tiger
1.0000002 tiger /mydrive/base/tiger

In [None]:
# Number of test records whose first field (prediction == ground truth) is truthy.
trues = len([record for record in test_records if record[0]])
print(f"Accuracy: {trues}/{len(test_records)}")

Accuracy: 92/114
