# Extraction Utilities
> Helper functions for extracting embedding weights from a model and preparing the metadata needed to label them.

In [None]:
# default_exp utils

In [None]:
# Create a format for saving the embeddings and loading them back in.

In [None]:
#export
from fastai.tabular.all import *
from copy import deepcopy
import json
import warnings

We'll create a fastai learner and extract its embeddings, but the same approach works for any PyTorch model whose embedding layers are grouped under a single attribute.

In [None]:
# Seed Python/NumPy/PyTorch RNGs up front: the train/valid split below is random, and so are the
# learner's initial embedding weights, so without a seed the extracted values change on every run.
SEED = 42
set_seed(SEED)

path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
dls = TabularDataLoaders.from_csv(
    path/'adult.csv',
    path=path,
    y_names="salary",
    cat_names=['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race'],
    cont_names=['age', 'fnlwgt', 'education-num'],
    procs=[Categorify, FillMissing, Normalize],
)

In [None]:
# An (untrained) tabular learner is enough here: we only need its embedding layers.
learn = tabular_learner(dls, metrics=[accuracy])

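Next, we need a metadata dictionary for the source data. It lists every categorical variable and, for each one, its classes, in the format below. The classes must be in the same order as the rows of that variable's embedding matrix; for a fastai learner this is the order of `dls.classes`.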

In [None]:
# Meta format: {'categories': ['cat1', 'cat2', ...], 'cat1': {'classes': ['class1', 'class2', ...]}, 'cat2': ...}
#
# Row j of a variable's embedding matrix belongs to class j of the dataloaders' vocabulary.
# fastai's Categorify sorts each vocabulary and prepends '#na#'. A hand-typed list in
# `df[col].unique()` order (e.g. ['nan', ' Private', ' Self-emp-inc', ...]) therefore has the
# right length but the wrong order, which silently mislabels every embedding row.
# Reading the classes from `dls` keeps them aligned with the model.
CAT_NAMES = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
meta = {"categories": CAT_NAMES}
meta.update({cat: {"classes": list(dls.classes[cat])} for cat in CAT_NAMES})

In [None]:
#export

#Meta Format: {'categories':['cat1','cat2'......],'cat1':{'classes':['class1','class2'......]},'cat2':...........}
def _to_builtin(value):
    "Return `value` as a plain Python scalar if it exposes `.item()` (numpy/torch scalars), else unchanged."
    item = getattr(value, 'item', None)
    return item() if callable(item) else value

def extract_meta_from_learner(learner):
    '''
    Build the metadata dictionary (format in the comment above) from a learner.

    learner: an object whose `dls` exposes `cat_names` (iterable of variable names) and
        `classes` (a mapping from variable name to its iterable of classes), e.g. a fastai
        tabular learner.

    Returns {'categories': [...], <cat>: {'classes': [...]}, ...} built from plain lists and
    Python scalars, so it can be passed to `extractembeds` and written with `json.dump`.
    Classes keep the order in which `learner.dls.classes` yields them.
    A name listed in `cat_names` but missing from `classes` maps to None.
    '''
    # fastai typically hands back fastcore containers (`L`, `CategoryMap`) rather than lists, and
    # their items may be numpy scalars (e.g. np.bool_ for FillMissing's `_na` columns).
    # json.dump rejects both, so convert everything to builtins here.
    cat_list = list(learner.dls.cat_names)
    vocabs = learner.dls.classes
    meta = {'categories': cat_list}
    meta.update(dict.fromkeys(cat_list))
    for cat in vocabs:
        meta[cat] = {'classes': [_to_builtin(c) for c in vocabs[cat]]}
    return meta

def extract_meta_from_df(df: pd.DataFrame):
    '''
    Build a metadata dictionary of the form
    {'categories': ['cat1', ...], 'cat1': {'classes': ['class1', ...]}, ...} from a DataFrame.

    Every column whose dtype compares equal to 'object' is treated as categorical. Its classes are
    the column's distinct values as returned by `Series.unique()`, i.e. in order of first
    appearance, with missing values kept as-is. Note that this is not necessarily the order a
    model's embedding rows use.
    '''
    categorical = [name for name, dtype in df.dtypes.items() if dtype == 'object']
    meta = {'categories': categorical}
    meta.update({name: {'classes': list(df[name].unique())} for name in categorical})
    return meta

In [None]:
#export
def extractembeds(model, embeddinglg: str, metadict, path):
    '''
    Copy the weights of each embedding layer into a copy of `metadict`, save it as JSON, and return it.

    model: Any pytorch model with an attribute (layer group) holding all the embedding layers,
        indexable so that layer i corresponds to `metadict["categories"][i]`.
    embeddinglg: Name of the attribute of `model` containing the embedding layers.
    metadict: A dictionary of the form
        {'categories': ['cat1', ...], 'cat1': {'classes': ['class1', ...]}, ...}.
        Each category's classes should be listed in the same order as the rows of its embedding
        layer. `metadict` itself is not modified.
    path: Path of the JSON file to write (overwritten if it exists).

    Returns a deep copy of `metadict` in which every category that has a "classes" entry also
    gets an "embeddings" entry: the layer's weight matrix as nested lists, one row per class.
    Categories without metadata are skipped with a warning.
    Raises AssertionError if a layer's `num_embeddings` differs from its number of classes.
    '''
    layers = getattr(model, embeddinglg)
    embedsdict = deepcopy(metadict)
    for i, cat in enumerate(metadict["categories"]):
        try:
            classes = metadict[cat]["classes"]
        except (KeyError, TypeError):
            # Best effort: keep extracting the other layers, but make the gap visible.
            warnings.warn(f"No 'classes' metadata for variable {cat!r}; its embeddings were not extracted.")
            continue
        layer = layers[i]
        if layer.num_embeddings != len(classes):
            # Explicit raise instead of `assert` so the check still runs under `python -O`.
            raise AssertionError(
                f"embeddings dimension {layer.num_embeddings} != num of classes {len(classes)} "
                f"for variable {cat}. Embeddings should have same number of classes. "
                "Something might have gone wrong."
            )
        # Detach first so the device copy is not recorded by autograd.
        embedsdict[cat]["embeddings"] = layer.weight.detach().cpu().numpy().tolist()
    with open(path, 'w') as fp:
        json.dump(embedsdict, fp)
    return embedsdict

In [None]:
# Extract every embedding matrix from the model, save it as JSON to the file "test", and keep the dict.
embeddict = extractembeds(model=learn.model, embeddinglg="embeds", metadict=meta, path="test")

In [None]:
# First row of the workclass embedding matrix, i.e. the vector for the first class in its class list.
embeddict["workclass"]["embeddings"][0:1]

[[0.011566209606826305,
  -0.009608260355889797,
  -0.005055208690464497,
  0.012445224449038506,
  -0.003585385624319315,
  0.0029318814631551504]]