In [51]:
import torch
from torchvision.transforms import ToPILImage
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18
import requests
import torch
import torch.nn as nn
import onnxruntime as ort
import numpy as np
from torch.utils.data import Dataset
from typing import Tuple





In [4]:
class ClusterDataset(Dataset):
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

In [68]:
dataset_0 = torch.load('./dataset_cluster_0.pt')
dataset_1 = torch.load('./dataset_cluster_1.pt')
images_0 = dataset_0.images
images_1 = dataset_1.images
batch_0 = len(images_0)//1000
batch_1 = len(images_1)//1000
print(batch_0, batch_1)
images_0 = images_0[:batch_0*1000]
images_1 = images_1[:batch_1*1000]
print(len(images_0), len(images_1))
images = images_0 + images_1

In [29]:
class RepresentationDataset(Dataset):
    def __init__(self, images, rep):
        self.images = images
        self.representation = rep

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.representation[idx]

In [45]:
class SimSiam(nn.Module):
    def __init__(self, encoder, projector_dim=1024, dropout_prob=0.5):
        super(SimSiam, self).__init__()
        self.encoder = encoder
        self.projector = nn.Sequential(
            nn.Linear(encoder.fc.in_features, projector_dim),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(projector_dim, projector_dim),
            nn.Dropout(dropout_prob)
        )
        self.encoder.fc = nn.Identity()

    def forward(self, x):
        features = self.encoder(x)
        projections = self.projector(features)
        return projections



In [46]:
def cosine_similarity_loss(output, target):
    output = F.normalize(output, dim=-1)
    target = F.normalize(target, dim=-1)
    return -torch.mean(torch.sum(output * target, dim=-1))

criterion = cosine_similarity_loss
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)


In [84]:
encoder = resnet18(pretrained=False)
denoise_model = SimSiam(encoder)



In [98]:
checkpoint = torch.load("./encoder.pt")
denoise_model.load_state_dict(checkpoint.state_dict())

<All keys matched successfully>

In [99]:
denoised_representations = torch.load('./denoised_feature_vectors.pt')

In [100]:
denoised_representations = denoised_representations.reshape(-1, 1024)

In [101]:
print(denoised_representations.shape)
print(denoised_representations[0])
print(type(images[0]))

torch.Size([12000, 1024])
tensor([-0.0349, -0.0349, -0.0349,  ..., -0.0349, -0.0349, -0.0349])
<class 'torch.Tensor'>


In [102]:
denoised_dataset = RepresentationDataset(images, denoised_representations)

In [103]:
denoised_dataloader = DataLoader(denoised_dataset, batch_size=32, shuffle=True)

In [104]:
num_epochs = 20

for epoch in range(num_epochs):
    total_loss = 0
    for images, representations in denoised_dataloader:
        outputs = denoise_model(images)
        loss = criterion(outputs[0].unsqueeze(0), representations[0].unsqueeze(0))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(images)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

Epoch [1/20], Loss: -0.0008
Epoch [2/20], Loss: -0.0012
Epoch [3/20], Loss: -0.0007
Epoch [4/20], Loss: -0.0002
Epoch [5/20], Loss: -0.0004
Epoch [6/20], Loss: -0.0023
Epoch [7/20], Loss: -0.0006
Epoch [8/20], Loss: -0.0011
Epoch [9/20], Loss: -0.0001
Epoch [10/20], Loss: -0.0015
Epoch [11/20], Loss: -0.0008
Epoch [12/20], Loss: -0.0007
Epoch [13/20], Loss: -0.0015
Epoch [14/20], Loss: -0.0017
Epoch [15/20], Loss: -0.0002
Epoch [16/20], Loss: -0.0010
Epoch [17/20], Loss: -0.0005
Epoch [18/20], Loss: -0.0020
Epoch [19/20], Loss: 0.0001
Epoch [20/20], Loss: -0.0017


In [105]:
torch.save(denoise_model, "denoised_encoder.pt")

In [95]:
torch.onnx.export(
    denoise_model,
    torch.randn(1, 3, 32, 32),
    './stolen_model.onnx',
    export_params=True,
    input_names=["x"],
)

In [96]:
with open('./stolen_model.onnx', "rb") as f:
    model = f.read()
    try:
        stolen_model = ort.InferenceSession(model)
    except Exception as e:
        raise Exception(f"Invalid model, {e=}")
    try:
        out = stolen_model.run(
            None, {"x": np.random.randn(1, 3, 32, 32).astype(np.float32)}
        )[0][0]
    except Exception as e:
        raise Exception(f"Some issue with the input, {e=}")
    assert out.shape == (1024,), "Invalid output shape"

In [2]:
import requests

SEED = '20499754'
PORT = '9052'
response = requests.post("http://34.71.138.79:9090/stealing", files={"file": open('./stolen_model.onnx', "rb")}, headers={"token": "40034445", "seed": SEED})
print(response.json())



{'detail': 'Exceeded submissions. Only 1/h allowed.'}
