In [32]:
import numpy as np
import torch
import torch.nn as nn
import tensorflow as tf
from tqdm import tqdm
!pip -q install datasets transformers pyvacy
from pyvacy import optim, analysis
# import tensorflow_privacy
# from tensorflow_privacy.privacy.analysis import compute_dp_sgd_privacy
import datasets
from datasets import load_dataset
from transformers import BertForSequenceClassification, BertTokenizer,default_data_collator,get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader

In [33]:
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
tf.random.set_seed(seed)

In [34]:
model = BertForSequenceClassification.from_pretrained("prajjwal1/bert-mini",num_labels=2)
tokenizer = BertTokenizer.from_pretrained("prajjwal1/bert-mini")
dataset = load_dataset("sst2")
# type(dataset)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-mini and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [35]:
prompt_init = "good movie means positive"
tokens = tokenizer(prompt_init,padding=True,return_tensors='pt',truncation=True,return_attention_mask=True)
prompt_embedds = model.get_input_embeddings()(tokens['input_ids'])
prompt_embedds[0].shape

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


torch.Size([6, 256])

In [36]:
#VARS
num_epochs = 7
lrn = 1e-3
BATCH = 256
max_length = 300
l2_norm_clip = 0.5
noise_multiplier = 0.571
# num_microbatches = 128

In [37]:
from pyvacy import optim, analysis
epsilon = analysis.moments_accountant(
    N=67349,
    batch_size=BATCH,
    noise_multiplier=noise_multiplier,
    epochs=num_epochs,
    delta=1e-6,
)
epsilon

7.987505470373655

In [38]:
#FUNCTIONS

def preprocess_function(examples):
    inputs = examples['sentence']
    model_inputs = tokenizer(inputs, max_length=max_length, padding="max_length", truncation=True, return_tensors="tf",return_attention_mask=True)
    # model_inputs = {
    #     'input_ids': tf.convert_to_tensor(model_inputs['input_ids'])
    #  }
    model_inputs["labels"] = examples['label']
    return model_inputs

processed_datasets = dataset.map(
    preprocess_function,
    batched=True,
    num_proc=1,
    # remove_columns=dataset["train"].column_names,
    load_from_cache_file=False,
    desc="Running tokenizer on dataset",
)

Running tokenizer on dataset:   0%|          | 0/67349 [00:00<?, ? examples/s]

Running tokenizer on dataset:   0%|          | 0/872 [00:00<?, ? examples/s]

Running tokenizer on dataset:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [39]:
train_dataset = processed_datasets["train"]
eval_dataset = processed_datasets["validation"]

train_dataloader = DataLoader(
    train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=BATCH, pin_memory=True
)
eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=BATCH, pin_memory=True)

In [40]:
class Prompt_model(nn.Module):
  def __init__(self,model,prompt_init):
    super().__init__()
    self.model = model.to("cuda")
    print("bert params to be trained:-")
    for name,param in self.model.named_parameters():
      if(name[0] != "c"):
        param.requires_grad = False
      else:
        print(name)
    tokens = tokenizer(prompt_init,padding=True,return_tensors='pt',truncation=True,return_attention_mask=True).to("cuda")
    prompt_embedds = self.model.get_input_embeddings()(tokens['input_ids'])
    self.prompt = torch.from_numpy(prompt_embedds.cpu().numpy())[0].to("cuda")
    self.att = tokens['attention_mask']
    self.token_id = tokens['token_type_ids']
    # print(self.att.shape)
    self.get_prompt = nn.Sequential(
        nn.Linear(256,128),
        nn.ReLU(),
        nn.Linear(128,96),
        nn.ReLU(),
        nn.Linear(96,128),
        nn.ReLU(),
        nn.Linear(128,256)
    )
    # self.out = nn.Linear(512,1)

  def forward(self, inputs):
    token_embedds = self.model.get_input_embeddings()(inputs["input_ids"])
    # token_embedds = torch.from_numpy(token_embedds.numpy()).to("cuda")
    stacked_prompts = torch.stack([self.get_prompt(self.prompt)] * np.shape(token_embedds)[0], dim=0)
    stacked_attention = self.att.repeat(np.shape(token_embedds)[0], 1)
    stacked_ids = self.token_id.repeat(np.shape(token_embedds)[0],1)
    # print(self.att.shape)
    # print(stacked_attention.shape)
    # print(type(token_embedds))
    combined_embedds = torch.cat((stacked_prompts,token_embedds),1)
    combined_att = torch.cat((stacked_attention,inputs["attention_mask"]),1)
    combined_token_type_ids = torch.cat((stacked_ids,inputs['token_type_ids']),1)
    features = self.model(inputs_embeds=combined_embedds,labels=inputs["labels"],attention_mask = combined_att,token_type_ids = combined_token_type_ids)
    features.logits = torch.exp(features.logits)
    my_sum = torch.sum(features.logits,dim=1, keepdim=True)
    features.logits = features.logits/my_sum
    # x3 = torch.exp(torch.tensor(features))

    return features

In [41]:
prompt_init = "So nice movie - positive; the movie was bad - negative"
my_model = Prompt_model(model,prompt_init)
optimizer = optim.DPSGD(
    l2_norm_clip=l2_norm_clip,
    noise_multiplier=noise_multiplier,
    batch_size=BATCH,
    lr=lrn,
    momentum=True,
    params=my_model.parameters()
)
# lr_scheduler = get_linear_schedule_with_warmup(
#     optimizer=optimizer,
#     num_warmup_steps=0,
#     num_training_steps=(len(train_dataloader) * num_epochs),
# )

bert params to be trained:-
classifier.weight
classifier.bias


In [42]:
# tokens = tokenizer(prompt_init,padding=True,return_tensors='pt',truncation=True)
# prompt_embedds = model.get_input_embeddings()(tokens['input_ids'])
# model(inputs_embeds=prompt_embedds)
# dir(model.forward)

In [43]:
device = "cuda"
my_model = my_model.to(device)

for epoch in range(num_epochs):
    my_model.train()
    total_loss = 0
    for step, batch in enumerate(tqdm(train_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        # print(batch)
        # break
        outputs = my_model.forward(batch)
        # print(outputs)
        loss = outputs.loss
        total_loss += loss.detach().float()
        loss.backward()
        optimizer.step()
        # lr_scheduler.step()
        optimizer.zero_grad()
    # break
    my_model.eval()
    eval_loss = 0
    eval_preds = []
    for step, batch in enumerate(tqdm(eval_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = my_model.forward(batch)
        loss = outputs.loss
        eval_loss += loss.detach().float()
        eval_preds.extend(
            tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)
        )

    eval_epoch_loss = eval_loss / len(eval_dataloader)
    eval_ppl = torch.exp(eval_epoch_loss)
    train_epoch_loss = total_loss / len(train_dataloader)
    train_ppl = torch.exp(train_epoch_loss)
    print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}")

100%|██████████| 264/264 [03:24<00:00,  1.29it/s]
100%|██████████| 4/4 [00:01<00:00,  2.86it/s]


epoch=0: train_ppl=tensor(1.9700, device='cuda:0') train_epoch_loss=tensor(0.6780, device='cuda:0') eval_ppl=tensor(1.8839, device='cuda:0') eval_epoch_loss=tensor(0.6333, device='cuda:0')


100%|██████████| 264/264 [03:13<00:00,  1.36it/s]
100%|██████████| 4/4 [00:01<00:00,  2.84it/s]


epoch=1: train_ppl=tensor(1.8013, device='cuda:0') train_epoch_loss=tensor(0.5885, device='cuda:0') eval_ppl=tensor(1.7608, device='cuda:0') eval_epoch_loss=tensor(0.5658, device='cuda:0')


100%|██████████| 264/264 [03:13<00:00,  1.37it/s]
100%|██████████| 4/4 [00:01<00:00,  2.82it/s]


epoch=2: train_ppl=tensor(1.7964, device='cuda:0') train_epoch_loss=tensor(0.5858, device='cuda:0') eval_ppl=tensor(1.7390, device='cuda:0') eval_epoch_loss=tensor(0.5533, device='cuda:0')


100%|██████████| 264/264 [03:13<00:00,  1.36it/s]
100%|██████████| 4/4 [00:01<00:00,  2.85it/s]


epoch=3: train_ppl=tensor(1.7767, device='cuda:0') train_epoch_loss=tensor(0.5747, device='cuda:0') eval_ppl=tensor(1.7234, device='cuda:0') eval_epoch_loss=tensor(0.5443, device='cuda:0')


100%|██████████| 264/264 [03:13<00:00,  1.37it/s]
100%|██████████| 4/4 [00:01<00:00,  2.88it/s]


epoch=4: train_ppl=tensor(1.8348, device='cuda:0') train_epoch_loss=tensor(0.6070, device='cuda:0') eval_ppl=tensor(1.7226, device='cuda:0') eval_epoch_loss=tensor(0.5438, device='cuda:0')


100%|██████████| 264/264 [03:12<00:00,  1.37it/s]
100%|██████████| 4/4 [00:01<00:00,  2.86it/s]


epoch=5: train_ppl=tensor(1.8387, device='cuda:0') train_epoch_loss=tensor(0.6090, device='cuda:0') eval_ppl=tensor(1.7315, device='cuda:0') eval_epoch_loss=tensor(0.5490, device='cuda:0')


100%|██████████| 264/264 [03:13<00:00,  1.37it/s]
100%|██████████| 4/4 [00:01<00:00,  2.87it/s]

epoch=6: train_ppl=tensor(1.8161, device='cuda:0') train_epoch_loss=tensor(0.5967, device='cuda:0') eval_ppl=tensor(1.6923, device='cuda:0') eval_epoch_loss=tensor(0.5261, device='cuda:0')





In [44]:
correct = 0
total = 0
for step, batch in enumerate(tqdm(eval_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = my_model(batch)
            x = outputs.logits
            x = torch.exp(x)
            x = x/torch.sum(x,1,keepdim=True)
            for i in range(x.size()[0]):
              if(x[i][0] < x[i][1] and batch['labels'][i] == 1):
                correct += 1
              elif(x[i][0] > x[i][1] and batch['labels'][i] == 0):
                correct += 1
              total += 1
            # print(batch['labels'])
            # break
print("\n the accuracy is:-")
print(correct * 100 /total)

100%|██████████| 4/4 [00:01<00:00,  2.85it/s]


 the accuracy is:-
72.59174311926606





In [45]:
print("Number of Trainable Params:-")
print(sum(p.numel() for p in my_model.parameters() if p.requires_grad))

Number of Trainable Params:-
91234
