<a href="https://colab.research.google.com/github/bokchisojeong/bokchi_open_lab/blob/main/pytorch_hook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812
import torch.optim as optim

In [43]:
_global_watch_idx = 0

In [59]:
def watch(models,idx=None):
  
  if not isinstance(models, (tuple, list)):
      models = (models,)
  global _global_watch_idx
  if idx is None:
    idx = _global_watch_idx
  prefix = ""
  for local_idx, model in enumerate(models):
    global_idx = idx + local_idx
    _global_watch_idx += 1
    if global_idx > 0:
      prefix = "graph_%i" % global_idx
    print(model)
    if not isinstance(model, torch.nn.Module):
        raise ValueError(
            "Expected a pytorch model (torch.nn.Module). Received "
            + str(type(model))
        )
    add_log_gradients_hook(model, name="", prefix=prefix, log_freq=1)
      
LOG_TRACK_COUNT, LOG_TRACK_THRESHOLD = range(2)


def log_track_init(log_freq):
    """create tracking structure used by log_track_update"""
    l = [0] * 2
    l[LOG_TRACK_THRESHOLD] = log_freq
    return l
def log_track_update(log_track: int) -> bool:
    """count (log_track[0]) up to threshold (log_track[1]), reset count (log_track[0]) and return true when reached"""
    log_track[LOG_TRACK_COUNT] += 1
    if log_track[LOG_TRACK_COUNT] < log_track[LOG_TRACK_THRESHOLD]:
        return False
    log_track[LOG_TRACK_COUNT] = 0
    return True
    
def add_log_gradients_hook(module, name, prefix,log_freq=0):
  prefix = prefix + name
  
  if not hasattr(model, "_deepdriver_hook_names"):
    module._deepdriver_hook_names = []
      
  for name, parameter in model.named_parameters():
    print(parameter.requires_grad)
    if parameter.requires_grad:
        log_track_grad = log_track_init(log_freq)
        module._deepdriver_hook_names.append("gradients/" + prefix + name)
        _hook_variable_gradient_stats(
            parameter, "gradients/" + prefix + name, log_track_grad
        )
def _hook_variable_gradient_stats( var, name, log_track):
  """Logs a Variable's gradient's distribution statistics next time backward()
  is called on it.
  """

  def _callback(grad, log_track):
      if not log_track_update(log_track):
          print("log_track_update not")
          return
      print("callback")
      log_tensor_stats(grad.data, name)
  
  handle = var.register_hook(lambda grad: _callback(grad, log_track))
  print("r_hook_variable_gradient_stats register")
  _hook_handles[name] = handle
  return handle
def log_tensor_stats( tensor, name):
    print("log_tensor_stats")
    """Add distribution statistics on a tensor's elements to the current History entry"""
    # TODO Handle the case of duplicate names.

    if isinstance(tensor, tuple) or isinstance(tensor, list):
        while (isinstance(tensor, tuple) or isinstance(tensor, list)) and (
            isinstance(tensor[0], tuple) or isinstance(tensor[0], list)
        ):
            tensor = [item for sublist in tensor for item in sublist]
        tensor = torch.cat([t.reshape(-1) for t in tensor])

    # checking for inheritance from _TensorBase didn't work for some reason
    if not hasattr(tensor, "shape"):
        cls = type(tensor)
        raise TypeError(f"Expected Tensor, not {cls.__module__}.{cls.__name__}")

    # HalfTensors on cpu do not support view(), upconvert to 32bit
    if isinstance(tensor, torch.HalfTensor):
        tensor = tensor.clone().type(torch.FloatTensor).detach()

    # Sparse tensors have a bunch of implicit zeros. In order to histo them correctly,
    # we have to count them up and add them to the histo ourselves.
    sparse_zeros = None
    if tensor.is_sparse:
        # Have to call this on a sparse tensor before most other ops.
        tensor = tensor.cpu().coalesce().clone().detach()

        backing_values = tensor._values()
        non_zero_values = backing_values.numel()
        all_values = tensor.numel()
        sparse_zeros = all_values - non_zero_values
        tensor = backing_values

    flat = tensor.reshape(-1)

    # For pytorch 0.3 we use unoptimized numpy histograms (detach is new in 0.4)
    if not hasattr(flat, "detach"):
        tensor = flat.cpu().clone().numpy()
        wandb.run._log({name: wandb.Histogram(tensor)}, commit=False)
        return
    print(flat.is_cuda)
    print(flat.tolist())
    # if flat.is_cuda:
    #     # TODO(jhr): see if pytorch will accept something upstream to check cuda support for ops
    #     # until then, we are going to have to catch a specific exception to check for histc support.
    #     if self._is_cuda_histc_supported is None:
    #         self._is_cuda_histc_supported = True
    #         check = torch.cuda.FloatTensor(1).fill_(0)
    #         try:
    #             check = flat.histc(bins=self._num_bins)
    #         except RuntimeError as e:
    #             # Only work around missing support with specific exception
    #             # if str(e).startswith("_th_histc is not implemented"):
    #             #    self._is_cuda_histc_supported = False
    #             # On second thought, 0.4.1 doesnt have support and maybe there are other issues
    #             # lets disable more broadly for now
    #             self._is_cuda_histc_supported = False

    #     if not self._is_cuda_histc_supported:
    #         flat = flat.cpu().clone().detach()

    #     # As of torch 1.0.1.post2+nightly, float16 cuda summary ops are not supported (convert to float32)
    #     if isinstance(flat, torch.cuda.HalfTensor):
    #         flat = flat.clone().type(torch.cuda.FloatTensor).detach()

    # if isinstance(flat, torch.HalfTensor):
    #     flat = flat.clone().type(torch.FloatTensor).detach()

    # # Skip logging if all values are nan or inf or the tensor is empty.
    # if self._no_finite_values(flat):
    #     return

    # # Remove nans and infs if present. There's no good way to represent that in histograms.
    # flat = self._remove_infs_nans(flat)

    # tmin = flat.min().item()
    # tmax = flat.max().item()
    # if sparse_zeros:
    #     # If we've got zeros to add in, make sure zero is in the hist range.
    #     tmin = 0 if tmin > 0 else tmin
    #     tmax = 0 if tmax < 0 else tmax
    # # Anecdotally, this can somehow happen sometimes. Maybe a precision error
    # # in min()/max() above. Swap here to prevent a runtime error.
    # if tmin > tmax:
    #     tmin, tmax = tmax, tmin
    # tensor = flat.histc(bins=self._num_bins, min=tmin, max=tmax)
    # tensor = tensor.cpu().clone().detach()
    # bins = torch.linspace(tmin, tmax, steps=self._num_bins + 1)

    # # Add back zeroes from a sparse tensor.
    # if sparse_zeros:
    #     bins_np = bins.numpy()
    #     tensor_np = tensor.numpy()
    #     bin_idx = 0
    #     num_buckets = len(bins_np) - 1
    #     for i in range(num_buckets):
    #         start = bins_np[i]
    #         end = bins_np[i + 1]
    #         # There are 3 cases to consider here, all of which mean we've found the right bucket
    #         # 1. The bucket range contains zero.
    #         # 2. The bucket range lower bound *is* zero.
    #         # 3. This is the last bucket and the bucket range upper bound is zero.
    #         if (start <= 0 and end > 0) or (i == num_buckets - 1 and end == 0):
    #             bin_idx = i
    #             break

    #     tensor_np[bin_idx] += sparse_zeros
    #     tensor = torch.Tensor(tensor_np)
    #     bins = torch.Tensor(bins_np)

    # wandb.run._log(
    #     {name: wandb.Histogram(np_histogram=(tensor.tolist(), bins.tolist()))},
    #     commit=False,
    # )

In [61]:
# We will use Shakespeare Sonnet 2
test_sentence = """When forty winters shall besiege thy brow,
And dig deep trenches in thy beauty's field,
Thy youth's proud livery so gazed on now,
Will be a totter'd weed of small worth held:
Then being asked, where all thy beauty lies,
Where all the treasure of thy lusty days;
To say, within thine own deep sunken eyes,
Were an all-eating shame, and thriftless praise.
How much more praise deserv'd thy beauty's use,
If thou couldst answer 'This fair child of mine
Shall sum my count, and make my old excuse,'
Proving his beauty by succession thine!
This were to be new made when thou art old,
And see thy blood warm when thou feel'st it cold.""".split()
# we should tokenize the input, but we will ignore that for now
# build a list of tuples.  Each tuple is ([ word_i-2, word_i-1 ], target word)
trigrams = [
    ([test_sentence[i], test_sentence[i + 1]], test_sentence[i + 2])
    for i in range(len(test_sentence) - 2)
]
_hook_handles ={}
module="torch.nn.Module"

CONTEXT_SIZE = 2
EMBEDDING_DIM = 10
vocab = set(test_sentence)
word_to_ix = {word: i for i, word in enumerate(vocab)}

class NGramLanguageModeler(nn.Module):
    def __init__(self, vocab_size, embedding_dim, context_size):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim, sparse=True)
        self.linear1 = nn.Linear(context_size * embedding_dim, 128)
        self.linear2 = nn.Linear(128, vocab_size)

    def forward(self, inputs):
        print("forward")
        embeds = self.embeddings(inputs).view((1, -1))
        out = F.relu(self.linear1(embeds))
        out = self.linear2(out)
        log_probs = F.log_softmax(out, dim=1)
        return log_probs

has_cuda = torch.cuda.is_available()
print(has_cuda)
losses = []
loss_function = nn.NLLLoss()
model = NGramLanguageModeler(len(vocab), EMBEDDING_DIM, CONTEXT_SIZE)
model = model.cuda() if has_cuda else model
print(model)
optimizer = optim.SGD(model.parameters(), lr=0.001)
watch(model)

for _ in range(100):
    total_loss = 0
    for context, target in trigrams:

        # Step 1. Prepare the inputs to be passed to the model (i.e, turn the words
        # into integer indices and wrap them in tensors)
        context_idxs = torch.tensor(
            [word_to_ix[w] for w in context], dtype=torch.long
        )
        context_idxs = context_idxs.cuda() if has_cuda else context_idxs

        # Step 2. Recall that torch *accumulates* gradients. Before passing in a
        # new instance, you need to zero out the gradients from the old
        # instance
        model.zero_grad()

        # Step 3. Run the forward pass, getting log probabilities over next
        # words
        log_probs = model(context_idxs)

        # Step 4. Compute your loss function. (Again, Torch wants the target
        # word wrapped in a tensor)
        target = torch.tensor([word_to_ix[target]], dtype=torch.long)
        target = target.cuda() if has_cuda else target
        loss = loss_function(log_probs, target)

        # Step 5. Do the backward pass and update the gradient
        loss.backward()
        optimizer.step()

        # Get the Python number from a 1-element Tensor by calling tensor.item()
        total_loss += loss.item()
        print({"batch_loss": loss.item()})
    losses.append(total_loss)
print(losses)  # The loss decreased ev

False
NGramLanguageModeler(
  (embeddings): Embedding(97, 10, sparse=True)
  (linear1): Linear(in_features=20, out_features=128, bias=True)
  (linear2): Linear(in_features=128, out_features=97, bias=True)
)
NGramLanguageModeler(
  (embeddings): Embedding(97, 10, sparse=True)
  (linear1): Linear(in_features=20, out_features=128, bias=True)
  (linear2): Linear(in_features=128, out_features=97, bias=True)
)
True
r_hook_variable_gradient_stats register
True
r_hook_variable_gradient_stats register
True
r_hook_variable_gradient_stats register
True
r_hook_variable_gradient_stats register
True
r_hook_variable_gradient_stats register
forward
callback
log_tensor_stats
False
[0.007561386097222567, 0.008088135160505772, 0.011528730392456055, 0.009642762131989002, 0.010151883587241173, 0.005516549572348595, 0.011106746271252632, 0.006290119607001543, 0.00789384450763464, 0.01098714116960764, 0.008990545757114887, 0.011951266787946224, 0.005884855519980192, 0.00865113828331232, 0.009848086163401604,

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[0.00831477902829647, 0.005434995051473379, 0.0006739002419635653, 0.0069125196896493435, 0.0, 0.004208920057862997, 0.0, 0.004408484790474176, 0.0, 0.002089595654979348, 0.0, 0.0062518916092813015, 0.002797504886984825, 0.0, 0.004617590457201004, 0.0007612879853695631, 0.005152715835720301, 0.0, 0.006322747096419334, 0.0006083800108171999, 0.0, 0.003834935138002038, 0.0, 0.0, 0.0, 0.0, 0.007865991443395615, 0.0, 0.0, 0.0034523368813097477, 0.00022376448032446206, 0.003766182577237487, 0.0007387413643300533, 0.004394300747662783, 0.0, 0.0006446128827519715, 0.0, 0.002448915969580412, 0.0, 0.002003490924835205, 0.0, 0.004213126376271248, 0.0, 0.005191057454794645, 0.001413969905115664, 0.004945174790918827, 0.00023702133330516517, 0.003243096172809601, 0.0, 0.0007479226333089173, 0.0, 0.0015188368270173669, 0.0013306180480867624, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0022753027733415365, 0.0, 0.005371826235204935, 0.0, 0.0, 0.005221489351242781, 0.0, 0.0030420152470469475, 0.0006022176239639521, 0

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



KeyboardInterrupt: ignored