In [1]:
import torch
from torch import nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class ModelOutput:
    def __init__(
        self,
        pooler_output: torch.Tensor,
        last_hidden_state: torch.Tensor,
        hidden_states: torch.Tensor,
        attentions: torch.Tensor,
        cross_attentions: torch.Tensor    
    ) -> None:
        self.pooler_output = pooler_output
        self.last_hidden_state = last_hidden_state
        self.hidden_states = hidden_states
        self.attentions = attentions
        self.cross_attentions = cross_attentions

class ForwardPassOutput:
    def __init__(
        self,
        student_output = None,
        teacher_output = None,
        align_fuse: dict = None,
        labels: torch.Tensor = None,
        output_modalities: dict = None
    ) -> None:
        self.student_output = student_output
        self.teacher_output = teacher_output
        self.align_fuse = align_fuse
        self.labels = labels
        self.output_modalities = output_modalities
        
    def set_attributes(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)


In [35]:
student_output = ModelOutput(pooler_output=torch.randn(64,128), last_hidden_state=torch.randn(64,256,128), hidden_states=[torch.randn(64,256,128)]*8, attentions=[torch.randn(64,8,256,256)]*8, cross_attentions=[torch.randn(64,1,256,299)])
teacher_output = ModelOutput(pooler_output=torch.randn(64,128), last_hidden_state=torch.randn(64,256,128), hidden_states=[torch.randn(64,256,128)]*8, attentions=[torch.randn(64,8,256,256)]*8, cross_attentions=[torch.randn(64,1,256,299)])
outputs = ForwardPassOutput(student_output=student_output, teacher_output=teacher_output)

In [33]:
# https://www.baeldung.com/cs/instance-vs-batch-normalization

class LatentPredictionLoss(nn.Module):
    def __init__(
        self,
        num_hidden_layers_to_predict: int,
        reduction: str = "mean",
        beta: float = 1.0        
        ) -> None:
        super().__init__()
        
        self.loss_fn = nn.SmoothL1Loss(reduction=reduction, beta=beta)
        
        self.num_hidden_layers_to_predict = num_hidden_layers_to_predict
        
    
    def forward(
        self,
        fwd_output: ForwardPassOutput,
        ) -> torch.Tensor:
        
        # take the last transformer layers from the student
        x = fwd_output.student_output.hidden_states[-1:][0]
        # Follow the same layer normalization for all modalities
        x = [torch.layer_norm(tl.float(), tl.shape[-1:]) for tl in x]
        x = sum(x) / len(x)
        # normalize targets
        x = torch.layer_norm(x.float(), x.shape[-1:])
    
        
        
        with torch.no_grad():
            # take the last k transformer layers from the teacher
            y = fwd_output.teacher_output.hidden_states[-self.num_hidden_layers_to_predict:]
            # Follow the same layer normalization for all modalities
            y = [torch.layer_norm(tl.float(), tl.shape[-1:]) for tl in y]
            y = sum(y) / len(y)
            # normalize targets
            y = torch.layer_norm(y.float(), y.shape[-1:])
                
        hidden_states_loss = self.loss_fn(x, y)
        
        x_pooler = fwd_output.student_output.pooler_output
        y_pooler = fwd_output.teacher_output.pooler_output
        pooler_loss = self.loss_fn(x_pooler, y_pooler) 
        
        loss = hidden_states_loss + pooler_loss
                
        return loss

In [36]:
loss = LatentPredictionLoss(num_hidden_layers_to_predict=2)
loss(outputs)

  return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta)
Traceback (most recent call last):
  File "_pydevd_bundle/pydevd_cython.pyx", line 1134, in _pydevd_bundle.pydevd_cython.PyDBFrame.trace_dispatch
  File "_pydevd_bundle/pydevd_cython.pyx", line 311, in _pydevd_bundle.pydevd_cython.PyDBFrame.do_wait_suspend
  File "c:\Users\marco\.venv\multimodal-ssl\lib\site-packages\debugpy\_vendored\pydevd\pydevd.py", line 2062, in do_wait_suspend
    keep_suspended = self._do_wait_suspend(thread, frame, event, arg, suspend_type, from_this_thread, frames_tracker)
  File "c:\Users\marco\.venv\multimodal-ssl\lib\site-packages\debugpy\_vendored\pydevd\pydevd.py", line 2098, in _do_wait_suspend
    time.sleep(0.01)
KeyboardInterrupt


KeyboardInterrupt: 

In [None]:
l1 = torch.nn.SmoothL1Loss(reduction='mean', beta=1.0)
mse = torch.nn.MSELoss()

In [None]:
l1(x, y)

In [None]:
mse(x, y)