In [1]:
# Register the `%%bigquery` cell magic that the data-loading cell uses.
%load_ext google.cloud.bigquery

In [2]:
import pandas as pd
import csv

import torch
import torch.nn as nn
from torchvision import datasets, models, transforms
from PIL import Image, ImageFile
from tqdm.notebook import tqdm

# Let Pillow load image files whose data is truncated rather than failing on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
cpu = torch.device("cpu")
device = torch.device("cuda:0") if torch.cuda.is_available() else cpu

In [3]:
# Display the selected device as the cell output.
device

device(type='cuda', index=0)

In [4]:
# Configure IPython to display the value of every top-level expression in a
# cell, not only the last one.
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

## Load Data

In [5]:
%%bigquery df --project zenscr-seefood-dev

-- Recipes joined to their image paths. recipe_id is selected explicitly
-- because the persistence step later in the notebook keys the exported
-- embeddings on df["recipe_id"].
SELECT recipe_id, title, image_path, total_calories
FROM `zenscr-seefood-dev.sparkrecipes.base_filtered`
INNER JOIN `zenscr-seefood-dev.sparkrecipes.image_path`
USING (recipe_id)

## Create Image Embeddings

In [6]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when ``feature_extracting`` is truthy.

    Args:
        model: Module whose ``parameters()`` are visited.
        feature_extracting: When falsy the model is left untouched.

    Returns:
        None. The model is modified in place.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False

In [7]:
class ImageDataset(torch.utils.data.Dataset):
    """Map-style dataset that loads the images listed in a DataFrame.

    Args:
        df: DataFrame with an ``image_path`` column; each value is handed to
            ``Image.open``.
        transform: Callable applied to every RGB-converted image.
    """

    def __init__(self, df, transform):
        paths = df["image_path"]
        # Renumber from 0 so the stored series has a clean positional index.
        self.images = paths.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        """Return the transformed image at position ``idx``."""
        location = self.images.iloc[idx]
        with Image.open(location) as handle:
            # The converted image is what gets used once the file is closed.
            rgb = handle.convert("RGB")
        return self.transform(rgb)

In [8]:
# The pretrained network is only used to extract features, so all of its
# parameters are frozen right after loading.
mobile_net = models.mobilenet_v2(pretrained=True)
set_parameter_requires_grad(mobile_net, feature_extracting=True)

In [9]:
class FeatureExtractor(nn.Module):
    """Wrap a backbone's ``features`` trunk to produce one vector per image.

    Args:
        model: Object exposing a ``features`` sub-module; only that sub-module
            is used.
    """

    def __init__(self, model):
        super().__init__()
        self.features = model.features
        self.avg_pool2d = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Return globally pooled features of shape ``(batch, channels)``.

        Args:
            x: Batch-first 4-D input accepted by ``model.features``.
        """
        x = self.features(x)
        x = self.avg_pool2d(x)  # -> (batch, channels, 1, 1)
        # Flatten from dim 1 so the batch axis always survives. A bare
        # ``squeeze()`` also removes the batch axis when the batch holds a
        # single image, which yields a 1-D tensor that cannot be concatenated
        # with the 2-D outputs of the other batches.
        return torch.flatten(x, 1)

In [10]:
# The extractor is used for inference only. eval() switches layers such as
# BatchNorm and Dropout to their inference behaviour, so an image's embedding
# does not depend on which other images happen to share its batch.
feature_extractor = FeatureExtractor(mobile_net).to(device).eval()

In [11]:
IMAGE_SIZE = 224
# Per-channel mean / std handed to Normalize.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# Resize, centre-crop to IMAGE_SIZE, convert to a tensor, then normalise.
_preprocessing_steps = [
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
]
data_transforms = transforms.Compose(_preprocessing_steps)

In [12]:
# Every image path in the query result, preprocessed with data_transforms.
dataset = ImageDataset(df=df, transform=data_transforms)

In [13]:
BATCH_SIZE = 128
NUM_WORKERS = 8

# shuffle=False keeps batches in dataset order, so the concatenated outputs
# line up row-for-row with the source frame.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

In [None]:
# Embed every batch. This is pure inference, so autograd bookkeeping is
# disabled and the model is put in eval mode before the first forward pass.
embeddings = []
with torch.no_grad():
    feature_extractor.eval()
    for batch in tqdm(dataloader):
        batch = batch.to(device)
        # Move each result to the CPU right away so GPU memory stays bounded.
        embeddings.append(feature_extractor(batch).to(cpu))
        # Drop the reference to the device copy before the next batch loads.
        del batch

HBox(children=(FloatProgress(value=0.0, max=10248.0), HTML(value='')))

### Persist results

In [None]:
# Stack the per-batch outputs into one (n_images, n_features) array. Going
# through NumPy avoids building a Python float object for every single value,
# which `.tolist()` does; the values keep the tensor's own dtype.
embedding_matrix = torch.cat(embeddings).detach().numpy()
df_embeddings = df.join(pd.DataFrame(embedding_matrix, index=df.index)).drop(
    ["image_path"], axis=1
)
# Preview only; the full frame has one column per embedding dimension.
df_embeddings.head()

In [None]:
# The embedding columns are the ones the join added with non-string (integer)
# labels. Selecting them by label type instead of by a hard-coded position
# keeps this cell correct whatever metadata columns `df` carries.
feature_cols = [col for col in df_embeddings.columns if not isinstance(col, str)]
if "recipe_id" not in df_embeddings.columns:
    raise KeyError(
        "df_embeddings has no 'recipe_id' column; include recipe_id in the "
        "BigQuery SELECT that builds df"
    )
df_out = df_embeddings[["recipe_id"] + feature_cols].reset_index(drop=True)
df_out.columns = ["recipe_id"] + [f"f_{col}" for col in feature_cols]
df_out.head()

In [None]:
EMBEDDINGS_CSV = "../../data/sparkrecipes_embeddings.csv"
# index=False: the row index is not written to the file.
df_out.to_csv(EMBEDDINGS_CSV, index=False)