In [3]:
import torch

class LayerA(torch.nn.Module):
    """Single linear projection followed by a ReLU non-linearity.

    Args:
        dim_in: size of the last dimension of the input (``in_features``
            of the linear layer).
        dim_out: size of the last dimension of the output (``out_features``
            of the linear layer).
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()

        # Attribute names are part of the state_dict keys
        # ('encoder.weight', 'encoder.bias'), so they must stay as they are.
        self.encoder = torch.nn.Linear(in_features=dim_in, out_features=dim_out)
        self.activation = torch.nn.ReLU()
    # end

    def forward(self, sequence):
        """Project ``sequence`` with the linear layer, then apply ReLU."""
        projected = self.encoder(sequence)
        return self.activation(projected)
    # end
# end

class HeadA(torch.nn.Module):
    """Decoder head stacked on top of an existing encoder layer.

    The decoder is a linear layer whose input size is the encoder's
    ``out_features`` and whose output size is the encoder's ``in_features``.

    Args:
        layer_encoder: module exposing an ``encoder`` attribute with
            ``in_features`` / ``out_features`` (e.g. a ``torch.nn.Linear``);
            it is stored as a sub-module and called in ``forward``.
    """

    def __init__(self, layer_encoder):
        super().__init__()

        self.layer_encoder = layer_encoder

        # The decoder mirrors the encoder's linear layer: sizes are swapped.
        projection = layer_encoder.encoder
        self.decoder = torch.nn.Linear(projection.out_features, projection.in_features)
    # end

    def forward(self, sequence):
        """Encode ``sequence``, decode it, and softmax over the last dimension."""
        encoded = self.layer_encoder(sequence)
        logits = self.decoder(encoded)
        return logits.softmax(dim=-1)
    # end
# end


class HeadB(torch.nn.Module):
    """Decoder head stacked on top of an existing encoder layer.

    The decoder is a linear layer whose input size is the encoder's
    ``out_features`` and whose output size is the encoder's ``in_features``.

    Args:
        layer_encoder: module exposing an ``encoder`` attribute with
            ``in_features`` / ``out_features`` (e.g. a ``torch.nn.Linear``);
            it is stored as a sub-module and called in ``forward``.
    """

    def __init__(self, layer_encoder):
        # Zero-argument super() resolves the class automatically. The previous
        # `super(HeadA, self)` was a copy-paste slip that raised TypeError,
        # because a HeadB instance is not an instance of HeadA.
        super().__init__()

        self.layer_encoder = layer_encoder

        dim_in = layer_encoder.encoder.out_features
        dim_out = layer_encoder.encoder.in_features
        self.decoder = torch.nn.Linear(dim_in, dim_out)
    # end

    def forward(self, sequence):
        """Encode ``sequence``, decode it, and softmax over the last dimension."""
        return self.decoder(self.layer_encoder(sequence)).softmax(dim=-1)
    # end
# end

In [6]:
# Encoder layer mapping 2 input features to 3 output features.
layer_a = LayerA(dim_in=2, dim_out=3)

In [9]:
# Persist only the parameters (state_dict) of layer_a, not the module object.
torch.save(obj=layer_a.state_dict(), f='./layera.pt')

In [19]:
# Write layer_a's parameters to disk, then restore them into a fresh LayerA.
torch.save(obj=layer_a.state_dict(), f='./layera.pt')
layer_b = LayerA(dim_in=2, dim_out=3)
layer_b.load_state_dict(state_dict=torch.load(f='./layera.pt'))
# Wrap the restored encoder in a decoder head.
head_a = HeadA(layer_encoder=layer_b)

In [24]:
# Display the saved state_dict to inspect its key names.
torch.load(f='./layera.pt')

OrderedDict([('encoder.weight',
              tensor([[-0.6163,  0.6652],
                      [ 0.5920,  0.5670],
                      [-0.3521,  0.4188]])),
             ('encoder.bias', tensor([-0.5758, -0.2209,  0.1403]))])

In [27]:
from collections import OrderedDict

def update_state_dict_prefix(prefix, dict_state):
    """Return a new ``OrderedDict`` whose keys are ``'<prefix>.<old key>'``.

    Args:
        prefix: string placed, followed by a dot, in front of every key.
        dict_state: mapping to re-key; it is not modified.

    Returns:
        An ``OrderedDict`` with the same values in the same order, under
        the prefixed keys.
    """
    return OrderedDict(
        (f'{prefix}.{name}', value) for name, value in dict_state.items()
    )
# end

In [None]:
# Re-key the saved LayerA parameters under the 'layer_encoder.' prefix.
update_state_dict_prefix(prefix='layer_encoder', dict_state=torch.load(f='./layera.pt'))


OrderedDict([('layer_encoder.encoder.weight',
              tensor([[-0.6163,  0.6652],
                      [ 0.5920,  0.5670],
                      [-0.3521,  0.4188]])),
             ('layer_encoder.encoder.bias',
              tensor([-0.5758, -0.2209,  0.1403]))])