### Imports

In [45]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import random
import tqdm

import pandas as pd

### Read the dataset

In [227]:
df_recipies = pd.read_csv("./../dataset/MM-Food-100K.csv")

print(f"Number of recipes: {len(df_recipies)}")
df_recipies.head()

Number of recipes: 100000


Unnamed: 0,image_url,camera_or_phone_prob,food_prob,dish_name,food_type,ingredients,portion_size,nutritional_profile,cooking_method,sub_dt
0,https://file.b18a.io/7843322356500104680_44354...,0.7,0.95,Fried Chicken,Restaurant food,"[""chicken"",""breading"",""oil""]","[""chicken:300g""]","{""fat_g"":25.0,""protein_g"":30.0,""calories_kcal""...",Frying,20250704
1,https://file.b18a.io/7833227147700100732_67487...,0.7,1.0,Pho,Restaurant food,"[""noodles"",""beef"",""basil"",""lime"",""green onions...","[""noodles:200g"",""beef:100g"",""vegetables:50g""]","{""fat_g"":15.0,""protein_g"":25.0,""calories_kcal""...",boiled,20250702
2,https://file.b18a.io/7832600581600103585_26423...,0.8,0.95,Pan-fried Dumplings,Restaurant food,"[""dumplings"",""chili oil"",""soy sauce""]","[""dumplings:300g"",""sauce:50g""]","{""fat_g"":15.0,""protein_g"":20.0,""calories_kcal""...",Pan-frying,20250625
3,https://file.b18a.io/7839056601700101188_98515...,0.7,1.0,Bananas,Raw vegetables and fruits,"[""Bananas""]","[""Bananas: 10 pieces (about 1kg)""]","{""fat_g"":3.0,""protein_g"":12.0,""calories_kcal"":...",Raw,20250718
4,https://file.b18a.io/7837642737500100261_17312...,0.8,0.9,Noodle Stir-Fry,Restaurant food,"[""noodles"",""chicken"",""vegetables"",""sauce""]","[""noodles:300g"",""chicken:100g"",""vegetables:50g""]","{""fat_g"":20.0,""protein_g"":25.0,""calories_kcal""...",stir-fried,20250711


### Pre-processing the dataset

In [228]:
# Convert the ingredients to list
df_recipies["ingredients"] = df_recipies["ingredients"].str.strip("[]").str.replace('"', '')

In [229]:
# Drop unnecessary columns
df_recipies.drop(["camera_or_phone_prob", "food_prob", "sub_dt"], axis=1, inplace=True)

In [230]:
df_recipies_new.head()

Unnamed: 0,image_url,camera_or_phone_prob,food_prob,dish_name,food_type,ingredients,portion_size,nutritional_profile,cooking_method,sub_dt
0,https://file.b18a.io/7843322356500104680_44354...,0.7,0.95,Fried Chicken,Restaurant food,"chicken,breading,oil","[""chicken:300g""]","{""fat_g"":25.0,""protein_g"":30.0,""calories_kcal""...",Frying,20250704
1,https://file.b18a.io/7833227147700100732_67487...,0.7,1.0,Pho,Restaurant food,"noodles,beef,basil,lime,green onions,chili","[""noodles:200g"",""beef:100g"",""vegetables:50g""]","{""fat_g"":15.0,""protein_g"":25.0,""calories_kcal""...",boiled,20250702
2,https://file.b18a.io/7832600581600103585_26423...,0.8,0.95,Pan-fried Dumplings,Restaurant food,"dumplings,chili oil,soy sauce","[""dumplings:300g"",""sauce:50g""]","{""fat_g"":15.0,""protein_g"":20.0,""calories_kcal""...",Pan-frying,20250625
3,https://file.b18a.io/7839056601700101188_98515...,0.7,1.0,Bananas,Raw vegetables and fruits,Bananas,"[""Bananas: 10 pieces (about 1kg)""]","{""fat_g"":3.0,""protein_g"":12.0,""calories_kcal"":...",Raw,20250718
4,https://file.b18a.io/7837642737500100261_17312...,0.8,0.9,Noodle Stir-Fry,Restaurant food,"noodles,chicken,vegetables,sauce","[""noodles:300g"",""chicken:100g"",""vegetables:50g""]","{""fat_g"":20.0,""protein_g"":25.0,""calories_kcal""...",stir-fried,20250711


### Extract the ingredients vocab

In [231]:
ingredients_list = (
    df_recipies["ingredients"]
        .dropna()                 # drop missing cells
        .str.split(',')           # split each string on commas
        .explode()                # turn lists into rows
        .str.strip()              # trim spaces around each item
        .str.replace('"', "")     # replace " with nothing
        .loc[lambda s: s.ne('')]  # remove empty strings (if any)
        .tolist()                 # back to a single Python list
)
unique_ingredients_list = list(set(ingredients_list))
unique_ingredients_list

['Roasted chicken',
 'transparent noodles',
 'grilled fish',
 'omelet',
 'Peas',
 'McSpicy wings',
 'purple rice',
 'glass noodles',
 'rice wine',
 'chicken skin',
 'Breaded chicken',
 'raisin',
 'soft drinks',
 'spring roll pastry',
 'cranberry filling',
 'small fruit',
 'tofu',
 'cookie dough',
 'sweet potato noodles',
 'apricots',
 'green chili',
 'milk filling',
 'Fruit drink',
 'Salmon',
 'espresso',
 'Pork',
 'bean stew',
 'Caesar dressing',
 'blue rice',
 'Pomelo',
 'sugar snap peas',
 'green items',
 'peanut',
 'shells',
 'mulberries',
 'green jelly',
 'chestnuts',
 'Oreos',
 'Vitamins',
 'black tea extract',
 'unknown vegetables',
 'fried eggs',
 'salad greens',
 'scallop flavoring',
 'turnips',
 'Whipped cream',
 'sundried tomatoes',
 'cooking oil',
 'pounded yam',
 'aloe vera',
 'vegetable stir-fry',
 'dessert items',
 'tomato garnish',
 'ice cream base',
 'icing',
 'fried spring roll',
 'orange slices',
 'roasted grains',
 'cookie base',
 'organic milk',
 'wrapping dough',


### Convert ingredients to index

In [232]:
ingredient2idx = {ing: i for i, ing in enumerate(unique_ingredients_list)}

print(f"Number of unique ingredients: {len(unique_ingredients_list)}")
ingredient2idx

Number of unique ingredients: 4085


{'Roasted chicken': 0,
 'transparent noodles': 1,
 'grilled fish': 2,
 'omelet': 3,
 'Peas': 4,
 'McSpicy wings': 5,
 'purple rice': 6,
 'glass noodles': 7,
 'rice wine': 8,
 'chicken skin': 9,
 'Breaded chicken': 10,
 'raisin': 11,
 'soft drinks': 12,
 'spring roll pastry': 13,
 'cranberry filling': 14,
 'small fruit': 15,
 'tofu': 16,
 'cookie dough': 17,
 'sweet potato noodles': 18,
 'apricots': 19,
 'green chili': 20,
 'milk filling': 21,
 'Fruit drink': 22,
 'Salmon': 23,
 'espresso': 24,
 'Pork': 25,
 'bean stew': 26,
 'Caesar dressing': 27,
 'blue rice': 28,
 'Pomelo': 29,
 'sugar snap peas': 30,
 'green items': 31,
 'peanut': 32,
 'shells': 33,
 'mulberries': 34,
 'green jelly': 35,
 'chestnuts': 36,
 'Oreos': 37,
 'Vitamins': 38,
 'black tea extract': 39,
 'unknown vegetables': 40,
 'fried eggs': 41,
 'salad greens': 42,
 'scallop flavoring': 43,
 'turnips': 44,
 'Whipped cream': 45,
 'sundried tomatoes': 46,
 'cooking oil': 47,
 'pounded yam': 48,
 'aloe vera': 49,
 'vegetabl

### Extract ingredients and recipes

In [233]:
df_recipies["ingredients"].tolist()

['chicken,breading,oil',
 'noodles,beef,basil,lime,green onions,chili',
 'dumplings,chili oil,soy sauce',
 'Bananas',
 'noodles,chicken,vegetables,sauce',
 'shrimp,noodles,garlic,green onions,chili sauce',
 'beef,vegetables,rice,soup',
 'dried noodles',
 'noodles,broth,meat,vegetables',
 'oranges',
 'pumpkin,rice,meat,spices',
 'eggs,tomatoes,onions,bread',
 'baby corn,sauce,lettuce',
 'crab,spices,vegetables,sauce',
 'bread,meat,lettuce,tomato,cucumber',
 'corn,eggs,blueberries',
 'chicken,seasoning,sauce',
 'noodles,ground meat,seaweed,carrots,bell peppers',
 'fried chicken,salad greens,soup broth,lemon,dipping sauce',
 'meat,sauce,green onions',
 'chicken wings,fried ribs,spicy sauce,green peppers',
 'bread,cheese,egg,sausage',
 'chicken,green peppers,red peppers,spices',
 'beef,shrimp,tofu,vegetables,noodles',
 'fruit,sugar,glaze',
 'snails,garlic,herbs,dipping sauce',
 'milk,sugar,stabilizers',
 'noodles,meat,broth,vegetables',
 'fruit,sugar,cream',
 'noodles,green onions,bean spr

In [241]:
recipe_ingredients = df_recipies["ingredients"].tolist()
recipies = df_recipies["dish_name"]

#### Test the extraction

In [242]:
recipies[10]

'Stuffed Pumpkin'

In [243]:
recipe_ingredients[10]

'pumpkin,rice,meat,spices'

### Embedding model for ingredients

In [244]:
class RecipeEmbeddingModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, ingredient_indices):
        # Pad variable-length ingredient lists
        padded = nn.utils.rnn.pad_sequence(
            [torch.tensor(x) for x in ingredient_indices],
            batch_first=True
        )
        mask = (padded != 0)  # padding mask
        embeds = self.embedding(padded)

        # Mean pooling over ingredients
        masked_embeds = embeds * mask.unsqueeze(-1)
        recipe_embeds = masked_embeds.sum(1) / mask.sum(1, keepdim=True)

        # normalize for cosine similarity
        recipe_embeds = F.normalize(recipe_embeds, dim=-1)
        return recipe_embeds

### Contrastive loss definition

In [245]:
def contrastive_loss(batch_one, batch_two, temperature=0.2):
    batch_size = batch_one.size(0)
    z = torch.cat([batch_one, batch_two], dim=0)  # (2N, d)

    # Cosine similarity matrix
    sim = torch.matmul(z, z.T) / temperature
    sim.fill_diagonal_(-9e15)  # mask self-similarity

    # Positive pairs: i-th in first half with i-th in second half
    labels = torch.cat([torch.arange(batch_size) + batch_size,
                        torch.arange(batch_size)], dim=0).to(z.device)

    loss = F.cross_entropy(sim, labels)
    return loss

### Definition of the train run

In [246]:
def train(model, recipes, epochs=1, batch_size=4, lr=1e-3, device="cpu"):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.to(device)

    for epoch in range(epochs):
        random.shuffle(recipes)

        total_loss = 0
        for i in tqdm.tqdm(range(0, len(recipes), batch_size)):
            batch = recipes[i:i+batch_size]

            # Create augmented views (drop random ingredients for augmentation)
            def augment(recipe):
                if len(recipe) > 1:
                    keep = random.sample(recipe, k=max(1, len(recipe)-1))
                    return keep
                return recipe

            batch_i = [augment(r) for r in batch]
            batch_j = [augment(r) for r in batch]

            z_i = model(batch_i).to(device)
            z_j = model(batch_j).to(device)

            loss = contrastive_loss(z_i, z_j)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {total_loss/len(recipes):.4f}")

### Execute the train run

In [247]:
vocab_size = len(unique_ingredients_list)
encoding_dim = 64

In [248]:
train_recipies = list()
for recipe_ingredient in recipe_ingredients:
    print(recipe_ingredient.split(","))
    try:
        train_recipies.append([ingredient2idx[i] for i in recipe_ingredient.split(',')])
    except KeyError:
        pass

['chicken', 'breading', 'oil']
['noodles', 'beef', 'basil', 'lime', 'green onions', 'chili']
['dumplings', 'chili oil', 'soy sauce']
['Bananas']
['noodles', 'chicken', 'vegetables', 'sauce']
['shrimp', 'noodles', 'garlic', 'green onions', 'chili sauce']
['beef', 'vegetables', 'rice', 'soup']
['dried noodles']
['noodles', 'broth', 'meat', 'vegetables']
['oranges']
['pumpkin', 'rice', 'meat', 'spices']
['eggs', 'tomatoes', 'onions', 'bread']
['baby corn', 'sauce', 'lettuce']
['crab', 'spices', 'vegetables', 'sauce']
['bread', 'meat', 'lettuce', 'tomato', 'cucumber']
['corn', 'eggs', 'blueberries']
['chicken', 'seasoning', 'sauce']
['noodles', 'ground meat', 'seaweed', 'carrots', 'bell peppers']
['fried chicken', 'salad greens', 'soup broth', 'lemon', 'dipping sauce']
['meat', 'sauce', 'green onions']
['chicken wings', 'fried ribs', 'spicy sauce', 'green peppers']
['bread', 'cheese', 'egg', 'sausage']
['chicken', 'green peppers', 'red peppers', 'spices']
['beef', 'shrimp', 'tofu', 'vegeta

In [249]:
model = RecipeEmbeddingModel(vocab_size, encoding_dim)

In [250]:
train(model, train_recipies, epochs=15, batch_size=8)

100%|██████████| 12495/12495 [00:08<00:00, 1495.39it/s]


Epoch 1, Loss: 0.0619


100%|██████████| 12495/12495 [00:08<00:00, 1490.87it/s]


Epoch 2, Loss: 0.0550


100%|██████████| 12495/12495 [00:08<00:00, 1433.22it/s]


Epoch 3, Loss: 0.0529


100%|██████████| 12495/12495 [00:10<00:00, 1167.98it/s]


Epoch 4, Loss: 0.0515


100%|██████████| 12495/12495 [00:08<00:00, 1497.86it/s]


Epoch 5, Loss: 0.0507


100%|██████████| 12495/12495 [00:08<00:00, 1502.03it/s]


Epoch 6, Loss: 0.0504


100%|██████████| 12495/12495 [00:08<00:00, 1498.63it/s]


Epoch 7, Loss: 0.0495


100%|██████████| 12495/12495 [00:08<00:00, 1502.46it/s]


Epoch 8, Loss: 0.0493


100%|██████████| 12495/12495 [00:08<00:00, 1505.57it/s]


Epoch 9, Loss: 0.0488


100%|██████████| 12495/12495 [00:08<00:00, 1501.45it/s]


Epoch 10, Loss: 0.0488


100%|██████████| 12495/12495 [00:08<00:00, 1502.34it/s]


Epoch 11, Loss: 0.0483


100%|██████████| 12495/12495 [00:08<00:00, 1498.22it/s]


Epoch 12, Loss: 0.0482


100%|██████████| 12495/12495 [00:08<00:00, 1478.95it/s]


Epoch 13, Loss: 0.0480


100%|██████████| 12495/12495 [00:08<00:00, 1455.63it/s]


Epoch 14, Loss: 0.0477


100%|██████████| 12495/12495 [00:08<00:00, 1453.27it/s]

Epoch 15, Loss: 0.0478





### Test the model

In [251]:
print(recipe_ingredients[1])
print(recipe_ingredients[2])

noodles,beef,basil,lime,green onions,chili
dumplings,chili oil,soy sauce


In [252]:
print(recipies[1])
print(recipies[2])

Pho
Pan-fried Dumplings


In [253]:
test_query_recipe_1 = train_recipies[1].copy()
test_query_recipe_2 = train_recipies[2].copy()

In [265]:
model.eval()
with torch.no_grad():
    test_query_recipe_1_emb = model([[ingredient2idx["meat"], ingredient2idx["spices"], ingredient2idx["sauce"]]])
    test_query_recipe_2_emb = model([[ingredient2idx["shrimp"], ingredient2idx["spices"], ingredient2idx["sauce"]]])

In [266]:
F.cosine_similarity(test_query_recipe_1_emb, test_query_recipe_2_emb)

tensor([0.6097])

### Save the model

In [267]:
torch.save(model.state_dict(), "./../models/recipe_embedding.pt")