In [1]:
import torch
import uuid
import numpy as np
import transformers
from torch.utils.data import DataLoader
from tqdm import tqdm
from accelerate import Accelerator
# import argparse
from torch.utils.tensorboard import SummaryWriter
import my_utils as ut
from sklearn.model_selection import train_test_split
from accelerate.utils import broadcast
import logging
# Global cuDNN switches: keep cuDNN enabled and turn on its benchmark mode.
# NOTE(review): benchmark mode is documented by PyTorch as a per-shape
# algorithm autotuner; confirm it is beneficial for this transformer workload.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class argment():
    """ notebook stand-in for the command-line arguments (argparse is commented out above) """

    def __init__(self):
        # NOTE(review): not read in this notebook; presumably consumed by
        # ut.generate_suffixes_distributed -- confirm
        self.num_beams = 1
        # keep the last `prefix_size` tokens of (preprefix + prefix) as the prompt
        self.prefix_size = 50
        # keep the first `suffix_size` tokens of each suffix as the target
        self.suffix_size = 50
        # truthy -> loss is computed on the suffix tokens only
        self.aligned = 1
        # number of samples held out by train_test_split
        self.test_set_size = 1000
        # 'small' / 'medium' / 'gpt2'; anything else selects gpt-neo-2.7B
        self.model_size =  'gpt2'
        # fixed typo: was 'cude:0', which is not a valid torch device string
        self.device = 'cuda:0'
        # dataset locations (the loading cell currently hard-codes the same paths)
        self.train_preprefix = '../datasets/train_preprefix.npy'
        self.train_prefix = '../datasets/train_prefix.npy'
        self.train_suffix = '../datasets/train_suffix.npy'
        self.test_prefix = '../datasets/val_prefix.npy'
        # batch size used by both dataloaders
        self.bs = 16
        # number of soft-prompt tokens prepended to every sequence
        self.len_prompt = 5
        # maximum number of training epochs
        self.num_epochs = 15
        # training-loss threshold used for early stopping in the training loop
        self.theta = 2

args = argment()

In [3]:
accelerator = Accelerator(mixed_precision='fp16')

# prepare datasets & dataloaders
DATASET_PATH = '../datasets'


def _load_tokens(stem):
    """ load one of the .npy files under DATASET_PATH via ut.load_prompts """
    return ut.load_prompts(f'{DATASET_PATH}/{stem}.npy')


# prompt = preprefix followed by prefix, trimmed to its last `prefix_size` columns
prefixes = np.concatenate(
    [_load_tokens('train_preprefix'), _load_tokens('train_prefix')],
    axis=1,
)
prefixes = prefixes[:, -args.prefix_size:]
# target = first `suffix_size` columns of the suffix file
suffixes = _load_tokens('train_suffix')[:, :args.suffix_size]

In [4]:
# sample a random training/test set
# (no random_state is passed, so the split differs between runs)
_split = train_test_split(prefixes, suffixes, test_size=args.test_set_size)
prefix_tr, prefix_test, suffix_tr, suffix_test = _split
# or use last 1k samples for deterministic evaluation
# prefix_tr, suffix_tr = prefixes[:-args.test_set_size], suffixes[:-args.test_set_size]
# prefix_test, suffix_test = prefixes[-args.test_set_size:], suffixes[-args.test_set_size:]

In [5]:
def _prepend_prompt_slots(prefix_arr, suffix_arr):
    """ build [prompt slots | prefix | suffix] rows as a single int64 tensor """
    # prepending 50256 (eos token) to make multi-token soft-prompt learning work
    slots = torch.full((len(prefix_arr), args.len_prompt), 50256)
    pieces = [slots]
    for arr in (prefix_arr, suffix_arr):
        pieces.append(torch.tensor(arr, dtype=torch.int64))
    return torch.cat(pieces, dim=1)


train_ds = _prepend_prompt_slots(prefix_tr, suffix_tr)
test_ds = _prepend_prompt_slots(prefix_test, suffix_test)
# make sure all GPUs see the same dataset split, which is what main process (GPU ID 0) has sampled
train_ds = broadcast(train_ds.cuda(), from_process=0)
test_ds = broadcast(test_ds.cuda(), from_process=0)
# dataloaders: only the training loader is shuffled
train_loader = DataLoader(train_ds, shuffle=True, batch_size=args.bs)
test_loader = DataLoader(test_ds, batch_size=args.bs)

In [6]:
# model: map the configured size to a checkpoint name,
# falling back to the 2.7B GPT-Neo checkpoint for any other value
_CHECKPOINTS = {
    'small': 'EleutherAI/gpt-neo-125M',
    'medium': 'EleutherAI/gpt-neo-1.3B',
    'gpt2': 'gpt2',
}
MODEL_PATH = _CHECKPOINTS.get(args.model_size, 'EleutherAI/gpt-neo-2.7B')
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_PATH)

In [7]:
# freeze every existing model parameter, then swap the input embedding
# for the soft-prompting "layer" that wraps it
model.requires_grad_(False)
soft_prompt = ut.SoftEmbedding(
    model.get_input_embeddings(),
    n_tokens=args.len_prompt,
    initialize_from_vocab=True,
)
model.set_input_embeddings(soft_prompt)

In [8]:
# display the module tree (shows SoftEmbedding installed as transformer.wte)
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): SoftEmbedding(
      (wte): Embedding(50257, 768)
    )
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [9]:
# the soft prompt's learned embedding is the only tensor handed to the optimizer
_trainable = [soft_prompt.learned_embedding]
optimizer = torch.optim.AdamW(_trainable, lr=5e-4, weight_decay=0)
# accelerator version of things
_prepared = accelerator.prepare(model, optimizer, train_loader, test_loader)
model, optimizer, train_loader, test_loader = _prepared

In [10]:
# # creating tensorboard logger
# if accelerator.is_main_process:
#     file_name = f"""promptLearnAttack_id:{uuid.uuid1().hex}_lenPrompt:{args.len_prompt}_nEpochs:{args.num_epochs}_aligned:{args.aligned}"""\
#         + f"""_prefixSize:{args.prefix_size}_suffixSize:{args.suffix_size}_modelSize:{args.model_size}_numBeams:{args.num_beams}_"""
#     writer = SummaryWriter('../logs/' + file_name)

In [11]:
def evaluate_distributed(model, data_loader, args, accelerator):
    """ get inference loss on supplied data loader (for distributed training)

    Args:
        model: callable as ``model(input_ids=..., labels=...)`` returning an
            object with a scalar ``.loss``; put into eval mode here.
        data_loader: iterable of 2-D token-id batches.
        args: reads ``aligned`` (truthy -> only the last ``suffix_size``
            tokens contribute to the loss), ``suffix_size`` and
            ``test_set_size`` (denominator of the returned mean).
        accelerator: object exposing ``gather``.

    Returns:
        float: sum of (batch loss * batch length) over all gathered entries,
        divided by ``args.test_set_size``.
    """
    model.eval()
    # local accumulator (previously a `global`, which leaked into the notebook
    # namespace and kept the string above from being a real docstring)
    batch_losses = []
    # inference_mode already disables autograd, so no nested no_grad is needed
    with torch.inference_mode():
        for batch in data_loader:
            if args.aligned:
                labels = torch.clone(batch)
                # predicting only the last args.suffix_size tokens,
                # so ignore everything else in loss calculation
                labels[:, :labels.shape[1]-args.suffix_size] = -100
            else:
                labels = batch
            outputs = model(input_ids=batch, labels=labels)
            gathered = accelerator.gather(outputs.loss*len(batch)).cpu()
            # gather may hand back a 0-dim tensor (single process) or a 1-D
            # tensor with one entry per process; flatten so both concatenate
            batch_losses.append(gathered.reshape(-1))
        if not batch_losses:
            # empty loader: nothing to average (matches sum([]) / N == 0)
            return 0.0
        # to match batch sizes, distributed training pad the last batch
        # we get rid of the extra samples by truncating
        # NOTE(review): entries are per-batch sums, so this cap only takes
        # effect when there are more entries than test_set_size
        losses = torch.cat(batch_losses)[:args.test_set_size]
        return (torch.sum(losses) / args.test_set_size).item()

In [13]:
# training the prompt
for ep in range(args.num_epochs):
    model.train()
    # per-step (batch loss * batch length) values, gathered across processes
    step_losses = []
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        with torch.no_grad():
            if args.aligned:
                labels = torch.clone(batch)
                # predicting only the last args.suffix_size tokens
                # so ignore everything else in loss calculation
                labels[:, :labels.shape[1]-args.suffix_size] = -100
            else:
                labels = batch
        outputs = model(input_ids=batch, labels=labels)
        accelerator.backward(outputs.loss)
        optimizer.step()
        # detach so the logged value does not keep the autograd graph alive;
        # flatten because gather may return a 0-dim tensor (single process)
        # or a 1-D tensor with one entry per process
        gathered = accelerator.gather(outputs.loss.detach()*len(batch)).cpu()
        step_losses.append(gathered.reshape(-1))
    with torch.inference_mode():
        n_train = len(train_loader.dataset)
        # cap at the dataset size, as before (entries are per-step sums)
        tr_loss = (torch.cat(step_losses) if step_losses else torch.zeros(0))[:n_train]
        tr_loss = (torch.sum(tr_loss) / n_train).item()
        tr_plp = np.exp(tr_loss)
        test_loss = evaluate_distributed(model, test_loader, args, accelerator)
        test_plp = np.exp(test_loss)
        if accelerator.is_main_process:
            accelerator.print('Train/Loss', tr_loss, ep)
            accelerator.print('Train/PLP', tr_plp, ep)
            accelerator.print('Test/Loss', test_loss, ep)
            accelerator.print('Test/PLP', test_plp, ep)
            accelerator.print(f'EP:{ep+1} Tr. Loss/PLP:{tr_loss:.3f}/{tr_plp:.3f}', end=' --- ')
            accelerator.print(f'Test Loss/PLP:{test_loss:.3f}/{test_plp:.3f}', end='\r')
        # if training loss has stabilized around theta, finish training
        # NOTE(review): the loop minimizes the loss but stops as soon as it is
        # >= theta, which ends training after the first epoch whenever the
        # initial loss is above theta -- confirm this direction is intended
        if tr_loss >= args.theta:
            break

100%|██████████| 875/875 [01:35<00:00,  9.21it/s]


Train/Loss 3.0227153301239014 0
Train/PLP 20.547007915528287 0
Test/Loss 2.506455659866333 0
Test/PLP 12.261394401223036 0
EP:1 Tr. Loss/PLP:3.023/20.547 --- Test Loss/PLP:2.506/12.261

In [None]:
# generate suffixes for the held-out prompts and stack them into one array
generations_test = np.stack(
    ut.generate_suffixes_distributed(model, test_loader, args, accelerator, use_cache=False),
    axis=0,
)
# always measure the final loss over suffix tokens
# (this overwrites the shared config for anything that runs afterwards)
args.aligned = True
test_loss = evaluate_distributed(model, test_loader, args, accelerator)
# log results
if accelerator.is_main_process:
    # measure  fractional and exact match rates
    fract_rate, exact_rate = ut.compute_reconstruct_rate(generations_test, suffix_test, args)
    test_plp = np.exp(test_loss)
    accelerator.print(f'Exact/Fract extract rate:{exact_rate:.3f}/{fract_rate:.3f}')
    accelerator.print(f'Test Loss/PLP:{test_loss:.3f}/{test_plp:.3f}')
    # scalar-style lines: tag, value, step
    final_scalars = (
        ('Memorization/Fract_Rate', fract_rate),
        ('Memorization/Exact_Rate', exact_rate),
        ('Test_Final/Loss', test_loss),
        ('Test_Final/PLP', test_plp),
    )
    for tag, value in final_scalars:
        accelerator.print(tag, value, 0)

 51%|█████     | 32/63 [01:05<01:05,  2.13s/it]