This notebook implements Assignment 2 of the course on Trustworthy Machine Learning.

In [None]:
!pip install onnx
!pip install onnxruntime

Collecting onnx
  Downloading onnx-1.16.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m15.9/15.9 MB[0m [31m56.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: onnx
Successfully installed onnx-1.16.1
Collecting onnxruntime
  Downloading onnxruntime-1.18.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m16.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m8.5 MB/s[0m 

In [None]:
import torch
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
import requests
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import torch.optim as optim
from torchvision.transforms import ToPILImage

In [None]:
import requests
import torch
import torch.nn as nn
# Do install:
# conda install onnx
# conda install onnxruntime
import onnxruntime as ort
import numpy as np
import json
import io
import sys
import base64
from torch.utils.data import Dataset
from typing import Tuple
import pickle
import os

# Capture the kernel's working directory and echo it in a single statement.
print('cwd: ', cwd := os.getcwd())

class TaskDataset(Dataset):
    """In-memory dataset of ``(id, image, label)`` triples.

    ``ids``, ``imgs`` and ``labels`` are parallel lists: position ``k`` of each
    list describes the same sample. All three start out empty.

    Args:
        transform: Optional callable applied to an image every time it is
            fetched. ``None`` returns the stored image unchanged.
    """

    def __init__(self, transform=None):
        self.ids = []
        self.imgs = []
        self.labels = []

        self.transform = transform

    def __getitem__(self, index) -> Tuple[int, torch.Tensor, int]:
        """Return ``(id, image, label)`` for ``index``, applying ``transform`` if set."""
        id_ = self.ids[index]
        img = self.imgs[index]
        if self.transform is not None:
            img = self.transform(img)
        label = self.labels[index]
        return id_, img, label

    def __len__(self):
        """Return the number of samples, taken from the length of ``ids``."""
        return len(self.ids)

cwd:  /content


In [None]:
### REQUESTING NEW API ###
# Ask the server to launch an API instance for this token. The JSON answer
# carries the seed and port that the later cells read from SEED and PORT.

# Token given to you for the assignments. The TML_TOKEN environment variable
# takes precedence so the credential does not have to live in the notebook.
# NOTE(review): the literal fallback is a hard-coded credential; remove it before sharing.
TOKEN = os.environ.get("TML_TOKEN", "92593601")

# A timeout keeps the cell from hanging forever when the server is unreachable.
response = requests.get(
    "http://34.71.138.79:9090" + "/stealing_launch",
    headers={"token": TOKEN},
    timeout=120,
)
answer = response.json()

print(answer)  # {"seed": "SEED", "port": PORT}
if 'detail' in answer:
    # A 'detail' key is treated as an error answer: stop with a readable message
    # rather than reading SEED/PORT from it.
    raise RuntimeError(f"API launch failed: {answer['detail']}")

# save the values
SEED = str(answer['seed'])
PORT = str(answer['port'])

# To reuse a previously issued seed/port instead of the values above, uncomment:
# SEED = "1868949"
# PORT = "9002"

{'seed': 96417488, 'port': '9060'}


Defining the transformations applied to the dataset

In [None]:
# Per-channel statistics passed to the normalisation step.
mean = [0.2980, 0.2962, 0.2987]
std = [0.2886, 0.2875, 0.2889]


def _as_rgb(image):
    """Return ``image`` converted to RGB via its ``convert`` method."""
    return image.convert("RGB")


# Pipeline order: force RGB -> tensor -> channel-wise normalisation.
_pipeline_steps = [
    transforms.Lambda(_as_rgb),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
]
transform = transforms.Compose(_pipeline_steps)

Loading Dataset and applying transformations

In [None]:
# The file appears to hold a pickled dataset object (later cells use its
# `.imgs` and `.transform` attributes), not plain tensors, so the restricted
# weights-only loader that newer PyTorch uses by default cannot read it.
# NOTE(review): full unpickling can execute arbitrary code - only load trusted files.
dataset = torch.load("/content/ModelStealingPub.pt", weights_only=False)

In [None]:
# Attach the preprocessing pipeline to the loaded dataset. TaskDataset.__getitem__
# applies `transform` when it is set (assumes the loaded object is a TaskDataset).
dataset.transform = transform

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
### QUERYING THE API ###

def model_stealing(images, port, token="92593601", timeout=None):
    """Query the victim encoder and return its representations for ``images``.

    Every image is serialised to PNG and base64-encoded; the JSON list of
    encoded images is sent as the ``file`` field of a single request to
    ``/query`` on the given port.

    Args:
        images: Iterable of image objects exposing ``save(fp, format=...)``
            (PIL-style).
        port: Port of the API instance to query.
        token: Value sent in the ``token`` request header. The default is the
            value that used to be hard-coded in the body; pass your own token
            to override it.
        timeout: Timeout in seconds handed to ``requests``. ``None`` (default)
            waits indefinitely, as before.

    Returns:
        The ``"representations"`` entry of the JSON response.

    Raises:
        Exception: If the server answers with a status code other than 200.
    """
    endpoint = "/query"
    url = f"http://34.71.138.79:{port}" + endpoint

    image_data = []
    for img in images:
        buffer = io.BytesIO()
        img.save(buffer, format='PNG')
        # getvalue() returns the whole buffer regardless of the stream position,
        # so no rewind is needed before encoding.
        image_data.append(base64.b64encode(buffer.getvalue()).decode('utf-8'))

    payload = json.dumps(image_data)
    response = requests.get(
        url,
        files={"file": payload},
        headers={"token": token},
        timeout=timeout,
    )
    if response.status_code == 200:
        return response.json()["representations"]

    # An error body is not guaranteed to be JSON; fall back to the raw text so a
    # decode error cannot hide the status code.
    try:
        content = response.json()
    except ValueError:
        content = response.text
    raise Exception(
        f"Model stealing failed. Code: {response.status_code}, content: {content}"
    )

# Query the first 1000 images in dataset order. The training loop pairs image k
# with representation k purely by position, so sending the images in a random
# order without keeping the permutation would mis-align inputs and targets.
query_indices = range(1000)
out = model_stealing([dataset.imgs[idx] for idx in query_indices], port=PORT)

In [None]:
# Print, one per line: the number of representations returned, the length of
# the first representation, and the first element of that representation.
first_rep = out[0]
for value in (len(out), len(first_rep), first_rep[0]):
    print(value)


1000
1024
-0.6153451204299927


In [None]:
# Store the output in a file.
# Be careful to store all the outputs from the API since the number of queries is limited.
# Write the API output to disk.
with open('out84.pickle', 'wb') as sink:
    pickle.dump(out, sink, protocol=pickle.HIGHEST_PROTOCOL)

# Read the same file back into `out` and report how many entries it holds.
with open('out84.pickle', 'rb') as source:
    out = pickle.load(source)

print(len(out))

1000


Storing the representations obtained from the victim encoder

In [None]:
# Concatenate the saved API outputs out1.pickle ... out84.pickle, in file
# order, into one flat list.
victim_representations = []
for chunk_no in range(1, 85):
    chunk_path = f'/content/out{chunk_no}.pickle'
    with open(chunk_path, 'rb') as fh:
        chunk = pickle.load(fh)
    victim_representations += chunk

In [None]:
# Save the combined list to a single pickle file.
combined_path = 'victim_representations84.pickle'
with open(combined_path, 'wb') as fh:
    pickle.dump(victim_representations, fh, protocol=pickle.HIGHEST_PROTOCOL)

Defining the stolen model architecture and creating an instance

In [None]:
class StolenEncoder(nn.Module):
    """Small CNN mapping an image batch to 1024-dimensional representations.

    Two stages of 3x3 convolution + ReLU + 2x2 max-pooling (64, then 128
    channels) are followed by a flatten and a linear projection to 1024 values.

    Args:
        input_channels: Number of channels of the input images.
        input_height: Height of the input images in pixels.
        input_width: Width of the input images in pixels.
    """

    def __init__(self, input_channels, input_height, input_width):
        super().__init__()
        self.input_channels = input_channels
        self.input_height = input_height
        self.input_width = input_width

        # Each pooling stage halves (and floors) the spatial size, so after two
        # stages the feature map is (H // 4) x (W // 4) with 128 channels.
        flat_features = 128 * (input_height // 4) * (input_width // 4)

        self.encoder = nn.Sequential(
            # Honour the requested channel count instead of assuming RGB.
            nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(flat_features, 1024),
        )

    def forward(self, x):
        """Return the encoder output for the batch ``x``."""
        return self.encoder(x)

In [None]:
# Input geometry handed to the encoder constructor.
input_channels, input_height, input_width = 3, 32, 32

# Build the encoder first, then move it to the selected device.
model = StolenEncoder(input_channels, input_height, input_width)
model = model.to(device)

In [None]:
# Convert the collected representations into one float32 tensor on `device`.
# as_tensor returns its input unchanged when it is already a tensor with this
# dtype and device, so re-running the cell neither copies the data nor warns.
victim_representations = torch.as_tensor(
    victim_representations, dtype=torch.float32, device=device
)

In [None]:
# shuffle=False matters: the training loop matches batch i to
# victim_representations[i * batch_size : ...] by position, so the loader
# must yield samples in dataset order.
loader = DataLoader(dataset, batch_size=64, shuffle=False)

In [None]:
# Train the stolen encoder to reproduce the victim's representations (MSE loss).
optimizer = optim.Adam(model.parameters(), lr=0.005)
criterion = nn.MSELoss()

# Targets are matched to images by position: batch i corresponds to rows
# i * batch_size onwards of victim_representations.
num_targets = len(victim_representations)

epochs = 30
for epoch in range(epochs):
    for i, (batch_ids, batch_images, batch_labels) in enumerate(loader):
        batch_images = batch_images.to(device)
        # Single-channel batches are expanded to three channels for the encoder.
        if batch_images.size(1) == 1:
            batch_images = batch_images.repeat(1, 3, 1, 1)

        start_idx = i * loader.batch_size
        if start_idx >= num_targets:
            # No targets left for this batch; start offsets only grow, so none
            # of the remaining batches has targets either.
            break
        end_idx = min(start_idx + batch_images.size(0), num_targets)

        # When the targets run out mid-batch, trim the images to the available
        # targets and train on the partial batch instead of discarding it.
        batch_victim_reps = victim_representations[start_idx:end_idx]
        batch_images = batch_images[:end_idx - start_idx]

        stolen_reps = model(batch_images)
        loss = criterion(stolen_reps, batch_victim_reps)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Epoch [{epoch+1}/{epochs}], Batch [{i+1}], Loss: {loss.item():.4f}")

# Saves the whole module object (not just its state_dict).
torch.save(model, 'stolen_encoder.pth')

In [None]:
#### SUBMISSION ####

path = 'dummy_submission3.onnx'

# The example input must sit on the same device as the model's weights;
# a CPU tensor fails during export when the model lives on the GPU.
export_device = next(model.parameters()).device
torch.onnx.export(
    model,
    torch.randn(1, 3, 32, 32, device=export_device),
    path,
    export_params=True,
    input_names=["x"],
)

#### Tests ####

# (these are run on the eval endpoint for every submission)
# The file contents get their own name so the trained `model` is not replaced
# by raw bytes and the cell can be re-run.
with open(path, "rb") as f:
    model_bytes = f.read()

try:
    stolen_model = ort.InferenceSession(model_bytes)
except Exception as e:
    raise Exception(f"Invalid model, {e=}") from e
try:
    out = stolen_model.run(
        None, {"x": np.random.randn(1, 3, 32, 32).astype(np.float32)}
    )[0][0]
except Exception as e:
    raise Exception(f"Some issue with the input, {e=}") from e
assert out.shape == (1024,), "Invalid output shape"

# Send the model to the server; the context manager closes the upload handle.
with open(path, "rb") as onnx_file:
    response = requests.post(
        "http://34.71.138.79:9090/stealing",
        files={"file": onnx_file},
        headers={"token": TOKEN, "seed": SEED},
    )
print(response.json())

{'L2': 11.385514259338379}
