In [None]:
# default_exp utils

# Extraction Utilities
> Helper functions for extracting embeddings from a model and preparing the data needed to do so.

In [None]:
#export
from copy import deepcopy
import json

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from fastcore.dispatch import *

In [None]:
# Toy (num_embeddings, embedding_dim) pairs used to build the example embedding layers below.
emb_szs = ((3, 10), (4, 8))

In [None]:
# One embedding layer per (number of classes, embedding width) pair in `emb_szs`.
embed = nn.ModuleList([nn.Embedding(n_classes, width) for n_classes, width in emb_szs])
embed

ModuleList(
  (0): Embedding(3, 10)
  (1): Embedding(4, 8)
)

In [None]:
#export
class JSONizerWithBool(json.JSONEncoder):
    '''
    JSON encoder that also accepts NumPy boolean and integer scalars.

    `np.bool_` values are written as JSON `true`/`false` and NumPy integers as
    JSON numbers. Every other unsupported object is handed to the base class,
    which raises `TypeError`.
    '''
    def default(self, obj):
        # `default` must return a JSON-serializable *object*, not already-encoded
        # text: returning the encoded string would be written as "true" (a JSON
        # string) instead of true (a JSON boolean).
        if isinstance(obj, np.bool_):
            return bool(obj)
        # Unique values of numeric columns are NumPy integers, which the stock
        # encoder rejects.
        if isinstance(obj, np.integer):
            return int(obj)
        return super().default(obj)

In [None]:
#export
def getcatdict(df, catcols=None):
    '''
    Maps each categorical column of `df` to the list of its unique values.

    df: DataFrame to read the columns from.
    catcols: Names of the columns to include. When `None`, every column whose
        dtype compares equal to 'object' is used.

    Returns a dict with column name as key and the list of that column's unique
    values as value.
    '''
    # `is None` rather than `== None`: comparing an array-like (e.g. a pandas
    # Index of column names) with `==` is element-wise and has no single truth value.
    if catcols is None:
        catcols = [col for col, dtype in df.dtypes.items() if dtype == 'object']
    return {cat: list(df[cat].unique()) for cat in catcols}

def _catdict2embedsdictstruct(catdict):
    embedsdict = {}
    for cat, classes in catdict.items():
        embedsdict[cat] = {}
        embedsdict[cat]["classes"] = classes
    return embedsdict

In [None]:
#hide
# `embed` has to be an `nn.Module` for `extractembeds` to dispatch on it.
isinstance(embed, nn.Module)

True

In [None]:
#export
@typedispatch
def extractembeds(model: nn.Module, df: pd.DataFrame, *, transfercats, allcats, path=None):
    '''
    Extracts embedding weights from `model`, taking the classes from `df`.

    The unique values of every column named in `transfercats` are collected
    from `df`, and the call is then forwarded to the `dict` overload of
    `extractembeds` with all other arguments unchanged.
    '''
    classes_by_cat = getcatdict(df, transfercats)
    return extractembeds(
        model,
        classes_by_cat,
        transfercats=transfercats,
        allcats=allcats,
        path=path,
    )


@typedispatch
def extractembeds(model: nn.Module, catdict: dict, *, transfercats, allcats, path=None):
    '''
    Extracts embedding weights from `model`, which can be further transferred to other models.

    model: Any pytorch model, containing the embedding layers.
    catdict: A dictionary with category as key, and classes as value.
    transfercats: Names of categories to be transferred.
    allcats: Names of all categories corresponding to the embedding layers in model.
    path: Path for the json to be stored. Nothing is written when it is `None`.

    Returns a dict with one entry per key of `catdict`, holding its "classes" and,
    for every category in `transfercats`, the "embeddings" as nested lists.

    Raises `AssertionError` when the first dimension of a weight differs from the
    number of classes of its category.
    '''
    embedsdict = _catdict2embedsdictstruct(catdict)
    # Entry `i` of the state dict is used as the weight of `allcats[i]`.
    # NOTE(review): this presumes the embedding weights are the leading
    # state-dict entries, in `allcats` order - confirm for models with other layers.
    state_entries = list(model.state_dict().items())
    cat_order = list(allcats)  # any sequence of names works, not only list/tuple
    for cat in transfercats:
        classes = catdict[cat]
        weights = state_entries[cat_order.index(cat)][1]
        n_rows = weights.shape[0]
        assert n_rows == len(classes), (
            f"embeddings dimension {n_rows} != "
            f"num of classes {len(classes)} for variable {cat}. Embeddings should have "
            f"same number of classes. Something might have gone wrong.")
        # `Tensor.tolist()` does not require the tensor to live on the CPU,
        # unlike going through `.numpy()`.
        embedsdict[cat]["embeddings"] = weights.tolist()
    if path is not None:
        with open(path, 'w') as fp:
            json.dump(embedsdict, fp, cls=JSONizerWithBool)
    return embedsdict

In [None]:
# Small example frame: one numeric column and two string-valued columns.
df = pd.DataFrame(
    {
        "cat1": [1, 2, 3, 4, 5],
        "cat2": ['a', 'b', 'c', 'b', 'a'],
        "cat3": ['A', 'B', 'C', 'D', 'A'],
    }
)
df

Unnamed: 0,cat1,cat2,cat3
0,1,a,A
1,2,b,B
2,3,c,C
3,4,b,D
4,5,a,A


In [None]:
# Collect the unique values of the two string columns; `cat1` is left out.
catdict = getcatdict(df, ("cat2", "cat3"))
catdict

{'cat2': ['a', 'b', 'c'], 'cat3': ['A', 'B', 'C', 'D']}

In [None]:
# Category names, listed in the same order as the layers of `embed`.
cats = ("cat2", "cat3")

In [None]:
# Every category is transferred here, so `transfercats` and `allcats` are the same tuple.
extractembeds(embed, df, transfercats=cats, allcats=cats)

{'cat2': {'classes': ['a', 'b', 'c'],
  'embeddings': [[0.6051239371299744,
    0.3564712107181549,
    -1.5467100143432617,
    1.5750962495803833,
    -0.42299988865852356,
    1.3493329286575317,
    1.3607025146484375,
    -1.0899522304534912,
    -0.553862452507019,
    1.3176639080047607],
   [-0.07791764289140701,
    -1.3256100416183472,
    0.8737502098083496,
    0.3552184998989105,
    0.20756079256534576,
    0.7821463942527771,
    -0.18401317298412323,
    1.9141485691070557,
    1.0479087829589844,
    0.641265332698822],
   [0.24757033586502075,
    0.2565772533416748,
    1.8610639572143555,
    0.40511053800582886,
    -0.4214091897010803,
    -0.6595247983932495,
    0.35093241930007935,
    2.0563204288482666,
    -0.8007354736328125,
    -2.4489898681640625]]},
 'cat3': {'classes': ['A', 'B', 'C', 'D'],
  'embeddings': [[-1.4485093355178833,
    -0.05285371094942093,
    0.793982744216919,
    0.7271164059638977,
    1.482647180557251,
    -0.8566319942474365,
    