# Conditional variational autoencoder
Variational autoencoder for tabular data, oriented upon: https://lschmiddey.github.io/fastpages_/2021/03/14/tabular-data-variational-autoencoder.html 

Adopted with one hot encoding for tabular data
## Load Json Database of recipes

In [1]:
import json
import sys
sys.path.append('../..')

from utils import validateJson

jsonData = json.load(open('recipes_valid.json'))
#validateJson.validateRecimeJson(jsonData)

print(jsonData.keys())

dict_keys(['title', 'ingredients', 'instructions'])


In [26]:
import pickle
import pandas as pd

with open('recipes_valid.pkl', 'rb') as f:
    pklData = pd.DataFrame(pickle.load(f))

print(pklData.keys())

Index(['title', 'ingredients', 'instructions'], dtype='object')


## Convert list of ingredients to pandas dataframe and one hot encode the dataframe.

In [50]:
data = []

for _, item in pklData.iterrows():
    rowData = {}
    for index, ingredient in item['ingredients'].iterrows():
        for key in ingredient.keys():
            rowData[key+str(index)] = ingredient[key]
    data.append(rowData)

frame = pd.DataFrame(data)


In [51]:
from sklearn.preprocessing import OneHotEncoder


dropColumns = []
for column in frame.columns:
    if 'amount' in column:
        dropColumns.append(column)
frameStripped = frame.drop(columns=dropColumns)

enc = OneHotEncoder()
enc.fit(frameStripped)
frameStrippedEncodedSki = pd.DataFrame(data=enc.transform(frameStripped).toarray(), columns=enc.get_feature_names(frameStripped.columns))

frameEncodedSki = pd.concat([frame[dropColumns], frameStrippedEncodedSki], axis=1)
frameEncodedSki

## VAE
### Setup

In [4]:
# Import pytorch dependencies
import torch
from torch import optim
from torch.utils.data import DataLoader

# Import additional libraries
import pandas as pd
import numpy as np

# Import custom autoencoder
from cvae import VariationalAutoEncoder

# Import custom helper functions
from networkUtils import DataBuilder, CustomLoss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Setup Datasets + Hyperparameters

In [6]:
batch_size = 1024
# Train/Testdataset split is defined in the DataBuilder
traindata_set=DataBuilder(frameEncodedSki, train=True)
testdata_set=DataBuilder(frameEncodedSki, train=False)
# Definition of batches
trainloader=DataLoader(dataset=traindata_set,batch_size=batch_size)
testloader=DataLoader(dataset=testdata_set,batch_size=batch_size)

In [7]:
D_in = testdata_set.x.shape[1]
H = 1024
H2 = 128
latent_dim = 32
model = VariationalAutoEncoder(D_in, H, H2, latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_mse = CustomLoss()

### Train Model

In [8]:
epochs = 50
log_interval = 5
val_losses = []
train_losses = []
test_losses = []

In [9]:
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, data in enumerate(trainloader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_mse(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    if epoch % log_interval == 0:        
        print('====> Epoch: {} Average training loss: {:.4f}'.format(
            epoch, train_loss / len(trainloader.dataset)))
        train_losses.append(train_loss / len(trainloader.dataset))

In [10]:
def test(epoch):
    with torch.no_grad():
        test_loss = 0
        for batch_idx, data in enumerate(testloader):
            data = data.to(device)
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
            loss = loss_mse(recon_batch, data, mu, logvar)
            test_loss += loss.item()
            if epoch % log_interval == 0:        
                print('====> Epoch: {} Average test loss: {:.4f}'.format(
                    epoch, test_loss / len(testloader.dataset)))
            test_losses.append(test_loss / len(testloader.dataset))

In [11]:
for epoch in range(1,epochs+1):
    train(epoch)
    test(epoch)

====> Epoch: 5 Average training loss: 1395.9323
====> Epoch: 5 Average test loss: 930.0054
====> Epoch: 5 Average test loss: 1150.5028
====> Epoch: 10 Average training loss: 1252.2459
====> Epoch: 10 Average test loss: 885.2545
====> Epoch: 10 Average test loss: 1097.0679
====> Epoch: 15 Average training loss: 1167.4181
====> Epoch: 15 Average test loss: 849.6265
====> Epoch: 15 Average test loss: 1057.2856
====> Epoch: 20 Average training loss: 1081.0109
====> Epoch: 20 Average test loss: 839.6844
====> Epoch: 20 Average test loss: 1040.9513
====> Epoch: 25 Average training loss: 1025.6782
====> Epoch: 25 Average test loss: 824.6608
====> Epoch: 25 Average test loss: 1019.2592
====> Epoch: 30 Average training loss: 975.8043
====> Epoch: 30 Average test loss: 800.5249
====> Epoch: 30 Average test loss: 993.0224
====> Epoch: 35 Average training loss: 931.6133
====> Epoch: 35 Average test loss: 790.6359
====> Epoch: 35 Average test loss: 974.9039
====> Epoch: 40 Average training loss: 87

In [12]:
with torch.no_grad():
    for batch_idx, data in enumerate(testloader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)

In [13]:
scaler = trainloader.dataset.standardizer
recon_row = scaler.inverse_transform(recon_batch[0].cpu().numpy())
real_row = scaler.inverse_transform(testloader.dataset.x[0].cpu().numpy())

In [14]:
cols = frameEncodedSki.columns
df = pd.DataFrame(np.stack((recon_row, real_row)), columns = cols)
df

Unnamed: 0,amount1,amount2,amount3,amount4,amount5,amount6,amount7,amount8,amount9,amount10,...,ingredient17_nan,unit18_,unit18_nan,unit19_tablespoon,unit19_nan,ingredient18_Belgian endive,ingredient18_salt,ingredient18_nan,ingredient19_fresh chives,ingredient19_nan
0,1.479655,3.22903,6.531128,9.989347,10.822272,8.01971,3.073665,29.038771,29.278084,18.770477,...,0.534768,0.3762447,0.622974,0.2532713,0.743393,0.2539843,0.2536113,0.62861,0.2528057,0.743451
1,1.0,13.0,14.0,0.999999,-99.0,-99.0,-99.0,-99.0,-99.0,-99.0,...,1.0,-1.596553e-10,1.0,-7.982765e-11,1.0,-7.982765e-11,-7.982765e-11,1.0,-7.982765e-11,1.0


### Draw random samples form latent space and generate new data

In [15]:
sigma = torch.exp(logvar/2)
# sample z from q
no_samples = 20
q = torch.distributions.Normal(mu.mean(axis=0), sigma.mean(axis=0))
z = q.rsample(sample_shape=torch.Size([no_samples]))

In [16]:
with torch.no_grad():
    pred = model.decode(z).cpu().numpy()

fake_data = scaler.inverse_transform(pred)
df_fake = pd.DataFrame(fake_data, columns = cols)
df_fake.head(10)

Unnamed: 0,amount1,amount2,amount3,amount4,amount5,amount6,amount7,amount8,amount9,amount10,...,ingredient17_nan,unit18_,unit18_nan,unit19_tablespoon,unit19_nan,ingredient18_Belgian endive,ingredient18_salt,ingredient18_nan,ingredient19_fresh chives,ingredient19_nan
0,6.829062,2.511984,5.961226,-19.159744,-11.163843,-7.973224,-34.141468,-28.99958,-46.243988,-75.130341,...,0.896284,0.176117,0.872221,0.119035,0.892649,0.117437,0.084045,0.844058,0.101687,0.887356
1,-0.558599,0.004498,7.685414,-0.943294,-1.023584,-8.774481,-20.269455,-2.331251,4.012699,-31.717709,...,0.917232,0.113151,0.977117,0.018057,0.957651,0.082628,0.027919,0.992583,0.023955,0.970259
2,4.958715,6.950881,2.380575,12.473952,24.097204,7.244372,10.182802,23.236761,-10.380653,-5.562351,...,1.062502,-0.121497,1.063607,-0.029941,1.022223,-0.101936,-0.061089,1.063715,-0.064873,1.097188
3,0.048239,-2.632349,0.665736,2.700111,-15.033566,-9.242425,-20.194672,-59.454082,-94.903099,-83.537376,...,1.084527,-0.152754,1.165295,-0.080572,1.091375,-0.089808,-0.131149,1.134534,-0.046177,1.038506
4,4.430035,5.309525,-0.867253,-65.163963,-75.567078,-89.155312,-114.084633,-154.677551,-149.751434,-131.107193,...,0.996318,-0.073089,1.092498,-0.029264,1.053509,-0.016746,-0.039226,1.093219,-0.012651,1.00568
5,1.907051,0.442167,-0.360131,-4.02731,15.836774,-9.919238,2.324528,-29.0149,-64.956993,-91.406204,...,1.021653,-0.017249,1.059388,-0.022352,1.036828,-0.006306,-0.054451,1.055804,-0.003711,1.019902
6,12.079546,6.263891,1.714478,3.30294,-12.109895,-18.353716,-51.311043,-86.895561,-80.77446,-105.111801,...,0.937842,0.057602,0.935723,0.0421,0.97234,0.015353,0.030166,0.964829,-0.001166,0.981929
7,-3.695374,-6.064839,6.136834,7.523921,8.935935,26.994719,17.182175,18.538019,62.913696,76.115662,...,1.110233,-0.082117,1.12583,-0.024521,1.035009,-0.007054,0.021099,1.011229,-0.07612,0.963093
8,-3.737212,-2.333614,6.303004,5.874744,38.085106,19.858004,20.786072,8.818888,-29.021166,-56.431808,...,0.92273,-0.068329,1.012441,-0.062005,1.03605,-0.044921,-0.089302,1.036312,-0.037275,1.045389
9,0.182713,-1.265313,9.031153,-5.851704,8.173281,16.656931,-3.619381,7.066553,-6.010429,-19.974388,...,0.162414,0.568153,0.404951,0.442963,0.565243,0.409952,0.36342,0.381841,0.436103,0.585767


In [17]:
df_fake_stripped = df_fake.drop(columns=dropColumns)
df_fake_stripped_decoded = pd.DataFrame(data=enc.inverse_transform(df_fake_stripped), columns=frameStripped.columns)
df_fake_decoded = pd.concat([df_fake[dropColumns], df_fake_stripped_decoded], axis=1)
df_fake_decoded

Unnamed: 0,amount1,amount2,amount3,amount4,amount5,amount6,amount7,amount8,amount9,amount10,...,unit15,unit16,unit17,ingredient15,ingredient16,ingredient17,unit18,unit19,ingredient18,ingredient19
0,6.829062,2.511984,5.961226,-19.159744,-11.163843,-7.973224,-34.141468,-28.99958,-46.243988,-75.130341,...,,,,,,,,,,
1,-0.558599,0.004498,7.685414,-0.943294,-1.023584,-8.774481,-20.269455,-2.331251,4.012699,-31.717709,...,tablespoon,,,oil,salt,,,,,
2,4.958715,6.950881,2.380575,12.473952,24.097204,7.244372,10.182802,23.236761,-10.380653,-5.562351,...,,,,,,,,,,
3,0.048239,-2.632349,0.665736,2.700111,-15.033566,-9.242425,-20.194672,-59.454082,-94.903099,-83.537376,...,,,,,,,,,,
4,4.430035,5.309525,-0.867253,-65.163963,-75.567078,-89.155312,-114.084633,-154.677551,-149.751434,-131.107193,...,,,,fresh ground black pepper,,,,,,
5,1.907051,0.442167,-0.360131,-4.02731,15.836774,-9.919238,2.324528,-29.0149,-64.956993,-91.406204,...,,,,,,,,,,
6,12.079546,6.263891,1.714478,3.30294,-12.109895,-18.353716,-51.311043,-86.895561,-80.77446,-105.111801,...,,,,,,,,,,
7,-3.695374,-6.064839,6.136834,7.523921,8.935935,26.994719,17.182175,18.538019,62.913696,76.115662,...,,,,,,,,,,
8,-3.737212,-2.333614,6.303004,5.874744,38.085106,19.858004,20.786072,8.818888,-29.021166,-56.431808,...,,,,hot dogs,,,,,,
9,0.182713,-1.265313,9.031153,-5.851704,8.173281,16.656931,-3.619381,7.066553,-6.010429,-19.974388,...,,tablespoon,,hot dogs,vanilla extract,salt,,,Belgian endive,


In [18]:
class Ingredient:
    def __init__(self, amount, unit, ingredient) -> None:
        self.amount = amount
        self.unit = unit
        self.ingredient = ingredient

    def __repr__(self) -> str:
        return "\nAmount: " + str(self.amount) + "\n Unit: " + str(self.unit) + "\n Ingredient: " + str(self.ingredient)

recipes = []
lenIngredients = int(len(df_fake_decoded.columns)/3)
for value in df_fake_decoded.values:
    ingredients = []
    for index in range(0,lenIngredients):
        frame = []
        frame.append(value[index])
        frame.append(value[index+lenIngredients])
        frame.append(value[index+2*lenIngredients])
        ingredients.append(frame)
    recipes.append(ingredients)

pd.DataFrame(recipes)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18
0,"[6.829062461853027, cup, avocado]","[2.511983871459961, tablespoon, nan]","[5.961226463317871, , nan]","[-19.159744262695312, teaspoon, nan]","[-11.163843154907227, teaspoon, teaspoon]","[-7.97322416305542, , nan]","[-34.1414680480957, cup, pepper]","[-28.99958038330078, gram, nan]","[-46.243988037109375, nan, salt]","[-75.13034057617188, fresh chives, nan]","[-72.3921890258789, crushed pineapple, nan]","[-81.75389862060547, orange liqueur, nan]","[-79.66908264160156, white cranberry juice, nan]","[-72.53083038330078, egg whites, nan]","[-84.18598175048828, red wine vinegar, nan]","[-80.61902618408203, balsamic vinegar, nan]","[-85.18389129638672, Kraft Grated Parmesan Che...","[-87.16886138916016, mustard powder, nan]","[-91.02033996582031, nan, nan]"
1,"[-0.5585987567901611, teaspoon, capers]","[0.004498175345361233, , ]","[7.685413837432861, gram, dried currants]","[-0.9432942867279053, , ]","[-1.0235835313796997, cup, teaspoon]","[-8.774480819702148, cubic centimetre, ]","[-20.269454956054688, cup, nutmeg]","[-2.331251382827759, ounce, dried currants]","[4.012699127197266, , salt]","[-31.717708587646484, cake flour, tablespoon]","[-26.628149032592773, cold water, ]","[-24.45585823059082, unsalted butter, nan]","[-0.8983702063560486, cottage cheese, oil]","[-22.13079261779785, fresh ground black pepper...","[-18.761037826538086, garlic cloves, nan]","[-22.9213809967041, cauliflower, nan]","[-103.84403228759766, white sugar, nan]","[-96.94766998291016, ground cumin, nan]","[-91.6397705078125, , nan]"
2,"[4.958714962005615, teaspoon, vanilla extract]","[6.950880527496338, , tablespoon]","[2.380574941635132, gram, lemon]","[12.473952293395996, cup, nan]","[24.097204208374023, tablespoon, tablespoon]","[7.2443718910217285, cubic centimetre, ]","[10.182802200317383, teaspoon, pepper]","[23.23676109313965, gram, salt and black pepper]","[-10.380653381347656, , morton lite salt]","[-5.562350749969482, acorn squash, nan]","[-25.291303634643555, rolled oats, nan]","[-64.44786071777344, jalapeno, nan]","[-42.170570373535156, chocolate chips, nan]","[-42.29719161987305, garlic cloves, nan]","[-100.59165954589844, extra virgin olive oil, ...","[-95.86457061767578, vanilla, nan]","[-95.35540008544922, cumin, nan]","[-103.20608520507812, water, nan]","[-107.07929992675781, cubic centimetre, nan]"
3,"[0.04823925718665123, gram, cashew nuts]","[-2.6323494911193848, , nan]","[0.665736198425293, cup, mushroom soup]","[2.700110673904419, gram, nan]","[-15.033565521240234, teaspoon, nan]","[-9.242424964904785, teaspoon, nan]","[-20.194671630859375, , pepper]","[-59.45408248901367, nan, cinnamon]","[-94.9030990600586, nan, lemon]","[-83.5373764038086, shredded coconut, nan]","[-89.45399475097656, brandy, nan]","[-95.18731689453125, tomatoes, nan]","[-105.66253662109375, lemon juice, nan]","[-103.00941467285156, onions, nan]","[-103.3084716796875, ground nutmeg, nan]","[-105.01751708984375, garlic, nan]","[-108.19489288330078, salt and pepper, nan]","[-108.66646575927734, fennel bulb, nan]","[-107.1209945678711, nan, nan]"
4,"[4.430034637451172, gram, nan]","[5.309525012969971, , nan]","[-0.8672531843185425, , pepper]","[-65.1639633178711, nan, nan]","[-75.56707763671875, nan, nan]","[-89.15531158447266, nan, nan]","[-114.08463287353516, nan, nan]","[-154.67755126953125, nan, mixed vegetables]","[-149.75143432617188, nan, nan]","[-131.10719299316406, cornstarch, nan]","[-123.24909210205078, russet potatoes, nan]","[-105.03746032714844, ketchup, nan]","[-115.44496154785156, baking soda, fresh groun...","[-106.54219818115234, onions, nan]","[-103.22066497802734, nan, nan]","[-109.89949035644531, nan, nan]","[-102.9728775024414, nan, nan]","[-107.24531555175781, dry dill weed, nan]","[-104.44977569580078, nan, nan]"
5,"[1.9070508480072021, gram, garbanzo beans]","[0.4421665072441101, , nan]","[-0.3601309061050415, tablespoon, rome apples]","[-4.027310371398926, gram, nan]","[15.836773872375488, ounce, teaspoon]","[-9.919238090515137, cup, nan]","[2.3245275020599365, tablespoon, pepper]","[-29.01490020751953, gram, hot pepper sauce]","[-64.95699310302734, nan, lemon]","[-91.40620422363281, onions, nan]","[-121.114013671875, pork tenderloin, nan]","[-115.3040771484375, fresh mushrooms, nan]","[-123.6551284790039, white sesame seeds, nan]","[-118.3110580444336, pimentos, nan]","[-104.32015228271484, lemon, nan]","[-110.54730987548828, kidney beans, nan]","[-93.82437896728516, ice cubes, nan]","[-103.61910247802734, vegetable broth, nan]","[-101.53034210205078, nan, nan]"
6,"[12.079545974731445, teaspoon, avocado]","[6.263890743255615, tablespoon, nan]","[1.7144775390625, , lemon]","[3.3029396533966064, drop, nan]","[-12.109894752502441, , nan]","[-18.353715896606445, quart, nan]","[-51.31104278564453, cup, oats]","[-86.89556121826172, gram, red wine vinegar]","[-80.77445983886719, nan, nan]","[-105.11180114746094, egg, nan]","[-82.15573120117188, black beans, nan]","[-92.66659545898438, boiling water, nan]","[-111.54261016845703, white sesame seeds, nan]","[-101.31503295898438, orange rind, nan]","[-110.2247085571289, dry white wine, nan]","[-97.49958801269531, reduced - sodium chicken ...","[-96.73670196533203, unsalted butter, nan]","[-98.3888168334961, crushed pineapple in juice...","[-93.3880844116211, cubic centimetre, nan]"
7,"[-3.6953744888305664, , oil]","[-6.064838886260986, cup, teaspoon]","[6.136833667755127, gram, dried currants]","[7.523921489715576, , ]","[8.935935020446777, , tablespoon]","[26.994718551635742, cubic centimetre, cup]","[17.182174682617188, tablespoon, mayonnaise]","[18.53801918029785, ounce, dried currants]","[62.9136962890625, teaspoon, morton lite salt]","[76.11566162109375, extra virgin olive oil, nan]","[64.798095703125, smoked bacon, nan]","[63.82243728637695, frozen lemonade concentrat...","[-21.496286392211914, vegetable oil, nan]","[3.3853635787963867, red bell peppers, nan]","[-101.85506439208984, dried oregano, nan]","[-113.97174072265625, Sprite, nan]","[-106.0098876953125, cumin, nan]","[-98.95318603515625, chili powder, nan]","[-107.46267700195312, pound-mass, nan]"
8,"[-3.7372121810913086, , oil]","[-2.333613872528076, cup, nan]","[6.303003787994385, tablespoon, pepper]","[5.874743938446045, , nan]","[38.085105895996094, cubic centimetre, nan]","[19.858003616333008, cubic centimetre, nan]","[20.78607177734375, , nan]","[8.818887710571289, , nan]","[-29.02116584777832, , nan]","[-56.43180847167969, wheat flour, nan]","[-75.15351104736328, oil, nan]","[-96.9666519165039, crushed tomatoes, nan]","[-100.87362670898438, chocolate chips, hot dogs]","[-111.63285827636719, chicken thighs, nan]","[-100.33314514160156, vegetable oil, nan]","[-94.57388305664062, garlic powder, nan]","[-96.7287368774414, pineapple chunks, nan]","[-99.42351531982422, frozen French - cut green...","[-99.3680648803711, , nan]"
9,"[0.182712584733963, teaspoon, capers]","[-1.2653127908706665, , cup]","[9.031152725219727, teaspoon, jarlsberg cheese]","[-5.851704120635986, drop, ]","[8.173280715942383, tablespoon, teaspoon]","[16.656930923461914, , tablespoon]","[-3.6193811893463135, teaspoon, pine nuts]","[7.066552639007568, gram, salt and black pepper]","[-6.010429382324219, pound-mass, salt]","[-19.974388122558594, frozen chopped spinach, ]","[-25.137781143188477, garlic cloves, tablespoon]","[-21.06532096862793, fruit cocktail, ]","[-23.538969039916992, diced tomatoes, hot dogs]","[-21.884462356567383, garlic cloves, vanilla e...","[-24.010671615600586, garlic cloves, salt]","[-34.258392333984375, olive oil, ]","[-14.556727409362793, dark brown sugar, nan]","[-35.13624954223633, mandarin oranges, Belgian...","[-56.1386604309082, pound-mass, nan]"
