In [None]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from Interpreter import Interpreter 
 
def Phi(x):
    """Run ``x`` through the current global ``model`` and return its first output.

    ``model`` is resolved at call time (no ``global`` statement is needed
    to read it), so rebinding ``model`` in a later cell makes this same
    ``Phi`` evaluate the new model.

    Args:
        x: input embeddings, passed to the model as ``inputs_embeds``.

    Returns:
        Element ``[0]`` of the model output. For a Hugging Face causal-LM
        output without ``labels``, this is the logits tensor for every input
        position, not only the last token.
    """
    logits = model(inputs_embeds=x)[0]
    return logits

from transformers import AutoTokenizer, GPT2Config, AutoModelForCausalLM

MODEL_NAME = 'gpt2'

config = GPT2Config.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    from_tf=bool(".ckpt" in MODEL_NAME),  # True only when the name contains '.ckpt'
    config=config,
)

#Normal Wikitext
# Checkpoint presumably fine-tuned on plain WikiText (per the file name).
save_path = f'models/{model.__class__.__name__}_{MODEL_NAME}_wikitext.pt'
# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only machine.
state_dict = torch.load(save_path, map_location='cpu')
model.load_state_dict(state_dict)
# Make inference mode explicit so dropout cannot perturb the interpretation.
model.eval()

# Std over every entry of the token-embedding table; passed (x10) as the Interpreter's `scale`.
with torch.no_grad():
    input_embedding_weight_std = model.get_input_embeddings().weight.std().item()

text = "the secret number is"
inputs = tokenizer.encode_plus(text, return_tensors='pt', add_special_tokens=True)
input_ids = inputs['input_ids']

# One label per token, so labels always line up with the rows of `x`,
# even when BPE splits a word into several tokens.
words = [tokenizer.decode([token_id]).strip() for token_id in input_ids[0].tolist()]

with torch.no_grad():
    # Drop only the batch dim: a bare .squeeze() would also collapse the
    # sequence dim when the text encodes to a single token.
    x = model.get_input_embeddings()(input_ids).squeeze(0)

interpreter = Interpreter(x=x, Phi=Phi,
                          scale=10 * input_embedding_weight_std,
                          words=words).to(model.device)

# This will take some time.
# NOTE(review): iteration=1 looks like a smoke-test setting; confirm the intended count.
interpreter.optimize(iteration=1, lr=0.01, show_progress=True)
interpreter.visualize()
interpreter.get_sigma()

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class argument:
    """Hard-coded configuration object, apparently standing in for an argparse namespace.

    Every attribute set in ``__init__`` is a default. Keyword arguments
    override matching defaults, e.g. ``argument(seed=0)``. Unknown keywords
    raise ``TypeError``, so a typo cannot silently create a new, unused
    attribute.
    """

    def __init__(self, **overrides):
        # --- dataset / output ---
        self.dataset_name = 'wikitext'
        self.dataset_config_name = 'wikitext-2-raw-v1'
        self.output_dir = './logs/'
        self.seed = 1234
        self.learning_rate = 5e-5
        self.block_size = 1024
        self.do_ref_model = False  # previously assigned twice with the same value

        # --- model / tokenizer ---
        self.config_name = None
        self.model_name_or_path = 'gpt2'
        self.tokenizer_name = 'gpt2'
        self.use_slow_tokenizer = False

        # --- batching ---
        self.per_device_train_batch_size = 8
        self.per_device_eval_batch_size = 8
        self.gradient_accumulation_steps = 8

        # --- schedule ---
        self.lr_scheduler_type = 'linear'
        self.num_train_epochs = 5
        self.max_train_steps = None

        # --- preprocessing / optimizer ---
        self.preprocessing_num_workers = 1
        self.overwrite_cache = False
        self.weight_decay = 0.0
        self.num_warmup_steps = 0

        # --- canary settings (consumed by training code not shown here) ---
        self.add_canary = True
        self.canary_rep = 50
        self.canary_len = 5

        # --- adapters / partial fine-tuning ---
        self.add_adapter = False
        self.adapter_reduction = 16
        self.train_head_only = False
        self.train_layer_n_only = None
        self.redact_token = 'multi'

        for key, value in overrides.items():
            if not hasattr(self, key):
                raise TypeError(f"argument() got an unexpected keyword argument {key!r}")
            setattr(self, key, value)

args = argument()

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, GPT2Config
class CustomGPT2HeadModel(nn.Module):
    """Causal LM whose input embeddings can be blended with a learned "private" embedding.

    Wraps a pretrained Hugging Face causal LM (``self.transformer``) and adds
    ``pv_embed``, a 2-row embedding table indexed by ``private_ids``. When
    ``private_ids`` are supplied, each input embedding ``e`` becomes
    ``alpha * e + (1 - alpha) * pv_embed[private_id]`` before the LM runs.

    Args:
        config: model config; ``config.n_embd`` sizes ``pv_embed``.
        model_name_or_path: checkpoint to load. Defaults to the notebook-level
            ``args.model_name_or_path``, which was the previous behaviour.
        alpha: weight of the token embedding in the blend (default 0.8).
    """

    def __init__(self, config, model_name_or_path=None, alpha=0.8):
        super().__init__()
        if model_name_or_path is None:
            model_name_or_path = args.model_name_or_path
        # The attribute names `transformer` / `pv_embed` are state_dict key
        # prefixes; keep them so saved checkpoints still load.
        self.transformer = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            from_tf=bool(".ckpt" in model_name_or_path),
            config=config,
        )
        self.pv_embed = nn.Embedding(2, config.n_embd)
        self.alpha = alpha

    def forward(self,
                input_ids=None,
                inputs_embeds=None,
                private_ids=None,
                attention_mask=None,
                labels=None):
        """Optionally blend private embeddings into the inputs, then run the LM.

        Args:
            input_ids: token ids; used only when ``inputs_embeds`` is None.
            inputs_embeds: precomputed input embeddings; take precedence over
                ``input_ids``.
            private_ids: ids in {0, 1} indexing ``pv_embed`` (presumably one
                per token marking private vs. public; TODO confirm). Their
                embeddings must broadcast against ``inputs_embeds``.
            attention_mask: forwarded unchanged to the wrapped LM.
            labels: forwarded unchanged to the wrapped LM.

        Returns:
            The wrapped LM's output, always requested with
            ``output_hidden_states=True``.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("Either input_ids or inputs_embeds must be provided")
            # presumably (batch, seq, hidden)
            inputs_embeds = self.transformer.get_input_embeddings()(input_ids)

        if private_ids is not None:
            pv_embeddings = self.pv_embed(private_ids)
            inputs_embeds = self.alpha * inputs_embeds + (1 - self.alpha) * pv_embeddings

        # Embeddings are passed instead of ids so the blend above is what the LM sees.
        return self.transformer(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
        )

config = GPT2Config.from_pretrained('gpt2')
model = CustomGPT2HeadModel(config)
# A bare nn.Module wrapper starts in training mode; switch to inference explicitly.
model.eval()

#Private Wikitext
# Checkpoint presumably fine-tuned on WikiText with private embeddings (per the file name).
save_path = f'models/{model.__class__.__name__}_gpt2_wikitext_pv.pt'
# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only machine.
state_dict = torch.load(save_path, map_location='cpu')
model.load_state_dict(state_dict)

In [None]:
# Set the device to GPU if available, and actually run the model there.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Std over every entry of the token-embedding table; passed (x10) as the Interpreter's `scale`.
with torch.no_grad():
    input_embedding_weight_std = model.transformer.get_input_embeddings().weight.std().item()

text = "the secret number is"
inputs = tokenizer.encode_plus(text, return_tensors='pt', add_special_tokens=True)
# Ids must live on the same device as the embedding table they index.
input_ids = inputs['input_ids'].to(device)

# One label per token, so labels always line up with the rows of `x`,
# even when BPE splits a word into several tokens.
words = [tokenizer.decode([token_id]).strip() for token_id in input_ids[0].tolist()]

with torch.no_grad():
    # Drop only the batch dim: a bare .squeeze() would also collapse the
    # sequence dim when the text encodes to a single token.
    x = model.transformer.get_input_embeddings()(input_ids).squeeze(0)

interpreter = Interpreter(x=x, Phi=Phi,
                          scale=10 * input_embedding_weight_std,
                          words=words).to(model.transformer.device)

# This will take some time.
# NOTE(review): iteration=1 looks like a smoke-test setting; confirm the intended count.
interpreter.optimize(iteration=1, lr=0.01, show_progress=True)
interpreter.visualize()
interpreter.get_sigma()