In [None]:
# default_exp extract

# Extraction Utilities
> Helper functions for extracting embedding weights from a model and for preparing the data needed to do so.

In [None]:
#export
import json
import torch.nn as nn
import pandas as pd
from fastcore.dispatch import *
from transfertab.utils import *

In [None]:
import os

In [1]:
# NOTE(review): this cell runs the tests of `02_transfer.ipynb`, while this notebook
# appears to be the one exporting `extract` — confirm that testing another notebook
# from here is intentional and not a leftover.
from nbdev.test import *
test_nb('02_transfer.ipynb')

In [None]:
# One (num_embeddings, embedding_dim) pair per embedding layer created in the next cell.
emb_szs = ((3, 10), (4, 8))

In [None]:
# Stand-in "model" for the examples below: one `nn.Embedding` per entry of `emb_szs`.
embed = nn.ModuleList(
    [nn.Embedding(num_embeddings, embedding_dim) for num_embeddings, embedding_dim in emb_szs]
)
embed

ModuleList(
  (0): Embedding(3, 10)
  (1): Embedding(4, 8)
)

In [None]:
#export
class JSONizerWithBool(json.JSONEncoder):
    '''
    `json.JSONEncoder` that can additionally serialize numpy booleans (`np.bool_`).

    `np.bool_` values are emitted as JSON `true` / `false`; every other object
    the stock encoder cannot handle still raises `TypeError`.
    '''
    def default(self, obj):
        '''
        Return a JSON-serializable replacement for `obj`.

        obj: An object the standard encoder does not know how to serialize.

        Returns a plain `bool` for `np.bool_` input; otherwise defers to
        `json.JSONEncoder.default`, which raises `TypeError`.
        '''
        # `default` has to hand back a serializable *object*. Returning already
        # encoded text (e.g. the string "true") would be written out as a JSON
        # string rather than a JSON boolean.
        if isinstance(obj, np.bool_):
            return bool(obj)
        return super().default(obj)

In [None]:
#export
def _catdict2embedsdictstruct(catdict):
    embedsdict = {}
    for cat, classes in catdict.items():
        embedsdict[cat] = {}
        embedsdict[cat]["classes"] = classes
    return embedsdict

In [None]:
#export
@typedispatch
def extractembeds(model: nn.Module, df: pd.DataFrame, *, transfercats, allcats, path=None, kind="bson"):
    '''
    DataFrame variant of `extractembeds`.

    Builds the category dictionary for `transfercats` from `df` with `getcatdict`
    and forwards it, together with all keyword arguments unchanged, to the
    `dict` variant of `extractembeds`, whose result is returned.
    '''
    return extractembeds(
        model,
        getcatdict(df, transfercats),
        transfercats=transfercats,
        allcats=allcats,
        path=path,
        kind=kind,
    )


@typedispatch
def extractembeds(model: nn.Module, catdict: dict, *, transfercats, allcats, path=None, kind="bson"):
    '''
    Extracts embedding weights from `model`, which can be further transferred to other models.

    model: Any pytorch model, containing the embedding layers.
    catdict: A dictionary with category as key, and classes as value.
    transfercats: Names of categories to be transferred.
    allcats: Names of all categories corresponding to the embedding layers in model.
    path: Path for the result to be stored. Nothing is written when it is `None`.
    kind: Storage format used when `path` is given. "json" writes JSON; any other
        value (default "bson") hands the data to `store_bson`.

    Returns a dictionary keyed by category. Every category of `catdict` gets a
    "classes" entry; the categories in `transfercats` additionally get an
    "embeddings" entry holding the weights as nested lists.

    Raises AssertionError when the first dimension of a selected weight does not
    equal the number of classes of its category.
    '''
    embedsdict = _catdict2embedsdictstruct(catdict)
    # Weights are looked up by *position* in the state dict: this assumes the entry at
    # index `allcats.index(cat)` is the embedding weight of `cat` — verify for models
    # whose state dict does not start with the embedding layers in `allcats` order.
    weights = list(model.state_dict().values())
    for cat in transfercats:
        classes = catdict[cat]
        weight = weights[allcats.index(cat)]
        assert weight.shape[0] == len(classes), (
            f"embeddings dimension {weight.shape[0]} != "
            f"num of classes {len(classes)} for variable {cat}. Embeddings should have "
            f"same number of classes. Something might have gone wrong.")
        # `.cpu()` so that weights living on another device can be converted too.
        embedsdict[cat]["embeddings"] = weight.detach().cpu().numpy().tolist()
    if path is not None:
        if kind == "json":
            with open(path, 'w') as fp:
                json.dump(embedsdict, fp, cls=JSONizerWithBool)
        else:
            # `store_bson` is given the path itself, so no file handle is opened here.
            store_bson(path, embedsdict)
    return embedsdict

In [None]:
# Toy frame: `cat2` has 3 distinct values and `cat3` has 4, matching the sizes in `emb_szs`.
df = pd.DataFrame(
    {
        "cat1": [1, 2, 3, 4, 5],
        "cat2": ['a', 'b', 'c', 'b', 'a'],
        "cat3": ['A', 'B', 'C', 'D', 'A'],
    }
)
df

Unnamed: 0,cat1,cat2,cat3
0,1,a,A
1,2,b,B
2,3,c,C
3,4,b,D
4,5,a,A


In [None]:
# Categories whose embeddings are extracted in the examples below.
cats = ("cat2", "cat3")
catdict = getcatdict(df, cats)

In [None]:
# Extract the embeddings of `cats` from `embed` and also store them at "tempwtbson" (kind="bson").
embdict = extractembeds(embed, df, transfercats=cats, allcats=cats, path="tempwtbson", kind="bson")
embdict

{'cat2': {'classes': ['a', 'b', 'c'],
  'embeddings': [[0.31570491194725037,
    -0.07632226496934891,
    1.5683248043060303,
    -0.417350172996521,
    -0.10798821598291397,
    1.4268646240234375,
    -0.22982962429523468,
    -0.16915012896060944,
    0.002859442261978984,
    -0.4939035475254059],
   [0.6530274748802185,
    -0.5577511191368103,
    -0.9275949001312256,
    -0.06805138289928436,
    -2.2739336490631104,
    0.1566399186849594,
    -0.0531904362142086,
    -0.43463948369026184,
    -0.0794961154460907,
    0.4645240008831024],
   [1.0870261192321777,
    -0.22893156111240387,
    -0.253396600484848,
    -0.3393022119998932,
    -2.0341274738311768,
    -0.31127995252609253,
    0.3499477803707123,
    -1.9891204833984375,
    0.674164891242981,
    -1.3391718864440918]]},
 'cat3': {'classes': ['A', 'B', 'C', 'D'],
  'embeddings': [[1.3585036993026733,
    0.024397719651460648,
    0.4804745614528656,
    1.1160022020339966,
    0.8734705448150635,
    0.7849490046

In [None]:
# Same extraction, this time stored as JSON at "tempwtjson".
embdict = extractembeds(embed, df, transfercats=cats, allcats=cats, path="tempwtjson", kind="json")
embdict

{'cat2': {'classes': ['a', 'b', 'c'],
  'embeddings': [[0.31570491194725037,
    -0.07632226496934891,
    1.5683248043060303,
    -0.417350172996521,
    -0.10798821598291397,
    1.4268646240234375,
    -0.22982962429523468,
    -0.16915012896060944,
    0.002859442261978984,
    -0.4939035475254059],
   [0.6530274748802185,
    -0.5577511191368103,
    -0.9275949001312256,
    -0.06805138289928436,
    -2.2739336490631104,
    0.1566399186849594,
    -0.0531904362142086,
    -0.43463948369026184,
    -0.0794961154460907,
    0.4645240008831024],
   [1.0870261192321777,
    -0.22893156111240387,
    -0.253396600484848,
    -0.3393022119998932,
    -2.0341274738311768,
    -0.31127995252609253,
    0.3499477803707123,
    -1.9891204833984375,
    0.674164891242981,
    -1.3391718864440918]]},
 'cat3': {'classes': ['A', 'B', 'C', 'D'],
  'embeddings': [[1.3585036993026733,
    0.024397719651460648,
    0.4804745614528656,
    1.1160022020339966,
    0.8734705448150635,
    0.7849490046

In [None]:
# Round-trip check: what was stored at "tempwtbson" must load back equal to the
# in-memory result. Asserted so that a mismatch fails the notebook run instead of
# merely displaying `False`.
bson_roundtrip_ok = load_bson("tempwtbson") == embdict
assert bson_roundtrip_ok, "data loaded from 'tempwtbson' differs from the extracted embeddings"
bson_roundtrip_ok

True

In [None]:
# Remove the temporary files written by the two `extractembeds` calls above.
os.remove("tempwtbson")
os.remove("tempwtjson")

### Export

In [None]:
#export
# NOTE(review): `_all_` (single underscores) looks like the nbdev convention for adding
# names to the generated module's `__all__` — confirm against the nbdev version in use.
_all_ = ['extractembeds']

In [None]:
#hide
# Convert the project's notebooks (see the "Converted ..." lines in the output).
from nbdev.export import notebook2script
notebook2script()

Converted 00_utils.ipynb.
Converted 01_extract.ipynb.
Converted 02_transfer.ipynb.
Converted 03_load_tests.ipynb.
Converted index.ipynb.
