In [1]:
"""
Pointer Immediate Tasks:
- Run an experiment where I take trained pointer networks and transition them to a dominoe based value 
function and a gamma < 1, and show that they can learn to prioritize playing high value dominoes first. Then...
-  Add the context vector that encodes the number of turns left (with uncertainty?)
    - so the full pointer network will get an extra context input that describes how many turns are left
    - 0 rewards will be given after the possible turns are over
    - so the network will have to learn to get as much value out as quickly as possible
-  Also apply these networks to the vehicle routing problem?
-  Analyze encoding space of pointer networks...
-  Do the encoder swap of different pointer layers...
-  Does the speed of learning for the different networks on the sequencer task come from true performance or just sensitivity to the temperature? 
"""

"""
We can speed up some of the processing by doing all networks output in a list at once-- for example for reward computation. I think 
the sequencer is slower because measuring reward takes a long time...

Add mechanism for printing the arguments used to build a pointernetwork so the user can see what they did. 
Add mechanism for storing hidden parameters to entire pointer network

Add documentation of baseline updates and performance etc
Add some dataset specific summary plots and integrate into plotting code? 
Get the supervised learning methods working for each dataset and task
Checkpointing, figure making, logging, etc

it worked!!! now trying without embedding bias...
it works without embedding bias. It works (with different speeds per pointer layer!) with lower train temperature
 (but of course that could be because of differential sensitivity to temperature..., should test that directly)
now trying with 1 encoding layer. 

:)

TODO: 
DOMINOES SEQUENCER Comparison of max to real reward:
- Add target to batch (can do post-hoc, even if not requested)
- Measure reward of target
- Add a 2D vector comparing max and real reward for each batch element!!

TSP Distance Traveled:
- Explicitly measure the distance traveled by the agent in the TSP task
- Compare to Held-Karp Solution


TODO ASAP!!!!!!
I'm running experiments with attention only pointers and saving the networks. Will use trained networks
to test some other dataset specific variables and plots, etc when they're finished, then integrate those
into the main workflow. 

TODO ASAP!!!!!!
Make a checkpoint from the trained & saved dominoe_sequencer results to continue training etc.
"""

# put differences in parameters in logger!!!
# think more carefully about how to handle the ignore index for reward computation.....
# add special messages for curricula phases

# handling of checkpoints for curriculum learning is kinda rough...
# -- I think I can make a central system for all curricula that will handle all of this

# I need a smarter "load_experiment" update prms / exp.args method... (for example probably shouldn't update the checkpointing parameters)
# And in general need better analysis tools for rebuilding networks and rerunning data processing

# consider updating masking methods (so that I can have a valid hook point after applying a mask)

# Interpretability Tools:
# https://github.com/TransformerLensOrg/TransformerLens?tab=readme-ov-file --- this has a really nice looking reading list...
# https://transformer-circuits.pub/2021/garcon/index.html

%reload_ext autoreload
%autoreload 2

# Imports
from time import time
from matplotlib import pyplot as plt
import torch
from ptrseq.experiments import get_experiment
from ptrseq.datasets import get_dataset
from ptrseq.experiments import get_experiment
from ptrseq.utils import get_scheduler, build_args, stack_results, get_dictionary_differences
from ptrseq.networks.net_utils import forward_batch

In [1]:
from ptrseq.networks.attention_modules import get_attention_layer
from ptrseq.networks.transformer_modules import get_transformer_layer

import torch
from transformer_lens.hook_points import HookedRootModule

transformer = get_transformer_layer(16, 4, 1, False, False, False, 0, False, False)

input = torch.rand(1, 2, 16)
out = transformer(input)
print(out)

out, cache = transformer.run_with_cache(input)
print(out)



tensor([[[ 0.4035,  0.9901,  1.3796, -1.2110, -1.1152, -0.2425,  0.5941,
          -1.3056,  0.0769, -1.4911,  0.2296,  2.1306,  0.7261, -0.6109,
           0.0776, -0.6318],
         [ 1.7410,  0.5438, -0.4212, -0.5027,  0.4249, -0.3977,  0.6491,
          -0.8524,  0.9294, -1.8993, -2.0965,  0.5857,  0.1425, -0.4267,
           1.0089,  0.5714]]], grad_fn=<NativeLayerNormBackward0>)
tensor([[[ 0.4035,  0.9901,  1.3796, -1.2110, -1.1152, -0.2425,  0.5941,
          -1.3056,  0.0769, -1.4911,  0.2296,  2.1306,  0.7261, -0.6109,
           0.0776, -0.6318],
         [ 1.7410,  0.5438, -0.4212, -0.5027,  0.4249, -0.3977,  0.6491,
          -0.8524,  0.9294, -1.8993, -2.0965,  0.5857,  0.1425, -0.4267,
           1.0089,  0.5714]]], grad_fn=<NativeLayerNormBackward0>)


In [4]:
for k, v in cache.items():
    print(k, v.shape)

attention._query_hook torch.Size([1, 2, 16])
attention._key_hook torch.Size([1, 2, 16])
attention._value_hook torch.Size([1, 2, 16])
attention._attention_hook torch.Size([4, 2, 2])
attention._head_output_hook torch.Size([1, 2, 4, 4])
attention._unify_head_hook torch.Size([1, 2, 16])
_attention_hook torch.Size([1, 2, 16])
_mlp_pre_hook torch.Size([1, 2, 16])
_mlp_post_hook torch.Size([1, 2, 16])


In [150]:
torch.set_grad_enabled(False)

experiment = "ptr_arch_comp"
args = dict(task="dominoe_sorter", encoder_method="attention", decoder_method="attention", replicates="3", embedding_bias="False")
exp = get_experiment(experiment, build=True, args=build_args(kvargs=args))

results = exp.load_experiment(use_saved_prms=False, verbose=True)
dataset = exp.prepare_dataset()

# input dimensionality
input_dim = dataset.get_input_dim()
context_parameters = dataset.get_context_parameters()

# create networks
nets, optimizers, prms = exp.create_networks(input_dim, context_parameters)
nets = exp.load_networks(nets)
for net in nets:
    net.eval()

# prepare dataset parameters
parameters = exp.make_train_parameters(dataset, train=False)
dominoes = dataset.get_dominoe_set(parameters["train"])

t = time()

# create example batch
batch = dataset.generate_batch(**(parameters | {"batch_size": 1024}))

# run data through the network
max_possible_output = parameters.get("max_possible_output")  # this is the maximum number of outputs ever
scores, choices = forward_batch(nets, batch, max_possible_output, temperature=1.0, thompson=False)

# measure rewards
rewards = [dataset.reward_function(choice, batch) for choice in choices]

print(time() - t)

1.280393123626709


In [146]:
# Example usage
from torch import nn
from copy import deepcopy
from contextlib import contextmanager

class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        # self.conv1 = nn.Conv2d(1, 2, kernel_size=5)
        # self.conv2 = nn.Conv2d(2, 5, kernel_size=5)
        # self.conv = nn.Sequential(self.conv1, nn.ReLU(), self.conv2, nn.ReLU())
        # self.fc1 = nn.Linear(80, 10)
        # self.fc2 = nn.Linear(10, 2)
        # self.linear = nn.Sequential(self.fc1, nn.ReLU(), self.fc2)

        self.layers = nn.Sequential(
            nn.Conv2d(1, 2, kernel_size=5),
            nn.ReLU(),
            nn.Conv2d(2, 5, kernel_size=5),
            nn.ReLU(),
            nn.Flatten(start_dim=1),
            nn.Linear(80, 10),
            nn.ReLU(),
            nn.Linear(10, 2)
        )

    def forward(self, x):
        return self.layers(x)
        # x = self.conv(x)
        # x = x.view(-1, 320)
        # x = self.linear(x)
        # return x
    
class HookedModel(nn.Module):
    def __init__(self, model):
        super(HookedModel, self).__init__()
        self.model = model
        self.cache = {}
        self._add_hooks()
        self.store_hidden = False

    def _add_hooks(self):
        self._layer_to_name = {}
        for name, layer in self.model.named_children():
            self._layer_to_name[layer] = name
            if True: #not isinstance(layer, nn.Sequential):
                layer.register_forward_hook(self._forward_hook)
    
    def _forward_hook(self, layer, input, output):
        if self.store_hidden:
            self.cache[self._layer_to_name[layer]] = output

    @contextmanager
    def _handle_cache(self, store_hidden):
        self.store_hidden = store_hidden
        try:
            yield
        finally:
            self.store_hidden = False

    def forward(self, x, store_hidden=False):
        with self._handle_cache(store_hidden):
            output = self.model(x)
        return output
    
net = CustomModel()
hnet = HookedModel(deepcopy(net))

In [164]:
def get_all_submodules(module, prefix=""):
    """
    Recursively collects all submodules of a given nn.Module.

    Args:
        module (nn.Module): The parent module.

    Returns:
        List[nn.Module]: A list of all submodules.
    """
    submodules = []
    for name, submodule in module.named_children():
        full_name = f"{prefix}.{name}" if prefix else name
        submodules.append(full_name)
        submodules.extend(get_all_submodules(submodule, prefix=full_name))
    return submodules

In [176]:
for idx, (name, module) in enumerate(nets[0].named_modules()):
    if idx==0: continue
    print(idx, name, module)
    print(module.__class__.__name__)
    break

1 embedding Linear(in_features=20, out_features=128, bias=False)
Linear


In [151]:
hnet = HookedModel(nets[0])

In [152]:
hnet(batch["input"], store_hidden=True)

(tensor([[[-6.4133e-05, -1.9241e+01, -1.9159e+01,  ..., -1.9241e+01,
           -1.9241e+01, -1.9241e+01],
          [-6.9578e+01, -1.9578e+01, -1.8302e+01,  ..., -1.9578e+01,
           -1.9578e+01, -1.9578e+01],
          [-6.2098e+01, -1.2097e+01, -3.6478e+00,  ..., -1.2098e+01,
           -1.2098e+01, -1.2087e+01],
          ...,
          [-6.2991e+01, -6.2991e+01, -6.2991e+01,  ..., -1.0136e+01,
           -6.2991e+01, -6.2991e+01],
          [-5.9735e+01, -5.9735e+01, -5.9735e+01,  ..., -2.9221e+00,
           -5.9735e+01, -5.9735e+01],
          [-5.7243e+01, -5.7243e+01, -5.7243e+01,  ...,  0.0000e+00,
           -5.7243e+01, -5.7243e+01]],
 
         [[-1.4718e+01, -1.3524e+01, -1.4718e+01,  ..., -2.3961e-05,
           -1.4718e+01, -1.0971e+01],
          [-7.4645e+00, -3.9639e+00, -7.4645e+00,  ..., -5.7465e+01,
           -7.4625e+00, -6.0858e-02],
          [-9.0136e+00, -1.5847e+00, -9.0136e+00,  ..., -5.9014e+01,
           -8.9822e+00, -5.9014e+01],
          ...,
    

In [153]:
hnet.cache.keys()

dict_keys(['embedding', 'decoder'])

In [148]:
x = torch.randn(1, 1, 12, 12)
out = net(x)
hout = hnet(x, store_hidden=True)
print(hnet.cache.keys())

dict_keys(['layers'])


In [149]:
print(hnet.cache["layers"])
print(hout)

tensor([[ 0.0494, -0.2356]])
tensor([[ 0.0494, -0.2356]])
