In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import TrainerCallback

from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from torch.utils.data import Dataset
from datasets import load_dataset

from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device('cuda:0')

In [3]:
model_name = "Qwen/Qwen2.5-1.5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Fix for models without pad token

llm = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device
)

In [4]:
import datasets

DATASET_SIZE = 400

dataset = datasets.load_dataset('ag_news')

dataset = [item['text'] for item in tqdm(dataset['train'])]

dataset = dataset[:DATASET_SIZE]

class SGTDataset(Dataset):
    """Custom dataset for SGT training"""
    def __init__(self, texts, tokenizer, max_length=128):
        self.encodings = tokenizer(
            texts,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )
    
    def __len__(self):
        return len(self.encodings.input_ids)
    
    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.encodings.items()}

100%|██████████| 120000/120000 [00:01<00:00, 105457.48it/s]


In [5]:
from transformers.modeling_outputs import ModelOutput

class ObfOutput(ModelOutput):
    loss: torch.Tensor
    # optional: surface individual losses to logs via callbacks if wanted
    mi_loss: torch.Tensor
    utility_loss: torch.Tensor
    abs_cos_loss: torch.Tensor
    norm_loss: torch.Tensor
    obfuscations_loss: torch.Tensor

class ObfuscationModel(nn.Module):
    def __init__(self, llm, sgt, sgt_loss):
        super().__init__()
        self.llm = llm
        self.sgt = sgt
        self.sgt_loss = sgt_loss

        # Freeze LLM parameters
        for p in self.llm.parameters():
            p.requires_grad_(False)

        self.last_metrics = {}

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        with torch.no_grad():
            embeds = self.llm.get_input_embeddings()(input_ids)
        
        B = embeds.size(0)
        
        if B % 2 != 0:
            embeds = embeds[:-1]
            B = B - 1
        
        if B < 2:
            loss = torch.tensor(0.0, device=embeds.device, requires_grad=True)
            return loss
        
        B_half = B // 2
        x = embeds[:B_half]
        x_independent = embeds[B_half:B_half*2]

        attention_mask = attention_mask[:B_half]

        loss_dict = self.sgt_loss(x, x_independent, self.llm, self.sgt, attention_mask)
        loss = loss_dict['total_loss']

        if loss < 0.2:
            self.sgt_loss.set_alpha('alpha_cos', 1.0)
            self.sgt_loss.set_alpha('alpha_obfuscation', 1.0)

        self.last_metrics = {
            "fuck": 12.0,
            'train/fucking': 1.0,
            'eval/fuc': 2.0
        }
        # self.log({
        #     # Primary losses
        #     "total_loss": loss.item(),
        #     "obfuscations_loss": loss_dict['obfuscations_loss'].item(),
            
        #     # Raw component losses
        #     "raw/utility_loss": loss_dict['utility_loss'].item(),
        #     "raw/mi_loss": loss_dict['mi_loss'].item(),
        #     "raw/abs_cos_loss": loss_dict['abs_cos_loss'].item(),
        #     "raw/norm_loss": loss_dict['norm_loss'].item(),
            
        #     # Scaled (weighted) component losses
        #     "scaled/utility_loss": loss_dict['scaled_utility_loss'].item(),
        #     "scaled/mi_loss": loss_dict['scaled_mi_loss'].item(),
        #     "scaled/cos_loss": loss_dict['scaled_cos_loss'].item(),
        #     "scaled/norm_loss": loss_dict['scaled_norm_loss'].item(),
        # })
        
        return (loss, )

    def sgt_forward(self, x):
        return self.sgt(x)
    
    def sgt_sample(self, x):
        return self.sgt.sample(x)

In [6]:
class SGT(nn.Module):
    def __init__(self, d, nhead=8, ff=4, layers=1):
        super().__init__()
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d, 
            nhead=nhead, 
            dim_feedforward=ff*d, 
            batch_first=True
        )
        self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
        self.mu_head = nn.Linear(d, d)
        self.logvar_head = nn.Linear(d, d)

        nn.init.zeros_(self.mu_head.weight)
        nn.init.zeros_(self.mu_head.bias)
        nn.init.zeros_(self.logvar_head.weight)
        nn.init.constant_(self.logvar_head.bias, -5.0) 
        
    def forward(self, x):
        hidden_embeds = self.enc(x)

        mu = self.mu_head(hidden_embeds)
        logvar = self.logvar_head(hidden_embeds)
        
        return mu, logvar
    
    def sample(self, x):
        mu, logvar = self(x)

        logvar = torch.clamp(logvar, min=-10, max=2)
        
        eps = torch.randn_like(mu)
        
        # Reparameterization trick
        z = x + mu + eps * torch.exp(0.5 * logvar)
        return z, mu, logvar

In [7]:
sgt = SGT(d=llm.config.hidden_size, nhead=8, ff=2, layers=1)
sgt = sgt.to(llm.device)

In [8]:
from loss import SGTLoss

sgt_loss = SGTLoss(
    embedding_weights=llm.model.embed_tokens.weight,
    alpha_mi=0.0,
    alpha_cos=1.0,
    alpha_norm=0.01,

    alpha_utility=1.0,
    alpha_obfuscation=0
)

In [9]:
train_dataset = SGTDataset(dataset, tokenizer, max_length=64)

In [10]:
model = ObfuscationModel(llm, sgt, sgt_loss)

In [11]:
class CustomMetricsCallback(TrainerCallback):
    def on_step_begin(self, args, state, control, **kwargs):
        # Clear model metrics at the start of each step
        if hasattr(self.model, 'last_metrics'):
            self.model.last_metrics = {}
    
    def on_log(self, args, state, control, logs=None, **kwargs):
        # This is called when trainer is about to log
        if logs is not None and hasattr(self.model, 'last_metrics'):
            # Add your metrics to the logs dict BEFORE they go to TensorBoard
            logs.update(self.model.last_metrics)

In [12]:
training_args = TrainingArguments(
    output_dir="./sgt_model",
    per_device_train_batch_size=10,
    gradient_accumulation_steps=4,
    learning_rate=1e-4,
    num_train_epochs=1000,
    fp16=True,
    logging_steps=10,
    save_steps=500,
    save_total_limit=2,
    warmup_steps=100,
    lr_scheduler_type="cosine",
    report_to="tensorboard",
    dataloader_drop_last=True,
    remove_unused_columns=False,
    eval_strategy="steps",
    eval_steps=10,
    per_device_eval_batch_size=8,
)

metrics_callback = CustomMetricsCallback()

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=torch.utils.data.Subset(train_dataset, list(range(64))),
    callbacks=[metrics_callback]
)


trainer.train()

AttributeError: 'CustomMetricsCallback' object has no attribute 'model'