In [1]:
import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, feature, target=None, transform=None):

        self.X = feature
        self.Y = target
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        if self.transform is not None:
            return self.transform(self.X[idx]), self.Y[idx]
        elif self.Y is None:
            return [self.X[idx]]
        return self.X[idx], self.Y[idx]

In [2]:
import pandas as pd
from sklearn.model_selection import train_test_split


train_data_dir = "./train.csv"

# load data
train = pd.read_csv(train_data_dir)

raw_train_labels = train["label"]
raw_train_imgs = train.drop(labels = ["label"], axis = 1)

# normalize data
normalized_train = raw_train_imgs/255.0

# split data into train and validation
train_split, validation_split, train_labels_split, validation_labels_split = train_test_split(normalized_train, raw_train_labels, random_state=0)

# data reshape
train_data = torch.from_numpy(train_split.values.reshape((-1,1,28,28)))
train_labels_data =  torch.from_numpy(train_labels_split.values)

# make data loader
train_set = MyDataset(train_data.float(), train_labels_data)
batch_size = 128
train_loader = torch.utils.data.DataLoader(train_set, batch_size = batch_size, shuffle = False)

In [3]:
from torchvision import models
from torch import nn
import torch.nn.functional as F


stolen_model = models.resnet18(pretrained=True)
stolen_model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
num_ftrs = stolen_model.fc.in_features
stolen_model.fc = nn.Linear(num_ftrs, 10)
stolen_model = torch.nn.Sequential(*(list(stolen_model.children())[:-1]))
victim_model = torch.load('models/mnist_model.pt', weights_only=False)
victim_model = torch.nn.Sequential(*(list(victim_model.children())[:-1]))



In [4]:
stolen_model.to('cuda')
victim_model.to('cuda')

optimizer = torch.optim.SGD(stolen_model.parameters(), lr=0.001)
loss_fn = torch.nn.MSELoss()

In [12]:
from torchvision import transforms
from transform_configs import get_random_resized_crop_config, get_jitter_color_config

DEVICE = 'cuda'


running_loss = 0.0
for j in range(50):
    total_loss = 0.0
    for i, (image, _) in enumerate(train_loader):
        optimizer.zero_grad()
        # step 3
        transforms_for_victim = transforms.Compose([
            transforms.RandomResizedCrop(size=(28, 28)),
            transforms.ColorJitter(**get_jitter_color_config()),
        ])
        # step 4
        image_for_victim = image
        image_for_stolen = image
        image_for_victim = image_for_victim.to(DEVICE, dtype=torch.float32)
        image_for_stolen = image_for_stolen.to(DEVICE, dtype=torch.float32)
        # step 5
        output_for_stolen = stolen_model(image_for_stolen)
        # output_for_victim = query_victim(image_for_victim)
        output_for_victim = victim_model(image_for_victim)
        # step 6
        loss = loss_fn(output_for_stolen, output_for_victim)
     #   loss.requires_grad = True
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {j+1}, loss: {total_loss / len(train_loader)}")

Epoch 1, loss: 0.6746327992875566
Epoch 2, loss: 0.6204418354188865
Epoch 3, loss: 0.5787805572212467
Epoch 4, loss: 0.5400721738454302
Epoch 5, loss: 0.5008048836277564
Epoch 6, loss: 0.45881997959816506
Epoch 7, loss: 0.41426604793139316
Epoch 8, loss: 0.3751444326721222
Epoch 9, loss: 0.34196930858287733
Epoch 10, loss: 0.3162383397339809
Epoch 11, loss: 0.29532204640780385
Epoch 12, loss: 0.2783585972631509
Epoch 13, loss: 0.26508263839401214
Epoch 14, loss: 0.25419704959942746
Epoch 15, loss: 0.2447540647225824
Epoch 16, loss: 0.23677406197617412
Epoch 17, loss: 0.23001431471664413
Epoch 18, loss: 0.2243318046152833
Epoch 19, loss: 0.2192950643267226
Epoch 20, loss: 0.2147028267021604
Epoch 21, loss: 0.21035087404222141
Epoch 22, loss: 0.20682724127885302
Epoch 23, loss: 0.20360953206958077
Epoch 24, loss: 0.19897883279844816
Epoch 25, loss: 0.19569928503712178
Epoch 26, loss: 0.19281283551864778
Epoch 27, loss: 0.1901922063668247
Epoch 28, loss: 0.18784450989985754
Epoch 29, loss

KeyboardInterrupt: 

In [13]:
stolenn_model = torch.load('models/mnist_model.pt', weights_only=False)
model = nn.Sequential(stolen_model, (list(stolenn_model.children())[-1]))
model = nn.Sequential(
    stolen_model,
    nn.AdaptiveAvgPool2d(1),  # Redukuje wymiar do [batch_size, 512, 1, 1]
    nn.Flatten(),  # Konwertuje do [batch_size, 512]
    list(stolenn_model.children())[-1]   # Dodajemy warstwę z victim_model
)

In [14]:
model.eval()
num_data = 0
total_thruts = 0
for i, (image, label) in enumerate(train_loader):
    num_data += image.size(0)
    image = image.to(DEVICE, dtype=torch.float32)
    label = label.to(DEVICE, dtype=torch.long)
    output = model(image)
    pred = torch.argmax(output, dim=1)
    truths = torch.sum(pred == label)
    total_thruts += truths
print(f"Accuracy: {total_thruts/num_data}")

Accuracy: 0.9820317625999451
