In [None]:
# default_exp core

# Transfer
> Contains methods for transferring embeddings from a source model to a destination model.

In [None]:
#export
from nbdev.showdoc import *
from transfertab.utils import *
import torch
import torch.nn as nn
import json
from functools import partial
from fastcore.foundation import *
from fastcore.dispatch import *
from transfertab.utils import *
import transfertab
import pandas as pd
import pathlib
from typing import Union

In [None]:
import os

We'll create collections of Embedding layers, which will be used to test our transfer methods.

In [None]:
# Each pair is (number of classes, embedding width) for one categorical column;
# the pairs are fed to `nn.Embedding` below.
emb_szs1 = (
    (3, 10),
    (2, 8),
)
emb_szs2 = (
    (2, 10),
    (2, 8),
)

In [None]:
# One embedding layer per (number of classes, embedding width) pair.
embed1 = nn.ModuleList([nn.Embedding(*size) for size in emb_szs1])
embed2 = nn.ModuleList([nn.Embedding(*size) for size in emb_szs2])

In [None]:
embed1

ModuleList(
  (0): Embedding(3, 10)
  (1): Embedding(2, 8)
)

Now, we'll create collections containing required metadata.

In [None]:
# Names of the categorical columns, in the same order as the embedding layers.
newcatcols = ("new_cat1", "new_cat2")
oldcatcols = ("old_cat2", "old_cat3")

# Classes of every categorical column, keyed by column name.
newcatdict = {
    "new_cat1": ["new_class1", "new_class2", "new_class3"],
    "new_cat2": ["new_class1", "new_class2"],
}
oldcatdict = {
    "old_cat2": ["a", "b"],
    "old_cat3": ["A", "B"],
}

In [None]:
json_file_path = "../data/jsons/metadict.json"

# Parse the transfer metadata straight from the file handle.
with open(json_file_path, 'r') as fp:
    metadict = json.load(fp)

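`metadict` is a `Dict` whose keys are the categorical columns in the dest. model's data. Each value is another `Dict`: `mapped_cat` names the corresponding categorical column in the src model's data, and `classes_info` maps every class of the dest. column to the list of src classes it corresponds to. An empty list means no src class is mapped to that dest. class, in which case the transfer aggregates over all src classes of the mapped column.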

In [None]:
metadict

{'new_cat1': {'mapped_cat': 'old_cat2',
  'classes_info': {'new_class1': ['a', 'b'],
   'new_class2': ['b'],
   'new_class3': []}},
 'new_cat2': {'mapped_cat': 'old_cat3',
  'classes_info': {'new_class1': ['A'], 'new_class2': []}}}

In [None]:
embed2

ModuleList(
  (0): Embedding(2, 10)
  (1): Embedding(2, 8)
)

In [None]:
df = pd.DataFrame(
    {
        "old_cat1": [1, 2, 3, 4, 5],
        "old_cat2": ['a', 'b', 'b', 'b', 'a'],
        "old_cat3": ['A', 'B', 'B', 'B', 'A'],
    }
)
cats = ("old_cat2", "old_cat3")
# Extract the embeddings of `embed2` for the columns in `cats`, passing "tempwtbson" as `path`.
embdict = transfertab.utils.extractembeds(
    embed2,
    df,
    transfercats=cats,
    allcats=cats,
    path="tempwtbson",
)
embdict

{'old_cat2': {'classes': ['a', 'b'],
  'embeddings': [[0.4175843298435211,
    -1.0199278593063354,
    -0.6709107756614685,
    -0.500321090221405,
    -0.13969416916370392,
    -0.5671979188919067,
    1.4760546684265137,
    -0.7525286674499512,
    0.43304163217544556,
    -2.7244691848754883],
   [-1.2662526369094849,
    0.39920574426651,
    0.1494867205619812,
    -2.012317419052124,
    0.7739086747169495,
    1.636749267578125,
    0.8324260711669922,
    1.0281352996826172,
    0.9744652509689331,
    0.520490288734436]]},
 'old_cat3': {'classes': ['A', 'B'],
  'embeddings': [[1.4979701042175293,
    -1.1346012353897095,
    0.05064915120601654,
    -0.2791922688484192,
    0.9253252148628235,
    -0.9470803141593933,
    0.13601328432559967,
    0.7067786455154419],
   [0.5263708233833313,
    1.1203124523162842,
    0.8524508476257324,
    -0.8725061416625977,
    -0.4161680340766907,
    0.08518705517053604,
    -1.741043210029602,
    -0.3950035274028778]]}}

In [None]:
#export
@typedispatch
def transferembeds_(
        dest_embeds: nn.Module,
        src_embeds: nn.Module,
        /,
        metatransfer,
        transfer_cats,
        *,
        newcatcols,
        oldcatcols,
        oldcatdict,
        newcatdict,
        aggfn = partial(torch.mean, dim=0)):
    '''
        Transfers embeddings from `src_embeds` to `dest_embeds` in place,
        with the help of collections containing various metadata.

        For every category in `transfer_cats`, each of its classes (in the
        order given by `newcatdict`) gets the `aggfn`-aggregate of the source
        rows belonging to the old classes it maps to. A class that maps to no
        old class gets the aggregate over all source rows of the mapped category.

        `dest_embeds`   : module whose state-dict tensors are overwritten.
        `src_embeds`    : module whose state-dict tensors are read.
        `metatransfer`  : dict keyed by new category, holding `"mapped_cat"`
                          (the old category) and `"classes_info"` (new class ->
                          collection of old classes).
        `transfer_cats` : iterable of new categories to transfer.
        `newcatcols`    : sequence of new categories; the position of a category
                          selects the state-dict entry of `dest_embeds` to write.
        `oldcatcols`    : sequence of old categories; the position of a category
                          selects the state-dict entry of `src_embeds` to read.
        `oldcatdict`    : old category -> list of its classes, used to turn old
                          classes into row indices of the source weights.
        `newcatdict`    : new category -> list of its classes.
        `aggfn`         : callable applied to the selected source rows; its result
                          becomes one row of the destination weights.
    '''
    src_params = list(src_embeds.state_dict().items())
    dest_state = dest_embeds.state_dict()
    dest_keys = list(dest_state)
    for newcat in transfer_cats:
        newidx = newcatcols.index(newcat)
        catinfo = metatransfer[newcat]
        oldcat = catinfo["mapped_cat"]
        oldidx = oldcatcols.index(oldcat)
        src_weight = src_params[oldidx][1]
        oldclasses = oldcatdict[oldcat]
        # Seed with an empty (width, 0) tensor that lives on the same device and has the
        # same dtype as the source weights, so the final `cat` never mixes devices.
        columns = [torch.zeros(src_weight.shape[1], 0, dtype=src_weight.dtype, device=src_weight.device)]
        for newclass in newcatdict[newcat]:
            mapped = catinfo["classes_info"][newclass]
            classidxs = [i for i, oldclass in enumerate(oldclasses) if oldclass in mapped]
            if len(classidxs) == 0:
                # Nothing maps to this class: fall back to aggregating every old class.
                classidxs = list(range(len(oldclasses)))
            # The index tensor must live on the same device as the tensor it indexes.
            index = torch.tensor(classidxs, dtype=torch.long, device=src_weight.device)
            columns.append(torch.unsqueeze(aggfn(torch.index_select(src_weight, 0, index)), -1))
        # Concatenate once instead of growing the tensor on every class.
        new_ps = torch.cat(columns, dim=1)
        # The state-dict tensor is written with an in-place `copy_`.
        dest_state[dest_keys[newidx]].copy_(new_ps.T)
        
@typedispatch
def transferembeds_(
        dest_embeds: nn.Module,
        src_embeds: dict,
        metatransfer,
        transfer_cats,
        *,
        newcatcols,
        oldcatcols,
        newcatdict,
        aggfn = partial(torch.mean, dim=0)):
    '''
        Transfers embeddings from the dict `src_embeds` to `dest_embeds` in place.

        `src_embeds` is keyed by old category; each value holds `'classes'`
        (list of old classes) and `'embeddings'` (one row of numbers per class).
        For every category in `transfer_cats`, each of its classes (in the order
        given by `newcatdict`) gets the `aggfn`-aggregate of the source rows
        belonging to the old classes listed in
        `metatransfer[newcat]["classes_info"]`. A class that maps to no old class
        gets the aggregate over all rows of the mapped category.

        `newcatcols` gives the position of each new category among the
        state-dict entries of `dest_embeds`. `oldcatcols` is accepted but not
        used by this overload.
    '''
    dest_state = dest_embeds.state_dict()
    dest_keys = list(dest_state)
    for newcat in transfer_cats:
        newidx = newcatcols.index(newcat)
        catinfo = metatransfer[newcat]
        oldcatname = catinfo['mapped_cat']
        oldclasses = src_embeds[oldcatname]['classes']
        # Convert the nested lists to a tensor once per category rather than once per class.
        src_weight = torch.tensor(src_embeds[oldcatname]['embeddings'])
        columns = [torch.zeros(src_weight.shape[1], 0)]
        for newclass in newcatdict[newcat]:
            mapped = catinfo["classes_info"][newclass]
            classidxs = [i for i, oldclass in enumerate(oldclasses) if oldclass in mapped]
            if len(classidxs) == 0:
                # Nothing maps to this class: fall back to aggregating every old class.
                classidxs = list(range(len(oldclasses)))
            index = torch.tensor(classidxs, dtype=torch.long)
            columns.append(torch.unsqueeze(aggfn(torch.index_select(src_weight, 0, index)), -1))
        # Concatenate once instead of growing the tensor on every class.
        new_ps = torch.cat(columns, dim=1)
        # The state-dict tensor is written with an in-place `copy_`.
        dest_state[dest_keys[newidx]].copy_(new_ps.T)
            
            
@typedispatch
def transferembeds_(
        dest_embeds: nn.Module,
        src_embeds: (pathlib.Path, str),
        metatransfer,
        transfer_cats,
        *,
        kind = "bson",
        **kwargs):
    '''
        Loads the source embeddings from the file at `src_embeds` and transfers
        them to `dest_embeds`.

        `kind` selects the reader: `"json"` parses the file with the `json`
        module, any other value loads it with `load_bson`. The loaded object is
        then passed on to `transferembeds_` together with `metatransfer`,
        `transfer_cats` and the remaining keyword arguments.

        The path annotation is `pathlib.Path` rather than `pathlib.PosixPath`
        so that `pathlib.WindowsPath` instances are accepted as well.
    '''
    if kind == "json":
        with open(src_embeds, 'r') as fp:
            loaded_embeds = json.load(fp)
    else:
        loaded_embeds = load_bson(src_embeds)
    transferembeds_(dest_embeds, loaded_embeds, metatransfer, transfer_cats, **kwargs)
    

In [None]:
transferembeds_

(Module,dict) -> transferembeds_
(Module,PosixPath) -> transferembeds_
(Module,str) -> transferembeds_
(Module,Module) -> transferembeds_

Embeddings before transfer:

In [None]:
embed1.state_dict()

OrderedDict([('0.weight',
              tensor([[-0.4243, -0.3104, -0.2607, -1.2563,  0.3171,  0.5348,  1.1542,  0.1378,
                        0.7038, -1.1020],
                      [-1.2663,  0.3992,  0.1495, -2.0123,  0.7739,  1.6367,  0.8324,  1.0281,
                        0.9745,  0.5205],
                      [-0.4243, -0.3104, -0.2607, -1.2563,  0.3171,  0.5348,  1.1542,  0.1378,
                        0.7038, -1.1020]])),
             ('1.weight',
              tensor([[ 1.4980, -1.1346,  0.0506, -0.2792,  0.9253, -0.9471,  0.1360,  0.7068],
                      [ 1.0122, -0.0071,  0.4516, -0.5758,  0.2546, -0.4309, -0.8025,  0.1559]]))])

In [None]:
embed2.state_dict()

OrderedDict([('0.weight',
              tensor([[ 0.4176, -1.0199, -0.6709, -0.5003, -0.1397, -0.5672,  1.4761, -0.7525,
                        0.4330, -2.7245],
                      [-1.2663,  0.3992,  0.1495, -2.0123,  0.7739,  1.6367,  0.8324,  1.0281,
                        0.9745,  0.5205]])),
             ('1.weight',
              tensor([[ 1.4980, -1.1346,  0.0506, -0.2792,  0.9253, -0.9471,  0.1360,  0.7068],
                      [ 0.5264,  1.1203,  0.8525, -0.8725, -0.4162,  0.0852, -1.7410, -0.3950]]))])

In [None]:
# Transfer from the in-memory `embdict`.
transfer_cats = ("new_cat1", "new_cat2")
transferembeds_(
    embed1,
    embdict,
    metadict,
    transfer_cats,
    newcatcols=newcatcols,
    oldcatcols=oldcatcols,
    newcatdict=newcatdict,
)

In [None]:
# Transfer directly from the source module `embed2`.
transfer_cats = ("new_cat1", "new_cat2")
transferembeds_(
    embed1,
    embed2,
    metadict,
    transfer_cats,
    newcatcols=newcatcols,
    oldcatcols=oldcatcols,
    oldcatdict=oldcatdict,
    newcatdict=newcatdict,
)

In [None]:
# Transfer from the "tempwtbson" file, addressed through a `pathlib.Path`.
transfer_cats = ("new_cat1", "new_cat2")
transferembeds_(
    embed1,
    pathlib.Path("tempwtbson"),
    metadict,
    transfer_cats,
    newcatcols=newcatcols,
    oldcatcols=oldcatcols,
    newcatdict=newcatdict,
)

Embeddings after transfer:

In [None]:
embed1.state_dict()

OrderedDict([('0.weight',
              tensor([[-0.4243, -0.3104, -0.2607, -1.2563,  0.3171,  0.5348,  1.1542,  0.1378,
                        0.7038, -1.1020],
                      [-1.2663,  0.3992,  0.1495, -2.0123,  0.7739,  1.6367,  0.8324,  1.0281,
                        0.9745,  0.5205],
                      [-0.4243, -0.3104, -0.2607, -1.2563,  0.3171,  0.5348,  1.1542,  0.1378,
                        0.7038, -1.1020]])),
             ('1.weight',
              tensor([[ 1.4980, -1.1346,  0.0506, -0.2792,  0.9253, -0.9471,  0.1360,  0.7068],
                      [ 1.0122, -0.0071,  0.4516, -0.5758,  0.2546, -0.4309, -0.8025,  0.1559]]))])

In [None]:
embed2.state_dict()

OrderedDict([('0.weight',
              tensor([[ 0.4176, -1.0199, -0.6709, -0.5003, -0.1397, -0.5672,  1.4761, -0.7525,
                        0.4330, -2.7245],
                      [-1.2663,  0.3992,  0.1495, -2.0123,  0.7739,  1.6367,  0.8324,  1.0281,
                        0.9745,  0.5205]])),
             ('1.weight',
              tensor([[ 1.4980, -1.1346,  0.0506, -0.2792,  0.9253, -0.9471,  0.1360,  0.7068],
                      [ 0.5264,  1.1203,  0.8525, -0.8725, -0.4162,  0.0852, -1.7410, -0.3950]]))])

In [None]:
# Clean up the temporary "tempwtbson" file that was passed as `path` to `extractembeds` above.
os.remove("tempwtbson")