# Import libraries

In [1]:
import ptan
import torch
import torch.nn as nn

# Define a small test NN

In [2]:
class DQNNet(nn.Module):
    """Minimal single-layer network used to demonstrate target-network syncing.

    Args:
        in_features: Size of each input vector. Defaults to 5, the value
            that was previously hard-coded.
        out_features: Size of each output vector. Defaults to 3, the value
            that was previously hard-coded.
    """

    def __init__(self, in_features=5, out_features=3):
        super().__init__()

        # One fully connected layer; the attribute keeps the name `ff`
        # because other cells read `net.ff.weight` directly.
        self.ff = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Apply the linear layer to `x` and return the result."""
        return self.ff(x)

# Instantiate NN

In [3]:
# Build the demo network and print its module structure.
net = DQNNet()
print(net)

DQNNet(
  (ff): Linear(in_features=5, out_features=3, bias=True)
)


# Create target net

In [4]:
# Wrap the main network in ptan's TargetNet helper. The wrapper exposes the
# network it holds as `target_model`, which is read below.
# NOTE(review): the recorded output shows both weight matrices are identical
# right after construction — presumably TargetNet starts from a copy of `net`;
# confirm against the ptan source.
tgt_net = ptan.agent.TargetNet(net)
# Print both weight matrices so they can be compared.
print("Main net:", net.ff.weight)
print("Target net:", tgt_net.target_model.ff.weight)

Main net: Parameter containing:
tensor([[ 0.1235,  0.2126,  0.0866, -0.0449,  0.4043],
        [-0.2248,  0.1968,  0.2820,  0.1606, -0.0464],
        [ 0.0180,  0.2604,  0.1390,  0.0325,  0.1356]], requires_grad=True)
Target net: Parameter containing:
tensor([[ 0.1235,  0.2126,  0.0866, -0.0449,  0.4043],
        [-0.2248,  0.1968,  0.2820,  0.1606, -0.0464],
        [ 0.0180,  0.2604,  0.1390,  0.0325,  0.1356]], requires_grad=True)


# Update the main net (the target net should stay unchanged)

In [5]:
# Shift every weight of the main network by 1.0, in place.
# The update runs under torch.no_grad() instead of going through the legacy
# `.data` attribute: autograd does not record the operation, but the change is
# still applied to the parameter tensor itself rather than behind its back.
# Note: this cell is not idempotent — every run adds another 1.0.
with torch.no_grad():
    net.ff.weight.add_(1.0)
print("After update")
# Print both weight matrices to compare main and target after the update.
print("Main net:", net.ff.weight)
print("Target net:", tgt_net.target_model.ff.weight)

After update
Main net: Parameter containing:
tensor([[1.1235, 1.2126, 1.0866, 0.9551, 1.4043],
        [0.7752, 1.1968, 1.2820, 1.1606, 0.9536],
        [1.0180, 1.2604, 1.1390, 1.0325, 1.1356]], requires_grad=True)
Target net: Parameter containing:
tensor([[ 0.1235,  0.2126,  0.0866, -0.0449,  0.4043],
        [-0.2248,  0.1968,  0.2820,  0.1606, -0.0464],
        [ 0.0180,  0.2604,  0.1390,  0.0325,  0.1356]], requires_grad=True)


# Syncing target net with main net

In [6]:
# Call sync() on the target-net wrapper, then print both weight matrices again.
# NOTE(review): in the recorded output the two matrices match after this call —
# presumably sync() copies the main network's current weights into
# `target_model`; confirm against the ptan source.
tgt_net.sync()
print("After sync")
print("Main net:", net.ff.weight)
print("Target net:", tgt_net.target_model.ff.weight)

After sync
Main net: Parameter containing:
tensor([[1.1235, 1.2126, 1.0866, 0.9551, 1.4043],
        [0.7752, 1.1968, 1.2820, 1.1606, 0.9536],
        [1.0180, 1.2604, 1.1390, 1.0325, 1.1356]], requires_grad=True)
Target net: Parameter containing:
tensor([[1.1235, 1.2126, 1.0866, 0.9551, 1.4043],
        [0.7752, 1.1968, 1.2820, 1.1606, 0.9536],
        [1.0180, 1.2604, 1.1390, 1.0325, 1.1356]], requires_grad=True)
