In [None]:
# Because only MBW (margin-based watermarking) uses label-noise trigger sets, the Noise Label Trigger Inversion framework is applied only to MBW models.

In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
import torchvision.datasets as datasets

# Run on the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Placeholder: replace with the directory that holds the Imagenette dataset.
root='root for the imagenette dataset'

# Training-time preprocessing: resize to 224x224, apply TrivialAugmentWide, convert to a tensor.
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.TrivialAugmentWide(),
    transforms.ToTensor()
])
# Evaluation-time preprocessing: the same resize and tensor conversion, without augmentation.
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Full-size validation split; download=False, so the files must already exist under `root`.
test_set = datasets.Imagenette(root=root, split='val', size='full', download=False, transform=transform_test)
# Accuracy is aggregated over the whole split, so shuffling adds nothing to the
# evaluation; shuffle=False keeps the batch order reproducible between runs.
test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=8, drop_last=False)

In [None]:
import pickle
import sys

import torch

# Make the margin_based_watermark code importable BEFORE anything is unpickled:
# unpickling resolves classes by module path, and the saved objects presumably
# reference classes from that folder (verify against how they were saved).
sys.path.append('root for the margin_based_watermark folder')
from loaders import get_imagenette_loaders
from models import queries

# NOTE(security): torch.load(..., weights_only=False) and pickle.load execute
# arbitrary code from the file. Only load checkpoints/queries you trust.
# - map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# - weights_only=False is stated explicitly because the file holds a whole model
#   object, which the weights-only default of recent PyTorch releases rejects.
model = torch.load(
    'root for a MBW pretrained model',
    map_location=device,
    weights_only=False,
).to(device)
model.eval()

# Watermark query object belonging to the model above, moved to the same device.
with open('root for the MBW queries corresponding to the model', 'rb') as f:
    query = pickle.load(f).to(device)

In [None]:
# Pull the first 100 entries out of the loaded query object as NumPy arrays:
#   xnoise <- query.query              (the query inputs)
#   ynoise <- query.response           (presumably the noisy watermark labels -- confirm)
#   yreal  <- query.original_response  (presumably the original labels -- confirm)
xnoise, ynoise, yreal = [], [], []

for i in range(100):
    xnoise.append(query.query[i].cpu().detach().numpy())
    ynoise.append(query.response[i].cpu().detach().numpy())
    yreal.append(query.original_response[i].cpu().detach().numpy())

In [None]:
def logits_loss(logits, y_source, y_target):
    """Mean gap between the source-class logit and the target-class logit.

    Args:
        logits: 2-D tensor indexed as (sample, class).
        y_source: integer class indices, one per sample, whose logit is pushed down
            when this loss is minimised.
        y_target: integer class indices, one per sample, whose logit is pushed up
            when this loss is minimised.

    Returns:
        Scalar tensor: mean over samples of ``logits[source] - logits[target]``.
    """
    picked_source = torch.gather(logits, 1, y_source.view(-1, 1))
    picked_target = torch.gather(logits, 1, y_target.view(-1, 1))
    # Subtract column-wise, then drop the singleton dimension before averaging.
    gap = (picked_source - picked_target).squeeze()
    return gap.mean()

def optimize_trigger(model, xnoise, ynoise, yreal, device, ε=32/255, epochs=10, lr=20/255):
    """Optimise one additive perturbation ("trigger") shared by all query samples.

    Each sample ``x`` is fed to the model as ``x - trigger``. The trigger is updated
    with Adam, one sample per step, to minimise ``logit[ynoise] - logit[yreal]``,
    i.e. to make the model prefer ``yreal`` over ``ynoise`` once the trigger is removed.
    After every step the trigger is clamped back into ``[-ε, ε]``.

    Args:
        model: callable mapping a batch of inputs to per-class logits.
        xnoise: sequence of array-likes (one per sample), all of the same shape.
        ynoise: sequence of scalar integer labels to move away from.
        yreal: sequence of scalar integer labels to move towards.
        device: torch device on which the optimisation runs.
        ε: element-wise bound on the trigger magnitude.
        epochs: number of passes over the samples.
        lr: initial Adam learning rate; it is halved every 5 epochs.

    Returns:
        The optimised trigger: a leaf tensor on ``device`` with the shape of
        ``xnoise[0]`` and ``requires_grad=True``.

    Raises:
        ValueError: if ``xnoise`` is empty.
    """
    num_samples = len(xnoise)
    if num_samples == 0:
        raise ValueError("optimize_trigger needs at least one query sample")

    # Random start in [0, ε). Drawn on the CPU and then moved, so a fixed seed
    # yields the same starting point regardless of the target device.
    trigger = torch.rand_like(torch.as_tensor(xnoise[0]).float(), dtype=torch.float).to(device)
    trigger *= ε
    trigger.requires_grad = True

    optimizer = torch.optim.Adam([trigger], lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    # The samples are shuffled once; the same order is reused in every epoch.
    order = torch.randperm(num_samples).tolist()

    # Convert and move every sample once instead of once per step per epoch.
    inputs = [torch.as_tensor(xnoise[i]).to(device) for i in order]
    noisy_labels = [torch.as_tensor(ynoise[i]).unsqueeze(0).to(device) for i in order]
    true_labels = [torch.as_tensor(yreal[i]).unsqueeze(0).to(device) for i in order]

    for epoch in range(epochs):
        total_loss = 0
        for x, y_false, y_true in zip(inputs, noisy_labels, true_labels):
            x_source = x - trigger
            logits = model(x_source.unsqueeze(0))
            loss = logits_loss(logits, y_false, y_true)

            # Differentiate with respect to the trigger only. loss.backward() would
            # also accumulate gradients into every model parameter, which nothing
            # here uses or clears.
            (grad,) = torch.autograd.grad(loss, trigger)
            trigger.grad = grad
            optimizer.step()

            # Project back onto the allowed perturbation range after each step.
            with torch.no_grad():
                trigger.clamp_(-ε, ε)

            total_loss += loss.item()

        scheduler.step()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(xnoise)}")

    return trigger

In [None]:
# Run the trigger optimisation on the extracted queries with a perturbation bound of
# 40/255 for 20 epochs, and keep the result as `trigger_optimized` for the evaluation.
trigger_optimized = optimize_trigger(model, xnoise, ynoise, yreal, device, ε=40/255, epochs=20, lr=20/255)

In [None]:
def test_attack_performance(model, testloader, trigger, device):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images = images.to(device) + trigger
            images = images.clamp(0, 1) 
            labels = labels.to(device)

            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    attack_accuracy = correct / total
    print(f'Attack Accuracy: {attack_accuracy * 100}%')
    return attack_accuracy
    
def test_clean_accuracy(model, testloader, device):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    clean_accuracy = correct / total
    print(f'Clean Accuracy: {clean_accuracy * 100}%')
    return clean_accuracy

def test_trivial_attack_accuracy(model, testloader, device, ε=32/255):
    model.eval()
    correct, total = 0, 0
    trigger_shape = torch.tensor(xnoise[0]).shape 
    max_noise = ε

    with torch.no_grad():
        for data in testloader:
            images, labels = data
            random_noise = (torch.randn(trigger_shape) - 0.5) * 2 * max_noise
            random_noise = torch.clamp(random_noise, min=-max_noise, max=max_noise)
            images = (images + random_noise).clamp(0, 1).to(device)
            labels = labels.to(device)

            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    trivial_attack_accuracy = correct / total
    print(f'Trivial Attack Accuracy: {trivial_attack_accuracy * 100}%')
    return trivial_attack_accuracy

# Evaluate the model on the test loader in three settings: with the optimised trigger
# added, on clean images, and with random noise bounded by the same 40/255.
# The three accuracies are shown together as the cell's output tuple.
test_attack_performance(model, test_loader, trigger_optimized, device), test_clean_accuracy(model, test_loader, device), test_trivial_attack_accuracy(model, test_loader, device, ε=40/255)