In [None]:
# default_exp utils

In [None]:
# Create a format for saving the embeddings and loading them back in.

In [None]:
# TODO: Add a convenience method which takes the data and infers the metadata itself.

In [None]:
#export
from fastai.tabular.all import *
from copy import deepcopy
import json

We'll create a fastai learner and extract the embeddings from it, but the same can be done for any PyTorch model.

In [None]:
# Locate the Adult sample dataset via `untar_data` and read its CSV into a DataFrame.
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path / 'adult.csv')

# Build the tabular DataLoaders from the same CSV file; `salary` is the target column.
dls = TabularDataLoaders.from_csv(
    path / 'adult.csv',
    path=path,
    y_names="salary",
    cat_names=[
        'workclass',
        'education',
        'marital-status',
        'occupation',
        'relationship',
        'race',
    ],
    cont_names=[
        'age',
        'fnlwgt',
        'education-num',
    ],
    procs=[Categorify, FillMissing, Normalize],
)

In [None]:
# Build a fastai tabular learner on `dls`, tracking accuracy as the metric.
# The "embeds" layer group of `learn.model` is what gets passed to `extractembeds` further down.
learn = tabular_learner(dls, metrics=accuracy)

We'll need to create a metadata dictionary for the source data. It lists all the categorical columns under `categories` and, for each of those columns, its classes under a `classes` key, in the format shown below.

In [None]:
# Metadata describing the categorical columns of the source data.
#   "categories": the categorical column names; `extractembeds` pairs the i-th name
#                 with the i-th layer of the embedding layer group.
#   <column name>: a dict whose "classes" entry lists the labels of that column.
# NOTE(review): each label's position is presumably meant to be its row index in the
# embedding matrix — confirm this order against the vocabulary of the fitted DataLoaders.
meta = {
    "categories": [
        "workclass",
        "education",
        "marital-status",
        "occupation",
        "relationship",
        "race",
    ],
    "workclass": {
        "classes": [
            "nan", " Private", " Self-emp-inc", " Self-emp-not-inc",
            " State-gov", " Federal-gov", " Local-gov", " ?",
            " Without-pay", " Never-worked",
        ],
    },
    "education": {
        "classes": [
            "nan", " Assoc-acdm", " Masters", " HS-grad",
            " Prof-school", " 7th-8th", " Some-college", " 11th",
            " Bachelors", " Assoc-voc", " 10th", " 9th",
            " Doctorate", " 12th", " 1st-4th", " 5th-6th",
            " Preschool",
        ],
    },
    "marital-status": {
        "classes": [
            "nan", " Married-civ-spouse", " Divorced", " Never-married",
            " Widowed", " Married-spouse-absent", " Separated", " Married-AF-spouse",
        ],
    },
    "occupation": {
        "classes": [
            "nan", " Exec-managerial", " Prof-specialty", " Other-service",
            " Handlers-cleaners", " Craft-repair", " Adm-clerical", " Sales",
            " Machine-op-inspct", " Transport-moving", " ?", " Farming-fishing",
            " Tech-support", " Protective-serv", " Priv-house-serv", " Armed-Forces",
        ],
    },
    "relationship": {
        "classes": [
            "nan", " Wife", " Not-in-family", " Unmarried",
            " Husband", " Own-child", " Other-relative",
        ],
    },
    "race": {
        "classes": [
            "nan", " White", " Black", " Asian-Pac-Islander",
            " Amer-Indian-Eskimo", " Other",
        ],
    },
}

In [None]:
#export
def extractembeds(model, embeddinglg: str, metadict, path):
    '''
    Attach the embedding weights of `model` to a copy of `metadict` and save the result as JSON.

    model: Any pytorch model, containing a layergroup with all the embedding layers.
    embeddinglg: Name of the layer group containing the embedding layers. The i-th layer of the
        group is used for the i-th entry of `metadict["categories"]`.
    metadict: A dictionary containing relevant metadata. Check the format given in docs for further details.
        It is not modified; a deep copy of it is populated instead.
    path: Path of the json file the resulting dictionary is written to.

    Returns the copy of `metadict` in which every category that has class metadata also carries an
    "embeddings" entry holding that layer's weights as nested lists.

    Raises AssertionError when a layer's `num_embeddings` differs from the number of classes listed
    for its category. A category without a "classes" entry in `metadict` is skipped with a warning.
    '''
    import warnings  # local import so the exported function does not depend on a star import

    embedsdict = deepcopy(metadict)
    for i, cat in enumerate(metadict["categories"]):
        # Only the metadata lookup is allowed to fail softly; errors raised while
        # fetching the layer must surface instead of being swallowed with it.
        try:
            classes = metadict[cat]["classes"]
        except KeyError:
            warnings.warn(
                f"No class metadata found for category {cat!r}; its embeddings were not extracted.",
                stacklevel=2,
            )
            continue
        layer = getattr(model, embeddinglg)[i]
        if layer.num_embeddings != len(classes):
            # Raised explicitly (instead of via `assert`) so the check survives `python -O`.
            raise AssertionError("Embeddings should have same number of classes. Something might have gone wrong.")
        embedsdict[cat]["embeddings"] = layer.weight.cpu().detach().numpy().tolist()
    with open(path, 'w') as fp:
        json.dump(embedsdict, fp)
    return embedsdict

In [None]:
# Attach the weights of the learner's "embeds" layer group to `meta` and write the result
# to the file "test". As the last expression of the cell, the returned dictionary is also
# displayed as the cell output.
extractembeds(learn.model, "embeds", meta, "test")

{'categories': ['workclass',
  'education',
  'marital-status',
  'occupation',
  'relationship',
  'race'],
 'workclass': {'classes': ['nan',
   ' Private',
   ' Self-emp-inc',
   ' Self-emp-not-inc',
   ' State-gov',
   ' Federal-gov',
   ' Local-gov',
   ' ?',
   ' Without-pay',
   ' Never-worked'],
  'embeddings': [[0.007202747277915478,
    -0.012875864282250404,
    -0.014799817465245724,
    0.003126859199255705,
    -0.006304586306214333,
    0.010779300704598427],
   [0.015383994206786156,
    0.013579949736595154,
    -0.0007835030555725098,
    0.0016431818949058652,
    0.0019429499516263604,
    0.0049425954930484295],
   [-0.0030373292975127697,
    -0.0014501578407362103,
    0.009285704232752323,
    -0.0018107056384906173,
    -0.0052525997161865234,
    0.018953772261738777],
   [0.0018311218591406941,
    0.004234636202454567,
    -0.003884845180436969,
    0.00023053947370499372,
    -0.008045045658946037,
    -0.0100539680570364],
   [0.0046618664637207985,
    -0.