In [12]:
import torch
import clip
import requests
import pandas as pd
import numpy as np
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from sklearn.metrics import accuracy_score, f1_score, precision_score,recall_score
from torch.utils.data import DataLoader,Dataset
from tqdm import tqdm
tqdm.pandas()

# CLIP Zero-shot

In [None]:
device = "cuda:7" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

## Get label list

In [13]:
# Get the label list for the dataset
all_data = "/local/xiaowang/Ingredient_Rec/1mtest_all_440_labels.json"
all_df = pd.read_json(all_data)
print(all_df.shape) # (60485, 10)
label_list = all_df["cleaned_ingredients"].explode().unique() # length 44

(60485, 10)


## Change unique key

In [14]:
# define unique key as image_file_name
test_data = "/local/xiaowang/Ingredient_Rec/Dataset/1mtest_test_440_labels.json"
test_df = pd.read_json(test_data)
print(test_df.shape) #(6048, 10)

test_exploded_df = test_df.explode("image_file_name_ls")
print(test_exploded_df.shape) # (13584, 10)
test_exploded_df.head(2)

(6048, 10)
(13584, 10)


Unnamed: 0,id,images,image_file_name_ls,ingredients,url,partition,title,instructions,extracted_ingredients,cleaned_ingredients
29552,7cd5228ac1,"[{'id': '877a0ba43d.jpg', 'url': 'https://img-...",877a0ba43d.jpg,"[{'text': '1 1/2 sleeves saltines'}, {'text': ...",https://cookpad.com/us/recipes/336742-lemon-be...,test,Lemon Beach Pie,[{'text': 'Crush or process crackers until fin...,"[saltin, butter, sugar, condens milk, egg yolk...","[butter, sugar, milk, egg, yolk, lemon, cream,..."
479,0226b5df5f,"[{'id': '0e3d44f3c6.jpg', 'url': 'http://tasty...",0e3d44f3c6.jpg,"[{'text': '3 whole Ripe Bananas, Crushed'}, {'...",http://tastykitchen.com/recipes/breads/banana-...,test,Banana Bread by Vicki,[{'text': 'Preheat oven to 275 degrees F. Mix ...,"[ripe banana, butter, margarin, sugar, egg, fl...","[banana, butter, sugar, egg, flour, powder, so..."


## Image and Ingredients Demo

In [4]:
# 1. get image file path
image_folder = '/local/xiaowang/Ingredient_Rec/1m_test'
sample_image_file_name = test_exploded_df.iloc[0]['image_file_name_ls']


def image_file_name_to_path(dataset_dir, image_file_name):
    """Convert image file name to full path"""
    sub_dir_path = ""
    for sub_dir in range(4):
        sub_dir_path += image_file_name[sub_dir] + "/"
    image_path = dataset_dir + "/" + sub_dir_path + image_file_name
    return image_path

sample_im_path = image_file_name_to_path(image_folder, sample_image_file_name)
print(sample_im_path)

/local/xiaowang/Ingredient_Rec/1m_test/8/7/7/a/877a0ba43d.jpg


In [None]:
print(test_exploded_df.iloc[0]['cleaned_ingredients'])
Image.open(sample_im_path)

## Evaluation Utils

In [18]:
def get_pred_ohe(pred_idx_ls: list, ingredient_list: list) -> np.ndarray:
    """
    Get the one-hot encoding of the predicted ingredients
    """
    ingredient_set = set(ingredient_list)
    ingredient_set_indices = {} 
    for idx, label in enumerate(ingredient_set):
        ingredient_set_indices[label] = idx
    temp = [0] * len(ingredient_set)
    pred_ingre_ls = [ingredient_list[idx] for idx in pred_idx_ls]
    for pred in pred_ingre_ls:
        idx = ingredient_set_indices.get(pred, -1)
        if idx != -1:
            temp[idx] = 1
    ohe = np.array(temp)
    return ohe

def get_true_ohe(true_ingre_ls: list, ingredient_list: list) -> np.ndarray:
    """
    Get the one-hot encoding of the true ingredients
    """
    ingredient_set = set(ingredient_list)
    ingredient_set_indices = {} 
    for idx, label in enumerate(ingredient_set):
        ingredient_set_indices[label] = idx
    temp = [0] * len(ingredient_set)
    for ingre in true_ingre_ls:
        idx = ingredient_set_indices.get(ingre, -1)
        if idx != -1:
            temp[idx] = 1
    ohe = np.array(temp)
    return ohe

def get_metrics(ingredient_ls, top_k_idx_ls, dataset):
    """
    Get the metrics for the accusation prediction
    """
    # ohe for pred dataset
    metrics_dict = {}
    pred_data_ohe = []
    for top_k_idx in top_k_idx_ls:
        pred_ohe = get_pred_ohe(top_k_idx, ingredient_ls)
        pred_data_ohe.append(pred_ohe)
    pred_data_ohe = np.array(pred_data_ohe)

    # ohe for true dataset
    true_data_ohe = []
    for i in range(len(test_exploded_df)):
        true_ohe = get_true_ohe(test_exploded_df.iloc[i]['cleaned_ingredients'], label_list)
        true_data_ohe.append(true_ohe)
    true_data_ohe = np.array(true_data_ohe)

    # get the metrics 
    acc = accuracy_score(true_data_ohe, pred_data_ohe)
    mi_f1 = f1_score(true_data_ohe, pred_data_ohe, average='micro')
    ma_f1 = f1_score(true_data_ohe, pred_data_ohe, average='macro')
    mi_precision = precision_score(true_data_ohe, pred_data_ohe, average='micro')
    ma_precision = precision_score(true_data_ohe, pred_data_ohe, average='macro')
    mi_recall = recall_score(true_data_ohe, pred_data_ohe, average='micro')
    ma_recall = recall_score(true_data_ohe, pred_data_ohe, average='macro')

    metrics_dict = {
        'accuracy': acc, 
        'micro_f1': mi_f1, 
        'macro_f1': ma_f1, 
        'micro_precision': mi_precision, 
        'macro_precision': ma_precision, 
        'micro_recall': mi_recall, 
        'macro_recall': ma_recall
        }
    return metrics_dict


## Dataloader

In [None]:
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor

class IngredientRecgDataset(Dataset):
    def __init__(self, df, all_ingredients_ls, transform=None):
        self.data = df
        self.all_ingredients_ls = all_ingredients_ls
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        img_file = image_file_name_to_path(image_folder, row['image_file_name_ls'])
        image = Image.open(img_file)
        image = preprocess(image).unsqueeze(0)

        true_ingredient_ls = row['cleaned_ingredients']
        true_ingredient_ohe = get_true_ohe(true_ingredient_ls, self.all_ingredients_ls)

        if self.transform:
            image = self.transform(image)

        return image, true_ingredient_ohe

batch_size = 16

test_dataset = IngredientRecgDataset(test_exploded_df, label_list)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


## Zero-shot Inference

In [None]:
def labels_to_prompts(labels: list[str]) -> list[str]:
    prompt = "Recipe contains ingredient: {}"
    return [prompt.format(label) for label in labels]


clip_score_ls = []
top_k_idx_ls = []
for imgs, true_ingredient_ls in tqdm(test_loader):
    imgs = [img.to(device) for img in imgs]
    imgs = torch.cat(imgs, dim = 0)
    prompts = labels_to_prompts(label_list)
    text = clip.tokenize(prompts).to(device)
    with torch.no_grad():
        image_features = model.encode_image(imgs) # [batch_size, 512]
        image_features = image_features.unsqueeze(1) # [batch_size, 1, 512]
        text_features = model.encode_text(text) # [label_list, 512]
        cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
        clip_scores = cos(image_features, text_features).cpu().numpy() # [batch_size, label_list]
        clip_score_ls.append(clip_scores)
        top_k_idx_ls.extend([clip_score.argsort()[-10:] for clip_score in clip_scores]) # k = 10

## Save / Read Results

In [1]:
import pickle

# with open('top_k_idx_ls.pkl', 'wb') as f:
#     pickle.dump(top_k_idx_ls, f)
# print(len(top_k_idx_ls))

# with open('clip_score_ls.pkl', 'rb') as f:
#     pickle.dump(clip_score_ls, f)


# open top_k_idx_ls.pkl
with open('top_k_idx_ls.pkl', 'rb') as f:
    top_k_idx_ls = pickle.load(f)

# open clip_score_ls.pkl
with open('clip_score_ls.pkl', 'rb') as f:
    clip_score_ls = pickle.load(f)

In [3]:
clip_score_np_ls = [clip_score.cpu().numpy() for clip_score in clip_score_ls]

In [10]:
top_k_np_idx_ls = [clip_score.argsort()[-10:] for batch_clip_score in clip_score_np_ls for clip_score in batch_clip_score]

In [11]:
len(top_k_np_idx_ls)

13584

## Evaluation

In [32]:
zs_results = get_metrics(label_list, top_k_np_idx_ls, test_exploded_df)

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [33]:
zs_results

{'accuracy': 0.0,
 'micro_f1': 0.05592504890353135,
 'macro_f1': 0.04997815061493393,
 'micro_precision': 0.04998527679623086,
 'macro_precision': 0.08102703958913474,
 'micro_recall': 0.06346684114595504,
 'macro_recall': 0.17494958233718533}