In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class DistillationLoss(nn.Module):
    """Blend a knowledge-distillation term with the masked classification loss.

    total = alpha * KL(softmax(teacher / T) || softmax(student / T))
            + (1 - alpha) * CustomMaskedLoss(student_scores, labels, pipelines)

    NOTE(review): Hinton et al. (2015) multiply the soft-target term by T**2 to
    keep its gradient magnitude independent of T; this implementation does not.
    """

    def __init__(self, temperature=1.0, alpha=0.95):
        """
        Args:
            temperature: initial softmax temperature T. Stored as an
                ``nn.Parameter``, so it is trainable if this module's parameters
                are given to an optimizer.
            alpha: weight of the distillation term; the classification loss
                gets ``1 - alpha``.
        """
        super().__init__()
        # Cast to float: torch.tensor(2) would be int64, and nn.Parameter cannot
        # require grad on an integer tensor (it raises at construction time).
        self.temperature = nn.Parameter(torch.tensor(float(temperature)))
        self.alpha = alpha
        self.loss_function = CustomMaskedLoss()
        # KLDivLoss is stateless (no parameters/buffers); build it once instead of
        # on every forward call. Input must be log-probabilities, target probabilities.
        self.kl_div = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_scores, student_logits, teacher_logits, labels, pipelines):
        """Return the weighted sum of distillation and classification losses.

        Args:
            student_scores: student outputs passed to CustomMaskedLoss with ``labels``.
            student_logits: student logits for distillation; softmax is taken over dim 1.
            teacher_logits: teacher logits, same shape as ``student_logits``.
            labels: targets for the classification loss.
            pipelines: per-sample pipeline indicators forwarded to CustomMaskedLoss.

        Returns:
            Scalar loss tensor.
        """
        classification_loss = self.loss_function(student_scores, labels, pipelines)

        temperature = self.temperature
        student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
        teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
        distillation_loss = self.kl_div(student_log_probs, teacher_probs)

        return self.alpha * distillation_loss + (1 - self.alpha) * classification_loss

In [None]:
class CustomMaskedLoss(nn.Module):
    """Element-wise BCE-with-logits loss averaged over the unmasked label positions.

    Which positions count depends on two globals defined elsewhere in the notebook:

    * ``is_bs_finetune`` -- when truthy, only label column 0 (brand stuffing) is trained.
    * ``spam_subtype_reverse_mapping`` -- maps names to column indices; the
      ``"Empty"`` (pipeline) and ``"IsSpam"`` (label) entries are used here.
    """

    def __init__(self):
        super().__init__()
        # Keep per-element losses so the mask can be applied before reducing.
        # NOTE: an earlier variant used pos_weight=torch.tensor(4.0) when is_bs_finetune was set.
        self.loss = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, output, label, pipelines):
        """Return the mean BCE loss over the label positions that are not masked out.

        Args:
            output: raw logits, (batch_size, num_labels).
            label: float targets, same shape as ``output``.
            pipelines: one indicator sequence per sample; a row whose only active
                entry is at ``spam_subtype_reverse_mapping["Empty"]`` is an empty pipeline.

        Returns:
            0-d tensor. When no position is unmasked the result is 0 but still
            attached to the autograd graph, so ``.backward()`` works.
        """
        # 1 = position contributes to the loss, 0 = ignored.
        mask = torch.ones_like(label)

        if is_bs_finetune:
            # Brand-stuffing fine-tuning: train column 0 only.
            mask.zero_()
            mask[:, 0] = 1
        else:
            empty_idx = spam_subtype_reverse_mapping["Empty"]
            is_spam_idx = spam_subtype_reverse_mapping["IsSpam"]
            for idx, pipeline in enumerate(pipelines):
                # Empty pipeline: only the "IsSpam" label is meaningful for this sample.
                if sum(pipeline) == 1 and pipeline[empty_idx] == 1:
                    assert label[idx, is_spam_idx] == 1, (
                        f"sample {idx} has an empty pipeline but IsSpam label != 1"
                    )
                    # Clear the whole row (also columns after IsSpam), then re-enable IsSpam.
                    mask[idx] = 0
                    mask[idx, is_spam_idx] = 1

        loss = self.loss(output, label) * mask

        # Average over unmasked positions using the mask itself. Filtering on
        # `loss != 0` would wrongly drop unmasked entries whose BCE underflows to
        # exactly 0.0 for very confident correct logits. clamp(min=1) avoids 0/0
        # and keeps the result connected to the graph.
        return loss.sum() / mask.sum().clamp(min=1)
