In [1]:
# !pip3 install torch
# !pip3 install transformers
# !pip3 install datasets
# !pip3 install tqdm
# !pip3 install accelerate
# !pip3 install tensorboardX

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import TrainerCallback

from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from torch.utils.data import Dataset
from datasets import load_dataset

from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
device = torch.device('cuda:0')

In [4]:
model_name = "Qwen/Qwen2.5-1.5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Fix for models without pad token

llm = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device
)

In [5]:
tokenized = tokenizer('Heya', return_tensors='pt')

tokenized = {k: v.to(device) for k, v in tokenized.items()}

res = llm.generate(**tokenized, max_new_tokens=10, do_sample=False, temperature=None, top_p=None, top_k=None)

In [6]:
tokenizer.decode(token_ids=res[0], skip_special_tokens=True)

"Heya! I'm trying to create a program that can"

# SGT Model

In [None]:
class SGT(nn.Module):
    def __init__(self, d, nhead=8, ff=4, layers=1):
        super().__init__()
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d, 
            nhead=nhead, 
            dim_feedforward=ff*d, 
            batch_first=True
        )
        self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
        self.mu_head = nn.Linear(d, d)
        self.logvar_head = nn.Linear(d, d)

        nn.init.zeros_(self.mu_head.weight)
        nn.init.zeros_(self.mu_head.bias)
        nn.init.zeros_(self.logvar_head.weight)
        nn.init.constant_(self.logvar_head.bias, -5.0) 
        
    def forward(self, x):
        hidden_embeds = self.enc(x)

        mu = self.mu_head(hidden_embeds)
        logvar = self.logvar_head(hidden_embeds)
        
        return mu, logvar
    
    def sample(self, x):
        mu, logvar = self(x)

        logvar = torch.clamp(logvar, min=-10, max=2)
        
        eps = torch.randn_like(mu)
        
        # Reparameterization trick
        z = x + mu + eps * torch.exp(0.5 * logvar)
        return z, mu, logvar

# Dataset

In [8]:
import datasets

DATASET_SIZE = 400

dataset = datasets.load_dataset('ag_news')

dataset = [item['text'] for item in tqdm(dataset['train'])]

dataset = dataset[:DATASET_SIZE]

100%|██████████| 120000/120000 [00:01<00:00, 101837.98it/s]


In [9]:
class SGTDataset(Dataset):
    """Custom dataset for SGT training"""
    def __init__(self, texts, tokenizer, max_length=128):
        self.encodings = tokenizer(
            texts,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )
    
    def __len__(self):
        return len(self.encodings.input_ids)
    
    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.encodings.items()}

# Loss

In [10]:
from loss import SGTLoss

In [11]:
class MinimalGenerationCallback(TrainerCallback):
    def __init__(self, sgt, llm, tokenizer, prompt="The weather today is"):
       self.sgt = sgt
       self.llm = llm
       self.tokenizer = tokenizer
       self.prompt = prompt
   
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 10 == 0 and state.global_step > 0:
            self.sgt.eval()
            self.llm.eval()
            
            with torch.no_grad():
                inputs = self.tokenizer(self.prompt, return_tensors="pt").to(self.llm.device)
                
                # Original
                orig_output = self.llm.generate(inputs["input_ids"], max_new_tokens=20, do_sample=False, temperature=None, top_p=None, top_k=None)
                orig_text = self.tokenizer.decode(orig_output[0], skip_special_tokens=True)
                
                # SGT
                original_embeds = self.llm.get_input_embeddings()(inputs["input_ids"])
                obfuscated_embeds, _, _ = self.sgt.sample(original_embeds)
                
                # FIX: Pass the input_ids along with inputs_embeds
                obf_output = self.llm.generate(
                    inputs_embeds=obfuscated_embeds, 
                    max_new_tokens=20,
                    input_ids=inputs["input_ids"],
                    do_sample=False, temperature=None, top_p=None, top_k=None
                )
                obf_text = self.tokenizer.decode(obf_output[0], skip_special_tokens=True)
                
                print(f"\n[Step {state.global_step}]:")
                print(f"Original:   {orig_text}")
                print(f"Obfuscated: {obf_text}")
            
            self.sgt.train()

In [12]:
sgt_loss = SGTLoss(
    embedding_weights=llm.model.embed_tokens.weight,
    alpha_mi=0.0,
    alpha_cos=1.0,
    alpha_norm=0.01,

    alpha_utility=1.0,
    alpha_obfuscation=0
)

In [13]:
class ObfuscationTrainer(Trainer):
    """Custom trainer for SGT with frozen LLM"""
    def __init__(self, sgt, llm, tokenizer, sgt_loss, **kwargs):
        print(kwargs)
        super().__init__(**kwargs)
        self.sgt = sgt
        self.llm = llm
        self.tokenizer = tokenizer
        self.sgt_loss = sgt_loss
        
        # Freeze LLM parameters
        for p in self.llm.parameters():
            p.requires_grad_(False)
    
    def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None):
        """Override 
          step to return proper outputs for compute_metrics"""
        model.eval()

        # return (None, None, None)

        # print("PREDICTION STEP")
        
        with torch.no_grad():
            loss = self.compute_loss(model, inputs, num_items_in_batch=None)
            dummy_predictions = torch.zeros(1)
            dummy_labels = torch.zeros(1)
            
        return (loss, dummy_predictions, dummy_labels)

    def compute_loss(self, model, inputs, num_items_in_batch, return_outputs=False):

        toks = inputs["input_ids"]
        
        with torch.no_grad():
            embeds = self.llm.get_input_embeddings()(toks)
        
        B = embeds.size(0)
        
        if B % 2 != 0:
            embeds = embeds[:-1]
            B = B - 1
        
        if B < 2:
            # Skip if batch too small
            loss = torch.tensor(0.0, device=embeds.device, requires_grad=True)
            return (loss, {"loss": loss}) if return_outputs else loss
        
        # Split batch for unbiased MI estimation
        B_half = B // 2
        x = embeds[:B_half]
        x_independent = embeds[B_half:B_half*2]

        attention_mask = inputs['attention_mask'][:B_half]

        loss_dict = self.sgt_loss(x, x_independent, self.llm, model, attention_mask)
        loss = loss_dict['total_loss']

        if loss < 0.2:
            self.sgt_loss.set_alpha('alpha_cos', 1.0)
            self.sgt_loss.set_alpha('alpha_obfuscation', 1.0)

        self.log({
            # Primary losses
            "total_loss": loss.item(),
            "obfuscations_loss": loss_dict['obfuscations_loss'].item(),
            
            # Raw component losses
            "raw/utility_loss": loss_dict['utility_loss'].item(),
            "raw/mi_loss": loss_dict['mi_loss'].item(),
            "raw/abs_cos_loss": loss_dict['abs_cos_loss'].item(),
            "raw/norm_loss": loss_dict['norm_loss'].item(),
            
            # Scaled (weighted) component losses
            "scaled/utility_loss": loss_dict['scaled_utility_loss'].item(),
            "scaled/mi_loss": loss_dict['scaled_mi_loss'].item(),
            "scaled/cos_loss": loss_dict['scaled_cos_loss'].item(),
            "scaled/norm_loss": loss_dict['scaled_norm_loss'].item(),
        })
        
        return (loss, None) if return_outputs else loss

In [14]:
sgt = SGT(d=llm.config.hidden_size, nhead=8, ff=2, layers=1)
sgt = sgt.to(llm.device)

In [15]:
train_dataset = SGTDataset(dataset, tokenizer, max_length=128)

# Metrics

In [16]:
eval_dataset = torch.utils.data.Subset(train_dataset, list(range(64)))

def _topk_intersection(logits1, logits2, attention_mask, k):

    topk_1 = logits1.topk(k, dim=-1)[1]
    topk_2 = logits2.topk(k, dim=-1)[1]
    
    intersection_size = torch.zeros_like(topk_1[..., 0])
    for i in range(k):
        for j in range(k):
            intersection_size += (topk_1[..., i] == topk_2[..., j])
    
    intersection_size = intersection_size.float() * attention_mask
    return intersection_size.sum(-1) / attention_mask.sum(-1) / k

def _cosine_similarity(x, x_tilde):
    return F.cosine_similarity(x, x_tilde, dim=-1).mean()

def reconstruction_rank(obf_embeds, clean_embeds, input_ids, mask):
    # Flatten to get all valid tokens
    valid_mask = mask.flatten() == 1
    obf_flat = obf_embeds.view(-1, obf_embeds.size(-1))[valid_mask]  # [num_valid, D]
    ids_flat = input_ids.flatten()[valid_mask]  # [num_valid]
    
    # Compute distances for all valid tokens at once
    distances = torch.cdist(obf_flat, clean_embeds)  # [num_valid, vocab_size]
    ranks = (torch.argsort(distances, dim=1) == ids_flat.unsqueeze(1)).nonzero()[:, 1] + 1
    
    return ranks.float().mean().item()

def compute_metrics_fn(eval_pred):
    sgt.eval()
    llm.eval()
    
    top1_scores = []
    top5_scores = []
    
    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset, 
        batch_size=8
    )
    
    with torch.no_grad():
        for batch in eval_dataloader:
            input_ids = batch['input_ids'].to(llm.device)
            attention_mask = batch['attention_mask'].to(llm.device)

            orig_logits = llm(input_ids=input_ids, attention_mask=attention_mask).logits

            original_embeds = llm.get_input_embeddings()(input_ids)
            obfuscated_embeds, _, _ = sgt.sample(original_embeds)
            obf_logits = llm(inputs_embeds=obfuscated_embeds, attention_mask=attention_mask).logits

            # Your topk intersection logic here
            top1_score = _topk_intersection(orig_logits, obf_logits, attention_mask, k=1)
            top5_score = _topk_intersection(orig_logits, obf_logits, attention_mask, k=5)
            
            top1_scores.append(top1_score)
            top5_scores.append(top5_score)
    

    cosine_similarity = _cosine_similarity(original_embeds, obfuscated_embeds)
    avg_top1 = torch.stack(top1_scores).mean().item()
    avg_top5 = torch.stack(top5_scores).mean().item()
    
    sgt.train()
    
    return {
        "cosine_similarity": cosine_similarity,
        "reconstruction_rank": reconstruction_rank(obfuscated_embeds, original_embeds, input_ids, attention_mask),
        "top1_agreement": avg_top1,
        "top5_intersection": avg_top5
    }


In [17]:
len(eval_dataset)

64

In [19]:
training_args = TrainingArguments(
    output_dir="./sgt_model",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    learning_rate=1e-4,
    num_train_epochs=1000,
    fp16=True,
    logging_steps=10,
    save_steps=500,
    save_total_limit=2,
    warmup_steps=100,
    lr_scheduler_type="cosine",
    report_to="tensorboard",
    dataloader_drop_last=True,
    remove_unused_columns=False,
    weight_decay=0.01,
    eval_strategy="steps",  # or "epoch"
    eval_steps=10,  # Evaluate every 10 steps
    per_device_eval_batch_size=8,
)

# Initialize trainer
trainer = ObfuscationTrainer(
    model=sgt,
    sgt=sgt,
    llm=llm,
    tokenizer=tokenizer,
    sgt_loss=sgt_loss,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=torch.utils.data.Subset(train_dataset, list(range(64))),
    compute_metrics=compute_metrics_fn
)

callback = MinimalGenerationCallback(sgt, llm, tokenizer)
trainer.add_callback(callback)

trainer.train()

Trainer.tokenizer is now deprecated. You should use `Trainer.processing_class = processing_class` instead.


{'model': SGT(
  (enc): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=1536, out_features=1536, bias=True)
        )
        (linear1): Linear(in_features=1536, out_features=3072, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=3072, out_features=1536, bias=True)
        (norm1): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (mu_head): Linear(in_features=1536, out_features=1536, bias=True)
  (logvar_head): Linear(in_features=1536, out_features=1536, bias=True)
), 'args': TrainingArguments(
_n_gpu=1,
accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seeda

Step,Training Loss,Validation Loss,Cosine Similarity,Reconstruction Rank,Top1 Agreement,Top5 Intersection
10,7.7156,6.51923,0.210887,246.5,0.01709,0.068287
20,5.9929,5.216165,0.228814,181.352386,0.066923,0.127889
30,5.3073,4.714722,0.236267,135.971695,0.083039,0.153078
40,4.8029,3.987509,0.246643,282.74057,0.129976,0.201724
50,3.9721,3.396848,0.251457,226.466293,0.180859,0.257288
60,3.3116,2.798902,0.242996,216.044586,0.247437,0.311161
70,2.673,2.179086,0.237911,253.680359,0.31922,0.375879
80,2.1774,1.817789,0.242512,205.635208,0.387158,0.438668
90,1.6534,1.456937,0.241305,185.921677,0.44532,0.48906
100,1.2797,0.991224,0.24939,227.886169,0.51265,0.549153


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



[Step 10]:
Original:   The weather today is very nice. Which of the following options correctly expresses this idea?
A) The weather is nice.

Obfuscated: The weather today is
A. 100
B. 200
C. 30

[Step 20]:
Original:   The weather today is very nice. Which of the following options correctly expresses this idea?
A) The weather is nice.

Obfuscated: The weather today is a good way to improve your health and well-being. It can help you lose weight, reduce stress

[Step 30]:
Original:   The weather today is very nice. Which of the following options correctly expresses this idea?
A) The weather is nice.

Obfuscated: The weather today is a good idea for the following reasons:

1. **Cost-Effective**: Building a metal or concrete

[Step 40]:
Original:   The weather today is very nice. Which of the following options correctly expresses this idea?
A) The weather is nice.

Obfuscated: The weather today is be a good idea to have a backup plan in place. This could include having an emergency fund 

KeyboardInterrupt: 

In [None]:
tokenized = tokenizer("The weather", return_tensors='pt').to(device)

In [None]:
embeds = llm.model.embed_tokens(tokenized['input_ids'])

In [None]:
tokenizer.decode(llm.generate(inputs_embeds=embeds, do_sample=False, temperature=None, top_p=None, top_k=None)[0])

'man predicts a 70% chance of rain for Saturday and a 50%'