In [18]:
import torch

class Encoder(torch.nn.Module):
    def __init__(self, dim_in, dim_out):
        super(Encoder, self).__init__()
        
        self.l = torch.nn.Linear(dim_in, dim_out, bias=False)
        self.a = torch.nn.ReLU()
    # end
    
    def forward(self, sequence):
        return self.a(self.l(sequence))
    # end
# end

class Decoder(torch.nn.Module):
    def __init__(self, dim_in, dim_out, source):
        super(Decoder, self).__init__()
        
        self.l = torch.nn.Linear(dim_in, dim_out, bias=False)
        self.l.weight = torch.nn.Parameter(source.l.weight.t())
        self.a = torch.nn.ReLU()
        
    # end
    
    def forward(self, sequence):
        return self.a(self.l(sequence))
    # end
# end


In [19]:
e = Encoder(2,3)
d = Decoder(3,2,e)

In [52]:
seq = torch.Tensor([[8, 6]])

In [53]:
func_loss = torch.nn.MSELoss()
optimizer = torch.optim.Adam(e.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False)

In [58]:
e(seq)

tensor([[0., 0., 0.]], grad_fn=<ReluBackward0>)

In [59]:
loss = func_loss(e(seq), torch.Tensor([1,3,1]))
print(loss)
loss.backward()
optimizer.step()


tensor(3.6667, grad_fn=<MseLossBackward0>)


In [60]:
e.state_dict()

OrderedDict([('l.weight',
              tensor([[ 0.0926, -0.1739],
                      [-0.4285, -0.5430],
                      [-0.1478, -0.4638]]))])

In [62]:
d.state_dict()

OrderedDict([('l.weight',
              tensor([[ 0.0926, -0.4285, -0.1478],
                      [-0.1739, -0.5430, -0.4638]]))])

In [63]:
torch.save(e.state_dict(), './e.pt')
torch.save(d.state_dict(), './d.pt')

In [64]:
e2 = Encoder(2,3)
e2.load_state_dict(torch.load('./e.pt'))

In [69]:
e2.state_dict()

OrderedDict([('l.weight',
              tensor([[ 0.0926, -0.1739],
                      [-0.4285, -0.5430],
                      [-0.1478, -0.4638]]))])

In [68]:
d2 = Decoder(3,2,e2)
d2.load_state_dict(torch.load('./d.pt'))

In [80]:
optimizer = torch.optim.Adam(e2.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False)

In [92]:
loss = func_loss(e2(seq), torch.Tensor([1,3,1]))
print(loss)
loss.backward()
optimizer.step()

tensor(3.6667, grad_fn=<MseLossBackward0>)


In [93]:
e2.state_dict()

OrderedDict([('l.weight',
              tensor([[ 0.0919, -0.1732],
                      [-0.4278, -0.5423],
                      [-0.1471, -0.4631]]))])

In [94]:
d2.state_dict()

OrderedDict([('l.weight',
              tensor([[ 0.0919, -0.4278, -0.1471],
                      [-0.1732, -0.5423, -0.4631]]))])