In [1]:
import pandas as pd

import torch
import torch.nn as nn
from torchvision import datasets, models, transforms
from PIL import Image, ImageFile
from tqdm.notebook import tqdm

ImageFile.LOAD_TRUNCATED_IMAGES = True

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device_cpu = torch.device("cpu")

SAMPLE_SIZE = 200000

In [2]:
device

device(type='cuda', index=0)

In [3]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

## Load data

In [4]:
df_base = pd.read_csv("../../data/sparkrecipes_base.csv")
df_images = pd.read_csv("../../data/sparkrecipes_images.csv")

In [5]:
df = df_base.merge(df_images, on="recipe_id", how="inner").sample(SAMPLE_SIZE)

In [6]:
df

Unnamed: 0,recipe_id,title,total_calories,url,servings,image_path
1176056,291788,Low Fat Low Calorie Cool 'n Easy Pie!,119.2,https://recipes.sparkpeople.com/recipe-detail....,8.0,../../data/images/291788/000011
748415,186129,Roasted Cauliflower,77.3,https://recipes.sparkpeople.com/recipe-detail....,4.0,../../data/images/186129/000014
348160,89647,Pilao,183.9,https://recipes.sparkpeople.com/recipe-detail....,8.0,../../data/images/89647/000012
237945,63258,Healthy Chicken and Pasta,347.0,https://recipes.sparkpeople.com/recipe-detail....,1.0,../../data/images/63258/000006
1377783,344852,Strawberry Spinach Salad,164.8,https://recipes.sparkpeople.com/recipe-detail....,1.0,../../data/images/344852/000011
...,...,...,...,...,...,...
63562,15002,Quick and Healthy Vegetable Beef Soup,417.9,https://recipes.sparkpeople.com/recipe-detail....,4.0,../../data/images/15002/000015
699107,174431,Oyako bowl,340.3,https://recipes.sparkpeople.com/recipe-detail....,4.0,../../data/images/174431/000008
219419,59168,Lite Caesar Chicken,297.0,https://recipes.sparkpeople.com/recipe-detail....,2.0,../../data/images/59168/000005
578654,144402,Easy Dinner Rolls,118.7,https://recipes.sparkpeople.com/recipe-detail....,12.0,../../data/images/144402/000002


## Create image embeddings

In [7]:
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

In [8]:
class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, df, transform):
        self.images = df["image_path"].reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images.iloc[idx]
        image = Image.open(img_path).convert("RGB")
        image = self.transform(image)
        return image

In [9]:
squeezenet = models.squeezenet1_0(pretrained=True)
set_parameter_requires_grad(squeezenet, True)

In [10]:
class FeatureExtractor(nn.Module):
    def __init__(self, model):
        super(FeatureExtractor, self).__init__()
        self.features = model.features
        self.avg_pool2d = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        x = self.features(x)
        x = self.avg_pool2d(x)
        x = x.squeeze()
        return x

In [11]:
feature_extractor = FeatureExtractor(squeezenet).to(device)

In [12]:
IMAGE_SIZE = 224

data_transforms = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

In [13]:
dataset = ImageDataset(df, data_transforms)

In [14]:
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=100, shuffle=False, num_workers=3
)

In [15]:
embeddings = []
for item in tqdm(dataloader):
    item = item.to(device)
    embeddings.append(feature_extractor(item))
    del item
    torch.cuda.empty_cache()

HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))

  "Palette images with Transparency expressed in bytes should be "
  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)
  "Palette images with Transparency expressed in bytes should be "
  "Palette images with Transparency expressed in bytes should be "
  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag 




### Persist results

In [18]:
# cpu_embeddings = [e.to(device_cpu) for e in embeddings]

In [16]:
df_embeddings = df.join(pd.DataFrame(torch.cat(embeddings).tolist(), index=df.index)).drop(["image_path"], axis=1)
df_embeddings

Unnamed: 0,recipe_id,title,total_calories,url,servings,0,1,2,3,4,...,502,503,504,505,506,507,508,509,510,511
1176056,291788,Low Fat Low Calorie Cool 'n Easy Pie!,119.2,https://recipes.sparkpeople.com/recipe-detail....,8.0,1.657542,3.306842,0.150144,0.716325,1.412788,...,0.851129,0.784485,0.000000,0.000000,0.416459,0.553366,1.400850,0.046361,2.531902,0.652344
748415,186129,Roasted Cauliflower,77.3,https://recipes.sparkpeople.com/recipe-detail....,4.0,0.131091,0.756927,0.000000,0.000000,0.000000,...,0.207664,0.217347,0.000000,0.007426,0.236245,0.743847,4.196649,2.199195,0.461011,0.003973
348160,89647,Pilao,183.9,https://recipes.sparkpeople.com/recipe-detail....,8.0,0.256996,0.506713,0.000000,0.000000,0.295367,...,0.104268,0.289502,1.122088,0.000000,6.332223,0.903992,6.343424,0.200299,6.157534,3.011780
237945,63258,Healthy Chicken and Pasta,347.0,https://recipes.sparkpeople.com/recipe-detail....,1.0,0.481375,0.975534,0.176997,0.003095,0.000000,...,0.540625,1.119777,0.385062,0.000000,3.115180,3.962124,5.996787,0.086081,4.714542,3.388931
1377783,344852,Strawberry Spinach Salad,164.8,https://recipes.sparkpeople.com/recipe-detail....,1.0,0.872174,1.515256,6.107618,0.000000,0.000000,...,2.801720,0.000000,0.204316,0.000000,1.837780,0.611731,3.276474,0.122105,2.763169,1.290578
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
63562,15002,Quick and Healthy Vegetable Beef Soup,417.9,https://recipes.sparkpeople.com/recipe-detail....,4.0,0.096497,0.538639,0.159070,1.035311,0.198755,...,0.027319,0.812292,0.593234,0.102146,3.986886,0.929327,2.781661,0.265407,6.964776,5.216623
699107,174431,Oyako bowl,340.3,https://recipes.sparkpeople.com/recipe-detail....,4.0,1.269787,3.974391,1.534365,0.107423,0.982815,...,1.824443,0.025285,0.000000,0.000000,1.270573,0.854426,2.218760,0.023541,1.541791,5.645164
219419,59168,Lite Caesar Chicken,297.0,https://recipes.sparkpeople.com/recipe-detail....,2.0,0.072812,2.873926,0.100634,0.573897,0.000000,...,0.092382,0.000000,4.007784,0.000000,0.525843,0.035090,0.373691,1.361642,0.266344,0.062209
578654,144402,Easy Dinner Rolls,118.7,https://recipes.sparkpeople.com/recipe-detail....,12.0,0.104968,0.239626,2.948921,1.556745,0.520375,...,0.654853,0.171278,0.000000,0.000000,2.507907,0.392842,0.200156,0.000000,0.095376,1.422864


In [17]:
df_out = df_embeddings[["recipe_id"] + list(df_embeddings.columns[5:])].reset_index(drop=True)
df_out.columns = ["recipe_id"] + [f"f_{i}" for i in df_embeddings.columns[5:]]
df_out

Unnamed: 0,recipe_id,f_0,f_1,f_2,f_3,f_4,f_5,f_6,f_7,f_8,...,f_502,f_503,f_504,f_505,f_506,f_507,f_508,f_509,f_510,f_511
0,291788,1.657542,3.306842,0.150144,0.716325,1.412788,0.000000,0.612706,0.156094,1.111069,...,0.851129,0.784485,0.000000,0.000000,0.416459,0.553366,1.400850,0.046361,2.531902,0.652344
1,186129,0.131091,0.756927,0.000000,0.000000,0.000000,0.042949,0.083060,0.087035,0.063373,...,0.207664,0.217347,0.000000,0.007426,0.236245,0.743847,4.196649,2.199195,0.461011,0.003973
2,89647,0.256996,0.506713,0.000000,0.000000,0.295367,0.000000,2.670163,0.070306,0.785169,...,0.104268,0.289502,1.122088,0.000000,6.332223,0.903992,6.343424,0.200299,6.157534,3.011780
3,63258,0.481375,0.975534,0.176997,0.003095,0.000000,0.389275,0.459856,0.805673,0.870744,...,0.540625,1.119777,0.385062,0.000000,3.115180,3.962124,5.996787,0.086081,4.714542,3.388931
4,344852,0.872174,1.515256,6.107618,0.000000,0.000000,0.544032,0.224283,1.092227,5.745899,...,2.801720,0.000000,0.204316,0.000000,1.837780,0.611731,3.276474,0.122105,2.763169,1.290578
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
199995,15002,0.096497,0.538639,0.159070,1.035311,0.198755,0.254970,0.194375,0.614713,1.017653,...,0.027319,0.812292,0.593234,0.102146,3.986886,0.929327,2.781661,0.265407,6.964776,5.216623
199996,174431,1.269787,3.974391,1.534365,0.107423,0.982815,0.000000,1.469487,0.050964,0.241303,...,1.824443,0.025285,0.000000,0.000000,1.270573,0.854426,2.218760,0.023541,1.541791,5.645164
199997,59168,0.072812,2.873926,0.100634,0.573897,0.000000,0.023103,2.683917,0.042412,0.825761,...,0.092382,0.000000,4.007784,0.000000,0.525843,0.035090,0.373691,1.361642,0.266344,0.062209
199998,144402,0.104968,0.239626,2.948921,1.556745,0.520375,1.501898,0.000000,0.237335,1.838255,...,0.654853,0.171278,0.000000,0.000000,2.507907,0.392842,0.200156,0.000000,0.095376,1.422864


In [18]:
df_out.to_csv("../../data/sparkrecipes_embeddings.csv" ,index=False)