In [1]:
from typing import Optional

from composer import Trainer
from composer.core import Algorithm, Event

In [2]:
import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.nn.modules.loss import _Loss

class AWP:
    """Adversarial Weight Perturbation (AWP) helper for a hand-written training loop.

    Intended per-step use, as implied by the methods below:

    1. after the regular backward pass, call ``attack_backward`` — it snapshots
       the eligible parameters, perturbs them along their gradients and returns
       the loss recomputed on the perturbed model;
    2. back-propagate that loss;
    3. call ``_restore`` to put the snapshotted parameter values back.
    """

    def __init__(
        self,
        model: Module, # State.model
        criterion: _Loss, # State.loss_func
        optimizer: Optimizer, # State.optimizer
        apex: bool, # infer from amp16
        adv_param: str="weight",
        adv_lr: float=1.0,
        adv_eps: float=0.01
    ) -> None:
        """
        Args:
            model: network whose parameters get perturbed.
            criterion: loss for the adversarial forward pass. Its output is
                filtered element-wise with ``torch.masked_select``, so it
                presumably needs ``reduction='none'`` — TODO confirm.
            optimizer: its ``zero_grad`` is called after the adversarial forward.
            apex: enables ``torch.cuda.amp.autocast`` around the adversarial pass.
            adv_param: substring a parameter name must contain to be perturbed.
            adv_lr: perturbation size relative to each parameter's L2 norm.
            adv_eps: per-element clip radius, as a fraction of ``|param|`` at save time.
        """
        # Collaborators.
        self.model, self.criterion, self.optimizer = model, criterion, optimizer
        # Perturbation hyper-parameters.
        self.adv_param, self.adv_lr, self.adv_eps = adv_param, adv_lr, adv_eps
        self.apex = apex
        # Scratch state: filled by _save(), emptied by _restore().
        self.backup, self.backup_eps = {}, {}

    def _targets(self):
        """Yield ``(name, param)`` for every parameter eligible for perturbation.

        Eligible means: trainable, has a gradient, and ``adv_param`` occurs in its name.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None and self.adv_param in name:
                yield name, param

    def attack_backward(self, inputs: dict, label: Tensor) -> Tensor:
        """Perturb the eligible weights and return the masked adversarial loss.

        Entries whose label equals -1 are excluded from the mean. Gradients are
        zeroed after the forward pass; the model is left perturbed until
        ``_restore`` is called.
        """
        with torch.cuda.amp.autocast(enabled=self.apex):
            self._save()
            self._attack_step()
            flat_label = label.view(-1, 1)
            loss_per_item = self.criterion(self.model(inputs).view(-1, 1), flat_label)
            keep = flat_label != -1
            adv_loss = loss_per_item.masked_select(keep).mean()
            self.optimizer.zero_grad()
        return adv_loss

    def _attack_step(self) -> None:
        """Move each eligible parameter along its gradient, clipped to the saved bounds.

        The step is ``adv_lr * grad / ||grad|| * ||param||``. Parameters with a zero
        or NaN gradient norm are skipped.
        """
        eps = 1e-6  # guards both norms against division by / scaling with zero
        for name, param in self._targets():
            grad_norm = torch.norm(param.grad)
            weight_norm = torch.norm(param.data.detach())
            if grad_norm == 0 or torch.isnan(grad_norm):
                continue
            step = self.adv_lr * param.grad / (grad_norm + eps) * (weight_norm + eps)
            param.data.add_(step)
            lower, upper = self.backup_eps[name]
            param.data = torch.minimum(torch.maximum(param.data, lower), upper)

    def _save(self) -> None:
        """Snapshot each eligible parameter once and record its clip interval.

        A parameter already present in ``backup`` keeps its first snapshot.
        """
        for name, param in self._targets():
            if name in self.backup:
                continue
            snapshot = param.data.clone()
            radius = self.adv_eps * param.abs().detach()
            self.backup[name] = snapshot
            self.backup_eps[name] = (snapshot - radius, snapshot + radius)

    def _restore(self) -> None:
        """Put every snapshotted parameter back and clear the scratch state."""
        saved = self.backup
        for name, param in self.model.named_parameters():
            if name in saved:
                param.data = saved[name]
        self.backup, self.backup_eps = {}, {}

In [5]:
def _restore(
        model: Module,
        backup: dict) -> None:
    for name, param in model.named_parameters():
        if name in backup:
            param.data = backup[name]
def _save(
        model: Module,
        adv_param: str,
        adv_eps: float,
        backup: dict,
        backup_eps: dict) -> None:
    for name, param in model.named_parameters():
        if param.requires_grad and param.grad is not None and adv_param in name:
            if name not in backup:
                backup[name] = param.data.clone()
                grad_eps = adv_eps * param.abs().detach()
                backup_eps[name] = (
                    backup[name] - grad_eps,
                    backup[name] + grad_eps,
                )

def _attack_step(
        model: Module,
        adv_param: str,
        adv_lr: float,
        backup_eps: dict) -> None:
    e = 1e-6
    for name, param in model.named_parameters():
        if param.requires_grad and param.grad is not None and adv_param in name:
            norm1 = torch.norm(param.grad)
            norm2 = torch.norm(param.data.detach())
            if norm1 != 0 and not torch.isnan(norm1):
                r_at = adv_lr * param.grad / (norm1 + e) * (norm2 + e)
                param.data.add_(r_at)
                param.data = torch.min(
                    torch.max(
                        param.data, backup_eps[name][0]), backup_eps[name][1]
                )

def _attack_backward(
        model: Module,
        criterion: _Loss,
        optimizer: Optimizer,
        inputs: dict,
        label: Tensor,
        apex: bool,
        adv_param: str = "weight",
        adv_lr: float = 1.0,
        adv_eps: float = 0.01,
        backup: Optional[dict] = None,
        backup_eps: Optional[dict] = None) -> Tensor:
    """Perturb the eligible parameters of ``model`` and return the adversarial loss.

    Steps:
    1. Snapshot the eligible parameters into ``backup`` / ``backup_eps`` (``_save``).
    2. Take one perturbation step (``_attack_step``).
    3. Run ``model(inputs)``.
    4. Return the mean of ``criterion`` over the entries whose label is not -1.

    The optimizer's gradients are zeroed after the forward pass. A subsequent
    ``backward()`` on the returned loss therefore leaves only adversarial
    gradients. The model is left PERTURBED: after back-propagating, the caller
    must call ``_restore(model, backup)``.

    Args:
        model: network to perturb.
        criterion: loss function. Its output is filtered element-wise, so it
            presumably needs ``reduction='none'`` — TODO confirm.
        optimizer: optimizer whose gradients are zeroed after the forward pass.
        inputs: passed unchanged to ``model``.
        label: targets; entries equal to -1 are masked out of the loss.
        apex: enable ``torch.cuda.amp.autocast`` for the adversarial pass.
        adv_param: substring a parameter name must contain to be perturbed.
        adv_lr: perturbation size relative to each parameter's L2 norm.
        adv_eps: per-element clip radius as a fraction of ``|param|``.
        backup: caller-owned dict that receives the saved parameter values.
        backup_eps: caller-owned dict that receives the per-parameter clip bounds.

    Returns:
        Scalar adversarial loss tensor.

    Raises:
        TypeError: if ``backup`` or ``backup_eps`` is not supplied. Without them
            the original weights could not be restored.
    """
    if backup is None or backup_eps is None:
        # Fail before touching the model: perturbing without caller-visible
        # backups would leave the weights permanently modified.
        raise TypeError(
            "_attack_backward() requires 'backup' and 'backup_eps' dicts so the "
            "caller can restore the original weights with _restore()")
    with torch.cuda.amp.autocast(enabled=apex):
        _save(model, adv_param, adv_eps, backup, backup_eps)
        _attack_step(model, adv_param, adv_lr, backup_eps)
        y_preds = model(inputs)
        flat_label = label.view(-1, 1)
        adv_loss = criterion(y_preds.view(-1, 1), flat_label)
        mask = flat_label != -1
        adv_loss = torch.masked_select(adv_loss, mask).mean()
        optimizer.zero_grad()
    return adv_loss

In [11]:
class AWP(Algorithm):
    """Composer ``Algorithm`` that applies Adversarial Weight Perturbation (AWP).

    Runs on ``Event.AFTER_BACKWARD`` once ``state.timestamp.epoch`` reaches
    ``start_epoch``. Each time it runs it:

    1. snapshots the eligible parameters (``_save``);
    2. perturbs them along their current gradients (``_attack_step``);
    3. recomputes the loss on the same batch and replaces the gradients with
       those of that adversarial loss;
    4. restores the original parameter values (``_restore``), even if steps 1-3
       raise.

    Note: this class reuses the name ``AWP`` and so shadows the standalone
    ``AWP`` helper class defined in an earlier cell.

    Args:
        start_epoch: first epoch (inclusive) at which AWP is applied.
        criterion: loss for the adversarial pass. Its output is filtered with
            ``torch.masked_select``, so it presumably needs
            ``reduction='none'`` — TODO confirm.
        adv_param: substring a parameter name must contain to be perturbed.
        adv_lr: perturbation size relative to each parameter's L2 norm.
        adv_eps: per-element clip radius as a fraction of ``|param|``.
        apex: enable ``torch.cuda.amp.autocast`` for the adversarial pass.
    """

    def __init__(self,
                 start_epoch: int,
                 criterion: _Loss, # State.loss_func
                 adv_param: str = 'weight',
                 adv_lr: float = 1.0,
                 adv_eps: float = 0.01,
                 apex: bool = True):
        self.start_epoch = start_epoch
        self.criterion = criterion
        self.adv_param = adv_param
        self.adv_lr = adv_lr
        self.adv_eps = adv_eps
        self.apex = apex
        # Per-step scratch space: filled by _save() and emptied at the end of apply().
        self.backup = {}
        self.backup_eps = {}

    def match(self, event, state):
        """Run right after the regular backward pass, from ``start_epoch`` onwards."""
        return event == Event.AFTER_BACKWARD and state.timestamp.epoch >= self.start_epoch

    def apply(self, event, state, logger):
        """Replace this step's gradients with the gradients of the adversarial loss."""
        # NOTE(review): assumes state.batch unpacks into (inputs, label) and that
        # state.model(inputs) returns a prediction tensor. A HuggingFaceModel batch
        # is typically a dict and its forward returns a model-output object —
        # confirm before using this with the model from the Testing section.
        inputs, label = state.batch
        model = state.model
        try:
            with torch.cuda.amp.autocast(enabled=self.apex):
                _save(model, self.adv_param, self.adv_eps, self.backup, self.backup_eps)
                _attack_step(model, self.adv_param, self.adv_lr, self.backup_eps)
                y_preds = model(inputs)
                flat_label = label.view(-1, 1)
                adv_loss = self.criterion(y_preds.view(-1, 1), flat_label)
                # Entries labelled -1 are excluded from the adversarial loss.
                mask = flat_label != -1
                adv_loss = torch.masked_select(adv_loss, mask).mean()
            # Drop the regular gradients so only the adversarial ones reach the
            # optimizer step. Zeroing through the model avoids `state.optimizer`,
            # which is presumably not a Composer State attribute (it exposes
            # `optimizers`) — confirm.
            model.zero_grad()
            # NOTE(review): this backward is unscaled; under fp16 AMP a grad
            # scaler may be required — see the Questions section.
            adv_loss.backward()
            state.loss = adv_loss
        finally:
            # Always undo the perturbation. Otherwise the weights would stay
            # perturbed for the rest of training.
            _restore(model, self.backup)
            self.backup, self.backup_eps = {}, {}

In [4]:
Algorithm??

Init signature: Algorithm(*args, **kwargs)
Source:
class Algorithm(Serializable, ABC):
    """Base class for algorithms.

    Algorithms are pieces of code which run at specific events (see :class:`.Event`) in the training loop.
    Algorithms modify the trainer's :class:`.State`, generally with the effect of improving the model's quality
    or increasing the efficiency and throughput of the training loop.

    Algorithms must implement the following two methods:
      +----------------+-------------------------------------------------------------------------------+
      | Method         | Description                                                           

## Testing

In [6]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torchmetrics import PearsonCorrCoef, MeanSquaredError
from composer.models import HuggingFaceModel

# Small DeBERTa-v3 checkpoint with a single-output classification head.
checkpoint = "microsoft/deberta-v3-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
num_labels = 1


model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint, num_labels=num_labels
)
# Match the embedding matrix to the tokenizer's vocabulary size.
model.resize_token_embeddings(len(tokenizer))

# torchmetrics objects accumulate state, so training and evaluation get their
# own Pearson instances. Sharing one object between `metrics` and
# `eval_metrics` would mix train and eval statistics, unless HuggingFaceModel
# copies them (not verified here).
pears_corr = PearsonCorrCoef(num_outputs=num_labels)
mse_metric = MeanSquaredError()
eval_pears_corr = PearsonCorrCoef(num_outputs=num_labels)
composer_model = HuggingFaceModel(
    model=model,
    tokenizer=tokenizer,
    metrics=[pears_corr],
    eval_metrics=[mse_metric, eval_pears_corr],
    use_logits=True,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Some weights of the model checkpoint at microsoft/deberta-v3-small were not used when initializing DebertaV2ForSequenceClassification: ['mask_predictions.dense.weight', 'mask_predictions.LayerNorm.bias', 'lm_predictions.lm_head.bias', 'lm_predictions.lm_head.LayerNorm.weight', 'lm_predictions.lm_head.LayerNorm.bias', 'mask_predictions.classifier.weight', 'lm_predictions.lm_head.dense.weight', 'mask_predictions.classifier.bias', 'mask_predictions.dense.bias', 'mask_predictions.LayerNorm.weight', 'lm_predictions.lm_head.dense.bias']
- This IS expected if you are initializing DebertaV2ForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a B

In [7]:
model??

Signature:      model(*args, **kwargs)
Type:           DebertaV2ForSequenceClassification
String form:
DebertaV2ForSequenceClassification(
  (deberta): DebertaV2Model(
    (embeddings): DebertaV2Embeddings(
      (word_embeddings): Embedding(128001, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)
      (dropout): StableDropout()
    )
    (encoder): DebertaV2Encoder(
      (layer): ModuleList(
        (0-5): 6 x DebertaV2Layer(
          (attention): DebertaV2Attention(
            (self): DisentangledSelfAttention(
              (query_proj): Linear(in_features=768, out_features=768, bias=True)
              (key_proj): Linear(in_features=768, out_features=768, bias=True)
              (value_proj): Linear(in_features=768, out_features=768, bias=True)
              (pos_dropout): StableDropout()
              (dropou

## Questions


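- How should different precisions (AMP / mixed precision) be handled? Currently `apex` only toggles `torch.cuda.amp.autocast`, and the adversarial loss is back-propagated without any gradient scaling.
- How can the loss function be inferred from the model or trainer state, instead of requiring the user to pass it in?
- Should the algorithm store the loss function (`criterion`) directly as an attribute?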