In [None]:
# default_exp core

# Transfer
> Contains methods for transferring learned embeddings from a source model to a destination model.

In [None]:
#export
from nbdev.showdoc import *
from transfertab.utils import *
import torch
import torch.nn as nn
import json
from functools import partial
from fastcore.foundation import *
from fastcore.dispatch import *

We'll create collections of Embedding layers, which will be used to test our transfer methods.

In [None]:
# Each pair is (number of classes, embedding width) for one categorical variable.
emb_szs1 = (
    (3, 10),
    (2, 8),
)
emb_szs2 = (
    (2, 10),
    (2, 8),
)

In [None]:
# One embedding table per (number of classes, embedding width) pair.
embed1 = nn.ModuleList([
    nn.Embedding(num_embeddings=n_classes, embedding_dim=emb_dim)
    for n_classes, emb_dim in emb_szs1
])
embed2 = nn.ModuleList([
    nn.Embedding(num_embeddings=n_classes, embedding_dim=emb_dim)
    for n_classes, emb_dim in emb_szs2
])

In [None]:
embed1

ModuleList(
  (0): Embedding(3, 10)
  (1): Embedding(2, 8)
)

Now, we'll create collections containing required metadata.

In [None]:
# Names of the categorical variables, in the same order as the embedding layers.
newcatcols = ("new_cat1", "new_cat2")
oldcatcols = ("old_cat1", "old_cat2")

# Classes of every categorical variable, keyed by variable name.
newcatdict = {
    "new_cat1": ["new_class1", "new_class2", "new_class3"],
    "new_cat2": ["new_class1", "new_class2"],
}
oldcatdict = {
    "old_cat1": ["old_class1", "old_class2"],
    "old_cat2": ["old_class1", "old_class2"],
}

In [None]:
json_file_path = "../data/jsons/metadict.json"

# Read the transfer metadata. The encoding is given explicitly so parsing does
# not depend on the platform's default locale encoding.
with open(json_file_path, 'r', encoding="utf-8") as j:
    metadict = json.load(j)

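`metadict` is a `Dict` whose keys are the categorical variables in the dest. model's data. Each value is another `Dict`: `mapped_cat` names the corresponding categorical variable in the src model's data, and `classes_info` maps every class of the dest. variable to the list of src classes its embedding should be built from. An empty list means the class has no counterpart in the src data; in that case the transfer aggregates over all src classes of the mapped variable.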

In [None]:
metadict

{'new_cat1': {'mapped_cat': 'old_cat1',
  'classes_info': {'new_class1': ['old_class1', 'old_class2'],
   'new_class2': ['old_class2'],
   'new_class3': []}},
 'new_cat2': {'mapped_cat': 'old_cat2',
  'classes_info': {'new_class1': ['old_class2'], 'new_class2': []}}}

In [None]:
#export
@typedispatch
def transferembeds_(
        dest_model: nn.Module,
        src_model: nn.Module,
        /,
        metatransfer,
        transfer_cats,
        *,
        newcatcols,
        oldcatcols,
        oldcatdict,
        newcatdict,
        aggfn = partial(torch.mean, dim=0)):
    '''
        Transfers embeddings from `src_model` to `dest_model` in place,
        with the help of collections containing various metadata.

        For every category in `transfer_cats`, each class of the destination
        category gets the `aggfn`-aggregate of the source embedding rows of the
        classes it is mapped to. A class with no mapped source classes gets the
        aggregate over all source classes of the mapped category.

        dest_model:    module whose parameters are overwritten.
        src_model:     module the embeddings are read from.
        metatransfer:  dict keyed by destination category; each value holds
                       `"mapped_cat"` (source category name) and
                       `"classes_info"` (destination class -> source classes).
        transfer_cats: destination categories to transfer.
        newcatcols:    ordered destination category names.
        oldcatcols:    ordered source category names.
        oldcatdict:    source category name -> ordered list of its classes.
        newcatdict:    destination category name -> ordered list of its classes.
        aggfn:         reduces a (k, emb_dim) tensor of rows to one (emb_dim,)
                       vector; defaults to the mean over rows.

        Returns None; `dest_model` is modified in place.

        NOTE: the i-th entry of each model's `state_dict()` is taken to be the
        embedding weight of the i-th category in `newcatcols` / `oldcatcols`.
    '''
    # state_dict() tensors share storage with the parameters, so an in-place
    # copy_ on them updates the model. Build each listing once, up front.
    src_params = list(src_model.state_dict().items())
    dest_params = list(dest_model.state_dict().items())
    for newcat in transfer_cats:
        newidx = newcatcols.index(newcat)
        oldidx = oldcatcols.index(metatransfer[newcat]["mapped_cat"])
        src_weight = src_params[oldidx][1]
        dest_weight = dest_params[newidx][1]
        oldclasses = oldcatdict[oldcatcols[oldidx]]
        # Start from an empty (emb_dim, 0) tensor on the source weight's
        # device/dtype so the final concatenation works for non-CPU models
        # and when the destination category has no classes.
        columns = [src_weight.new_zeros(src_weight.shape[1], 0)]
        for newclass in newcatdict[newcat]:
            mapped = metatransfer[newcat]["classes_info"][newclass]
            classidxs = [i for i, oldclass in enumerate(oldclasses) if oldclass in mapped]
            if len(classidxs) == 0:
                # No mapped source class: fall back to every source class.
                classidxs = list(range(len(oldclasses)))
            # index_select needs its index on the same device as the weights.
            rowidxs = torch.tensor(classidxs, dtype=torch.long, device=src_weight.device)
            ps = torch.unsqueeze(aggfn(torch.index_select(src_weight, 0, rowidxs)), -1)
            columns.append(ps)
        # Concatenate once instead of growing the tensor on every iteration.
        new_ps = torch.cat(columns, dim=1)
        print(f"new param size: {new_ps.shape}\nold param size: {dest_weight.shape}\n")
        # new_ps is (emb_dim, n_classes); the embedding weight is (n_classes, emb_dim).
        dest_weight.copy_(new_ps.T)
        
# @typedispatch
# def transferembeds_(
#         dest_model: nn.Module, 
#         src_embed_json: dict,
#
#         metatransfer,
#         transfer_cats,
#         *,
#         newcatcols, 
#         oldcatcols, 
#         oldcatdict, 
#         newcatdict, 
#         aggfn = partial(torch.mean, dim=0)):

# @typedispatch
# def transferembeds_(
#         dest_model: nn.Module, 
#         src_embed_json: pathlib.PosixPath, 
#         metatransfer,
#         transfer_cats,
#         *,
#         newcatcols, 
#         oldcatcols, 
#         oldcatdict, 
#         newcatdict, 
#         aggfn = partial(torch.mean, dim=0)):

In [None]:
from pathlib import Path

In [None]:
type(Path("a"))

pathlib.PosixPath

Embeddings before transfer:

In [None]:
embed1.state_dict()

OrderedDict([('0.weight',
              tensor([[-0.8785, -1.3125, -1.7447, -0.0523,  0.2181, -0.4410, -0.2545, -0.7003,
                        0.8835,  0.1169],
                      [-0.2754,  1.6872, -1.4758, -0.6321, -0.8176, -0.5889, -0.7625,  1.1944,
                       -0.8908,  0.4763],
                      [ 0.8055,  1.6962,  0.7415, -0.9416, -0.4558,  0.3970,  1.4989, -0.2572,
                       -0.4770,  0.9706]])),
             ('1.weight',
              tensor([[ 0.5748, -2.2568, -0.1955,  0.1461,  0.8718,  1.2307, -0.3835, -1.2792],
                      [ 0.5399, -0.0471, -0.5947,  0.3210, -0.7245, -0.3612, -1.4094, -0.8337]]))])

In [None]:
embed2.state_dict()

OrderedDict([('0.weight',
              tensor([[-0.8082,  0.0527, -0.4755, -0.2601, -1.5288, -1.4793, -1.2562,  1.7078,
                        1.3433,  0.1644],
                      [ 1.3142, -0.0689,  0.1832,  0.2934,  0.7709,  1.2324, -0.8176, -0.2325,
                        1.6991, -1.1931]])),
             ('1.weight',
              tensor([[-0.3152,  1.4867, -2.0957, -1.1404,  1.4763,  0.9344, -0.2639, -1.7845],
                      [ 1.3346,  1.3075, -0.7350,  0.4069, -1.7279, -1.3855,  0.3464,  0.1231]]))])

In [None]:
transfer_cats = ("new_cat1", "new_cat2")
# The category metadata is keyword-only in `transferembeds_`, so it must be
# passed by name; passing it positionally raises a TypeError.
transferembeds_(
    embed1,
    embed2,
    metadict,
    transfer_cats,
    newcatcols=newcatcols,
    oldcatcols=oldcatcols,
    oldcatdict=oldcatdict,
    newcatdict=newcatdict,
)

new param size: torch.Size([10, 3])
old param size: torch.Size([3, 10])

new param size: torch.Size([8, 2])
old param size: torch.Size([2, 8])



Embeddings after transfer:

In [None]:
embed1.state_dict()

OrderedDict([('0.weight',
              tensor([[ 0.2530, -0.0081, -0.1462,  0.0167, -0.3789, -0.1234, -1.0369,  0.7377,
                        1.5212, -0.5143],
                      [ 1.3142, -0.0689,  0.1832,  0.2934,  0.7709,  1.2324, -0.8176, -0.2325,
                        1.6991, -1.1931],
                      [ 0.2530, -0.0081, -0.1462,  0.0167, -0.3789, -0.1234, -1.0369,  0.7377,
                        1.5212, -0.5143]])),
             ('1.weight',
              tensor([[ 1.3346,  1.3075, -0.7350,  0.4069, -1.7279, -1.3855,  0.3464,  0.1231],
                      [ 0.5097,  1.3971, -1.4153, -0.3668, -0.1258, -0.2256,  0.0412, -0.8307]]))])

In [None]:
embed2.state_dict()

OrderedDict([('0.weight',
              tensor([[-0.8082,  0.0527, -0.4755, -0.2601, -1.5288, -1.4793, -1.2562,  1.7078,
                        1.3433,  0.1644],
                      [ 1.3142, -0.0689,  0.1832,  0.2934,  0.7709,  1.2324, -0.8176, -0.2325,
                        1.6991, -1.1931]])),
             ('1.weight',
              tensor([[-0.3152,  1.4867, -2.0957, -1.1404,  1.4763,  0.9344, -0.2639, -1.7845],
                      [ 1.3346,  1.3075, -0.7350,  0.4069, -1.7279, -1.3855,  0.3464,  0.1231]]))])