In [1]:
import torch
from tensordict import TensorDict, from_module
from tensordict.nn import set_composite_lp_aggregate, TensorDictModule
import copy

## Get new model for 2 agents

In [2]:
filename = "/home/hieule/research/saris_revised/local_assets/Hallway_L/hallway_focus_3_agents/checkpoint_187.pt"
model = torch.load(filename)

In [3]:
model.keys()

dict_keys(['policy', 'critic', 'loss_module', 'optimizer'])

In [4]:
policy = model['policy']
policy.keys()
new_policy = copy.deepcopy(policy)
for k, v in policy.items():
    if '__' not in k:
        new_policy[k] = v[:2]  # Extract the first two agents' weights and biases
        print(k, new_policy[k].shape)
    elif "batch_size" in k:
        new_policy[k] = torch.Size([2])  # Adjust batch size for the new policy
        print(k, new_policy[k])
    else:
        new_policy[k] = v
        print(k, new_policy[k])

module.0.module.0.params.0.weight torch.Size([2, 256, 9])
module.0.module.0.params.0.bias torch.Size([2, 256])
module.0.module.0.params.2.weight torch.Size([2, 256, 256])
module.0.module.0.params.2.bias torch.Size([2, 256])
module.0.module.0.params.4.weight torch.Size([2, 6, 256])
module.0.module.0.params.4.bias torch.Size([2, 6])
module.0.module.0.params.__batch_size torch.Size([2])
module.0.module.0.params.__device None


In [5]:
critic = model['critic']
critic.keys()
new_critic = copy.deepcopy(critic)
for k, v in critic.items():
    if '__' not in k:
        new_critic[k] = v[:2]  # Extract the first two agents' weights and biases
        if 'module.params.0.weight' in k:
            new_critic[k] = new_critic[k][..., :18]
        print(k, new_critic[k].shape)
    elif "batch_size" in k:
        new_critic[k] = torch.Size([2])  # Adjust batch size for the new critic
        print(k, new_critic[k])
    else:
        new_critic[k] = v
        print(k, new_critic[k])

module.params.0.weight torch.Size([2, 256, 18])
module.params.0.bias torch.Size([2, 256])
module.params.2.weight torch.Size([2, 256, 256])
module.params.2.bias torch.Size([2, 256])
module.params.4.weight torch.Size([2, 1, 256])
module.params.4.bias torch.Size([2, 1])
module.params.__batch_size torch.Size([2])
module.params.__device None


In [6]:
loss_module = model['loss_module']
loss_module.keys()
new_loss_module = copy.deepcopy(loss_module)
for k, v in loss_module.items():
    if '__' not in k and 'params' in k:
        new_loss_module[k] = v[:2]  # Extract the first two agents' weights and biases
        if 'critic_network_params.module.params.0.weight' in k:
            new_loss_module[k] = new_loss_module[k][..., :18]
        print(k, new_loss_module[k].shape)
    elif "batch_size" in k:
        new_loss_module[k] = torch.Size([2])  # Adjust batch size for the new loss_module
        print(k, new_loss_module[k])
    else:
        new_loss_module[k] = v
        print(k, new_loss_module[k])

entropy_coef tensor(1.0000e-04, device='cuda:0')
critic_coef tensor(1., device='cuda:0')
clip_epsilon tensor(0.2000, device='cuda:0')
actor_network_params.module.0.module.0.params.0.weight torch.Size([2, 256, 9])
actor_network_params.module.0.module.0.params.0.bias torch.Size([2, 256])
actor_network_params.module.0.module.0.params.2.weight torch.Size([2, 256, 256])
actor_network_params.module.0.module.0.params.2.bias torch.Size([2, 256])
actor_network_params.module.0.module.0.params.4.weight torch.Size([2, 6, 256])
actor_network_params.module.0.module.0.params.4.bias torch.Size([2, 6])
actor_network_params.__batch_size torch.Size([2])
actor_network_params.__device None
critic_network_params.module.params.0.weight torch.Size([2, 256, 18])
critic_network_params.module.params.0.bias torch.Size([2, 256])
critic_network_params.module.params.2.weight torch.Size([2, 256, 256])
critic_network_params.module.params.2.bias torch.Size([2, 256])
critic_network_params.module.params.4.weight torch.Si

In [7]:
optimizer = model['optimizer']
optimizer.keys()
new_optimizer = copy.deepcopy(optimizer)
for k, state_idx in new_optimizer['state'].items():
    state_idx['exp_avg'] = state_idx['exp_avg'][:2]
    state_idx['exp_avg_sq'] = state_idx['exp_avg_sq'][:2]
    if k == 6:
        state_idx['exp_avg'] = state_idx['exp_avg'][..., :18]
        state_idx['exp_avg_sq'] = state_idx['exp_avg_sq'][..., :18]
for k, state_idx in new_optimizer['state'].items():
    print(k, state_idx['exp_avg'].shape, state_idx['exp_avg_sq'].shape)

0 torch.Size([2, 256, 9]) torch.Size([2, 256, 9])
1 torch.Size([2, 256]) torch.Size([2, 256])
2 torch.Size([2, 256, 256]) torch.Size([2, 256, 256])
3 torch.Size([2, 256]) torch.Size([2, 256])
4 torch.Size([2, 6, 256]) torch.Size([2, 6, 256])
5 torch.Size([2, 6]) torch.Size([2, 6])
6 torch.Size([2, 256, 18]) torch.Size([2, 256, 18])
7 torch.Size([2, 256]) torch.Size([2, 256])
8 torch.Size([2, 256, 256]) torch.Size([2, 256, 256])
9 torch.Size([2, 256]) torch.Size([2, 256])
10 torch.Size([2, 1, 256]) torch.Size([2, 1, 256])
11 torch.Size([2, 1]) torch.Size([2, 1])


In [8]:
new_model = {
    'policy': new_policy,
    'critic': new_critic,
    'loss_module': new_loss_module,
    'optimizer': new_optimizer,
}
torch.save(new_model, "/home/hieule/research/saris_revised/local_assets/Hallway_L/hallway_focus_2_agents/checkpoint_187_agents.pt")

## Get new model for 4 agents

In [9]:
filename = "/home/hieule/research/saris_revised/local_assets/Hallway_L/hallway_focus_3_agents/checkpoint_187.pt"
model = torch.load(filename)

In [10]:
model.keys()

dict_keys(['policy', 'critic', 'loss_module', 'optimizer'])

In [11]:
policy = model['policy']
policy.keys()
new_policy = copy.deepcopy(policy)
for k, v in policy.items():
    if '__' not in k:
        new_policy[k] = torch.cat((v[:2], v[1:].clone()), dim=0)  # Concatenate the first two agents' weights and biases with the cloned second agent
        print(k, new_policy[k].shape)
    elif "batch_size" in k:
        new_policy[k] = torch.Size([4])  # Adjust batch size for the new policy
        print(k, new_policy[k])
    else:
        new_policy[k] = v
        print(k, new_policy[k])

module.0.module.0.params.0.weight torch.Size([4, 256, 9])
module.0.module.0.params.0.bias torch.Size([4, 256])
module.0.module.0.params.2.weight torch.Size([4, 256, 256])
module.0.module.0.params.2.bias torch.Size([4, 256])
module.0.module.0.params.4.weight torch.Size([4, 6, 256])
module.0.module.0.params.4.bias torch.Size([4, 6])
module.0.module.0.params.__batch_size torch.Size([4])
module.0.module.0.params.__device None


In [12]:
critic = model['critic']
critic.keys()
new_critic = copy.deepcopy(critic)
for k, v in critic.items():
    if '__' not in k:
        new_critic[k] = torch.cat((v[:2], v[1:].clone()), dim=0)  # Concatenate the first two agents' weights and biases
        print(k, new_critic[k].shape)
    elif "batch_size" in k:
        new_critic[k] = torch.Size([2])  # Adjust batch size for the new critic
        print(k, new_critic[k])
    else:
        new_critic[k] = v
        print(k, new_critic[k])

module.params.0.weight torch.Size([4, 256, 27])
module.params.0.bias torch.Size([4, 256])
module.params.2.weight torch.Size([4, 256, 256])
module.params.2.bias torch.Size([4, 256])
module.params.4.weight torch.Size([4, 1, 256])
module.params.4.bias torch.Size([4, 1])
module.params.__batch_size torch.Size([2])
module.params.__device None


In [13]:
loss_module = model['loss_module']
loss_module.keys()
new_loss_module = copy.deepcopy(loss_module)
for k, v in loss_module.items():
    if '__' not in k and 'params' in k:
        new_loss_module[k] = torch.cat((v[:2], v[1:].clone()), dim=0)  # Concatenate the first two agents' weights and biases
        print(k, new_loss_module[k].shape)
    elif "batch_size" in k:
        new_loss_module[k] = torch.Size([2])  # Adjust batch size for the new loss_module
        print(k, new_loss_module[k])
    else:
        new_loss_module[k] = v
        print(k, new_loss_module[k])

entropy_coef tensor(1.0000e-04, device='cuda:0')
critic_coef tensor(1., device='cuda:0')
clip_epsilon tensor(0.2000, device='cuda:0')
actor_network_params.module.0.module.0.params.0.weight torch.Size([4, 256, 9])
actor_network_params.module.0.module.0.params.0.bias torch.Size([4, 256])
actor_network_params.module.0.module.0.params.2.weight torch.Size([4, 256, 256])
actor_network_params.module.0.module.0.params.2.bias torch.Size([4, 256])
actor_network_params.module.0.module.0.params.4.weight torch.Size([4, 6, 256])
actor_network_params.module.0.module.0.params.4.bias torch.Size([4, 6])
actor_network_params.__batch_size torch.Size([2])
actor_network_params.__device None
critic_network_params.module.params.0.weight torch.Size([4, 256, 27])
critic_network_params.module.params.0.bias torch.Size([4, 256])
critic_network_params.module.params.2.weight torch.Size([4, 256, 256])
critic_network_params.module.params.2.bias torch.Size([4, 256])
critic_network_params.module.params.4.weight torch.Si

In [14]:
optimizer = model['optimizer']
optimizer.keys()
new_optimizer = copy.deepcopy(optimizer)
for k, state_idx in new_optimizer['state'].items():
    state_idx['exp_avg'] = torch.cat((state_idx['exp_avg'][:2], state_idx['exp_avg'][1:].clone()), dim=0)  # Concatenate the first two agents' exp_avg with the cloned second agent's exp_avg
    state_idx['exp_avg_sq'] = torch.cat((state_idx['exp_avg_sq'][:2], state_idx['exp_avg_sq'][1:].clone()), dim=0)  # Concatenate the first two agents' exp_avg_sq with the cloned second agent's exp_avg_sq
for k, state_idx in new_optimizer['state'].items():
    print(k, state_idx['exp_avg'].shape, state_idx['exp_avg_sq'].shape)

0 torch.Size([4, 256, 9]) torch.Size([4, 256, 9])
1 torch.Size([4, 256]) torch.Size([4, 256])
2 torch.Size([4, 256, 256]) torch.Size([4, 256, 256])
3 torch.Size([4, 256]) torch.Size([4, 256])
4 torch.Size([4, 6, 256]) torch.Size([4, 6, 256])
5 torch.Size([4, 6]) torch.Size([4, 6])
6 torch.Size([4, 256, 27]) torch.Size([4, 256, 27])
7 torch.Size([4, 256]) torch.Size([4, 256])
8 torch.Size([4, 256, 256]) torch.Size([4, 256, 256])
9 torch.Size([4, 256]) torch.Size([4, 256])
10 torch.Size([4, 1, 256]) torch.Size([4, 1, 256])
11 torch.Size([4, 1]) torch.Size([4, 1])


In [15]:
new_optimizer['param_groups']

[{'lr': 0.0002,
  'betas': (0.9, 0.999),
  'eps': 1e-08,
  'weight_decay': 0,
  'amsgrad': False,
  'maximize': False,
  'foreach': None,
  'capturable': False,
  'differentiable': False,
  'fused': None,
  'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]}]

In [16]:
new_model = {
    'policy': new_policy,
    'critic': new_critic,
    'loss_module': new_loss_module,
    'optimizer': new_optimizer,
}
torch.save(new_model, "/home/hieule/research/saris_revised/local_assets/Hallway_L/hallway_focus_4_agents/checkpoint_187_agents.pt")