In [None]:
# default_exp transfer

# Transfer
> Contains methods for transferring embeddings from a source model to a destination model.

In [None]:
#export
import pathlib
import torch
import torch.nn as nn
import json
import pandas as pd
from functools import partial
from fastcore.foundation import *
from fastcore.dispatch import *
from transfertab.utils import *
from transfertab.extract import *

In [None]:
#skip
#hide
from nbdev.showdoc import *

In [None]:
import os

We'll create collections of Embedding layers, which will be used to test our transfer methods.

In [None]:
# One (number of rows, embedding width) pair per embedding layer; each pair is
# unpacked into `nn.Embedding(ni, nf)` when the collections are built below.
emb_szs1 = ((3, 10), (2, 8))
emb_szs2 = ((2, 10), (2, 8))

In [None]:
# Build one embedding layer per size pair; `embed1` is used as the transfer
# destination and `embed2` as the transfer source in the examples below.
embed1 = nn.ModuleList([nn.Embedding(*size) for size in emb_szs1])
embed2 = nn.ModuleList([nn.Embedding(*size) for size in emb_szs2])

In [None]:
embed1

ModuleList(
  (0): Embedding(3, 10)
  (1): Embedding(2, 8)
)

Now, we'll create collections containing required metadata.

In [None]:
# Category names in layer order: the position of a name in these tuples is used
# by `transferembeds_` as the index of its embedding layer.
newcatcols = ("new_cat1", "new_cat2")
oldcatcols = ("old_cat2", "old_cat3")

# Classes of every category; the position of a class in its list is used as the
# row index of that class in the corresponding embedding weight.
newcatdict = {"new_cat1" : ["new_class1", "new_class2", "new_class3"], "new_cat2" : ["new_class1", "new_class2"]}
oldcatdict = {"old_cat2" : ["a", "b"], "old_cat3" : ["A", "B"]}

In [None]:
# Read the category/class mapping used by the transfer examples below.
json_file_path = "../data/jsons/metadict.json"

with open(json_file_path, 'r') as fp:
    metadict = json.load(fp)

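`metadict` is a `Dict` whose keys are the categories (categorical columns) in the destination model's data. Each value is another `Dict`: `mapped_cat` names the corresponding category in the source model's data, and `classes_info` maps every class of the destination category to the list of source classes it corresponds to. A destination class mapped to an empty list receives the aggregate of all source classes of the mapped category.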

In [None]:
metadict

{'new_cat1': {'mapped_cat': 'old_cat2',
  'classes_info': {'new_class1': ['a', 'b'],
   'new_class2': ['b'],
   'new_class3': []}},
 'new_cat2': {'mapped_cat': 'old_cat3',
  'classes_info': {'new_class1': ['A'], 'new_class2': []}}}

In [None]:
# Toy source data: `old_cat1` holds integers, `old_cat2` / `old_cat3` hold the
# class labels listed in `oldcatdict`.
df = pd.DataFrame({"old_cat1": [1, 2, 3, 4, 5], "old_cat2": ['a', 'b', 'b', 'b', 'a'], "old_cat3": ['A', 'B', 'B', 'B', 'A']})
# Columns passed to `extractembeds` as both `transfercats` and `allcats`.
cats = ("old_cat2", "old_cat3")

In [None]:
#skip
# Same extraction as the next cell, but with `path` set; presumably this also
# writes the result to "tempwtbson" (a later cell loads from and then removes
# that file) -- TODO confirm against `extractembeds`.
embdict = extractembeds(embed2, df, transfercats=cats, allcats=cats, path="tempwtbson")

In [None]:
# Extract the source embeddings in memory; `embdict` is later passed to
# `transferembeds_` as `src_embeds`.
embdict = extractembeds(embed2, df, transfercats=cats, allcats=cats)

In [None]:
#export
def get_metadict_skeleton(df: pd.DataFrame, *, catcols=None, path=None):
    '''
        Builds an empty metadata template for `df`, to be filled in by hand
        before it is passed to `transferembeds_` as `metatransfer`.

        The categories and their classes come from `getcatdict(df, catcols)`.
        Every category gets an entry with an empty `mapped_cat` string and a
        `classes_info` dict that maps each of its classes to an empty list.

        If `path` is given, the template is also written there as JSON.
        Returns the template `dict`.
    '''
    catdict = getcatdict(df, catcols)
    metadict = {
        cat: {'mapped_cat': '', 'classes_info': {clas: [] for clas in classes}}
        for (cat, classes) in catdict.items()
    }
    # Identity check: `None` is a singleton, and `!=` could be overridden by a path-like object.
    if path is not None:
        with open(path, 'w') as fp:
            json.dump(metadict, fp)
    return metadict

In [None]:
# `catcols` and `path` are left at `None`: nothing is written to disk and the
# template is only displayed.
get_metadict_skeleton(df)

{'old_cat2': {'mapped_cat': '', 'classes_info': {'a': [], 'b': []}},
 'old_cat3': {'mapped_cat': '', 'classes_info': {'A': [], 'B': []}}}

In [None]:
#export
@typedispatch
def transferembeds_(
        dest_embeds: nn.Module,
        src_embeds: nn.Module,
        /,
        metatransfer,
        transfer_cats,
        *,
        newcatcols,
        oldcatcols,
        oldcatdict,
        newcatdict,
        aggfn = partial(torch.mean, dim=0)):
    '''
        Transfers embeddings from `src_embeds` to `dest_embeds`,
        with the help of collections containing various metadata.

        For every category in `transfer_cats`, each of its classes (taken from
        `newcatdict`) receives `aggfn` applied to the source embedding rows of
        the old classes listed for it in `metatransfer[cat]["classes_info"]`.
        A class with no matching old classes receives `aggfn` over all rows of
        the mapped source category. The result overwrites the matching entry
        of `dest_embeds.state_dict()` in place; nothing is returned.

        `newcatcols` / `oldcatcols` give the position of each category within
        the state dict of `dest_embeds` / `src_embeds`, and `oldcatdict` gives
        the row order of the classes of each source category. `aggfn` must
        reduce a `(n_rows, emb_dim)` tensor to a 1-D tensor of size `emb_dim`.
    '''
    src_state_dict = L(src_embeds.state_dict().items())
    dest_state_dict = L(dest_embeds.state_dict().items())
    for newcat in transfer_cats:
        newidx = newcatcols.index(newcat)
        oldidx = oldcatcols.index(metatransfer[newcat]["mapped_cat"])
        # Loop-invariant lookups, hoisted out of the per-class loop.
        src_weight = src_state_dict[oldidx][1]
        oldclasses = L(oldcatdict[oldcatcols[oldidx]])
        classes_info = metatransfer[newcat]["classes_info"]
        # Keep the accumulator and the indices on the source weight's device so
        # that `index_select` / `cat` also work when the source is not on the CPU.
        new_ps = torch.zeros(src_weight.shape[1], 0, dtype=src_weight.dtype, device=src_weight.device)
        for newclass in newcatdict[newcat]:
            mapped = classes_info[newclass]
            classidxs = oldclasses.argwhere(lambda x: x in mapped)
            if len(classidxs) == 0:
                # No mapping given: fall back to aggregating every source class.
                classidxs = list(range(len(oldclasses)))
            rowidxs = torch.tensor(list(classidxs), dtype=torch.long, device=src_weight.device)
            ps = torch.unsqueeze(aggfn(torch.index_select(src_weight, 0, rowidxs)), -1)
            new_ps = torch.cat((new_ps, ps), dim=1)
        # `new_ps` is (emb_dim, n_classes); the weight is (n_classes, emb_dim).
        # `copy_` handles any device/dtype difference to the destination.
        dest_embeds.state_dict()[dest_state_dict[newidx][0]].copy_(new_ps.T)
        
@typedispatch
def transferembeds_(
        dest_embeds: nn.Module,
        src_embeds: dict,
        metatransfer,
        transfer_cats,
        *,
        newcatcols,
        oldcatcols,
        newcatdict,
        aggfn = partial(torch.mean, dim=0)):
    '''
        Transfers embeddings into `dest_embeds` from a `dict` of extracted
        embeddings.

        `src_embeds[cat]` must provide `'embeddings'` (one row per class) and
        `'classes'` (the class of each row, in row order). For every category
        in `transfer_cats`, each of its classes (taken from `newcatdict`)
        receives `aggfn` applied to the source rows of the old classes listed
        for it in `metatransfer[cat]["classes_info"]`. A class with no matching
        old classes receives `aggfn` over all rows. The result overwrites the
        matching entry of `dest_embeds.state_dict()` in place; nothing is
        returned.

        `newcatcols` gives the position of each category within the state dict
        of `dest_embeds`. `oldcatcols` is accepted for signature parity with
        the module-to-module overload and is not used here.
    '''
    dest_state_dict = L(dest_embeds.state_dict().items())
    for newcat in transfer_cats:
        newidx = newcatcols.index(newcat)
        oldcatname = metatransfer[newcat]['mapped_cat']
        # Convert the stored embeddings once per category instead of once per class.
        src_weight = torch.tensor(src_embeds[oldcatname]['embeddings'])
        oldclasses = L(src_embeds[oldcatname]['classes'])
        classes_info = metatransfer[newcat]["classes_info"]
        new_ps = torch.zeros(src_weight.shape[1], 0, dtype=src_weight.dtype, device=src_weight.device)
        for newclass in newcatdict[newcat]:
            mapped = classes_info[newclass]
            classidxs = oldclasses.argwhere(lambda x: x in mapped)
            if len(classidxs) == 0:
                # No mapping given: fall back to aggregating every source class.
                classidxs = list(range(len(oldclasses)))
            rowidxs = torch.tensor(list(classidxs), dtype=torch.long, device=src_weight.device)
            ps = torch.unsqueeze(aggfn(torch.index_select(src_weight, 0, rowidxs)), -1)
            new_ps = torch.cat((new_ps, ps), dim=1)
        # `new_ps` is (emb_dim, n_classes); the weight is (n_classes, emb_dim).
        dest_embeds.state_dict()[dest_state_dict[newidx][0]].copy_(new_ps.T)
            
            
@typedispatch
def transferembeds_(
        dest_embeds: nn.Module,
        src_embeds: (pathlib.Path, str),
        metatransfer,
        transfer_cats,
        *,
        kind = "bson",
        **kwargs):
    '''
        Transfers embeddings into `dest_embeds` from a file of extracted
        embeddings.

        `src_embeds` is the file location. With `kind == "json"` the file is
        parsed as JSON; any other `kind` (default `"bson"`) is read with
        `load_bson`. The loaded object is then passed back to
        `transferembeds_` together with `metatransfer`, `transfer_cats` and
        `**kwargs` (presumably it is a `dict`, so the dict overload handles
        it -- confirm against `load_bson`).

        The annotation uses `pathlib.Path` rather than `pathlib.PosixPath` so
        that path objects created on Windows dispatch here as well.
    '''
    if kind == "json":
        with open(src_embeds, 'r') as fp:
            src_embeds = json.load(fp)
    else:
        src_embeds = load_bson(src_embeds)
    transferembeds_(dest_embeds, src_embeds, metatransfer, transfer_cats, **kwargs)
    

In [None]:
transferembeds_

(Module,Module) -> transferembeds_
(Module,dict) -> transferembeds_
(Module,PosixPath) -> transferembeds_
(Module,str) -> transferembeds_

Embeddings before transfer:

In [None]:
embed1.state_dict()

OrderedDict([('0.weight',
              tensor([[-2.1244e+00, -1.2137e+00,  5.0160e-01,  1.1785e+00, -1.1206e+00,
                        1.2005e+00, -4.1906e-02,  1.6484e+00, -1.8413e-01, -7.3832e-01],
                      [-1.5847e+00, -1.2023e+00,  1.0965e+00,  3.5035e-01, -1.1027e-01,
                       -5.2887e-02,  9.9228e-01,  7.0271e-01,  1.6446e-04, -1.3017e+00],
                      [-7.9990e-01,  2.3920e-01, -1.6605e+00,  4.5783e-01, -9.4039e-01,
                       -1.0398e+00, -9.1239e-01, -6.1318e-01, -6.4078e-01,  8.1879e-01]])),
             ('1.weight',
              tensor([[-1.4255, -1.2286,  0.6272, -0.4778, -0.6330, -2.6385,  1.2995,  1.3089],
                      [ 0.2308, -0.2324,  1.3043,  0.3647, -0.9237, -0.0981,  0.4950, -0.7677]]))])

In [None]:
embed2.state_dict()

OrderedDict([('0.weight',
              tensor([[ 0.0999,  0.9369, -2.1694,  0.9061, -0.0143, -2.5886,  0.4395,  0.3534,
                       -0.8270, -0.6671],
                      [-0.8343,  1.6814, -0.5556, -1.7140, -0.0165, -1.1426, -1.1122, -0.0464,
                        1.4701, -1.0296]])),
             ('1.weight',
              tensor([[-0.6503, -1.0543,  1.4842, -1.9728, -2.0252, -0.1327,  0.6693,  0.2248],
                      [-0.6789,  1.6921, -0.4586,  0.2002,  0.5933,  0.6757,  1.1067, -0.2443]]))])

In [None]:
# Transfer from the extracted embeddings (`embdict`). No `oldcatdict` is passed:
# the dict overload reads the source classes from `src_embeds` itself
# (presumably `embdict` is a dict -- confirm against `extractembeds`).
transfer_cats = ("new_cat1", "new_cat2")
transferembeds_(embed1, embdict, metadict, transfer_cats, newcatcols=newcatcols, oldcatcols=oldcatcols, newcatdict=newcatdict)

In [None]:
# Transfer directly from the source module; this overload needs `oldcatdict`
# to know the row order of the source classes.
transfer_cats = ("new_cat1", "new_cat2")
transferembeds_(embed1, embed2, metadict, transfer_cats, newcatcols=newcatcols, oldcatcols=oldcatcols, oldcatdict=oldcatdict, newcatdict=newcatdict)

In [None]:
#skip
# Transfer from the "tempwtbson" file; `kind` is left at its default ("bson").
transfer_cats = ("new_cat1", "new_cat2")
transferembeds_(embed1, pathlib.Path("tempwtbson"), metadict, transfer_cats, newcatcols=newcatcols, oldcatcols=oldcatcols, newcatdict=newcatdict)

Embeddings after transfer:

In [None]:
embed1.state_dict()

OrderedDict([('0.weight',
              tensor([[-0.3672,  1.3092, -1.3625, -0.4039, -0.0154, -1.8656, -0.3363,  0.1535,
                        0.3215, -0.8483],
                      [-0.8343,  1.6814, -0.5556, -1.7140, -0.0165, -1.1426, -1.1122, -0.0464,
                        1.4701, -1.0296],
                      [-0.3672,  1.3092, -1.3625, -0.4039, -0.0154, -1.8656, -0.3363,  0.1535,
                        0.3215, -0.8483]])),
             ('1.weight',
              tensor([[-0.6503, -1.0543,  1.4842, -1.9728, -2.0252, -0.1327,  0.6693,  0.2248],
                      [-0.6646,  0.3189,  0.5128, -0.8863, -0.7160,  0.2715,  0.8880, -0.0097]]))])

In [None]:
embed2.state_dict()

OrderedDict([('0.weight',
              tensor([[ 0.0999,  0.9369, -2.1694,  0.9061, -0.0143, -2.5886,  0.4395,  0.3534,
                       -0.8270, -0.6671],
                      [-0.8343,  1.6814, -0.5556, -1.7140, -0.0165, -1.1426, -1.1122, -0.0464,
                        1.4701, -1.0296]])),
             ('1.weight',
              tensor([[-0.6503, -1.0543,  1.4842, -1.9728, -2.0252, -0.1327,  0.6693,  0.2248],
                      [-0.6789,  1.6921, -0.4586,  0.2002,  0.5933,  0.6757,  1.1067, -0.2443]]))])

In [None]:
#hide
#skip
# Remove the temporary embeddings file used by the skipped cells above.
os.remove("tempwtbson")

### Export

In [None]:
#export
# nbdev convention: names listed in `_all_` are added to the exported module's `__all__`.
_all_ = ['transferembeds_']

In [None]:
#hide
# Convert the project's notebooks into library modules (see the "Converted ..." output).
from nbdev.export import notebook2script
notebook2script()

Converted 00_utils.ipynb.
Converted 01_extract.ipynb.
Converted 02_transfer.ipynb.
Converted 03_load_tests.ipynb.
Converted index.ipynb.
