# Developing the training loop

Now that I have a util function to patch gaze masks and a hook class to extract the attention logits, I need to bring them into the training loop.

In [6]:
import sys
sys.path.append('..') # allows me to pull from GABRIL_utils/ which is a sibling directory to dev/

In [None]:
import torch
from utils.hook import AttentionExtractor
from GABRIL_utils.utils import load_dataset
from utils.patch_gaze_masks import patch_gaze_masks
import torch.nn.functional as F
import train_vit
from einops import rearrange

# Loss functions
## gaze predictions (model's QK_t logits)
- shape: `(batch_size, num_heads, num_patches)`
- raw logits, not softmaxed. This is ok bc that's what `F.cross_entropy()` wants (it applies `log_softmax` internally)

## gaze targets (human gaze masks)
- shape: `(batch_size, num_frames, num_patches)`
- already softmaxed across `num_patches`
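
A quick sanity check on why raw logits plus already-softmaxed targets is a valid pairing: when the target is a float tensor with the same shape as the input, `F.cross_entropy` treats it as class probabilities (supported since PyTorch 1.10). The sketch below uses toy 2D shapes chosen only for illustration and compares the built-in call with the formula written out by hand.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_patches = 4, 9  # toy sizes, chosen only for illustration

gaze_logits = torch.randn(batch_size, num_patches)  # raw, not softmaxed
gaze_probs = torch.softmax(torch.randn(batch_size, num_patches), dim=-1)  # rows sum to 1

# Probability targets: same shape as the logits, float dtype.
builtin = F.cross_entropy(gaze_logits, gaze_probs)

# Same quantity written out: -sum_n targ[n] * log_softmax(pred)[n],
# then averaged over the batch.
manual = -(gaze_probs * F.log_softmax(gaze_logits, dim=-1)).sum(dim=-1).mean()

print(torch.allclose(builtin, manual))
```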

`gaze_targs` needs to be broadcast across the `num_heads` dim, and the `num_frames`/`num_channels` dim needs to disappear. It shouldn't be the responsibility of the loss function to do the reshaping: the loss function should be as dumb as possible, leaving the processing steps to `train_step`.


So for a dumb loss function, assume you magically get `(batch_size, num_heads, num_patches)` for both `gaze_preds` and `gaze_targs`.
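
For reference, here is one way the "magic" could look on the `train_step` side. This is only a sketch: `prep_gaze_targs` is a hypothetical helper name, and collapsing the frames with a mean is an assumption, since the notes do not say how the `num_frames` dim should disappear.

```python
import torch

def prep_gaze_targs(gaze_targs: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Hypothetical helper: (b, f, n) -> (b, h, n).

    ASSUMPTION: frames are collapsed with a mean. Taking only the last
    frame would be another option; the notes do not specify.
    """
    # The mean of distributions is still a distribution over num_patches.
    per_example = gaze_targs.mean(dim=1)  # (b, n)
    # expand() adds the head dim as a broadcast view, without copying memory.
    return per_example.unsqueeze(1).expand(-1, num_heads, -1)

b, f, h, n = 2, 4, 3, 9  # toy sizes
targs = torch.softmax(torch.randn(b, f, n), dim=-1)
out = prep_gaze_targs(targs, num_heads=h)
print(out.shape)  # torch.Size([2, 3, 9])
```

Every head receives the same target map, so the result still sums to 1 across `num_patches` for each `(batch, head)` pair.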

In [1]:
# Both functions take `self` because they are written as methods;
# they need to sit on a class to be called.
def _ce_only(self, action_preds, action_targs, **kwargs):
    """Plain cross entropy on the action logits."""
    return F.cross_entropy(action_preds, action_targs)

def _ce_plus_gaze_reg(self, action_preds, action_targs, gaze_preds, gaze_targs, reg_lambda=1.0, **kwargs):
    """Action cross entropy plus a gaze regularizer.

    Averages over batch_size and num_heads in one go. To see the intermediate
    ce values for each head separately, define a separate loss function.
    """
    ce = self._ce_only(action_preds, action_targs)

    # Collapse batch_size and num_heads into one dim so a single cross_entropy
    # call averages over both the batch and the heads and returns a scalar.
    gaze_preds = rearrange(gaze_preds, 'b h n -> (b h) n')
    gaze_targs = rearrange(gaze_targs, 'b h n -> (b h) n')
    reg = reg_lambda * F.cross_entropy(gaze_preds, gaze_targs)

    return ce + reg

Note that `_ce_plus_gaze_reg()` collapses the `batch_size` and `num_heads` dims into one. `F.cross_entropy()` already averages across the batch examples to return 1 scalar value instead of 1 cross entropy value per example. If you regularized each head in some sort of for loop, you'd get `num_heads` cross entropy values in total. That won't do bc backprop needs a single scalar value to do gradient descent on. The way I chose to tackle this is to average the per-head cross entropy values as well. That is two averages (over the batch, then over the heads), and since every head sees the same number of examples, I can combine them into one: collapse the `batch_size` and `num_heads` dims before passing into `F.cross_entropy`, which (with the default `reduction='mean'`) averages across the first dimension.
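
The claim that one collapsed call equals "average over the batch, then average over heads" can be checked numerically. The sketch below uses random toy tensors and `flatten(0, 1)`, which is the plain-torch equivalent of `rearrange(x, 'b h n -> (b h) n')`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, h, n = 4, 3, 9  # toy sizes, chosen only for illustration

gaze_preds = torch.randn(b, h, n)
gaze_targs = torch.softmax(torch.randn(b, h, n), dim=-1)

# One call on the collapsed (b*h, n) tensors.
collapsed = F.cross_entropy(gaze_preds.flatten(0, 1), gaze_targs.flatten(0, 1))

# Per-head loop: one batch-averaged value per head, then average over heads.
per_head = torch.stack([
    F.cross_entropy(gaze_preds[:, i], gaze_targs[:, i]) for i in range(h)
])
looped = per_head.mean()

print(torch.allclose(collapsed, looped))
```

The collapse is also needed for a second reason: for inputs with more than 2 dims, `F.cross_entropy` treats dim 1 as the class dim, so passing `(b, h, n)` directly would normalize over heads instead of patches.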

This approach still regularizes each head separately, which is what Yutai wanted.

Another option is to mean pool the attention masks across heads before passing into `F.cross_entropy`, but I think that only constrains the average mask: one head's attention map can drift as long as the other head(s) compensate. I don't think this equates to regularizing each of the heads separately, but we'll test both anyway.
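
A rough sketch of the two variants side by side, on random toy tensors. It assumes the pooling happens on the softmaxed attention maps rather than on the raw logits, which is one reading of "mean pool the attention masks"; the other reading would give different numbers.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, h, n = 4, 3, 9  # toy sizes, chosen only for illustration

gaze_preds = torch.randn(b, h, n)
gaze_targs = torch.softmax(torch.randn(b, n), dim=-1)  # one target per example

# Per-head version: every head is scored against the same target.
targs_per_head = gaze_targs.unsqueeze(1).expand(-1, h, -1)
per_head = F.cross_entropy(gaze_preds.flatten(0, 1), targs_per_head.flatten(0, 1))

# Pooled version. ASSUMPTION: pool the softmaxed maps, not the logits.
pooled_attn = torch.softmax(gaze_preds, dim=-1).mean(dim=1)  # (b, n), rows sum to 1
# The log of a normalized distribution acts as logits: softmax(log p) == p.
pooled = F.cross_entropy(torch.log(pooled_attn), gaze_targs)

print(pooled.item() <= per_head.item())
```

Under this assumption the pooled loss can never exceed the per-head loss (Jensen's inequality, since `log` is concave), which is consistent with the worry that pooling is the weaker constraint: heads can disagree with the target individually as long as their average matches it.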