In [None]:
import os
import argparse
import json
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [None]:
def _count_correct(model, loader, device):
    """Return how many samples yielded by ``loader`` are classified correctly by ``model``."""
    correct = 0
    for images, labels in loader:
        outputs = model(images.to(device))
        predict_y = torch.max(outputs, dim=1)[1]
        correct += torch.eq(predict_y, labels.to(device)).sum().item()
    return correct


def _report_accuracy(tag, model, loader, device):
    """Print ``tag`` followed by the accuracy of ``model`` on ``loader`` and return that accuracy."""
    total = len(loader.dataset)
    # An empty dataset has no meaningful accuracy; report NaN instead of dividing by zero.
    accuracy = _count_correct(model, loader, device) / total if total else float('nan')
    print(tag, accuracy)
    return accuracy


def extract_loop(model, teacher, wm_loader1, wm_loader2, loader, opt, lr_scheduler, epoch,
                temperature=5.0, max_epoch=100, mode='train', device='cuda'):
    """Run one pass of distillation-based extraction, or one evaluation pass.

    Args:
        model: The student network being trained / evaluated.
        teacher: The network whose outputs the student imitates (training mode only).
            Its train/eval state is left as the caller set it.
        wm_loader1: Loader reported as 'wm1 acc:' in evaluation mode.
        wm_loader2: Loader reported as 'wm2 acc:' in evaluation mode.
        loader: Training batches when ``mode == 'train'``; otherwise the loader
            reported as 'test acc:'.
        opt: Optimizer stepping the student's parameters.
        lr_scheduler: Accepted for interface compatibility; not used here.
        epoch: Accepted for interface compatibility; not used here.
        temperature: Divisor applied to the teacher logits before the softmax.
        max_epoch: Accepted for interface compatibility; not used here.
        mode: ``'train'`` to optimise the student; any other value evaluates it.
        device: Device the batches are moved to.

    Returns:
        None. Evaluation results are printed.
    """
    T = temperature

    if mode != 'train':
        model.eval()
        # Evaluation never needs gradients, regardless of how the caller invoked us.
        with torch.no_grad():
            _report_accuracy('test acc:', model, loader, device)
            _report_accuracy('wm1 acc:', model, wm_loader1, device)
            _report_accuracy('wm2 acc:', model, wm_loader2, device)
        return None

    model.train()
    for batch in loader:
        # Only the inputs are needed: the supervision signal is the teacher's output.
        images = batch[0].to(device)

        opt.zero_grad()
        preds = model(images)

        # The teacher is a fixed target: no graph is built for it and no gradients
        # can flow into (or accumulate on) its parameters.
        with torch.no_grad():
            teacher_preds = teacher(images)

        # NOTE(review): the student logits are used at temperature 1 while the teacher
        # distribution is softened by T. This asymmetry is kept from the original code.
        extract_loss = F.kl_div(
            F.log_softmax(preds, dim=-1),
            F.softmax(teacher_preds / T, dim=-1),
            reduction='batchmean',
        )

        extract_loss.backward()
        opt.step()

    return None

In [None]:
def extraction(teacher, model, epochs, wm_loader1, wm_loader2, train_loader, test_loader, opt, lr_scheduler, device):
    """Alternate one training pass and one evaluation pass of ``extract_loop`` per epoch.

    For each of ``epochs`` epochs the student ``model`` is trained on
    ``train_loader``, then evaluated (gradients disabled) on ``test_loader``
    and the two watermark loaders, and finally ``lr_scheduler`` is stepped once.
    ``teacher`` is switched to eval mode once, before the first epoch.
    """
    teacher.eval()

    # Keyword arguments that are the same for the training and the evaluation pass.
    shared = dict(max_epoch=epochs, device=device)

    for epoch in range(epochs):
        print('epoch:', epoch)

        # Training pass over the training data.
        model.train()
        extract_loop(model, teacher, wm_loader1, wm_loader2, train_loader,
                     opt, lr_scheduler, epoch, mode='train', **shared)

        # Evaluation pass: any mode other than 'train' makes extract_loop report accuracies.
        model.eval()
        with torch.no_grad():
            extract_loop(model, teacher, wm_loader1, wm_loader2, test_loader,
                         opt, lr_scheduler, epoch, mode='val', **shared)

        lr_scheduler.step()

In [None]:
import torch
from torchvision import datasets
import torchvision.transforms as transforms

# Training images are augmented; test images are only converted to tensors.
transform_train = transforms.Compose([
    transforms.TrivialAugmentWide(),
    transforms.ToTensor(),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
])
device = 'cuda:0'

# NOTE(security): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source. On recent PyTorch releases, where
# torch.load defaults to weights_only=True, loading a fully pickled module
# requires passing weights_only=False explicitly.
victim = torch.load('root ofthe watermarked model').to(device)
# The victim is only ever queried in this notebook, never trained: put it in
# eval mode right away so the later accuracy cells do not run it in whatever
# mode it was pickled in.
victim.eval()

trainset = datasets.CIFAR10(root='root for CIFAR-10 dataset', train=True, download=True, transform=transform_train)
testset = datasets.CIFAR10(root='root for CIFAR-10 dataset', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=8)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=8)

In [None]:
import torchvision

# Accuracy of the victim on the *training* split, using the un-augmented test transform.
train_acc_set = torchvision.datasets.CIFAR10(root='root for CIFAR-10 dataset', train=True, download=True, transform=transform_test)
# Shuffling is pointless for a pure accuracy measurement.
train_acc_loader = torch.utils.data.DataLoader(dataset=train_acc_set, batch_size=128, shuffle=False, num_workers=8)
train_num = len(train_acc_loader.dataset)
acc = 0.0

# Inference only: make sure the model is in eval mode and no autograd graph is built.
victim.eval()
with torch.no_grad():
    for test_data in train_acc_loader:
        test_images, test_labels = test_data
        outputs = victim(test_images.to(device))
        predict_y = torch.max(outputs, dim=1)[1]
        acc += torch.eq(predict_y, test_labels.to(device)).sum().item()

train_accurate = acc / train_num
# The original message said 'test acc:' although the training split is measured here.
print('train acc:', train_accurate)

In [None]:
from torch.utils.data import DataLoader, TensorDataset

# First set: loaded from the path labelled as the trigger set.
# Each file is expected to hold a dict with 'samples' and 'labels' entries
# (presumably tensors with matching first dimension — TensorDataset requires that).
load_path0 = 'root of the trigger set'
loaded_data0 = torch.load(load_path0)
misclassified_samples0, misclassified_labels0 = loaded_data0['samples'], loaded_data0['labels']
wm_set1 = TensorDataset(misclassified_samples0, misclassified_labels0)

# Second set: loaded from the path labelled as the UAE control group.
load_path1 = 'root of the UAE control group'
loaded_data1 = torch.load(load_path1)
misclassified_samples1, misclassified_labels1 = loaded_data1['samples'], loaded_data1['labels']
wm_set2 = TensorDataset(misclassified_samples1, misclassified_labels1)

# Both sets are served with the same batch size and shuffling enabled.
batch_size = 100
wmloader1, wmloader2 = (
    DataLoader(wm_dataset, batch_size=batch_size, shuffle=True)
    for wm_dataset in (wm_set1, wm_set2)
)

In [None]:
# Evaluate an independently trained model on the clean test set and on a watermark set.
# NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
unrelative = torch.load('root of a third-party model').to(device)
unrelative.eval()

# Inference only: no autograd graph is needed for any of the loops below.
with torch.no_grad():
    test_num = len(testloader.dataset)
    acc = 0.0
    for test_data in testloader:
        test_images, test_labels = test_data
        outputs = unrelative(test_images.to(device))
        predict_y = torch.max(outputs, dim=1)[1]
        acc += torch.eq(predict_y, test_labels.to(device)).sum().item()

    test_accurate = acc / test_num
    print('test acc:', test_accurate)

    # The original divided the hits counted on wmloader2 by the size of
    # wmloader1's dataset, which gives a wrong accuracy whenever the two sets
    # differ in size. The denominator now comes from the loader being iterated.
    # NOTE(review): confirm wmloader2 (not wmloader1) is the set meant to be evaluated here.
    wm_num = len(wmloader2.dataset)
    acc = 0.0
    for wm_data in wmloader2:
        wm_images, wm_labels = wm_data
        outputs = unrelative(wm_images.to(device))
        predict_y = torch.max(outputs, dim=1)[1]
        acc += torch.eq(predict_y, wm_labels.to(device)).sum().item()
    wm_accurate = acc / wm_num
    print('wm acc:', wm_accurate)

In [None]:
# Print the third-party model's predicted classes for every batch of wmloader2.
# Inference only: force eval mode (this cell may be run on its own) and skip autograd.
unrelative.eval()
with torch.no_grad():
    for wm_data in wmloader2:
        wm_images, wm_labels = wm_data
        outputs = unrelative(wm_images.to(device))
        predict_y = torch.max(outputs, dim=1)[1]
        print(predict_y)

In [None]:
# Accuracy of the victim model on the clean test set and on the first watermark set.
# Inference only: force eval mode and disable autograd for both loops.
victim.eval()
with torch.no_grad():
    test_num = len(testset)
    acc = 0.0
    for test_data in testloader:
        test_images, test_labels = test_data
        outputs = victim(test_images.to(device))
        predict_y = torch.max(outputs, dim=1)[1]
        acc += torch.eq(predict_y, test_labels.to(device)).sum().item()

    test_accurate = acc / test_num
    print('test acc:', test_accurate)

    wm_num = len(wmloader1.dataset)
    acc = 0.0
    for wm_data in wmloader1:
        wm_images, wm_labels = wm_data
        outputs = victim(wm_images.to(device))
        predict_y = torch.max(outputs, dim=1)[1]
        acc += torch.eq(predict_y, wm_labels.to(device)).sum().item()

    wm_accurate = acc / wm_num
    print('wm acc:', wm_accurate)

In [None]:
import torch.nn as nn
from torchvision.models.resnet import resnet18, ResNet18_Weights

# Start from the default pretrained ResNet-18, then replace the stem and the head:
# a 3x3 stride-1 first convolution, no max-pooling, and a 10-way classifier.
surrogate = resnet18(weights=ResNet18_Weights.DEFAULT)
surrogate.conv1 = nn.Conv2d(
    in_channels=3,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=False,
)
surrogate.maxpool = nn.Identity()
surrogate.fc = nn.Linear(in_features=512, out_features=10)
surrogate = surrogate.to(device)
print('model prepared.')

# Optimisation setup: Adam without weight decay, cosine-annealed over 100 scheduler steps.
lr = 1e-3
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    surrogate.parameters(),
    lr=lr,
    weight_decay=0.0,
)
# NOTE(review): `verbose` is deprecated in recent PyTorch releases; kept to preserve behavior.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=100,
    eta_min=1e-6,
    verbose=True,
)

In [None]:
# Distil the victim into the surrogate for 100 epochs, reporting test / watermark
# accuracy after every epoch.
extraction(
    teacher=victim,
    model=surrogate,
    epochs=100,
    wm_loader1=wmloader1,
    wm_loader2=wmloader2,
    train_loader=trainloader,
    test_loader=testloader,
    opt=optimizer,
    lr_scheduler=scheduler,
    device=device,
)

In [None]:
# Persist the whole surrogate module object (not just its state_dict) via torch.save.
torch.save(obj=surrogate, f='root to save the extraction surrogate')