In [96]:


import abc
import importlib
import logging
import math
import os
import sys
import wandb
from abc import abstractmethod
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Optional, Union

import torch
import transformers
from datasets import Dataset
from datasets.distributed import split_dataset_by_node
import torch.nn.functional as F
from torch.optim.lr_scheduler import _LRScheduler, OneCycleLR
from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import SequentialDistributedSampler
from transformers.trainer_utils import has_length

from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
    EvalFirstStepCallback,
    GPUStatsCallback,
    SaveAxolotlConfigtoWandBCallback,
    SaveBetterTransformerModelCallback,
    bench_eval_callback_factory,
    log_prediction_callback_factory,
)
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader, StreamingMultipackDistributedDataloaderNew
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
from axolotl.utils.distributed import (
    is_main_process,
)


try:
    import torch._dynamo  # pylint: disable=ungrouped-imports
except ImportError:
    pass

LOG = logging.getLogger("axolotl.core.trainer_builder")


class CandidatePenaltyCrossEntropyCriterion():
    """Applies a (1-p(x_nt)) loss to each negative target ('candidate') x_nt."""

    def __init__(self):
        self.padding_idx = 0
        self.IGNORE_TOKEN_ID = -100  # Copied from prompt_strategies

    def forward(self, target, pred_logits):
        shift_targets = target[..., 1:].contiguous()
        shift_targets = shift_targets.view(-1)
        shift_targets = shift_targets.masked_fill(shift_targets == self.IGNORE_TOKEN_ID, self.padding_idx)
        mask_ignore = shift_targets != self.padding_idx
        shift_targets = shift_targets[mask_ignore]
        print(shift_targets.shape)

        shift_logits = pred_logits[..., :-1, :].contiguous()
        shift_lprobs = F.log_softmax(shift_logits, dim=-1)
        shift_lprobs = shift_lprobs.view(-1, shift_lprobs.size(-1))
        shift_lprobs = shift_lprobs[mask_ignore]
        print(shift_lprobs.shape)

        # # Sanity check that this is the same as primary_loss.
        # sanity_mle_loss = F.nll_loss(
        #     shift_lprobs,
        #     shift_targets,
        #     reduction='mean')

        # -- unliklihood loss
        # Maximize (1 - p(x_nt)) for negative target tokens x_nt (equivalently minimize -log(1-p(x_nt)))

        # - form negative targets
        with torch.no_grad():
            # E.g. DABCC | D | EFFGD => {A,B,C} are negative targets.
            # Make 'the triangle'.
            # There's still a bug since we have packed batches: https://github.com/facebookresearch/unlikelihood_training/issues/11#issue-1630788451
            #ctx_cands = shift_targets.unsqueeze(0).expand(shift_targets.size(0), shift_targets.size(0))
            ctx_cands = shift_targets.unsqueeze(0).repeat(shift_targets.size(0), 1)
            print(ctx_cands)
            rows, cols = torch.triu_indices(shift_targets.size(0), shift_targets.size(0))
            ctx_cands[rows, cols] = 0
            print(ctx_cands)
            # Don't include the target for that timestep as a negative target.
            ctx_cands = ctx_cands.masked_fill(ctx_cands == shift_targets.unsqueeze(1), self.padding_idx)
            print(ctx_cands)
            negative_targets = torch.zeros_like(shift_lprobs).scatter_(1, ctx_cands, 1)

        # - compute loss
        one_minus_probs = torch.clamp((1.0 - shift_lprobs.exp()), min=1e-5)
        unliklihood_loss = -torch.log(one_minus_probs)*negative_targets
        unliklihood_loss = unliklihood_loss.sum(1).mean()
        return unliklihood_loss


In [103]:
# a a b b b pad pad pad
inputs = torch.tensor([[1, 1, 2, 2, 2, -100, -100, -100]])
pred_logits = torch.tensor([[[0.0, 0.0, 0.0],
                            [0.0, 0.9, 0.1],
                            [0.0, 0.9, 0.1],
                            [0.0, 0.9, 0.1],
                            [0.0, 0.9, 0.1],
                            [0.0, 0.9, 0.1],
                            [0.0, 0.9, 0.1],
                            [0.0, 0.9, 0.1],
                           ]])

loss_fn = CandidatePenaltyCrossEntropyCriterion()

In [104]:
loss_fn.forward(inputs, pred_logits)

torch.Size([4])
torch.Size([4, 3])
tensor([[1, 2, 2, 2],
        [1, 2, 2, 2],
        [1, 2, 2, 2],
        [1, 2, 2, 2]])
tensor([[0, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 2, 0, 0],
        [1, 2, 2, 0]])
tensor([[0, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0]])


tensor(0.8673)