In [1]:
# default_exp transfer

# Transfer
> Contains methods for transferring categorical embeddings from a source (old) model to a destination (new) model.

In [2]:
#export
import pathlib
import torch
import torch.nn as nn
import json
import pandas as pd
from functools import partial
from fastcore.foundation import *
from fastcore.dispatch import *
from transfertab.utils import *
from transfertab.extract import *

In [3]:
#skip
#hide
from nbdev.showdoc import *

In [4]:
#skip
import os

We'll create collections of Embedding layers, which will be used to test our transfer methods.

In [5]:
# (num_embeddings, embedding_dim) per layer, in categorical-column order.
emb_szs1 = (
    (3, 10),  # new_cat1: 3 classes
    (2, 8),   # new_cat2: 2 classes
)
emb_szs2 = (
    (2, 10),  # old_cat2: 2 classes
    (2, 8),   # old_cat3: 2 classes
)

In [6]:
# Seed the RNG so the randomly initialised weights (and the state dicts shown below)
# are reproducible under Restart & Run All.
torch.manual_seed(42)
embed1 = nn.ModuleList([nn.Embedding(ni, nf) for ni, nf in emb_szs1])  # destination; overwritten below
embed2 = nn.ModuleList([nn.Embedding(ni, nf) for ni, nf in emb_szs2])  # source; only read below

In [7]:
# Destination embeddings: their weights are overwritten by the transfers below.
embed1

ModuleList(
  (0): Embedding(3, 10)
  (1): Embedding(2, 8)
)

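Next, we'll create the metadata the transfer methods need: the categorical column names of the new (destination) and old (source) data, and the classes of each column.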

In [8]:
# Categorical columns of the destination (new) and source (old) data, in embedding-layer order.
newcatcols = ("new_cat1", "new_cat2")
oldcatcols = ("old_cat2", "old_cat3")

# Classes of each categorical column, in embedding-row order.
newcatdict = {
    "new_cat1": ["new_class1", "new_class2", "new_class3"],
    "new_cat2": ["new_class1", "new_class2"],
}
oldcatdict = {
    "old_cat2": ["a", "b"],
    "old_cat3": ["A", "B"],
}

In [9]:
#skip
json_file_path = "../data/jsons/metadict.json"

# Load the new -> old column/class mapping from disk.
with open(json_file_path, 'r') as fp:
    metadict = json.load(fp)

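`metadict` is a `Dict` whose keys are the categorical columns of the dest. model's data. Each value is another `Dict`. Its `mapped_cat` entry names the column in the src model's data that the dest. column maps to. Its `classes_info` entry maps each dest. class to the list of src classes whose embeddings are aggregated (averaged, by default) to initialise it. An empty list means the dest. class is initialised from all src classes of `mapped_cat`.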

In [10]:
#hide
metadict = {
    'new_cat1': {
        'mapped_cat': 'old_cat2',
        'classes_info': {
            'new_class1': ['a', 'b'],
            'new_class2': ['b'],
            'new_class3': [],  # nothing mapped: built from all classes of old_cat2
        },
    },
    'new_cat2': {
        'mapped_cat': 'old_cat3',
        'classes_info': {
            'new_class1': ['A'],
            'new_class2': [],  # nothing mapped: built from all classes of old_cat3
        },
    },
}

In [11]:
# The mapping used by every transfer below.
metadict

{'new_cat1': {'mapped_cat': 'old_cat2',
  'classes_info': {'new_class1': ['a', 'b'],
   'new_class2': ['b'],
   'new_class3': []}},
 'new_cat2': {'mapped_cat': 'old_cat3',
  'classes_info': {'new_class1': ['A'], 'new_class2': []}}}

In [12]:
# Toy source data: one numeric column plus the two categorical columns we transfer from.
df = pd.DataFrame(
    {
        "old_cat1": [1, 2, 3, 4, 5],
        "old_cat2": ['a', 'b', 'b', 'b', 'a'],
        "old_cat3": ['A', 'B', 'B', 'B', 'A'],
    }
)
cats = ("old_cat2", "old_cat3")

In [13]:
#skip
# Same extraction as the next cell, but also saved to "tempwtbson" -- presumably as BSON,
# since the `pathlib.Path` demo below reads it back with the default `kind="bson"`.
# The file is deleted at the end of the notebook.
embdict = extractembeds(embed2, df, transfercats=cats, allcats=cats, path="tempwtbson")

In [14]:
# Extract `embed2`'s embeddings for `cats` in memory; presumably a dict, since it is
# passed to the `dict` overload of `transferembeds_` below.
embdict = extractembeds(embed2, df, transfercats=cats, allcats=cats)

In [15]:
#export
def get_metadict_skeleton(df: pd.DataFrame, *, catcols=None, path=None):
    '''
        Builds an empty `metadict` for the categorical columns of `df`.

        Every column/classes pair returned by `getcatdict(df, catcols)` becomes
        `{'mapped_cat': '', 'classes_info': {<class>: [] for each class}}`,
        i.e. empty placeholders for the source column and the source classes.
        If `path` is given, the skeleton is also written there as JSON.
        Returns the skeleton dict.
    '''
    catdict = getcatdict(df, catcols)
    metadict = {
        cat: {'mapped_cat': '', 'classes_info': {clas: [] for clas in classes}}
        for cat, classes in catdict.items()
    }
    if path is not None:
        # Indented so the written skeleton is easy to fill in by hand.
        with open(path, 'w') as fp:
            json.dump(metadict, fp, indent=2)
    return metadict

In [16]:
# Skeleton for `df`: no `mapped_cat` set and no source classes listed yet.
get_metadict_skeleton(df)

{'old_cat2': {'mapped_cat': '', 'classes_info': {'a': [], 'b': []}},
 'old_cat3': {'mapped_cat': '', 'classes_info': {'A': [], 'B': []}}}

In [17]:
#export
def _agg_cat_embeds(src_weight, src_classes, classes_info, new_classes, aggfn):
    '''
        Builds the new embedding matrix for one destination categorical column.

        Row `i` of the result is `aggfn` applied to the rows of `src_weight` whose
        class (matched by position in `src_classes`) is listed in
        `classes_info[new_classes[i]]`. When nothing is listed (or nothing matches),
        all rows of `src_weight` are aggregated instead. Returns a tensor of shape
        `(len(new_classes), src_weight.shape[1])`, assuming `aggfn` reduces dim 0
        as the default `partial(torch.mean, dim=0)` does.
    '''
    rows = []
    for newclass in new_classes:
        mapped = classes_info[newclass]
        idxs = [i for i, clas in enumerate(src_classes) if clas in mapped]
        if not idxs:
            # Nothing mapped: initialise from every class of the source column.
            idxs = list(range(len(src_classes)))
        # Build the index on the weight's own device so CUDA source weights work too.
        idx = torch.tensor(idxs, dtype=torch.long, device=src_weight.device)
        rows.append(aggfn(torch.index_select(src_weight, 0, idx)))
    if not rows:
        return src_weight.new_zeros((0, src_weight.shape[1]))
    return torch.stack(rows)


def _copy_embeds_(dest_weight, new_weight, newcat):
    '''
        Copies `new_weight` into `dest_weight` in place.

        Raises `ValueError` on a shape mismatch. A plain `copy_` would instead
        silently broadcast e.g. a single computed row over every destination row.
    '''
    if dest_weight.shape != new_weight.shape:
        raise ValueError(
            f"Cannot transfer embeddings for {newcat!r}: the destination weight has shape "
            f"{tuple(dest_weight.shape)}, but the metadata produced {tuple(new_weight.shape)} "
            f"(one row per class in `newcatdict[{newcat!r}]`).")
    dest_weight.copy_(new_weight)


@typedispatch
def transferembeds_(
        dest_embeds: nn.Module,
        src_embeds: nn.Module,
        /,
        metatransfer,
        transfer_cats,
        *,
        newcatcols,
        oldcatcols,
        oldcatdict,
        newcatdict,
        aggfn = partial(torch.mean, dim=0)):
    '''
        Transfers embeddings from `src_embeds` to `dest_embeds`, 
        with the help of collections containing various metadata.

        The i-th entry of each module's `state_dict()` is taken as the weight of
        the i-th column of `newcatcols` (dest.) / `oldcatcols` (src). For each
        `newcat` in `transfer_cats`, the source column is
        `metatransfer[newcat]['mapped_cat']` and its classes (in weight-row order)
        are `oldcatdict[mapped_cat]`. The destination rows follow `newcatdict[newcat]`.
        Each destination row is `aggfn` over the source rows listed in
        `metatransfer[newcat]['classes_info']`, or over all source rows when that
        list matches nothing. `dest_embeds` is modified in place.
    '''
    src_weights = list(src_embeds.state_dict().values())
    # state_dict() tensors share storage with the parameters, so copy_ updates the module.
    dest_weights = list(dest_embeds.state_dict().values())
    for newcat in transfer_cats:
        newidx = newcatcols.index(newcat)
        oldcat = metatransfer[newcat]["mapped_cat"]
        new_weight = _agg_cat_embeds(
            src_weights[oldcatcols.index(oldcat)],
            oldcatdict[oldcat],
            metatransfer[newcat]["classes_info"],
            newcatdict[newcat],
            aggfn)
        _copy_embeds_(dest_weights[newidx], new_weight, newcat)


@typedispatch
def transferembeds_(
        dest_embeds: nn.Module,
        src_embeds: dict,
        metatransfer,
        transfer_cats,
        *,
        newcatcols,
        oldcatcols,
        newcatdict,
        aggfn = partial(torch.mean, dim=0)):
    '''
        Transfers embeddings stored in a dict into `dest_embeds`, in place.

        `src_embeds[mapped_cat]` must provide `'embeddings'` (one row per class)
        and `'classes'` (the class of each row). Mapping and aggregation work as
        in the `nn.Module` overload. `oldcatcols` is accepted for signature parity
        with that overload but is not used here.
    '''
    dest_weights = list(dest_embeds.state_dict().values())
    for newcat in transfer_cats:
        newidx = newcatcols.index(newcat)
        src = src_embeds[metatransfer[newcat]['mapped_cat']]
        new_weight = _agg_cat_embeds(
            torch.as_tensor(src['embeddings']),  # converted once per column, not once per class
            src['classes'],
            metatransfer[newcat]["classes_info"],
            newcatdict[newcat],
            aggfn)
        _copy_embeds_(dest_weights[newidx], new_weight, newcat)


@typedispatch
def transferembeds_(
        dest_embeds: nn.Module,
        src_embeds: pathlib.Path,
        metatransfer,
        transfer_cats,
        *,
        kind = "bson",
        **kwargs):
    '''
        Loads source embeddings from the file `src_embeds` and transfers them.

        `kind` selects the file format: "json" or "bson" (read with `load_bson`).
        The loaded object is dispatched again, e.g. to the `dict` overload, with
        the remaining `kwargs`. Annotated with `pathlib.Path` so both POSIX and
        Windows paths dispatch here.
    '''
    if kind == "json":
        with open(src_embeds, 'r') as fp:
            loaded = json.load(fp)
    elif kind == "bson":
        loaded = load_bson(src_embeds)
    else:
        raise ValueError(f"Unsupported kind {kind!r}; expected 'json' or 'bson'.")
    transferembeds_(dest_embeds, loaded, metatransfer, transfer_cats, **kwargs)

In [18]:
#skip
# Show the registered (dest_embeds, src_embeds) type signatures.
transferembeds_

(Module,Module) -> transferembeds_
(Module,dict) -> transferembeds_
(Module,PosixPath) -> transferembeds_
(Module,str) -> transferembeds_

Embeddings before transfer:

In [19]:
#skip
# Destination weights before any transfer.
embed1.state_dict()

OrderedDict([('0.weight',
              tensor([[-0.5078,  1.1805, -0.7152, -0.2090, -0.3791, -0.3538,  2.0874, -1.2128,
                        0.5305, -1.5456],
                      [-0.6374,  0.5955, -0.6828, -0.3867, -0.1883,  0.6705,  0.0604,  1.2784,
                       -0.0404,  0.5722],
                      [-0.8835, -1.9918,  0.8149, -0.4123, -0.6302,  0.0977, -0.4842, -0.6732,
                        1.2273, -0.4338]])),
             ('1.weight',
              tensor([[-0.2471, -1.1222,  1.2055,  0.5520, -0.4882, -0.3273, -0.7765,  0.4155],
                      [ 0.2077,  0.5769, -0.1930, -0.9279, -1.0475, -3.3709, -0.9812, -0.5121]]))])

In [20]:
#skip
# Source weights; the transfers below only read them.
embed2.state_dict()

OrderedDict([('0.weight',
              tensor([[-0.6536,  0.1894, -0.4986, -0.1123,  0.3157,  1.4084,  0.5261,  1.0641,
                        0.5555, -2.8336],
                      [ 0.1326, -1.0560,  1.3056,  0.5426,  0.7622,  1.3548, -1.6767, -0.1300,
                        0.5395, -0.0292]])),
             ('1.weight',
              tensor([[-0.2570, -0.3869, -0.3253, -0.0735, -0.1460,  1.2832,  2.0547,  0.4501],
                      [ 0.9373, -0.3996,  2.0042, -0.8417, -0.9869,  1.8711, -1.2488, -0.6895]]))])

In [21]:
#skip
transfer_cats = ("new_cat1", "new_cat2")
# `dict` overload: source weights and classes both come from `embdict`.
transferembeds_(
    embed1, embdict, metadict, transfer_cats,
    newcatcols=newcatcols, oldcatcols=oldcatcols, newcatdict=newcatdict,
)

In [22]:
#skip
transfer_cats = ("new_cat1", "new_cat2")
# `nn.Module` overload: weights are read from `embed2`'s state dict, so the
# source classes must be supplied separately via `oldcatdict`.
transferembeds_(
    embed1, embed2, metadict, transfer_cats,
    newcatcols=newcatcols, oldcatcols=oldcatcols,
    oldcatdict=oldcatdict, newcatdict=newcatdict,
)

In [23]:
#skip
transfer_cats = ("new_cat1", "new_cat2")
# Path overload: load the file saved above (default `kind="bson"`), then transfer.
transferembeds_(
    embed1, pathlib.Path("tempwtbson"), metadict, transfer_cats,
    newcatcols=newcatcols, oldcatcols=oldcatcols, newcatdict=newcatdict,
)

Embeddings after transfer:

In [24]:
#skip
# Destination weights after the transfers.
embed1.state_dict()

OrderedDict([('0.weight',
              tensor([[-0.2605, -0.4333,  0.4035,  0.2152,  0.5390,  1.3816, -0.5753,  0.4670,
                        0.5475, -1.4314],
                      [ 0.1326, -1.0560,  1.3056,  0.5426,  0.7622,  1.3548, -1.6767, -0.1300,
                        0.5395, -0.0292],
                      [-0.2605, -0.4333,  0.4035,  0.2152,  0.5390,  1.3816, -0.5753,  0.4670,
                        0.5475, -1.4314]])),
             ('1.weight',
              tensor([[-0.2570, -0.3869, -0.3253, -0.0735, -0.1460,  1.2832,  2.0547,  0.4501],
                      [ 0.3402, -0.3932,  0.8394, -0.4576, -0.5665,  1.5772,  0.4030, -0.1197]]))])

In [25]:
#skip
# Source weights, unchanged by the transfers.
embed2.state_dict()

OrderedDict([('0.weight',
              tensor([[-0.6536,  0.1894, -0.4986, -0.1123,  0.3157,  1.4084,  0.5261,  1.0641,
                        0.5555, -2.8336],
                      [ 0.1326, -1.0560,  1.3056,  0.5426,  0.7622,  1.3548, -1.6767, -0.1300,
                        0.5395, -0.0292]])),
             ('1.weight',
              tensor([[-0.2570, -0.3869, -0.3253, -0.0735, -0.1460,  1.2832,  2.0547,  0.4501],
                      [ 0.9373, -0.3996,  2.0042, -0.8417, -0.9869,  1.8711, -1.2488, -0.6895]]))])

In [26]:
#skip
#hide
# Remove the temporary file created via `path="tempwtbson"` above; `missing_ok`
# keeps re-runs from failing when it is already gone.
pathlib.Path("tempwtbson").unlink(missing_ok=True)

### Export

In [27]:
#export
#skip
# Presumably read by nbdev to add these names to the generated module's `__all__` — confirm.
_all_ = ['transferembeds_']

In [28]:
#hide
#skip
from nbdev.export import notebook2script
# Regenerate the library modules from every notebook in the project.
notebook2script()

Converted 00_utils.ipynb.
Converted 01_extract.ipynb.
Converted 02_transfer.ipynb.
Converted 03_load_tests.ipynb.
Converted Untitled.ipynb.
Converted index.ipynb.
