# Model inversion attack against SGT

This notebook explores how a model inversion attack might work against SGT. The idea is to test the following hypothesis: even if SGT corrupts the input embeddings, the model still functions, which means we can potentially exploit its prior knowledge to reconstruct the original input from its hidden states.
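Why corruption alone may not hide the input: as long as the noise is small relative to the gaps between token embeddings, even a trivial nearest-neighbour lookup recovers the tokens. The sketch below is a toy illustration under stated assumptions: a random embedding table and fixed isotropic Gaussian noise. It is neither SGT nor the learned BART inversion model used in this notebook, which can additionally exploit the LLM's contextual knowledge.

```python
import numpy as np

def nearest_neighbour_inversion(noisy, table):
    """Map each noisy embedding to the vocabulary row with the highest cosine similarity."""
    table_n = table / np.linalg.norm(table, axis=1, keepdims=True)
    noisy_n = noisy / np.linalg.norm(noisy, axis=1, keepdims=True)
    return np.argmax(noisy_n @ table_n.T, axis=1)

def recovery_rate(sigma, vocab=1000, dim=64, seq_len=200, seed=0):
    """Fraction of tokens recovered after adding Gaussian noise of scale sigma."""
    rng = np.random.default_rng(seed)
    table = rng.standard_normal((vocab, dim))      # hypothetical embedding table
    tokens = rng.integers(0, vocab, size=seq_len)  # hypothetical input sequence
    # Corrupt the looked-up embeddings with isotropic Gaussian noise
    noisy = table[tokens] + sigma * rng.standard_normal((seq_len, dim))
    recovered = nearest_neighbour_inversion(noisy, table)
    return float(np.mean(recovered == tokens))

for sigma in (0.5, 2.0, 10.0):
    print(f"sigma={sigma:>4}: {recovery_rate(sigma):.0%} of tokens recovered")
```

Recovery degrades as the noise grows, but a corruption that still lets the model work is, by construction, one that preserves much of this information.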

The code for training the inversion model is available at https://github.com/katcinskiy/model-inversion-attack

Pretrained weights for the inversion models: https://drive.google.com/drive/folders/10P259HD9siA4foxBeN8c_pIdKeKmWM8R?usp=share_link

Pretrained weights for SGT: https://drive.google.com/drive/folders/1j-h2Xz7KWn1xgxJFYqoAt0kMZz-ogRry?usp=share_link

In [1]:
import sys
sys.path.append('..')

import torch

from transformers import AutoTokenizer, AutoModelForCausalLM, BartForConditionalGeneration

In [2]:
device = torch.device('cuda:0')

In [None]:
LLM_NAME = "Qwen/Qwen2.5-1.5B"
INVERSE_MODEL_NAME = "facebook/bart-base"

llm_tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
llm_tokenizer.pad_token = llm_tokenizer.eos_token
base_llm = AutoModelForCausalLM.from_pretrained(LLM_NAME, device_map=device)

inv_tokenizer = AutoTokenizer.from_pretrained(INVERSE_MODEL_NAME)
inv_base_model = BartForConditionalGeneration.from_pretrained(INVERSE_MODEL_NAME, device_map=device)

## Loading pretrained SGT

In [None]:
from sgt_model import SGTModel

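# 1536 = hidden size of Qwen2.5-1.5B; the remaining arguments follow SGTModel's constructor
sgt = SGTModel(1536, 8, 2, 1, None, None, None, None).to(device)

sgt.load_state_dict(torch.load('pretrain_sgt_cos_only.pt', map_location=device))
# Alternative checkpoint:
# sgt.load_state_dict(torch.load('pretrain_sgt_cos_and_mselogvar.pt', map_location=device))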

mu_init_weight is None - skipping mu_head.weight initialization
mu_init_bias is None - skipping mu_head.bias initialization
logvar_init_weight is None - skipping logvar_head.weight initialization
logvar_init_bias is None - skipping logvar_head.bias initialization


<All keys matched successfully>
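The source of `SGTModel` is not shown in this notebook. The initialization log mentions `mu_head` and `logvar_head`, and the wrapper in the next cell unpacks three values from `sgt.sample(...)`, which suggests (but does not confirm) a Gaussian reparameterization. The hypothetical stand-in `ToyGaussianSGT` below mimics that interface so the pipeline can be experimented with without the checkpoint; the return order `(noisy, mu, logvar)` and the handling of padded positions are assumptions, not SGT's actual behaviour.

```python
import torch
from torch import nn

class ToyGaussianSGT(nn.Module):
    """Hypothetical stand-in for SGTModel (not its real implementation)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.mu_head = nn.Linear(hidden_size, hidden_size)
        self.logvar_head = nn.Linear(hidden_size, hidden_size)

    def sample(self, embeds, attention_mask=None):
        mu = self.mu_head(embeds)
        logvar = self.logvar_head(embeds)
        # Reparameterization trick: mu + sigma * eps with eps ~ N(0, I)
        noisy = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
        if attention_mask is not None:
            # Assumption: padded positions keep their original embedding
            mask = attention_mask.unsqueeze(-1).to(noisy.dtype)
            noisy = mask * noisy + (1 - mask) * embeds
        return noisy, mu, logvar

torch.manual_seed(0)
sgt_toy = ToyGaussianSGT(hidden_size=16)
embeds = torch.randn(2, 5, 16)
attention_mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])
noisy, mu, logvar = sgt_toy.sample(embeds, attention_mask=attention_mask)
print(noisy.shape)
```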

In [5]:
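class LLM(torch.nn.Module):
    """Wraps the base LLM so that token embeddings pass through SGT before the transformer layers."""
    def __init__(self, sgt, base_llm):
        super().__init__()

        self.sgt = sgt
        self.base_llm = base_llm

    def forward(self, input_ids, attention_mask, **kwargs):  # extra tokenizer outputs are ignored
        # Look up the clean token embeddings of the base model
        embeds = self.base_llm.model.embed_tokens(input_ids)
        # Replace them with embeddings sampled (corrupted) by SGT
        embeds, _, _ = self.sgt.sample(embeds, attention_mask=attention_mask)

        # hidden_states[0] holds the corrupted input embeddings; for the early layers used here,
        # hidden_states[k] is the output of decoder layer k
        return self.base_llm(inputs_embeds=embeds, attention_mask=attention_mask, output_hidden_states=True)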

In [6]:
llm = LLM(sgt, base_llm).eval().to(device)

## Loading the pretrained inversion models for different layers and testing them

In [7]:
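# Note: if models.py already exists, wget saves this copy as models.py.1 (see the log),
# and `from models import InversionModel` keeps importing the existing file; add `-O models.py` to overwrite it
!wget https://raw.githubusercontent.com/katcinskiy/model-inversion-attack/master/models.py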

--2025-08-30 09:23:21--  https://raw.githubusercontent.com/katcinskiy/model-inversion-attack/master/models.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8001::154, 2606:50c0:8002::154, 2606:50c0:8003::154, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8001::154|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3677 (3,6K) [text/plain]
Saving to: ‘models.py.1’


2025-08-30 09:23:21 (363 KB/s) - ‘models.py.1’ saved [3677/3677]
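`models.py` defines `InversionModel`; its exact architecture lives in the linked repository and is not reproduced here. One plausible design, consistent with the constructor taking a BART model and the LLM hidden size (1536), is a learned projection from the LLM's hidden size to BART's `d_model` whose output is fed to BART's encoder as `inputs_embeds`, with the decoder generating the text. The sketch below shows only that data flow, using a tiny randomly initialised BART and an untrained `nn.Linear` adapter; both the adapter and all sizes are assumptions, so the generated token IDs are meaningless.

```python
import torch
from torch import nn
from transformers import BartConfig, BartForConditionalGeneration

torch.manual_seed(0)
LLM_HIDDEN = 1536  # hidden size of Qwen2.5-1.5B, as passed to InversionModel

# Tiny randomly initialised BART (hypothetical sizes; bart-base uses d_model=768)
config = BartConfig(
    vocab_size=100, d_model=32, max_position_embeddings=64,
    encoder_layers=1, decoder_layers=1,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=64, decoder_ffn_dim=64,
)
bart = BartForConditionalGeneration(config).eval()

# Hypothetical adapter: project LLM hidden states into BART's embedding space
project = nn.Linear(LLM_HIDDEN, config.d_model)

hidden_states = torch.randn(2, 7, LLM_HIDDEN)  # stand-in for outputs.hidden_states[k]
attention_mask = torch.ones(2, 7, dtype=torch.long)

with torch.no_grad():
    generated = bart.generate(
        inputs_embeds=project(hidden_states),
        attention_mask=attention_mask,
        max_new_tokens=8,
        do_sample=True,
        temperature=1.0,
    )
print(generated.shape)
```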



In [None]:
inv_model_path = [
    'MIA_pretrain_1_layer.bin',
    'MIA_pretrain_2_layer.bin',
    'MIA_pretrain_3_layer.bin',
    'MIA_pretrain_4_layer.bin'
]

texts = [
    "Alexander Green's checking account shows a recent deposit of $4,250 on September 15, 2025.",
    "Jessica Rivera lives at 29 Oakwood Lane, Riverton, and her contact number is (555) 234-8790.",
    "Marcus Lee registered his driver's license DL9823475 in the state of California.",
    "Emily Thompson’s email is emily.t92@example.net, which she also uses for online banking.",
    "Benjamin Harris has a medical record ID MRN-45721-B, issued by St. Mary’s Hospital."
]

LAYERS = len(inv_model_path)

In [9]:
from models import InversionModel

# Inference only: disable autograd so no activations are kept for backprop
with torch.no_grad():
    for layer in range(LAYERS):

        # One inversion checkpoint per LLM layer
        inv_model = InversionModel(inv_base_model, inv_tokenizer, 1536).eval().to(device)
        inv_model.load_state_dict(torch.load(inv_model_path[layer], map_location=device))

        all_tokenized = [llm_tokenizer(text, return_tensors='pt') for text in texts]

        # Forward through SGT + LLM; llm() calls sgt.sample() on every pass
        outputs = [llm(**tokenized.to(device)) for tokenized in all_tokenized]

        # hidden_states[0] holds the corrupted embeddings, so layer k is at index k
        generated_texts = [inv_model.generate_text(
            encoder_embeds=outputs[i].hidden_states[layer + 1],
            encoder_attention_mask=all_tokenized[i]['attention_mask'],
            max_length=42,
            do_sample=True,
            temperature=1.0
        )[0] for i in range(len(texts))]

        print()
        print(f"------- Reversed texts for layer {layer + 1} -------")

        for text in generated_texts:
            print(text)

        print()

trainable params: 3,538,944 || all params: 142,959,360 || trainable%: 2.4755

------- Reversed texts for layer 1 -------
Alexander Green's checking account shows a recent deposit of $4,250 on September 15, 2000.
Jessica Rivera lives at 2900 Oakwood Lane, Renton, and her contact number is (585) 233-7684.
Marcus Lee registered his driver's license DL949704 in the state of California.
Emily Thompson's email is emilyt902 logo.net, which she also uses for online banking.
Benjamin Harris has a medical record ID MRN-455772-B, issued by St. Mary's Hospital.

------- Reversed texts for layer 2 -------
Alexander Green's checking account shows a recent deposit of $4,2,3 on September 17, 2017.
Maria Rodriguez lives at 29 Oakwood Lane, Renton, and her contact number is (55) 2-8-8.
Lee Lee registered his driver's license DL9 in the state of California.
Emily Whitman's email is emilyt72 email, which she also uses for online banking.
Benjamin Harris has a medical record ID MRN-4447-B, issued by St. Mary's Hospital.

------- Reversed texts for layer 3 -------
Australia's checking account shows a recent deposit of $4,2,250 on September 25.
Maria Rodriguez lives at 29 Oakwood Lane, her contact number is (55) Oakwood, and her contact telephone is Rivona.
Manuel registered his driver's license DL964
Emily's account is emily.t, which she also uses for online banking
Benjamin Harris has a medical record ID MRN-423, issued by St. Mary's Hospital.
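The reconstructions above are judged by eye. To compare layers more systematically, each original/reconstruction pair could be scored. The helpers below are a hypothetical sketch: `sensitive_tokens` treats any whitespace token containing a digit or `@` as a crude PII proxy, and `leak_report` returns a character-level similarity (via `difflib.SequenceMatcher`) plus the fraction of those tokens recovered verbatim. The example pair is taken from the layer 1 output above.

```python
import re
from difflib import SequenceMatcher

def sensitive_tokens(text):
    """Whitespace tokens containing a digit or '@' (a crude, hypothetical PII proxy)."""
    tokens = (t.strip(".,;:") for t in text.split())
    return [t for t in tokens if re.search(r"[\d@]", t)]

def leak_report(original, reconstructed):
    """Return (character-level similarity, fraction of sensitive tokens recovered verbatim)."""
    similarity = SequenceMatcher(None, original, reconstructed).ratio()
    secrets = sensitive_tokens(original)
    found = set(sensitive_tokens(reconstructed))
    recall = sum(s in found for s in secrets) / len(secrets) if secrets else 0.0
    return similarity, recall

original = "Alexander Green's checking account shows a recent deposit of $4,250 on September 15, 2025."
layer1 = "Alexander Green's checking account shows a recent deposit of $4,250 on September 15, 2000."
similarity, recall = leak_report(original, layer1)
print(f"similarity={similarity:.2f}, sensitive-token recall={recall:.2f}")
```

Here the amount and the day are recovered but the year is not, so two of the three digit-bearing tokens leak.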
