In [None]:
# default_exp utils

In [None]:
# Create a format for saving the embeddings and loading them back in.

In [None]:
# TODO: Add a convenience method which takes the data and infers the metadata itself.

In [None]:
#export
from fastai.tabular.all import *
from copy import deepcopy
import json
import warnings

We'll create a fastai learner and extract the embeddings from it, but the same can be done for any PyTorch model.

In [None]:
# Fetch the Adult sample dataset and read its CSV into a DataFrame.
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')

# Build the tabular DataLoaders from the same CSV; "salary" is the target column.
dls = TabularDataLoaders.from_csv(
    path/'adult.csv',
    path=path,
    y_names="salary",
    cat_names=[
        'workclass', 'education', 'marital-status',
        'occupation', 'relationship', 'race',
    ],
    cont_names=['age', 'fnlwgt', 'education-num'],
    procs=[Categorify, FillMissing, Normalize],
)

In [None]:
# Build a fastai tabular learner over `dls`, reporting accuracy as its metric.
learn = tabular_learner(dls, metrics=accuracy)

We'll need to create a metadata dictionary for the source data. It lists all the categorical columns under `categories` and, for each of those columns, its classes under a `classes` key, in the format shown below.

In [None]:
# Metadata describing the categorical columns of the source data.
# "categories" lists the columns; every column then gets its own entry holding
# the list of its classes.
# NOTE(review): extractembeds only checks that each class list has the same
# length as the matching embedding layer, so the ORDER of each list must match
# the row order of that layer - TODO confirm against the vocabulary built by
# Categorify.
meta = {
    "categories": [
        'workclass', 'education', 'marital-status',
        'occupation', 'relationship', 'race',
    ],
}

meta["workclass"] = {"classes": [
    'nan', ' Private', ' Self-emp-inc', ' Self-emp-not-inc', ' State-gov',
    ' Federal-gov', ' Local-gov', ' ?', ' Without-pay', ' Never-worked',
]}

meta['education'] = {"classes": [
    'nan', ' Assoc-acdm', ' Masters', ' HS-grad', ' Prof-school', ' 7th-8th',
    ' Some-college', ' 11th', ' Bachelors', ' Assoc-voc', ' 10th', ' 9th',
    ' Doctorate', ' 12th', ' 1st-4th', ' 5th-6th', ' Preschool',
]}

meta["marital-status"] = {"classes": [
    'nan', ' Married-civ-spouse', ' Divorced', ' Never-married', ' Widowed',
    ' Married-spouse-absent', ' Separated', ' Married-AF-spouse',
]}

meta["occupation"] = {"classes": [
    "nan", ' Exec-managerial', ' Prof-specialty', ' Other-service',
    ' Handlers-cleaners', ' Craft-repair', ' Adm-clerical', ' Sales',
    ' Machine-op-inspct', ' Transport-moving', ' ?', ' Farming-fishing',
    ' Tech-support', ' Protective-serv', ' Priv-house-serv', ' Armed-Forces',
]}

meta["relationship"] = {"classes": [
    'nan', ' Wife', ' Not-in-family', ' Unmarried', ' Husband', ' Own-child',
    ' Other-relative',
]}

meta["race"] = {"classes": [
    'nan', ' White', ' Black', ' Asian-Pac-Islander', ' Amer-Indian-Eskimo',
    ' Other',
]}

In [None]:
#export
def extractembeds(model, embeddinglg: str, metadict, path):
    '''
    Extract the embedding weights of `model` and save them, together with their metadata, as json.

    model: Any pytorch model, containing a layergroup with all the embedding layers.
    embeddinglg: Name of the layer group containing the embedding layers. It is indexed by the
        position of each category in `metadict["categories"]`.
    metadict: A dictionary containing relevant metadata: a "categories" list plus, for every
        category, an entry holding its "classes" list. It is deep-copied and never modified.
    path: Path of the json file to write. It is opened in 'w' mode, so an existing file is overwritten.

    Returns the copy of `metadict` in which every extracted category also carries an
    "embeddings" entry (the layer weights as nested lists, one row per class).

    Raises AssertionError when a layer's `num_embeddings` differs from the number of classes
    listed for its category. A category whose lookup raises KeyError (for example because it
    has no "classes" entry) is skipped, and a warning is emitted so the gap is not silent.
    '''
    embedsdict = deepcopy(metadict)
    for i, cat in enumerate(metadict["categories"]):
        try:
            classes = metadict[cat]["classes"]
            layer = getattr(model, embeddinglg)[i]
            # Raised explicitly rather than with `assert`, so the check survives `python -O`.
            if layer.num_embeddings != len(classes):
                raise AssertionError("Embeddings should have same number of classes. Something might have gone wrong.")
            embedsdict[cat]["embeddings"] = layer.weight.cpu().detach().numpy().tolist()
        except KeyError as err:
            # Still best-effort: the category is skipped, but the caller gets told about it.
            warnings.warn(
                f"Skipping category {cat!r}: key {err} not found, no embeddings were extracted for it.",
                stacklevel=2,
            )
    with open(path, 'w') as fp:
        json.dump(embedsdict, fp)
    return embedsdict

In [None]:
# Pull the weights out of the model's "embeds" layer group, write them together with
# `meta` to a file named "test" (relative path), and display the returned dict.
extractembeds(learn.model, "embeds", meta, "test")

{'categories': ['workclass',
  'education',
  'marital-status',
  'occupation',
  'relationship',
  'race'],
 'workclass': {'classes': ['nan',
   ' Private',
   ' Self-emp-inc',
   ' Self-emp-not-inc',
   ' State-gov',
   ' Federal-gov',
   ' Local-gov',
   ' ?',
   ' Without-pay',
   ' Never-worked'],
  'embeddings': [[-0.005240431986749172,
    0.007816952653229237,
    0.0072440956719219685,
    0.013248818926513195,
    0.01408692728728056,
    -0.006864527706056833],
   [0.007001197896897793,
    -0.0007983995601534843,
    0.005386832635849714,
    -0.007644626311957836,
    -0.001206113025546074,
    -0.0027684723027050495],
   [-0.015234800055623055,
    -0.0018089546356350183,
    -0.0022030663676559925,
    4.4906235416419804e-05,
    0.005240086931735277,
    -0.006352895405143499],
   [-0.0013421158073469996,
    -6.775157089577988e-05,
    0.011917988769710064,
    0.014241461642086506,
    0.010808122344315052,
    3.164950976497494e-05],
   [-0.004133164882659912,
    0.