In [None]:
# default_exp utils

# Extraction Utilities
> Helper functions for extracting embedding weights from a model and for preparing the data needed to do so.

In [None]:
#export
import json
from copy import deepcopy

import bson
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from fastcore.dispatch import *

In [None]:
import os

In [None]:
# (number of classes, embedding width) for each of the two demo embedding layers
emb_szs = ((3, 10), (4, 8))

In [None]:
# One nn.Embedding per (ni, nf) pair; the ModuleList is used as the "model" in the demos below
embed = nn.ModuleList([nn.Embedding(ni, nf) for ni,nf in emb_szs])
embed

ModuleList(
  (0): Embedding(3, 10)
  (1): Embedding(4, 8)
)

In [None]:
#export
class JSONizerWithBool(json.JSONEncoder):
    """`json.JSONEncoder` that additionally knows how to serialize `np.bool_`.

    Pass it as `cls=JSONizerWithBool` to `json.dump` / `json.dumps`. NumPy
    booleans are written as the JSON literals `true` / `false`; every other
    unsupported object is handled by the base class, which raises `TypeError`.
    """

    def default(self, obj):
        """Return a JSON-serializable replacement for `obj`.

        obj: The object the standard encoder could not serialize.

        Returns a plain Python `bool` for `np.bool_` inputs; otherwise defers
        to `json.JSONEncoder.default`, which raises `TypeError`.
        """
        # `default` must hand back a serializable *object*. Returning an
        # already-encoded string (e.g. "true") would be emitted as a JSON
        # string rather than as a boolean literal.
        if isinstance(obj, np.bool_):
            return bool(obj)
        return super().default(obj)

In [None]:
#export
def getcatdict(df, catcols=None):
    """Map each categorical column of `df` to the list of its unique values.

    df: DataFrame holding the categorical columns.
    catcols: Iterable of column names to use. When `None`, every column whose
        dtype compares equal to `'object'` is used.

    Returns a dict `{column name: list of unique values}`, with the values as
    produced by `Series.unique()`.
    """
    # Identity check on purpose: `catcols == None` is evaluated element-wise
    # for array-likes such as `pd.Index` / `np.ndarray`, and the resulting
    # boolean array cannot be used as an `if` condition.
    if catcols is None:
        catcols = [col for col, dtype in zip(df.columns, df.dtypes) if dtype == 'object']
    return {cat: list(df[cat].unique()) for cat in catcols}

def _catdict2embedsdictstruct(catdict):
    embedsdict = {}
    for cat, classes in catdict.items():
        embedsdict[cat] = {}
        embedsdict[cat]["classes"] = classes
    return embedsdict

In [None]:
#export
def store_bson(path, data):
    """Encode `data` with `bson.dumps` and write the bytes to `path`.

    path: Destination file; it is opened in binary write mode.
    data: Object to encode.
    """
    # Encode before opening so a failing `dumps` never touches the file.
    payload = bson.dumps(data)
    with open(path, "wb") as sink:
        sink.write(payload)

def load_bson(path):
    """Read the file at `path` and decode its bytes with `bson.loads`.

    path: File to read; it is opened in binary read mode.

    Returns whatever `bson.loads` produces for the file's contents.
    """
    with open(path, "rb") as source:
        raw = source.read()
    decoded = bson.loads(raw)
    return decoded

In [None]:
#hide
# Sanity check: the ModuleList is an nn.Module, matching the `model: nn.Module` annotation of `extractembeds`
issubclass(type(embed), nn.Module)

True

In [None]:
#export
@typedispatch
def extractembeds(model: nn.Module, df: pd.DataFrame, *, transfercats, allcats, path=None, kind="bson"):
    '''
    DataFrame variant: collects the classes of `transfercats` from `df` with
    `getcatdict`, then calls `extractembeds` again with that dictionary.

    model: Any pytorch model, containing the embedding layers.
    df: DataFrame from which the classes of each category are read.
    transfercats, allcats, path, kind: Passed on unchanged.
    '''
    return extractembeds(
        model,
        getcatdict(df, transfercats),
        transfercats=transfercats,
        allcats=allcats,
        path=path,
        kind=kind,
    )


@typedispatch
def extractembeds(model: nn.Module, catdict: dict, *, transfercats, allcats, path=None, kind="bson"):
    '''
    Extracts embedding weights from `model`, which can be further transferred to other models.

    model: Any pytorch model, containing the embedding layers.
    catdict: A dictionary with category as key, and classes as value.
    transfercats: Names of categories to be transferred.
    allcats: Names of all categories corresponding to the embedding layers in model.
    path: Optional file path; when given, the result is also written there.
    kind: "json" writes JSON text to `path`; any other value writes BSON.

    Returns a dict `{category: {"classes": [...], "embeddings": [[...], ...]}}`.
    Categories present in `catdict` but absent from `transfercats` only get
    the "classes" entry.

    Raises AssertionError when the number of rows of an embedding weight does
    not match the number of classes of its category.
    '''
    embedsdict = _catdict2embedsdictstruct(catdict)
    # The i-th `state_dict` entry is taken to be the weight of the i-th
    # category in `allcats`.
    # NOTE(review): this presumes the model's leading state_dict entries are
    # exactly the embedding weights, in `allcats` order - confirm for models
    # that hold other parameters.
    weights = list(model.state_dict().values())
    for cat in transfercats:
        classes = catdict[cat]
        weight = weights[allcats.index(cat)]
        assert weight.shape[0] == len(classes), (
            f"embeddings dimension {weight.shape[0]} != "
            f"num of classes {len(classes)} for variable {cat}. Embeddings should have "
            f"same number of classes. Something might have gone wrong.")
        # `.numpy()` only works for CPU tensors, so move the weight first.
        embedsdict[cat]["embeddings"] = weight.detach().cpu().numpy().tolist()
    if path is not None:
        if kind == "json":
            with open(path, 'w') as fp:
                json.dump(embedsdict, fp, cls=JSONizerWithBool)
        else:
            # `store_bson` opens the file itself (binary mode); do not open it
            # a second time in text mode here.
            store_bson(path, embedsdict)
    return embedsdict

In [None]:
# Toy frame: `cat1` is numeric, `cat2` / `cat3` are string columns with 3 and 4 distinct values
df = pd.DataFrame({"cat1": [1, 2, 3, 4, 5], "cat2": ['a', 'b', 'c', 'b', 'a'], "cat3": ['A', 'B', 'C', 'D', 'A']})
df

Unnamed: 0,cat1,cat2,cat3
0,1,a,A
1,2,b,B
2,3,c,C
3,4,b,D
4,5,a,A


In [None]:
# Unique classes of the two string columns, keyed by column name
catdict = getcatdict(df, ("cat2", "cat3"))
catdict

{'cat2': ['a', 'b', 'c'], 'cat3': ['A', 'B', 'C', 'D']}

In [None]:
# Category names, used below both as `transfercats` and as `allcats`
cats = ("cat2", "cat3")

In [None]:
# DataFrame overload: classes are read from `df`; the result is also written to "tempwtbson" as BSON
embdict = extractembeds(embed, df, transfercats=cats, allcats=cats, path="tempwtbson", kind="bson")
embdict

{'cat2': {'classes': ['a', 'b', 'c'],
  'embeddings': [[0.05603214353322983,
    1.3744007349014282,
    -0.31262442469596863,
    0.37519311904907227,
    -1.9234352111816406,
    0.6861013770103455,
    -1.4775152206420898,
    1.109705924987793,
    0.10772546380758286,
    -1.831138014793396],
   [-0.6345962285995483,
    -1.946880578994751,
    -0.3423684537410736,
    0.04363042116165161,
    0.2794189751148224,
    0.06722379475831985,
    -0.28763726353645325,
    0.9291570782661438,
    0.00894381757825613,
    0.7326543927192688],
   [0.4946509003639221,
    -0.409279465675354,
    0.5394951701164246,
    0.08757159858942032,
    -0.5254548192024231,
    0.6215347647666931,
    -1.835079550743103,
    -1.5779095888137817,
    1.0498842000961304,
    -0.7102258801460266]]},
 'cat3': {'classes': ['A', 'B', 'C', 'D'],
  'embeddings': [[-0.7851612567901611,
    1.226929783821106,
    -0.7716110944747925,
    0.1601991206407547,
    -0.11754895001649857,
    -1.1803913116455078,
 

In [None]:
# Same extraction, this time written to "tempwtjson" as JSON
embdict = extractembeds(embed, df, transfercats=cats, allcats=cats, path="tempwtjson", kind="json")
embdict

{'cat2': {'classes': ['a', 'b', 'c'],
  'embeddings': [[0.05603214353322983,
    1.3744007349014282,
    -0.31262442469596863,
    0.37519311904907227,
    -1.9234352111816406,
    0.6861013770103455,
    -1.4775152206420898,
    1.109705924987793,
    0.10772546380758286,
    -1.831138014793396],
   [-0.6345962285995483,
    -1.946880578994751,
    -0.3423684537410736,
    0.04363042116165161,
    0.2794189751148224,
    0.06722379475831985,
    -0.28763726353645325,
    0.9291570782661438,
    0.00894381757825613,
    0.7326543927192688],
   [0.4946509003639221,
    -0.409279465675354,
    0.5394951701164246,
    0.08757159858942032,
    -0.5254548192024231,
    0.6215347647666931,
    -1.835079550743103,
    -1.5779095888137817,
    1.0498842000961304,
    -0.7102258801460266]]},
 'cat3': {'classes': ['A', 'B', 'C', 'D'],
  'embeddings': [[-0.7851612567901611,
    1.226929783821106,
    -0.7716110944747925,
    0.1601991206407547,
    -0.11754895001649857,
    -1.1803913116455078,
 

In [None]:
# Round trip: the BSON file written earlier should decode to a dict equal to `embdict`
load_bson("tempwtbson") == embdict

True

In [None]:
# Clean up the temporary files written by the two extraction demos
os.remove("tempwtbson")
os.remove("tempwtjson")