In [None]:
!git clone -b dev https://github.com/katcinskiy/stained-glass-transform-pytorch
!cd stained-glass-transform-pytorch

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from torch.utils.data import Dataset
from datasets import load_dataset

from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device('cuda:0')

In [3]:
model_name = "google/gemma-2-2b"  # or "meta-llama/Llama-3.2-1B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Fix for models without pad token

llm = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device
)

Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  2.00it/s]


In [5]:
tokenized = tokenizer('Heya', return_tensors='pt')

tokenized = {k: v.to(device) for k, v in tokenized.items()}

res = llm.generate(**tokenized, max_new_tokens=10)

W0813 13:34:29.048000 26981 torch/_inductor/utils.py:1436] [0/0] Not enough SMs to use max_autotune_gemm mode


In [6]:
tokenizer.decode(token_ids=res[0], skip_special_tokens=True)

"Heya!\n\nI'm back with another drawing!"

# SGT Model

In [7]:
class SGT(nn.Module):
    def __init__(self, d, nhead=8, ff=4, layers=1):
        super().__init__()
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d, 
            nhead=nhead, 
            dim_feedforward=ff*d, 
            batch_first=True
        )
        self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
        self.mu_head = nn.Linear(d, d)
        self.logvar_head = nn.Linear(d, d)
        
    def forward(self, x):
        hidden_embeds = self.enc(x)

        mu = self.mu_head(hidden_embeds)
        logvar = self.logvar_head(hidden_embeds)
        
        return mu, logvar
    
    def sample(self, x):
        mu, logvar = self(x)
        eps = torch.randn_like(mu)
        
        # Reparameterization trick
        z = x + mu + eps * torch.exp(0.5 * logvar)
        return z, mu, logvar

# Dataset

In [8]:
import datasets

DATASET_SIZE = 400

dataset = datasets.load_dataset('ag_news')

dataset = [item['text'] for item in tqdm(dataset['train'])]

dataset = dataset[:DATASET_SIZE]

100%|██████████| 120000/120000 [00:01<00:00, 105951.46it/s]


In [9]:
class SGTDataset(Dataset):
    """Custom dataset for SGT training"""
    def __init__(self, texts, tokenizer, max_length=128):
        self.encodings = tokenizer(
            texts,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )
    
    def __len__(self):
        return len(self.encodings.input_ids)
    
    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.encodings.items()}

# Loss

In [10]:
class SGTLoss(nn.Module):
    def __init__(
            self, 
            embedding_weights,
            alpha_mi,
            alpha_cos,
            alpha_norm,
            alpha_utility,
            alpha_obfuscation
        ):
        super().__init__()
        self.embeddings_norm = embedding_weights.norm(dim=-1).detach().median()
        self.alpha_mi = alpha_mi
        self.alpha_cos = alpha_cos
        self.alpha_norm = alpha_norm
        self.alpha_utility = alpha_utility
        self.alpha_obfuscation = alpha_obfuscation
    
    def forward(self, x, x_independent, llm, sgt):
        x_tilde, mu, logvar = sgt.sample(x)

        utility_loss = self._utility_loss(llm, x, x_tilde)

        mi_loss = self._mi_loss(x, x_independent, x_tilde, logvar, sgt)
        abs_cos_loss = self._abs_cos_loss(x, x_tilde)
        norm_loss = self._median_norm_penalty(x, mu)

        obfuscations_loss = self.alpha_mi * mi_loss + self.alpha_cos * abs_cos_loss + self.alpha_norm * norm_loss

        total_loss = self.alpha_utility * utility_loss + self.alpha_obfuscation * obfuscations_loss

        return {
            'total_loss': total_loss,
            'obfuscations_loss': obfuscations_loss,
            'utility_loss': utility_loss,
            'mi_loss': mi_loss,
            'abs_cos_loss': abs_cos_loss,
            'norm_loss': norm_loss
        }


    def _abs_cos_loss(self, x, x_tilde):
        cos_sim = F.cosine_similarity(x, x_tilde, dim=-1) # shape (b, l, d)
        return cos_sim.abs().mean()
    
    def _utility_loss(self, llm, x, x_tilde):
        with torch.no_grad():
            logits_clean = llm(inputs_embeds=x).logits
        
        logits_obf = llm(inputs_embeds=x_tilde).logits

        x_probas = F.softmax(logits_clean, dim=-1)
        
        x_tilde_log_probas = F.log_softmax(logits_obf, dim=-1)
        ce_loss = (-x_probas * x_tilde_log_probas).sum(dim=-1).mean()
        return ce_loss
    
    def _mi_loss(self, x, x_independent, x_tilde, logvar, sgt):

        mu_independent, logvar_independent = sgt(x_independent)
        
        # 1. Log determinant ratio

        # log_det_ratio = (logvar_l - logvar_i).sum(dim=(-1, -2)) # it was from chatgpt
        #TODO: check here, why logvar is a vector? do we need torch.linalg.det here?? why we have sum here? можем ли мы считать, что у нас матрица всегда диагональная? мы можем из СГТ всегда возвращать диагональную матрицу! итого это будут векторы
        log_det_ratio = (logvar - logvar_independent).sum(dim=(-1, -2))
        
        # 2. Махаланобис
        mahalanobis_distance = self._mahalanobis(x_tilde, x_independent, mu_independent, logvar_independent)
        
        # MI loss - среднее по батчу
        return (log_det_ratio + mahalanobis_distance).mean()
    
    def _mahalanobis(self, x_tilde, x_independent, mu_independent, logvar_independent):
        # TODO: test it

        vector_in_norm = (x_tilde - x_independent - mu_independent)

        logvar_independent_inverse = torch.exp(-logvar_independent)
        
        mahalanobis_distance = ((vector_in_norm ** 2) * logvar_independent_inverse).sum(dim=(-1, -2))

        return mahalanobis_distance

    def _median_norm_penalty(self, x, mu):
        norms = (x + mu).norm(dim=-1)
        
        penalty = (norms.mean() - self.embeddings_norm).abs()

        return penalty

In [11]:
sgt_loss = SGTLoss(
    embedding_weights=llm.model.embed_tokens.weight,
    alpha_mi=1,
    alpha_cos=1,
    alpha_norm=1,
    alpha_utility=1,
    alpha_obfuscation=1
)

In [12]:
class ObfuscationTrainer(Trainer):
    """Custom trainer for SGT with frozen LLM"""
    def __init__(self, sgt, llm, tokenizer, sgt_loss, **kwargs):
        super().__init__(**kwargs)
        self.sgt = sgt
        self.llm = llm
        self.tokenizer = tokenizer
        self.sgt_loss = sgt_loss
        
        # Freeze LLM parameters
        for p in self.llm.parameters():
            p.requires_grad_(False)
    
    def compute_loss(self, model, inputs, num_items_in_batch, return_outputs=False):

        # print(inputs)

        toks = inputs["input_ids"]
        
        with torch.no_grad():
            embeds = self.llm.get_input_embeddings()(toks)
        
        B = embeds.size(0)
        
        if B % 2 != 0:
            embeds = embeds[:-1]
            B = B - 1
        
        if B < 2:
            # Skip if batch too small
            loss = torch.tensor(0.0, device=embeds.device, requires_grad=True)
            return (loss, {"loss": loss}) if return_outputs else loss
        
        # Split batch for unbiased MI estimation
        B_half = B // 2
        x = embeds[:B_half]
        x_independent = embeds[B_half:B_half*2]

        loss_dict = self.sgt_loss(x, x_independent, self.llm, model)
        loss = loss_dict['total_loss']

        self.log({
                "train/total_loss": loss.item(),
                "train/obfuscations_loss": loss_dict['obfuscations_loss'].item(),
                "train/utility_loss": loss_dict['utility_loss'].item(),
                "train/mi_loss": loss_dict['mi_loss'].item(),
                "train/abs_cos_loss": loss_dict['abs_cos_loss'].item(),
                "train/norm_loss": loss_dict['norm_loss'].item(),
        })
        
        return loss

In [13]:
sgt = SGT(d=llm.config.hidden_size, nhead=8, ff=4, layers=2)
sgt = sgt.to(llm.device)

train_dataset = SGTDataset(dataset, tokenizer, max_length=128)

# Training arguments
training_args = TrainingArguments(
    output_dir="./sgt_model",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=1e-4,
    num_train_epochs=3,
    fp16=True,
    logging_steps=10,
    save_steps=500,
    save_total_limit=2,
    warmup_steps=100,
    logging_dir="./logs",
    report_to="none",  # or "tensorboard"
    dataloader_drop_last=True,  # Ensure consistent batch sizes
    remove_unused_columns=False
)

# Initialize trainer
trainer = ObfuscationTrainer(
    model=sgt,
    sgt=sgt,
    llm=llm,
    tokenizer=tokenizer,
    sgt_loss=sgt_loss,
    args=training_args,
    train_dataset=train_dataset
)

# Train
trainer.train()

Trainer.tokenizer is now deprecated. You should use `Trainer.processing_class = processing_class` instead.


OutOfMemoryError: CUDA out of memory. Tried to allocate 126.00 MiB. GPU 0 has a total capacity of 11.60 GiB of which 61.94 MiB is free. Including non-PyTorch memory, this process has 11.05 GiB memory in use. Of the allocated memory 10.75 GiB is allocated by PyTorch, with 70.00 MiB allocated in private pools (e.g., CUDA Graphs), and 39.57 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)