In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
import torchvision.datasets as datasets

# Seed torch's global RNG (CPU and CUDA). DataLoader shuffling (including per-worker seeds)
# and the torchvision random transforms draw from it, so data order and augmentation are
# reproducible across runs. CUDA kernels may still be nondeterministic.
SEED = 0
torch.manual_seed(SEED)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

CIFAR10_ROOT = 'root for CIFAR-10 dataset'  # placeholder: set to the local dataset directory
BATCH_SIZE = 128
NUM_WORKERS = 8

# Training images get a random TrivialAugmentWide op; test images are only converted to tensors.
transform_train = transforms.Compose([
    transforms.TrivialAugmentWide(),
    transforms.ToTensor(),
])

transform = transforms.Compose([
    transforms.ToTensor(),
])

cifar10_train = datasets.CIFAR10(root=CIFAR10_ROOT, train=True, download=True, transform=transform_train)
testset = datasets.CIFAR10(root=CIFAR10_ROOT, train=False, download=True, transform=transform)

# NOTE: `trainloader` is rebuilt in a later cell over the trigger set + CIFAR-10 train split.
trainloader = DataLoader(cifar10_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class WMDataset(Dataset):
    """Map-style dataset over an in-memory trigger (watermark) set.

    Item ``i`` is ``(data[i], labels[i])``. When ``transform`` is truthy, it is
    applied to the sample only, never to the label.

    Args:
        data: indexable collection of samples; its length is the dataset length.
        labels: indexable collection of labels, aligned index-by-index with ``data``.
        transform: optional callable applied to each sample on access.
    """

    def __init__(self, data, labels, transform=None):
        self.data, self.labels, self.transform = data, labels, transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Sample is fetched before label, and transform runs last (same order as before).
        item, target = self.data[idx], self.labels[idx]
        return (self.transform(item) if self.transform else item), target

# Trigger-set augmentation. Each stored sample (presumably a CHW tensor — TODO confirm) is
# converted to a PIL image, given one random TrivialAugmentWide op, and turned back into a tensor.
transform_wm = transforms.Compose(
    [transforms.ToPILImage(), transforms.TrivialAugmentWide(), transforms.ToTensor()]
)

In [None]:
load_path = 'root for load the UAEs (trigger set)'
# map_location='cpu': if the trigger tensors were saved from a GPU, they would otherwise be
# restored onto CUDA. That fails on CPU-only machines and inside the forked DataLoader
# workers that iterate `wmset`. Consumers move batches to `device` themselves.
loaded_data = torch.load(load_path, map_location='cpu')

misclassified_samples = loaded_data['samples']
misclassified_labels = loaded_data['labels']
if len(misclassified_samples) != len(misclassified_labels):
    raise ValueError(
        'trigger set is inconsistent: %d samples vs %d labels'
        % (len(misclassified_samples), len(misclassified_labels)))

# The whole trigger set is used. The *_subset names are kept because later cells use them.
misclassified_samples_subset = misclassified_samples
# Turn each label (a 0-d tensor element) into a plain Python number for WMDataset.
misclassified_labels_subset = [label.item() for label in misclassified_labels]
wmset = WMDataset(misclassified_samples_subset, misclassified_labels_subset, transform=transform_wm)

In [None]:
# NOTE(review): `classifier` is not used by any later cell in this notebook. The model that
# gets fine-tuned is loaded again from the same path further down, so this holds a second
# copy on `device`.
# weights_only=False is needed to unpickle a full nn.Module on PyTorch >= 2.6, where the
# default became True. Unpickling executes code, so only load trusted checkpoints.
# map_location lets a checkpoint saved on another device load on this machine.
classifier = torch.load('root of the pretrained classifier', map_location=device, weights_only=False).to(device)

In [None]:
from torch.utils.data import TensorDataset
# Un-augmented copy of the trigger set, used to measure watermark accuracy.
wm_valid_set = TensorDataset(misclassified_samples_subset, misclassified_labels)

from torch.utils.data import ConcatDataset, DataLoader

batch_size = 100
# Evaluation only: batch order cannot change the accuracy, so iterate deterministically.
wm_valid_loader = DataLoader(wm_valid_set, batch_size=batch_size, shuffle=False)

# Oversample the trigger set: each epoch sees it WM_REPEAT times next to the CIFAR-10 train split.
WM_REPEAT = 50
trainsets_list = [wmset] * WM_REPEAT
trainsets_list.append(cifar10_train)

combined_dataset = ConcatDataset(trainsets_list)

batch_size = 128
# Replaces the CIFAR-10-only `trainloader` from the first cell; `train()` reads this global.
trainloader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=True, num_workers=8)

In [None]:
from sam import SAM, disable_running_stats, enable_running_stats
from tqdm import tqdm
import sys
from torchvision.transforms import v2
import torch.nn.functional as F
from torchvision.transforms import RandomErasing

# Erased view used for the consistency loss in `train`. p=1.0 means a patch covering 1-5% of
# the image is always erased, and value='random' fills it with random values.
# NOTE(review): the name `re` shadows the stdlib regex module in this namespace.
re = RandomErasing(p=1.0, scale=(0.01, 0.05), value='random')
# Softmax temperature and weight of the KL consistency term.
consistency_temperature, consistency_rate = 5.0, 1e-3

# Sizes of the CIFAR-10 train split and of the test split (the "val" set here).
tra_num, val_num = len(cifar10_train), len(testset)

def extract_samples(dataset, num_samples):
    """Return a random ``Subset`` of ``dataset`` with at most ``num_samples`` items.

    Indices come from ``torch.randperm``, so they are drawn without replacement from
    torch's global RNG. If ``num_samples`` exceeds ``len(dataset)``, the whole dataset is
    returned in shuffled order.

    NOTE(review): not called anywhere in this notebook.
    """
    permutation = torch.randperm(len(dataset))
    chosen = permutation[:num_samples]
    return torch.utils.data.Subset(dataset, chosen)

def _evaluate_accuracy(model, loader, desc=None):
    """Return the top-1 accuracy of ``model`` over every sample in ``loader``.

    Switches the model to eval mode and disables autograd. The denominator is
    ``len(loader.dataset)``.

    Args:
        model: classifier; assumed to return (batch, num_classes) logits.
        loader: DataLoader yielding ``(images, labels)`` batches.
        desc: if given, a tqdm progress bar with this label is shown.

    Returns:
        float: fraction of correct predictions, or NaN for an empty dataset.
    """
    total = len(loader.dataset)
    if total == 0:
        return float('nan')
    model.eval()
    batches = loader if desc is None else tqdm(loader, file=sys.stdout, desc=desc)
    correct = 0
    with torch.no_grad():
        for images, labels in batches:
            predictions = model(images.to(device)).argmax(dim=1)
            correct += (predictions == labels.to(device)).sum().item()
    return correct / total


def train(model, epochs=30, lr=5e-3):
    """Fine-tune ``model`` in place with SAM and report clean and watermark accuracy.

    Reads the notebook globals ``trainloader``, ``testloader``, ``wm_valid_loader``,
    ``device``, ``re``, ``consistency_temperature`` and ``consistency_rate``.

    After each epoch it prints the mean training loss, the test accuracy, the learning rate
    used that epoch, and the accuracy on ``wm_valid_loader``. At the end it prints the final
    test accuracy.

    Args:
        model: the classifier to fine-tune (already on ``device``).
        epochs: number of passes over ``trainloader``; the cosine schedule spans all of them.
        lr: initial learning rate of the SGD base optimizer inside SAM.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = SAM(model.parameters(), torch.optim.SGD, lr=lr, momentum=0.9)
    # T_max follows `epochs` (it was hard-coded to 30). The deprecated `verbose=True` is
    # dropped; the per-epoch LR is printed in the epoch summary instead.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer.base_optimizer, T_max=epochs, eta_min=1e-3)

    train_steps = max(len(trainloader), 1)  # guards the mean loss against an empty loader

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        train_bar = tqdm(trainloader, file=sys.stdout)
        for step, (images, labels) in enumerate(train_bar):
            inputs, targets = images.to(device), labels.to(device)

            # SAM first pass at the current weights. The `sam` helpers presumably toggle
            # BatchNorm running-stat updates — TODO confirm against the sam package.
            enable_running_stats(model)
            logits = model(inputs)
            loss = criterion(logits, targets)
            if step % 10 == 0:
                # Every 10th step, add a KL consistency term between clean and randomly
                # erased inputs. `re` is applied to the CPU batch.
                logits_erased = model(re(images).to(device))
                consistency_loss = F.kl_div(
                    F.log_softmax(logits / consistency_temperature, dim=1),
                    F.softmax(logits_erased / consistency_temperature, dim=1),
                    reduction='batchmean',
                )
                loss = loss + consistency_rate * consistency_loss
            loss.backward()
            optimizer.first_step(zero_grad=True)

            # SAM second pass. Only the plain cross-entropy is used here; the consistency
            # term contributes only to the first pass.
            disable_running_stats(model)
            criterion(model(inputs), targets).backward()
            optimizer.second_step(zero_grad=True)

            running_loss += loss.item()
            train_bar.set_description(
                "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss.item()))

        epoch_lr = scheduler.get_last_lr()[0]  # LR actually used during this epoch
        scheduler.step()

        val_accurate = _evaluate_accuracy(
            model, testloader, desc="valid epoch[{}/{}]".format(epoch + 1, epochs))
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.4f  lr: %.2e'
              % (epoch + 1, running_loss / train_steps, val_accurate, epoch_lr))
        print('WM Acc:', _evaluate_accuracy(model, wm_valid_loader))

    # Final clean test accuracy.
    print(_evaluate_accuracy(model, testloader, desc="test"))

In [None]:
import torchvision
import torch.nn as nn
# This is the copy of the pretrained classifier that `train` fine-tunes.
# weights_only=False is needed to unpickle a full nn.Module on PyTorch >= 2.6, where the
# default became True. Unpickling executes code, so only load trusted checkpoints.
# map_location lets a checkpoint saved on another device load on this machine.
model = torch.load('root of the pretrained classifier', map_location=device, weights_only=False)
model.to(device)
print('model prepared.')

In [None]:
# Fine-tune `model` in place on the combined trigger-set + CIFAR-10 loader.
train(model=model)

In [None]:
# Persist the whole pickled module, matching how the pretrained classifier is loaded above.
torch.save(obj=model, f='root to save the watermarked model')

In [None]:
# Watermark (trigger-set) accuracy of the fine-tuned model on un-augmented trigger samples.
model.eval()  # don't rely on whatever mode `train()` happened to leave the model in
acc = 0.0
with torch.no_grad():  # inference only: avoid building autograd graphs for every batch
    for val_images, val_labels in wm_valid_loader:
        outputs = model(val_images.to(device))
        predict_y = torch.max(outputs, dim=1)[1]
        acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
print('WM Acc:', acc / len(wm_valid_loader.dataset))