In [3]:
import torch
from sft_eval import PolicyNet  # or import your model definition

# Load the supervised-fine-tuned policy checkpoint and rebuild the network so
# its weights can be inspected in the cells below.
#
# map_location="cpu" makes the load work on machines without a GPU even if the
# checkpoint was saved from CUDA tensors; every later cell reads the weights
# through `.detach().numpy()`, which requires CPU tensors anyway.
# NOTE(review): torch.load unpickles the file -- only open checkpoints you trust.
checkpoint = torch.load("policy_sft.pt", map_location="cpu")

# Metadata stored next to the weights in the checkpoint dict.
input_dim = checkpoint["input_dim"]
num_actions = checkpoint["num_actions"]
idx2action = checkpoint["idx2action"]  # looked up by output index in later cells

model = PolicyNet(input_dim, num_actions)
model.load_state_dict(checkpoint["state_dict"])
# Switch to inference mode; as the last expression this also displays the
# module structure.
model.eval()

PolicyNet(
  (net): Sequential(
    (0): Linear(in_features=6, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=521, bias=True)
  )
)

In [4]:
# Print every learnable tensor in the model together with its shape.
for param_name, tensor in model.named_parameters():
    print(param_name, tensor.shape)


net.0.weight torch.Size([256, 6])
net.0.bias torch.Size([256])
net.2.weight torch.Size([256, 256])
net.2.bias torch.Size([256])
net.4.weight torch.Size([521, 256])
net.4.bias torch.Size([521])


In [5]:
import pandas as pd

# First-layer weights as a table: one row per hidden unit, one column per
# input feature (the weight matrix is (out_features, in_features)).
# NOTE(review): the column labels assume the six input features were encoded
# in exactly this order -- confirm against the training data pipeline.
w = model.net[0].weight.detach().numpy()
df = pd.DataFrame(w, columns=["BRICK","WOOD","SHEEP","WHEAT","ORE","VP"])
# A bare last expression uses the notebook's rich table display instead of the
# plain-text dump that print() produces.
df.head()


      BRICK      WOOD     SHEEP     WHEAT       ORE        VP
0  0.352846  0.390151 -0.138592  0.482652 -0.134897  0.107915
1 -0.281113  0.288690  0.431545 -0.318852  0.396663  0.094861
2  0.330936  0.092001  0.239117 -0.060689  0.345641  0.062946
3 -0.142741  0.047987 -0.310156  0.016095 -0.155314  0.324403
4 -0.322259 -0.188207 -0.115279 -0.245468  0.038532 -0.403218


In [6]:
# Grab the output layer once, then export its weight matrix and bias vector
# as NumPy arrays.
output_layer = model.net[-1]
final_w = output_layer.weight.detach().numpy()
final_b = output_layer.bias.detach().numpy()


In [7]:
import numpy as np

# Weight matrices of the three Linear layers (positions 0, 2 and 4 of the
# Sequential), converted to NumPy:
#   first  -> (256, 6)
#   second -> (256, 256)
#   final  -> (num_actions, 256)
first, second, final = (
    model.net[layer_idx].weight.detach().numpy() for layer_idx in (0, 2, 4)
)

# Heuristic feature importance: chain the absolute weight matrices from the
# output layer back to the inputs. Only the weights are used here -- biases
# and the ReLU activations do not enter this product.
abs_final, abs_second, abs_first = np.abs(final), np.abs(second), np.abs(first)
importance = np.matmul(np.matmul(abs_final, abs_second), abs_first)

print("Importance shape:", importance.shape)


Importance shape: (521, 6)


In [9]:
# Sanity check before labelling model outputs with action names: the action
# vocabulary must have exactly one entry per network output, otherwise the
# per-action report would mislabel rows or run out of bounds.
assert len(idx2action) == num_actions, (
    f"idx2action has {len(idx2action)} entries but the model has {num_actions} outputs"
)
len(idx2action)

521

In [8]:
# Per-action report: for each action, list the input features from highest to
# lowest importance score.
# The label list is built once here instead of being re-created inside the
# f-string on every inner iteration.
# NOTE(review): this prints a section for every entry in idx2action, which is
# a lot of output -- consider slicing while exploring.
feature_names = ['BRICK','WOOD','SHEEP','WHEAT','ORE','VP']

for i, act in enumerate(idx2action):
    print(f"\nAction: {act}")
    scores = importance[i]
    # argsort on the negated scores yields indices in descending score order.
    for j in np.argsort(-scores):
        print(f"  {feature_names[j]}: {scores[j]:.4f}")



Action: ('Action', "Action(color=<Color.RED: 'RED'>, action_type=<ActionType.BUILD_SETTLEMENT: 'BUILD_SETTLEMENT'>, value=50)")
  WOOD: 46.2041
  WHEAT: 45.7152
  BRICK: 43.4060
  ORE: 43.0001
  SHEEP: 42.7056
  VP: 33.6435

Action: ('Action', "Action(color=<Color.RED: 'RED'>, action_type=<ActionType.BUILD_ROAD: 'BUILD_ROAD'>, value=(49, 50))")
  WOOD: 49.3756
  WHEAT: 49.0783
  ORE: 46.2981
  BRICK: 46.1651
  SHEEP: 45.6969
  VP: 36.1307

Action: ('Action', "Action(color=<Color.WHITE: 'WHITE'>, action_type=<ActionType.BUILD_SETTLEMENT: 'BUILD_SETTLEMENT'>, value=47)")
  WOOD: 44.2020
  WHEAT: 43.9649
  BRICK: 41.5528
  ORE: 41.2516
  SHEEP: 40.8789
  VP: 32.2332

Action: ('Action', "Action(color=<Color.WHITE: 'WHITE'>, action_type=<ActionType.BUILD_ROAD: 'BUILD_ROAD'>, value=(45, 47))")
  WOOD: 47.1575
  WHEAT: 47.0575
  BRICK: 44.4414
  ORE: 44.2139
  SHEEP: 43.8866
  VP: 34.4324

Action: ('Action', "Action(color=<Color.WHITE: 'WHITE'>, action_type=<ActionType.BUILD_SETTLEMENT: 'BUI

In [None]:
import matplotlib.pyplot as plt
# Heatmap of the first Linear layer's weights. Rows are hidden units and
# columns are input features (the weight matrix is (out_features, in_features)).
# NOTE(review): tick labels assume the input features are ordered
# BRICK, WOOD, SHEEP, WHEAT, ORE, VP -- confirm against the training pipeline.
feature_names = ["BRICK", "WOOD", "SHEEP", "WHEAT", "ORE", "VP"]
first_layer_w = model.net[0].weight.detach().numpy()

# Explicit figure/axes objects so the plot carries its own title and labels
# and can be read without the surrounding code.
fig, ax = plt.subplots()
im = ax.imshow(first_layer_w, aspect='auto')
ax.set_title("First-layer weights")
ax.set_xlabel("Input feature")
ax.set_ylabel("Hidden unit")
ax.set_xticks(range(len(feature_names)))
ax.set_xticklabels(feature_names)
fig.colorbar(im, ax=ax, label="Weight")
plt.show()
