In [1]:
import argparse
import os
import torch
from peft import PeftModel, PeftConfig
from torch.optim import AdamW, SGD
from torch.utils.data import DataLoader
from peft import (
    get_peft_config,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
    LoraConfig,
    PeftType,
    PrefixTuningConfig,
    PromptEncoderConfig,
)

import evaluate
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
from opacus import PrivacyEngine
from tqdm import tqdm

batch_size = 32
model_name_or_path = "roberta-large"
dataset_name_or_path = "stanfordnlp/snli"
task = "snli"
peft_type = PeftType.LORA
device = "cuda"
num_epochs = 20

In [2]:
datasets = load_dataset(dataset_name_or_path)

In [160]:
from transformers import RobertaConfig, RobertaTokenizer, RobertaForSequenceClassification, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    llm_int8_skip_modules=["classifier"]
)

model_name = "FacebookAI/roberta-large"
config = RobertaConfig.from_pretrained(
    model_name,
    num_labels=3,
)
tokenizer = RobertaTokenizer.from_pretrained(
    "FacebookAI/roberta-large",
    do_lower_case=False,
)
model = RobertaForSequenceClassification.from_pretrained(
    "FacebookAI/roberta-large",
    config=config,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)

`low_cpu_mem_usage` was None, now default to True since model is quantized.
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at FacebookAI/roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [89]:
import torch
import torch.nn as nn
import transformers
from torch.utils.data import TensorDataset
from transformers.data.processors.utils import InputExample
from transformers.data.processors.glue import glue_convert_examples_to_features

LABEL_LIST = [0, 1, 2]

def _create_examples(dataset, set_type):
    """ Convert raw dataframe to a list of InputExample. Filter malformed examples
    """
    examples = []
    for index, item in enumerate(dataset):
        if item['label'] not in LABEL_LIST:
            continue
        if not isinstance(item['premise'], str) or not isinstance(item['hypothesis'], str):
            continue
        guid = f"{index}-{set_type}"
        examples.append(
            InputExample(guid=guid, text_a=item['premise'], text_b=item['hypothesis'], label=item['label']))
    return examples

def _dataset_to_features(dataset, set_type):
    """ Pre-process text. This method will:
    1) tokenize inputs
    2) cut or pad each sequence to MAX_SEQ_LENGHT
    3) convert tokens into ids

    The output will contain:
    `input_ids` - padded token ids sequence
    `attention mask` - mask indicating padded tokens
    `token_type_ids` - mask indicating the split between premise and hypothesis
    `label` - label
    """
    examples = _create_examples(dataset, set_type)

    #backward compatibility with older transformers versions
    legacy_kwards = {}
    from packaging import version
    if version.parse(transformers.__version__) < version.parse("2.9.0"):
        legacy_kwards = {
            "pad_on_left": False,
            "pad_token": tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            "pad_token_segment_id": 0,
        }

    return glue_convert_examples_to_features(
        examples=examples,
        tokenizer=tokenizer,
        label_list=LABEL_LIST,
        max_length=128,
        output_mode="classification",
        **legacy_kwards,
    )

def _features_to_dataset(features):
    """ Convert features from `_df_to_features` into a single dataset
    """
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor(
        [f.attention_mask for f in features], dtype=torch.long
    )
    # all_token_type_ids = torch.tensor(
    #     [f.token_type_ids for f in features], dtype=torch.long
    # )
    all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
    dataset = TensorDataset(
        all_input_ids, all_attention_mask, all_labels
    )

    return dataset


train_features = _dataset_to_features(datasets['train'], "train")
test_features = _dataset_to_features(datasets['test'], "test")

train_dataset = _features_to_dataset(train_features)
test_dataset = _features_to_dataset(test_features)



In [5]:
train_features[0]

InputFeatures(input_ids=[0, 250, 621, 15, 10, 5253, 13855, 81, 10, 3187, 159, 16847, 4, 2, 2, 250, 621, 16, 1058, 39, 5253, 13, 10, 1465, 4, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], attention_mask=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], token_type_ids=None, label=1)

In [161]:
BATCH_SIZE = 256
MAX_PHYSICAL_BATCH_SIZE = 64

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from opacus.utils.uniform_sampler import UniformWithReplacementSampler

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
test_dataloader = DataLoader(test_dataset, sampler=RandomSampler(test_dataset), batch_size=BATCH_SIZE)

In [162]:
from peft import get_peft_model, LoraConfig, TaskType

total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters count: {total_params:,}")


lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,  # our particular task is sequence classification
    inference_mode=False,  # Enable training mode
    r=8,  # Low-rank dimension
    lora_alpha=8,  # Alpha scaling factor
    lora_dropout=0.05,  # Dropout for LoRA layers
    target_modules=["query", "value"],
)

model_with_lora = get_peft_model(model, lora_config)
trainable_params = sum(p.numel() for p in model_with_lora.parameters() if p.requires_grad)
print(f"Total trainable parameters with LoRA: {trainable_params:,}")

# classification_head = model_with_lora.base_model.model.classifier[:]

# for param in classification_head.parameters():
#     param.requires_grad = False

# FFA-LoRA modification: freeze all adapter A matrices so that only B matrices are trainable
for name, param in model_with_lora.named_parameters():
    if "lora_A" in name or "classifier" in name:
        param.requires_grad = False

# for name, param in model_with_lora.named_parameters():
#     if "classifier" in name:
#         param.requires_grad = False

trainable_params = sum(p.numel() for p in model_with_lora.parameters() if p.requires_grad)
print(f"Total trainable parameters with LoRA after freezing matrix A: {trainable_params:,}")

Total parameters count: 53,151,747
Total trainable parameters with LoRA: 1,839,107
Total trainable parameters with LoRA after freezing matrix A: 393,216


In [163]:
EPOCHS = 3
LOGGING_INTERVAL = 800 # once every how many steps we run evaluation cycle and report metrics
EPSILON = 1.0
DELTA = 1e-5 # Parameter for privacy accounting. Probability of not achieving privacy guarantees

In [164]:
import numpy as np
from tqdm.notebook import tqdm

# def accuracy(preds, labels):
#     return (preds == labels).mean()

def accuracy(preds, labels):
    # (preds == labels) returns a boolean tensor. Convert it to float and take the mean.
    return (preds == labels).float().mean().item()

# define evaluation cycle
def evaluate(model):
    model.eval()

    loss_arr = []
    accuracy_arr = []

    for batch in test_dataloader:
        batch = tuple(t.to(device) for t in batch)

        with torch.no_grad():
            inputs = {'input_ids':      batch[0],
                      'attention_mask': batch[1],
                    #   'token_type_ids': batch[2],
                      'labels':         batch[2]}

            outputs = model(**inputs)
            loss, logits = outputs[:2]

            # preds = np.argmax(logits.detach().cpu().numpy(), axis=1)
            preds = torch.argmax(logits, dim=1)
            labels = inputs['labels']#.detach().cpu().numpy()

            loss_arr.append(loss.item())
            accuracy_arr.append(accuracy(preds, labels))

    model.train()
    avg_loss = sum(loss_arr) / len(loss_arr)
    avg_accuracy = sum(accuracy_arr) / len(accuracy_arr)
    return avg_loss, avg_accuracy

In [None]:
from opacus import PrivacyEngine
from opacus.utils.batch_memory_manager import BatchMemoryManager

MAX_GRAD_NORM = 2.0

device = torch.device("cuda:0")
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# optimizer = torch.optim.AdamW(model.parameters(), lr=0.01, eps=1e-8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.2)
model = model_with_lora
model = model.train()

privacy_engine = PrivacyEngine()

model, optimizer, train_dataloader = privacy_engine.make_private_with_epsilon(
    module=model,
    optimizer=optimizer,
    data_loader=train_dataloader,
    target_delta=DELTA,
    target_epsilon=EPSILON,
    epochs=EPOCHS,
    max_grad_norm=MAX_GRAD_NORM,
)

model = model.to(device)
model = model.train()



In [166]:
for epoch in range(1, EPOCHS+1):
    losses = []

    with BatchMemoryManager(
        data_loader=train_dataloader,
        max_physical_batch_size=MAX_PHYSICAL_BATCH_SIZE,
        optimizer=optimizer
    ) as memory_safe_data_loader:
        for step, batch in enumerate(tqdm(memory_safe_data_loader)):
            optimizer.zero_grad()

            batch = tuple(t.to(device) for t in batch)
            inputs = {'input_ids':      batch[0],
                    'attention_mask': batch[1],
                    # 'token_type_ids': batch[2],
                    'labels':         batch[2]}

            outputs = model(**inputs) # output = loss, logits, hidden_states, attentions

            loss = outputs[0]
            loss.backward()

            losses.append(loss.item())

            optimizer.step()

            if step > 0 and step % LOGGING_INTERVAL == 0:
                train_loss = np.mean(losses)
                eps = privacy_engine.get_epsilon(DELTA)

                eval_loss, eval_accuracy = evaluate(model)

                print(
                  f"Epoch: {epoch} | "
                  f"Step: {step} | "
                  f"Train loss: {train_loss:.3f} | "
                  f"Eval loss: {eval_loss:.3f} | "
                  f"Eval accuracy: {eval_accuracy:.3f} | "
                  f"ɛ: {eps:.2f} "
                )

  0%|          | 0/8583 [00:00<?, ?it/s]



Epoch: 1 | Step: 800 | Train loss: 1.105 | Eval loss: 1.100 | Eval accuracy: 0.327 | ɛ: 0.34 
Epoch: 1 | Step: 1600 | Train loss: 1.105 | Eval loss: 1.097 | Eval accuracy: 0.362 | ɛ: 0.42 
Epoch: 1 | Step: 2400 | Train loss: 1.104 | Eval loss: 1.093 | Eval accuracy: 0.379 | ɛ: 0.49 
Epoch: 1 | Step: 3200 | Train loss: 1.101 | Eval loss: 1.059 | Eval accuracy: 0.447 | ɛ: 0.53 
Epoch: 1 | Step: 4000 | Train loss: 1.083 | Eval loss: 0.879 | Eval accuracy: 0.558 | ɛ: 0.57 
Epoch: 1 | Step: 4800 | Train loss: 1.051 | Eval loss: 0.838 | Eval accuracy: 0.637 | ɛ: 0.60 
Epoch: 1 | Step: 5600 | Train loss: 1.017 | Eval loss: 0.813 | Eval accuracy: 0.687 | ɛ: 0.63 
Epoch: 1 | Step: 6400 | Train loss: 0.983 | Eval loss: 0.688 | Eval accuracy: 0.750 | ɛ: 0.66 
Epoch: 1 | Step: 7200 | Train loss: 0.951 | Eval loss: 0.613 | Eval accuracy: 0.792 | ɛ: 0.68 
Epoch: 1 | Step: 8000 | Train loss: 0.922 | Eval loss: 0.598 | Eval accuracy: 0.811 | ɛ: 0.70 
Epoch: 1 | Step: 8800 | Train loss: 0.896 | Eval lo

  0%|          | 0/8583 [00:00<?, ?it/s]

Epoch: 2 | Step: 800 | Train loss: 0.630 | Eval loss: 0.551 | Eval accuracy: 0.846 | ɛ: 0.75 
Epoch: 2 | Step: 1600 | Train loss: 0.628 | Eval loss: 0.554 | Eval accuracy: 0.849 | ɛ: 0.77 
Epoch: 2 | Step: 2400 | Train loss: 0.625 | Eval loss: 0.559 | Eval accuracy: 0.855 | ɛ: 0.78 
Epoch: 2 | Step: 3200 | Train loss: 0.622 | Eval loss: 0.558 | Eval accuracy: 0.856 | ɛ: 0.80 
Epoch: 2 | Step: 4000 | Train loss: 0.620 | Eval loss: 0.566 | Eval accuracy: 0.856 | ɛ: 0.81 
Epoch: 2 | Step: 4800 | Train loss: 0.619 | Eval loss: 0.540 | Eval accuracy: 0.862 | ɛ: 0.82 
Epoch: 2 | Step: 5600 | Train loss: 0.617 | Eval loss: 0.539 | Eval accuracy: 0.865 | ɛ: 0.84 
Epoch: 2 | Step: 6400 | Train loss: 0.615 | Eval loss: 0.541 | Eval accuracy: 0.865 | ɛ: 0.85 
Epoch: 2 | Step: 7200 | Train loss: 0.614 | Eval loss: 0.537 | Eval accuracy: 0.867 | ɛ: 0.86 
Epoch: 2 | Step: 8000 | Train loss: 0.612 | Eval loss: 0.537 | Eval accuracy: 0.867 | ɛ: 0.87 
Epoch: 2 | Step: 8800 | Train loss: 0.611 | Eval lo

  0%|          | 0/8583 [00:00<?, ?it/s]

Epoch: 3 | Step: 800 | Train loss: 0.585 | Eval loss: 0.531 | Eval accuracy: 0.872 | ɛ: 0.90 
Epoch: 3 | Step: 1600 | Train loss: 0.583 | Eval loss: 0.539 | Eval accuracy: 0.871 | ɛ: 0.91 
Epoch: 3 | Step: 2400 | Train loss: 0.586 | Eval loss: 0.532 | Eval accuracy: 0.871 | ɛ: 0.92 
Epoch: 3 | Step: 3200 | Train loss: 0.587 | Eval loss: 0.529 | Eval accuracy: 0.872 | ɛ: 0.93 
Epoch: 3 | Step: 4000 | Train loss: 0.587 | Eval loss: 0.523 | Eval accuracy: 0.874 | ɛ: 0.94 
Epoch: 3 | Step: 4800 | Train loss: 0.586 | Eval loss: 0.520 | Eval accuracy: 0.873 | ɛ: 0.95 
Epoch: 3 | Step: 5600 | Train loss: 0.585 | Eval loss: 0.528 | Eval accuracy: 0.871 | ɛ: 0.96 
Epoch: 3 | Step: 6400 | Train loss: 0.585 | Eval loss: 0.520 | Eval accuracy: 0.874 | ɛ: 0.96 
Epoch: 3 | Step: 7200 | Train loss: 0.583 | Eval loss: 0.519 | Eval accuracy: 0.873 | ɛ: 0.97 
Epoch: 3 | Step: 8000 | Train loss: 0.584 | Eval loss: 0.511 | Eval accuracy: 0.874 | ɛ: 0.98 
Epoch: 3 | Step: 8800 | Train loss: 0.583 | Eval lo