In [1]:
%load_ext autoreload
%autoreload 2

In [4]:
import torch
from torch import Tensor
import torch.optim as optim

from testing import logit_diff_metric
from applications.pipeline import run_attribution_steps, identify_target_components, optimise_edit_components
from applications.datasets import CounterFact

from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device

In [5]:
device = get_device()
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Explicitly calculate and expose the result for each attention head
model.set_use_attn_result(True)
model.set_use_hook_mlp_in(True)

Loaded pretrained model gpt2-small into HookedTransformer


In [26]:
# Verify that loading works, for one example

counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=1)

clean_input, corrupted_input, labels = next(iter(counterfact_dataloader))

print(clean_input)
print(corrupted_input)
print(labels)

clean_tokens = model.to_tokens(clean_input)
corrupted_tokens = model.to_tokens(corrupted_input)

clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logit_diff_metric(clean_logits, labels)
print(f"Clean logit difference: {clean_logit_diff}")

corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
corrupted_logit_diff = logit_diff_metric(corrupted_logits, labels)
print(f"Corrupted logit difference: {corrupted_logit_diff}")

['The mother tongue of Danielle Darrieux is']
['The mother tongue of Paul McCartney is']
tensor([[24111, 15823]])
Clean logit difference: tensor([0.1160], device='cuda:0', grad_fn=<SubBackward0>)
Corrupted logit difference: tensor([-1.1990], device='cuda:0', grad_fn=<SubBackward0>)


In [25]:
n_epochs = 5

original_tokens = model.to_tokens(clean_input)
rewrite_tokens = model.to_tokens(corrupted_input)

original_logits, original_cache = model.run_with_cache(original_tokens)
original_logit_diff = logit_diff_metric(original_logits, labels)

rewrite_logits, rewrite_cache = model.run_with_cache(rewrite_tokens)
rewrite_logit_diff = logit_diff_metric(rewrite_logits, labels)

# LOCALISATION STAGE

mlp_highlighted, attn_highlighted = run_attribution_steps(
    model,
    original_tokens,
    rewrite_tokens,
    labels,
    original_cache,
    rewrite_cache,
    original_logit_diff,
    rewrite_logit_diff,
)

target_mlp = identify_target_components(model, mlp_highlighted)
target_attn = identify_target_components(model, attn_highlighted)

# EDITING STAGE

relevant_parameters = [
    p for name, p in model.named_parameters() if "attn" in name or "mlp" in name
]
optimiser = optim.Adam(relevant_parameters, lr=2e-4)

for _ in range(n_epochs):
    logits = model(original_tokens)
    answer_index = labels[:, 1]  # Aim for rewritten answer
    optimise_edit_components(
        model, logits, answer_index, target_mlp, target_attn, optimiser
    )

AssertionError: Baseline can be provided as a tensor for just one input and broadcasted to the batch or input and baseline must have the same shape or the baseline corresponding to each input tensor must be a scalar. Found baseline: tensor([[[[ 6.3183e-02,  5.9828e-02, -1.8017e-02,  ..., -3.1300e-02,
            1.3576e-01,  1.0932e-01],
          [ 1.8581e-01,  2.3438e-01, -2.4526e-01,  ..., -5.3240e-01,
           -3.8492e-01,  2.4711e-01],
          [-4.0007e-02, -1.3616e-01,  3.8954e-01,  ..., -6.7756e-02,
            1.2671e-01, -1.8432e-01],
          ...,
          [-2.8296e-01, -5.0378e-01,  1.6165e-01,  ..., -4.9117e-02,
           -1.7141e-01, -1.1625e-01],
          [-5.6480e-02, -4.1342e-01, -2.0710e-01,  ...,  1.9242e-01,
            2.5598e-01, -7.1256e-02],
          [ 1.5511e-01, -5.1378e-01, -1.7323e-01,  ...,  1.5156e-01,
           -3.0645e-01,  3.7273e-01]],

         [[ 5.6135e-02,  5.2765e-02,  8.7760e-03,  ..., -4.0619e-02,
            1.3522e-01,  9.8149e-02],
          [ 2.6660e-01,  6.5456e-03,  2.5709e-02,  ...,  1.8073e-01,
            1.1111e-01,  1.3402e-01],
          [-4.8208e-02, -1.4902e-01,  3.7220e-01,  ..., -6.3423e-02,
            1.3125e-01, -1.6714e-01],
          ...,
          [-3.3139e-01, -4.4216e-01,  1.5671e-01,  ..., -5.7046e-02,
           -1.5611e-01, -1.1741e-01],
          [-1.2947e-01, -3.6367e-01, -2.0415e-01,  ...,  1.7347e-01,
            2.6245e-01, -3.8326e-02],
          [ 1.0529e-01, -3.8342e-01, -1.3472e-01,  ...,  1.8566e-01,
           -1.9970e-01,  3.2450e-01]],

         [[ 2.2397e-02,  4.0383e-02,  7.2053e-02,  ..., -1.7661e-02,
            1.0304e-01,  4.1182e-02],
          [ 1.6782e-01, -1.9282e-01,  9.4300e-02,  ...,  2.8860e-01,
           -6.3679e-02,  1.7539e-01],
          [-6.8276e-02, -1.6953e-01,  3.4273e-01,  ..., -7.1643e-02,
            1.4635e-01, -1.4183e-01],
          ...,
          [-3.3865e-01, -4.1659e-01,  1.5499e-01,  ..., -5.7295e-02,
           -1.5056e-01, -1.0963e-01],
          [-1.2201e-01, -4.4275e-01, -1.8614e-01,  ...,  7.5768e-02,
            2.1584e-01, -5.0138e-03],
          [ 8.6065e-02, -2.1674e-01, -1.0654e-01,  ...,  1.5300e-01,
           -1.4707e-01,  3.2081e-01]],

         ...,

         [[ 6.4724e-02, -4.8898e-03,  1.1926e-01,  ..., -3.0955e-02,
            4.4494e-02,  5.0276e-02],
          [-2.5488e-01, -1.5862e-01, -3.1766e-01,  ...,  1.9324e-01,
           -2.3306e-01, -2.6534e-02],
          [ 2.0906e-02, -1.1226e-01,  3.4931e-01,  ..., -2.2410e-02,
            9.0138e-02, -1.0550e-01],
          ...,
          [-2.8575e-01, -2.9427e-01,  1.1049e-01,  ..., -8.5844e-03,
           -6.3982e-02, -7.3881e-02],
          [-8.0222e-02, -3.0836e-01, -9.6145e-02,  ...,  8.1953e-02,
           -7.9909e-02, -3.2124e-02],
          [ 8.8678e-02, -3.5062e-02, -9.9833e-02,  ...,  1.2702e-01,
           -5.0877e-03,  3.4614e-01]],

         [[ 8.0942e-02, -1.7746e-02,  1.2092e-01,  ..., -2.1856e-02,
            4.9897e-02,  2.3710e-02],
          [ 1.8996e-01,  2.1745e-01, -3.2437e-02,  ..., -2.5151e-02,
           -9.7690e-02,  1.1364e-01],
          [ 1.7757e-02, -8.9753e-02,  3.3869e-01,  ..., -4.9914e-02,
            1.1590e-01, -9.6719e-02],
          ...,
          [-3.2242e-01, -2.5883e-01,  1.4814e-01,  ..., -9.8169e-03,
           -1.3618e-01, -8.9883e-02],
          [-1.2651e-01, -2.9261e-01, -1.3029e-01,  ..., -2.2379e-02,
            4.8556e-03, -3.2084e-02],
          [ 8.3148e-02, -1.4907e-01, -1.0640e-01,  ...,  1.2983e-01,
           -2.8610e-02,  3.3631e-01]],

         [[ 7.1720e-02,  6.8320e-03,  9.2763e-02,  ..., -3.3083e-02,
            8.7794e-02,  3.8504e-02],
          [ 8.2080e-02,  1.0938e-01, -6.7033e-02,  ...,  1.2012e-02,
            1.2315e-01,  2.2886e-01],
          [-3.1729e-02, -1.3783e-01,  3.9982e-01,  ..., -5.1877e-02,
            1.4156e-01, -1.1621e-01],
          ...,
          [-3.2167e-01, -3.6756e-01,  1.6634e-01,  ...,  1.1999e-02,
           -1.3713e-01, -1.2182e-01],
          [-1.4868e-01, -3.5158e-01, -1.6904e-01,  ...,  6.7471e-02,
            1.5907e-01, -3.8502e-03],
          [ 1.4103e-01, -1.4993e-01, -1.8868e-01,  ...,  1.8337e-01,
           -5.2561e-02,  3.9406e-01]]],


        [[[ 6.3183e-02,  5.9828e-02, -1.8017e-02,  ..., -3.1300e-02,
            1.3576e-01,  1.0932e-01],
          [ 1.8581e-01,  2.3438e-01, -2.4526e-01,  ..., -5.3240e-01,
           -3.8492e-01,  2.4711e-01],
          [-4.0007e-02, -1.3616e-01,  3.8954e-01,  ..., -6.7756e-02,
            1.2671e-01, -1.8432e-01],
          ...,
          [-2.8296e-01, -5.0378e-01,  1.6165e-01,  ..., -4.9117e-02,
           -1.7141e-01, -1.1625e-01],
          [-5.6480e-02, -4.1342e-01, -2.0710e-01,  ...,  1.9242e-01,
            2.5598e-01, -7.1256e-02],
          [ 1.5511e-01, -5.1378e-01, -1.7323e-01,  ...,  1.5156e-01,
           -3.0645e-01,  3.7273e-01]],

         [[ 5.6135e-02,  5.2765e-02,  8.7760e-03,  ..., -4.0619e-02,
            1.3522e-01,  9.8149e-02],
          [ 2.6660e-01,  6.5456e-03,  2.5709e-02,  ...,  1.8073e-01,
            1.1111e-01,  1.3402e-01],
          [-4.8208e-02, -1.4902e-01,  3.7220e-01,  ..., -6.3423e-02,
            1.3125e-01, -1.6714e-01],
          ...,
          [-3.3139e-01, -4.4216e-01,  1.5671e-01,  ..., -5.7046e-02,
           -1.5611e-01, -1.1741e-01],
          [-1.2947e-01, -3.6367e-01, -2.0415e-01,  ...,  1.7347e-01,
            2.6245e-01, -3.8326e-02],
          [ 1.0529e-01, -3.8342e-01, -1.3472e-01,  ...,  1.8566e-01,
           -1.9970e-01,  3.2450e-01]],

         [[ 3.7676e-02,  4.1196e-02,  6.9550e-02,  ..., -6.6748e-02,
            1.2939e-01,  5.9925e-02],
          [-1.6487e-02, -2.1374e-01, -2.9048e-02,  ...,  2.6674e-01,
            1.6551e-01,  1.5811e-01],
          [-5.7450e-02, -1.6644e-01,  3.4600e-01,  ..., -5.3515e-02,
            1.3963e-01, -1.2742e-01],
          ...,
          [-3.3570e-01, -3.8015e-01,  1.6264e-01,  ..., -4.4775e-02,
           -1.5235e-01, -8.9518e-02],
          [-2.0001e-02, -2.8253e-01, -1.5684e-01,  ...,  1.0784e-01,
            1.1128e-01, -6.7402e-02],
          [ 1.2750e-01, -3.3137e-01, -1.7557e-01,  ...,  1.7779e-01,
           -1.4017e-01,  3.3005e-01]],

         ...,

         [[ 4.9619e-02, -1.0999e-02,  7.5101e-02,  ..., -1.2517e-01,
            4.7757e-02,  3.9624e-02],
          [ 2.5300e-03,  1.7440e-01, -9.4192e-03,  ...,  3.1506e-02,
            2.5101e-02,  2.2121e-01],
          [ 6.8708e-03, -1.3345e-01,  3.1822e-01,  ..., -5.4231e-02,
            1.2632e-01, -8.2346e-02],
          ...,
          [-3.2515e-01, -2.6622e-01,  1.2463e-01,  ..., -8.7309e-03,
           -1.3990e-01, -9.7888e-02],
          [ 3.2924e-03, -1.5808e-01, -1.0563e-01,  ...,  6.1098e-02,
           -1.1146e-02, -7.0645e-02],
          [ 1.1333e-01, -1.2933e-01, -1.1946e-01,  ...,  1.3591e-01,
           -1.6498e-02,  3.5188e-01]],

         [[ 4.3119e-02, -5.9800e-02,  1.1382e-01,  ..., -1.7368e-01,
            1.0856e-01,  1.3357e-02],
          [-3.0208e-01, -1.3085e-01,  2.6735e-01,  ...,  2.2674e-01,
           -1.0058e-01, -2.0607e-01],
          [ 4.3683e-03, -1.5895e-01,  2.7759e-01,  ..., -5.7834e-02,
            1.2631e-01, -1.1445e-01],
          ...,
          [-3.1207e-01, -3.2192e-01,  1.2122e-01,  ..., -3.1067e-03,
           -1.3828e-01, -1.0901e-01],
          [-6.9010e-02, -3.4290e-01, -3.9674e-02,  ...,  7.8374e-02,
            5.5379e-03,  2.7121e-02],
          [ 9.8452e-02, -6.0071e-02, -1.3188e-01,  ...,  1.3406e-01,
            6.8523e-02,  3.4852e-01]],

         [[ 6.4684e-02, -3.2300e-02,  7.5071e-02,  ..., -1.1745e-01,
            9.3861e-02,  2.1488e-02],
          [ 1.6433e-01,  2.1547e-01, -4.0457e-02,  ..., -2.1143e-02,
           -1.0321e-01,  1.1648e-01],
          [ 5.5406e-02, -9.5885e-02,  3.2056e-01,  ..., -5.6154e-02,
            1.2361e-01, -1.0570e-01],
          ...,
          [-3.4844e-01, -2.5039e-01,  1.3747e-01,  ...,  7.9164e-05,
           -1.7358e-01, -1.0579e-01],
          [-1.0557e-01, -2.3877e-01, -1.4631e-01,  ..., -3.3431e-02,
           -5.0332e-02, -5.7785e-02],
          [ 1.0105e-01, -1.6895e-01, -1.2693e-01,  ...,  1.4249e-01,
            1.3330e-02,  3.3134e-01]]],


        [[[ 6.3183e-02,  5.9828e-02, -1.8017e-02,  ..., -3.1300e-02,
            1.3576e-01,  1.0932e-01],
          [ 1.8581e-01,  2.3438e-01, -2.4526e-01,  ..., -5.3240e-01,
           -3.8492e-01,  2.4711e-01],
          [-4.0007e-02, -1.3616e-01,  3.8954e-01,  ..., -6.7756e-02,
            1.2671e-01, -1.8432e-01],
          ...,
          [-2.8296e-01, -5.0378e-01,  1.6165e-01,  ..., -4.9117e-02,
           -1.7141e-01, -1.1625e-01],
          [-5.6480e-02, -4.1342e-01, -2.0710e-01,  ...,  1.9242e-01,
            2.5598e-01, -7.1256e-02],
          [ 1.5511e-01, -5.1378e-01, -1.7323e-01,  ...,  1.5156e-01,
           -3.0645e-01,  3.7273e-01]],

         [[ 6.1459e-02,  6.2121e-02, -2.8548e-02,  ..., -3.7793e-02,
            1.2494e-01,  1.0063e-01],
          [ 2.0101e-01, -5.9048e-02, -2.1680e-02,  ...,  4.5949e-01,
            1.9642e-01,  4.2402e-02],
          [-4.5368e-02, -1.2999e-01,  3.5952e-01,  ..., -6.8466e-02,
            1.1048e-01, -1.7313e-01],
          ...,
          [-3.2417e-01, -4.4840e-01,  1.6240e-01,  ..., -1.5891e-02,
           -1.4142e-01, -7.4223e-02],
          [-1.0908e-01, -3.6107e-01, -1.7821e-01,  ...,  1.3265e-01,
            2.3327e-01, -6.7998e-02],
          [ 1.1848e-01, -4.2342e-01, -1.5293e-01,  ...,  1.5777e-01,
           -2.0917e-01,  3.3538e-01]],

         [[ 4.8368e-02,  7.0170e-02, -5.6113e-02,  ..., -5.3152e-02,
            8.0512e-02,  8.4481e-02],
          [-2.2859e-01,  2.1869e-02,  3.6398e-02,  ...,  8.5735e-02,
            1.3093e-01, -8.4808e-02],
          [-3.9078e-02, -1.2918e-01,  3.2743e-01,  ..., -6.1834e-02,
            1.0285e-01, -1.5874e-01],
          ...,
          [-3.2902e-01, -4.3611e-01,  1.6472e-01,  ..., -1.7701e-02,
           -1.3316e-01, -6.8558e-02],
          [-4.9892e-02, -3.5500e-01, -1.4411e-01,  ...,  1.3653e-01,
            1.3581e-01, -1.0052e-01],
          [ 1.1163e-01, -3.6817e-01, -1.5165e-01,  ...,  1.5272e-01,
           -1.7209e-01,  3.2340e-01]],

         ...,

         [[ 6.1552e-02,  9.7966e-02, -6.2787e-02,  ..., -4.0116e-02,
            1.4242e-01,  8.2613e-02],
          [ 1.3697e-01,  1.0831e-01, -5.7050e-02,  ...,  2.3917e-02,
            1.4825e-01,  2.2475e-01],
          [-1.2342e-01, -1.4249e-01,  3.8880e-01,  ..., -7.2036e-02,
            1.1983e-01, -1.4325e-01],
          ...,
          [-3.1303e-01, -4.9957e-01,  2.1792e-01,  ...,  6.5372e-02,
           -1.3064e-01, -9.1173e-02],
          [-1.6771e-01, -3.6374e-01, -1.9052e-01,  ...,  1.4985e-01,
            2.3929e-01, -2.6730e-02],
          [ 2.2840e-01, -3.4029e-01, -3.3357e-01,  ...,  2.9052e-01,
           -2.0867e-01,  4.4150e-01]],

         [[ 5.9767e-02,  1.0052e-01, -7.7298e-02,  ..., -4.1897e-02,
            1.7159e-01,  9.3002e-02],
          [ 1.2573e-01,  1.0747e-01, -5.5569e-02,  ...,  2.7017e-02,
            1.5030e-01,  2.2532e-01],
          [-1.3033e-01, -1.5881e-01,  4.3875e-01,  ..., -6.9282e-02,
            1.3641e-01, -1.6496e-01],
          ...,
          [-3.1057e-01, -5.1790e-01,  2.1481e-01,  ...,  6.9270e-02,
           -1.2908e-01, -1.0996e-01],
          [-1.8284e-01, -3.7971e-01, -2.0292e-01,  ...,  1.3540e-01,
            2.6760e-01, -1.1640e-02],
          [ 2.5303e-01, -3.5602e-01, -3.6007e-01,  ...,  3.0982e-01,
           -2.3017e-01,  4.6528e-01]],

         [[ 5.7735e-02,  1.0278e-01, -8.5991e-02,  ..., -4.1738e-02,
            1.9132e-01,  1.0020e-01],
          [ 1.1321e-01,  1.0665e-01, -5.6380e-02,  ...,  2.9411e-02,
            1.4955e-01,  2.2620e-01],
          [-1.3577e-01, -1.7232e-01,  4.7817e-01,  ..., -6.6842e-02,
            1.4873e-01, -1.8249e-01],
          ...,
          [-3.0780e-01, -5.3170e-01,  2.1181e-01,  ...,  7.2728e-02,
           -1.2800e-01, -1.2404e-01],
          [-1.9340e-01, -3.9288e-01, -2.1265e-01,  ...,  1.2336e-01,
            2.8932e-01,  1.7411e-04],
          [ 2.7039e-01, -3.6830e-01, -3.7865e-01,  ...,  3.2212e-01,
           -2.4867e-01,  4.8173e-01]]]], device='cuda:0') and input: tensor([[[[ 0.0632,  0.0598, -0.0180,  ..., -0.0313,  0.1358,  0.1093],
          [ 0.1858,  0.2344, -0.2453,  ..., -0.5324, -0.3849,  0.2471],
          [-0.0400, -0.1362,  0.3895,  ..., -0.0678,  0.1267, -0.1843],
          ...,
          [-0.2830, -0.5038,  0.1617,  ..., -0.0491, -0.1714, -0.1162],
          [-0.0565, -0.4134, -0.2071,  ...,  0.1924,  0.2560, -0.0713],
          [ 0.1551, -0.5138, -0.1732,  ...,  0.1516, -0.3064,  0.3727]],

         [[ 0.0561,  0.0528,  0.0088,  ..., -0.0406,  0.1352,  0.0981],
          [ 0.2666,  0.0065,  0.0257,  ...,  0.1807,  0.1111,  0.1340],
          [-0.0482, -0.1490,  0.3722,  ..., -0.0634,  0.1313, -0.1671],
          ...,
          [-0.3314, -0.4422,  0.1567,  ..., -0.0570, -0.1561, -0.1174],
          [-0.1295, -0.3637, -0.2042,  ...,  0.1735,  0.2624, -0.0383],
          [ 0.1053, -0.3834, -0.1347,  ...,  0.1857, -0.1997,  0.3245]],

         [[ 0.0224,  0.0404,  0.0721,  ..., -0.0177,  0.1030,  0.0412],
          [ 0.1678, -0.1928,  0.0943,  ...,  0.2886, -0.0637,  0.1754],
          [-0.0683, -0.1695,  0.3427,  ..., -0.0716,  0.1464, -0.1418],
          ...,
          [-0.3386, -0.4166,  0.1550,  ..., -0.0573, -0.1506, -0.1096],
          [-0.1220, -0.4427, -0.1861,  ...,  0.0758,  0.2158, -0.0050],
          [ 0.0861, -0.2167, -0.1065,  ...,  0.1530, -0.1471,  0.3208]],

         ...,

         [[ 0.0390, -0.0059,  0.0447,  ..., -0.0441,  0.1086,  0.0771],
          [ 0.1154,  0.1087, -0.0581,  ...,  0.0240,  0.1421,  0.2260],
          [-0.0464, -0.1620,  0.4461,  ..., -0.0695,  0.1694, -0.1567],
          ...,
          [-0.3399, -0.4017,  0.1574,  ...,  0.0261, -0.1458, -0.1475],
          [-0.1677, -0.3812, -0.2123,  ...,  0.0474,  0.2063, -0.0148],
          [ 0.2072, -0.2253, -0.2902,  ...,  0.2524, -0.1484,  0.4360]],

         [[ 0.0401,  0.0144,  0.0147,  ..., -0.0430,  0.1390,  0.0875],
          [ 0.1022,  0.1072, -0.0583,  ...,  0.0282,  0.1439,  0.2270],
          [-0.0644, -0.1743,  0.4826,  ..., -0.0669,  0.1756, -0.1744],
          ...,
          [-0.3308, -0.4403,  0.1651,  ...,  0.0391, -0.1413, -0.1550],
          [-0.1791, -0.3947, -0.2208,  ...,  0.0462,  0.2348, -0.0043],
          [ 0.2339, -0.2616, -0.3218,  ...,  0.2757, -0.1817,  0.4593]],

         [[ 0.0416,  0.0292, -0.0095,  ..., -0.0427,  0.1646,  0.0950],
          [ 0.0877,  0.1075, -0.0574,  ...,  0.0300,  0.1454,  0.2279],
          [-0.0788, -0.1834,  0.5108,  ..., -0.0656,  0.1801, -0.1887],
          ...,
          [-0.3247, -0.4666,  0.1696,  ...,  0.0464, -0.1377, -0.1605],
          [-0.1885, -0.4046, -0.2280,  ...,  0.0453,  0.2583,  0.0038],
          [ 0.2535, -0.2840, -0.3431,  ...,  0.2916, -0.2048,  0.4742]]],


        [[[ 0.0632,  0.0598, -0.0180,  ..., -0.0313,  0.1358,  0.1093],
          [ 0.1858,  0.2344, -0.2453,  ..., -0.5324, -0.3849,  0.2471],
          [-0.0400, -0.1362,  0.3895,  ..., -0.0678,  0.1267, -0.1843],
          ...,
          [-0.2830, -0.5038,  0.1617,  ..., -0.0491, -0.1714, -0.1162],
          [-0.0565, -0.4134, -0.2071,  ...,  0.1924,  0.2560, -0.0713],
          [ 0.1551, -0.5138, -0.1732,  ...,  0.1516, -0.3064,  0.3727]],

         [[ 0.0561,  0.0528,  0.0088,  ..., -0.0406,  0.1352,  0.0981],
          [ 0.2666,  0.0065,  0.0257,  ...,  0.1807,  0.1111,  0.1340],
          [-0.0482, -0.1490,  0.3722,  ..., -0.0634,  0.1313, -0.1671],
          ...,
          [-0.3314, -0.4422,  0.1567,  ..., -0.0570, -0.1561, -0.1174],
          [-0.1295, -0.3637, -0.2042,  ...,  0.1735,  0.2624, -0.0383],
          [ 0.1053, -0.3834, -0.1347,  ...,  0.1857, -0.1997,  0.3245]],

         [[ 0.0377,  0.0412,  0.0695,  ..., -0.0667,  0.1294,  0.0599],
          [-0.0165, -0.2137, -0.0290,  ...,  0.2667,  0.1655,  0.1581],
          [-0.0575, -0.1664,  0.3460,  ..., -0.0535,  0.1396, -0.1274],
          ...,
          [-0.3357, -0.3801,  0.1626,  ..., -0.0448, -0.1524, -0.0895],
          [-0.0200, -0.2825, -0.1568,  ...,  0.1078,  0.1113, -0.0674],
          [ 0.1275, -0.3314, -0.1756,  ...,  0.1778, -0.1402,  0.3300]],

         ...,

         [[-0.0097,  0.0493,  0.0675,  ...,  0.0189, -0.0203,  0.0910],
          [ 0.0876,  0.2347, -0.3979,  ..., -0.0596,  0.1815,  0.0359],
          [ 0.0124, -0.0730,  0.2740,  ..., -0.0552,  0.0996, -0.0772],
          ...,
          [-0.3377, -0.2410,  0.1608,  ..., -0.0502, -0.0979, -0.0935],
          [-0.1080, -0.1970, -0.1752,  ...,  0.0729,  0.0166, -0.0993],
          [ 0.0906, -0.1237, -0.1725,  ...,  0.1998, -0.0151,  0.3163]],

         [[-0.0240,  0.0152,  0.0722,  ..., -0.0323,  0.0153,  0.0777],
          [-0.0253,  0.1407, -0.1750,  ...,  0.0721, -0.2685, -0.1088],
          [ 0.0451, -0.0695,  0.2314,  ..., -0.0395,  0.0745, -0.0457],
          ...,
          [-0.3236, -0.2250,  0.1637,  ..., -0.0363, -0.0888, -0.0859],
          [-0.0778, -0.1826, -0.1689,  ...,  0.0864, -0.0321, -0.0719],
          [ 0.0875, -0.1289, -0.1605,  ...,  0.1928, -0.0261,  0.3056]],

         [[ 0.0227,  0.0337,  0.0462,  ..., -0.0308,  0.0519,  0.0793],
          [ 0.1303,  0.2213, -0.0339,  ..., -0.0225, -0.1041,  0.1170],
          [ 0.0395, -0.1052,  0.2418,  ..., -0.0531,  0.0944, -0.0917],
          ...,
          [-0.3438, -0.1803,  0.1702,  ..., -0.0141, -0.1516, -0.0745],
          [-0.1168, -0.2029, -0.1669,  ..., -0.0287, -0.0602, -0.0956],
          [ 0.0929, -0.1973, -0.1562,  ...,  0.1743, -0.0282,  0.3170]]],


        [[[ 0.0632,  0.0598, -0.0180,  ..., -0.0313,  0.1358,  0.1093],
          [ 0.1858,  0.2344, -0.2453,  ..., -0.5324, -0.3849,  0.2471],
          [-0.0400, -0.1362,  0.3895,  ..., -0.0678,  0.1267, -0.1843],
          ...,
          [-0.2830, -0.5038,  0.1617,  ..., -0.0491, -0.1714, -0.1162],
          [-0.0565, -0.4134, -0.2071,  ...,  0.1924,  0.2560, -0.0713],
          [ 0.1551, -0.5138, -0.1732,  ...,  0.1516, -0.3064,  0.3727]],

         [[ 0.0626,  0.0688, -0.0081,  ..., -0.0430,  0.1256,  0.0931],
          [ 0.0056, -0.0056, -0.0125,  ...,  0.3735, -0.0373,  0.1431],
          [-0.0324, -0.1349,  0.3851,  ..., -0.0637,  0.1241, -0.1818],
          ...,
          [-0.3141, -0.4468,  0.1448,  ..., -0.0364, -0.1514, -0.1081],
          [-0.0845, -0.3957, -0.1217,  ...,  0.1417,  0.2262, -0.0501],
          [ 0.1276, -0.4066, -0.1258,  ...,  0.2022, -0.2168,  0.3499]],

         [[ 0.0538,  0.0695,  0.0071,  ..., -0.0680,  0.0800,  0.0669],
          [ 0.1445, -0.0442, -0.1842,  ...,  0.3799,  0.1650,  0.0277],
          [-0.0213, -0.1298,  0.3684,  ..., -0.0589,  0.1145, -0.1673],
          ...,
          [-0.3241, -0.3642,  0.1712,  ..., -0.0419, -0.1588, -0.1205],
          [-0.0851, -0.3901, -0.1720,  ...,  0.0625,  0.1336, -0.0081],
          [ 0.1281, -0.4012, -0.1481,  ...,  0.1854, -0.2014,  0.3665]],

         ...,

         [[ 0.0579,  0.0904, -0.0220,  ..., -0.0799,  0.1815,  0.0354],
          [ 0.1256,  0.1074, -0.0555,  ...,  0.0274,  0.1500,  0.2253],
          [-0.1137, -0.1721,  0.4689,  ..., -0.0657,  0.1387, -0.1832],
          ...,
          [-0.3110, -0.5115,  0.1993,  ...,  0.0593, -0.1486, -0.1371],
          [-0.1789, -0.3888, -0.1912,  ...,  0.1167,  0.2705, -0.0022],
          [ 0.2536, -0.3351, -0.3439,  ...,  0.3073, -0.2030,  0.4756]],

         [[ 0.0567,  0.0939, -0.0378,  ..., -0.0741,  0.1982,  0.0499],
          [ 0.1131,  0.1066, -0.0563,  ...,  0.0297,  0.1494,  0.2262],
          [-0.1213, -0.1830,  0.5026,  ..., -0.0639,  0.1501, -0.1974],
          ...,
          [-0.3082, -0.5263,  0.1989,  ...,  0.0647, -0.1440, -0.1461],
          [-0.1899, -0.4000, -0.2034,  ...,  0.1079,  0.2919,  0.0076],
          [ 0.2701, -0.3495, -0.3639,  ...,  0.3195, -0.2242,  0.4902]],

         [[ 0.0553,  0.0963, -0.0522,  ..., -0.0694,  0.2126,  0.0628],
          [ 0.0992,  0.1069, -0.0560,  ...,  0.0309,  0.1494,  0.2271],
          [-0.1276, -0.1907,  0.5278,  ..., -0.0630,  0.1581, -0.2081],
          ...,
          [-0.3064, -0.5369,  0.1978,  ...,  0.0672, -0.1403, -0.1529],
          [-0.1987, -0.4083, -0.2131,  ...,  0.1004,  0.3091,  0.0149],
          [ 0.2830, -0.3579, -0.3774,  ...,  0.3279, -0.2381,  0.4997]]]],
       device='cuda:0')