In [17]:
import torch
import os
import json
import pickle
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
cudnn.benchmark = True  # fire on all cylinders
import sys

sys.path.insert(0, '..')

## Create the dataset class

All patch triggers in the dataset are rectangular, so we regress the top-left and bottom-right corner coordinates instead of predicting the segmentation mask directly. Directly predicting the segmentation mask is another valid option.

In [2]:
class NetworkDatasetTriggerSynthesis(torch.utils.data.Dataset):
    """
    Training set for trigger synthesis: one item per model sub-directory.

    Each sub-directory of ``model_folder`` is expected to contain ``info.json`` (with a
    ``'dataset'`` key naming the data source), ``attack_specification.pt`` (with a
    ``'trigger'`` dict holding ``'mask'``, ``'top_left'`` and ``'bottom_right'``) and
    ``model.pt``. Trigger metadata is read eagerly in ``__init__``; networks are only
    loaded on access in ``__getitem__``.
    """
    def __init__(self, model_folder):
        super().__init__()
        # Sort so item order (and hence any index-based train/val split) does not depend on
        # the arbitrary os.listdir order; skip stray files (e.g. .DS_Store, READMEs).
        model_paths = [os.path.join(model_folder, x) for x in sorted(os.listdir(model_folder))
                       if os.path.isdir(os.path.join(model_folder, x))]
        coords = []
        masks = []
        data_sources = []
        for p in model_paths:
            with open(os.path.join(p, 'info.json'), 'r') as f:
                info = json.load(f)
            data_sources.append(info['dataset'])
            attack_specification = torch.load(os.path.join(p, 'attack_specification.pt'))
            trigger = attack_specification['trigger']
            masks.append(trigger['mask'])
            # Stacked as [top_left, bottom_right]; the training loop flattens this into the
            # 4 regression targets.
            coords.append(np.stack([trigger['top_left'], trigger['bottom_right']]))

        self.model_paths = model_paths
        self.coords = coords
        self.masks = masks
        self.data_sources = data_sources

    def __len__(self):
        return len(self.model_paths)

    def __getitem__(self, index):
        """
        :param index: position in the (sorted) list of model directories
        :returns: (network, stacked corner coords, trigger mask, data source name)
        """
        # NOTE(review): model.pt presumably stores a pickled nn.Module. torch.load unpickles
        # arbitrary objects, so only load trusted files; torch>=2.6 may need weights_only=False.
        return torch.load(os.path.join(self.model_paths[index], 'model.pt')), \
               self.coords[index], self.masks[index], self.data_sources[index]

def custom_collate(batch):
    """
    Collate dataset items into one plain list per field.

    The items presumably include network objects and masks that should not be stacked by
    the default collate, so every field is kept as a list. Works for tuples of any length,
    so the same function serves the 4-field training items and the 2-field test items.

    :param batch: list of equal-length tuples, one per dataset item
    :returns: tuple of lists; element k holds field k of every item in the batch
    """
    return tuple(list(field) for field in zip(*batch))

## Load data
Splitting off a validation set (20% of the training models) from the training set for evaluation purposes.

In [3]:
dataset_path = '../../tdc_datasets'
task = 'trigger_synthesis'
dataset = NetworkDatasetTriggerSynthesis(os.path.join(dataset_path, task, 'train'))

# Hold out 20% of the training models for validation. A seeded generator keeps the split
# identical across kernel restarts, so train/val metrics stay comparable between runs.
split_seed = 0
split = int(len(dataset) * 0.8)
rnd_idx = np.random.default_rng(split_seed).permutation(len(dataset))
train_dataset = torch.utils.data.Subset(dataset, rnd_idx[:split])
val_dataset = torch.utils.data.Subset(dataset, rnd_idx[split:])

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True,
                                           num_workers=0, pin_memory=False, collate_fn=custom_collate)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1,
                                         num_workers=0, pin_memory=False, collate_fn=custom_collate)

## Construct the MNTD network

In [4]:
data_sources = ['CIFAR-10', 'CIFAR-100', 'GTSRB', 'MNIST']
data_source_to_channel = {k: 1 if k == 'MNIST' else 3 for k in data_sources}
data_source_to_resolution = {k: 28 if k == 'MNIST' else 32 for k in data_sources}
data_source_to_num_classes = {'CIFAR-10': 10, 'CIFAR-100': 100, 'GTSRB': 43, 'MNIST': 10}

class MetaNetwork(nn.Module):
    """
    MNTD-style meta-network that inspects a network through its responses to learned queries.

    One batch of ``num_queries`` learnable query images is kept per data source, shaped to
    that source's channels/resolution and initialised uniformly in [0, 1). The inspected
    network's outputs on its queries are flattened, projected to 32 features by a
    per-source linear layer, layer-normalised, passed through ReLU and mapped to
    ``num_classes`` outputs by a shared head. In this notebook ``num_classes=4`` and the
    outputs are regressed onto the trigger's top-left / bottom-right corner coordinates.
    """
    def __init__(self, num_queries, num_classes=1):
        super().__init__()
        self.queries = nn.ParameterDict(
            {k: nn.Parameter(torch.rand(num_queries,
                                        data_source_to_channel[k],
                                        data_source_to_resolution[k],
                                        data_source_to_resolution[k])) for k in data_sources}
        )
        # the input size is num_queries * (number of classes of the inspected network)
        self.affines = nn.ModuleDict(
            {k: nn.Linear(data_source_to_num_classes[k]*num_queries, 32) for k in data_sources}
        )
        self.norm = nn.LayerNorm(32)
        self.relu = nn.ReLU(True)
        self.final_output = nn.Linear(32, num_classes)
    
    def forward(self, net, data_source):
        """
        :param net: the network to inspect; it is called on the query batch of ``data_source``
            and presumably returns (num_queries, classes-of-data_source) outputs, matching the
            per-source affine layer's input size
        :param data_source: one of ``data_sources`` (e.g. 'MNIST'); selects queries and affine layer
        :returns: tensor of shape (1, num_classes); with num_classes=4 these are the predicted
            corner coordinates of the trigger
        """
        query = self.queries[data_source]
        out = net(query)
        # All query responses of the single inspected network become one feature row.
        # reshape (not view) so non-contiguous network outputs are accepted.
        out = self.affines[data_source](out.reshape(1, -1))
        out = self.norm(out)
        out = self.relu(out)
        return self.final_output(out)

## Train the network

In [5]:
# Seed query initialisation and DataLoader shuffling so training runs are reproducible.
torch.manual_seed(0)
meta_network = MetaNetwork(10, num_classes=4).cuda().train()

num_epochs = 10
lr = 0.01
weight_decay = 0.
optimizer = torch.optim.Adam(meta_network.parameters(), lr=lr, weight_decay=weight_decay)
# The scheduler is stepped once per batch, so anneal over the total number of batches.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs * len(train_loader))

loss_ema = np.inf
for epoch in range(num_epochs):
    pbar = tqdm(train_loader)
    pbar.set_description(f"Epoch {epoch + 1}")
    for net, coords, _mask, data_source in pbar:
        # batch_size=1: unwrap the single network / target from the collated lists
        net = net[0].cuda().eval()
        coords = torch.FloatTensor(coords[0]).view(-1).cuda()
        data_source = data_source[0]

        out = meta_network(net, data_source)

        # Euclidean distance between predicted and true corner coordinates
        loss = (out - coords).pow(2).sum().pow(0.5)

        optimizer.zero_grad()
        # Only accumulate gradients into the meta-network, not the inspected network.
        loss.backward(inputs=list(meta_network.parameters()))
        optimizer.step()
        scheduler.step()
        # Project the query images back into the valid pixel range [0, 1].
        with torch.no_grad():
            for query in meta_network.queries.values():
                query.clamp_(0, 1)
        loss_ema = loss.item() if loss_ema == np.inf else 0.95 * loss_ema + 0.05 * loss.item()

        pbar.set_postfix(loss=loss_ema)

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

In [6]:
# Switch to inference mode (.train(False) is what .eval() does); the returned module is displayed.
meta_network.train(False)

MetaNetwork(
  (queries): ParameterDict(
      (CIFAR-10): Parameter containing: [torch.cuda.FloatTensor of size 10x3x32x32 (GPU 0)]
      (CIFAR-100): Parameter containing: [torch.cuda.FloatTensor of size 10x3x32x32 (GPU 0)]
      (GTSRB): Parameter containing: [torch.cuda.FloatTensor of size 10x3x32x32 (GPU 0)]
      (MNIST): Parameter containing: [torch.cuda.FloatTensor of size 10x1x28x28 (GPU 0)]
  )
  (affines): ModuleDict(
    (CIFAR-10): Linear(in_features=100, out_features=32, bias=True)
    (CIFAR-100): Linear(in_features=1000, out_features=32, bias=True)
    (GTSRB): Linear(in_features=430, out_features=32, bias=True)
    (MNIST): Linear(in_features=100, out_features=32, bias=True)
  )
  (norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  (relu): ReLU(inplace=True)
  (final_output): Linear(in_features=32, out_features=4, bias=True)
)

## Evaluate the network

In [7]:
from utils import create_rectangular_mask

# data_source_to_resolution (image side length per data source) is defined once, next to
# MetaNetwork; it is deliberately not redefined here to keep a single source of truth.
def is_valid_rectangle(top_left, bottom_right):
    """
    Return True iff the two corners span a rectangle with positive extent.

    Each of the two coordinates of ``top_left`` must be strictly smaller than the
    matching coordinate of ``bottom_right``; touching or inverted corners (zero or
    negative width/height) are rejected.
    """
    return bool(top_left[0] < bottom_right[0] and top_left[1] < bottom_right[1])

def evaluate(meta_network, loader):
    """
    Mean Euclidean corner-regression error of ``meta_network`` over ``loader``.

    :param meta_network: MetaNetwork producing corner coordinates for an inspected network
    :param loader: DataLoader yielding (nets, coords, masks, data_sources) lists; only the
        first element of each list is used (batch_size=1 assumed)
    :returns: average per-network L2 distance between predicted and true corners
    """
    distances = []
    for nets, corner_lists, _masks, sources in tqdm(loader):
        model = nets[0]
        model.cuda().eval()
        target = torch.FloatTensor(corner_lists[0]).view(-1).cuda()
        with torch.no_grad():
            prediction = meta_network(model, sources[0])
        distances.append((prediction - target).pow(2).sum().pow(0.5).item())

    return np.mean(distances)

def evaluate_iou(meta_network, loader):
    """
    Mean intersection-over-union between predicted and ground-truth trigger masks.

    Predicted corner coordinates are rounded to integers and rasterised with
    ``create_rectangular_mask``; a prediction whose corners do not form a valid
    rectangle (or whose union with the true mask is empty) scores 0.

    :param meta_network: MetaNetwork regressing (top-left, bottom-right) corners
    :param loader: DataLoader yielding (nets, coords, masks, data_sources) lists; only the
        first element of each list is used (batch_size=1 assumed)
    :returns: average IoU over all networks in ``loader``
    """
    scores = []
    for nets, _coords, masks, sources in tqdm(loader):
        model, source = nets[0], sources[0]
        model.cuda().eval()

        with torch.no_grad():
            # round the 4 regressed values to integer pixel coordinates
            corners = np.rint(meta_network(model, source)[0].cpu().numpy()).astype(np.int32)
        top_left, bottom_right = corners[:2], corners[2:]

        iou = 0
        if is_valid_rectangle(top_left, bottom_right):
            side_len = data_source_to_resolution[source]
            pred_mask = create_rectangular_mask(side_len, top_left, bottom_right)
            true_mask = masks[0].numpy().astype(np.int32)
            union = np.logical_or(pred_mask, true_mask).sum()
            if union != 0:
                iou = np.logical_and(pred_mask, true_mask).sum() / union
        scores.append(iou)

    return np.mean(scores)

In [8]:
# Corner-regression loss and mask IoU on the training split.
loss = evaluate(meta_network, train_loader)
iou = evaluate_iou(meta_network, train_loader)
print(f'Train Loss: {loss:.3f}, Train IOU: {iou:.3f}')

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

Train Loss: 4.250, Train IOU: 0.433


In [9]:
# Same metrics on the held-out validation split.
loss = evaluate(meta_network, val_loader)
iou = evaluate_iou(meta_network, val_loader)
print(f'Val Loss: {loss:.3f}, Val IOU: {iou:.3f}')

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Val Loss: 15.212, Val IOU: 0.069


## Make submission

In [10]:
class NetworkDatasetTriggerSynthesisTest(torch.utils.data.Dataset):
    """
    Unlabelled networks for the submission: one item per model sub-directory.

    Only ``info.json`` (its ``'dataset'`` key) is read eagerly; ``model.pt`` is loaded on
    access. Item order determines the order of the submitted predictions.
    """
    def __init__(self, model_folder):
        super().__init__()
        # Sort so predictions come out in a deterministic, directory-name order instead of
        # the arbitrary os.listdir order; skip stray non-directory entries.
        # NOTE(review): assumes directory names sort in model-index order (e.g. zero-padded
        # ids) — confirm against the expected submission format.
        model_paths = [os.path.join(model_folder, x) for x in sorted(os.listdir(model_folder))
                       if os.path.isdir(os.path.join(model_folder, x))]
        data_sources = []
        for p in model_paths:
            with open(os.path.join(p, 'info.json'), 'r') as f:
                info = json.load(f)
            data_sources.append(info['dataset'])

        self.model_paths = model_paths
        self.data_sources = data_sources
    
    def __len__(self):
        return len(self.model_paths)
    
    def __getitem__(self, index):
        """
        :returns: (network, data source name) for model ``index``
        """
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
        return torch.load(os.path.join(self.model_paths[index], 'model.pt')), self.data_sources[index]

def custom_collate(batch):
    """
    Collate dataset items into one plain list per field, for tuples of any length.

    This name is also bound by the training loaders' collate function; keeping it
    arity-agnostic means re-running the data-loading cell after this one still works
    (a 2-field-only version would break the 4-field training items).

    :param batch: list of equal-length tuples, one per dataset item
    :returns: tuple of lists; element k holds field k of every item in the batch
    """
    return tuple(list(field) for field in zip(*batch))

In [11]:
dataset_path = '../../tdc_datasets'
task = 'trigger_synthesis'

# The submission is made on the 'val' split; NetworkDatasetTriggerSynthesisTest reads only
# info.json, so items are (network, data source) pairs. No shuffling keeps prediction order stable.
test_dataset = NetworkDatasetTriggerSynthesisTest(
    os.path.join(dataset_path, task, 'val'))
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    pin_memory=False,
    collate_fn=custom_collate,
)

In [52]:
def predict(meta_network, loader):
    """
    Predict a boolean trigger mask for every network in ``loader``.

    The regressed corner coordinates are rounded to integers and rasterised with
    ``create_rectangular_mask``. If they do not form a valid rectangle, an all-ones mask
    of the network's input resolution is returned as a fallback.

    :param meta_network: MetaNetwork regressing (top-left, bottom-right) corners
    :param loader: DataLoader yielding (nets, data_sources) lists; only the first element
        of each list is used (batch_size=1 assumed)
    :returns: list of boolean numpy masks, in ``loader`` order
    """
    masks = []

    for net, data_source in tqdm(loader):
        model, source = net[0], data_source[0]
        model.cuda().eval()
        # Resolve this network's resolution before branching: previously it was only set in
        # the valid branch, so the fallback raised UnboundLocalError on the first invalid
        # prediction or reused a stale (possibly wrong-sized) value from an earlier network.
        side_len = data_source_to_resolution[source]

        with torch.no_grad():
            out = meta_network(model, source)
            pred_top_left = np.rint(out[0, :2].cpu().numpy()).astype(np.int32)  # rounding to integer array
            pred_bottom_right = np.rint(out[0, 2:].cpu().numpy()).astype(np.int32)  # rounding to integer array

        if is_valid_rectangle(pred_top_left, pred_bottom_right):
            pred_mask = create_rectangular_mask(side_len, pred_top_left, pred_bottom_right)
        else:
            # as a heuristic, we output an all-ones mask if the predicted corners do not form a valid rectangle
            pred_mask = create_rectangular_mask(side_len, [0, 0], [side_len, side_len])
        masks.append(pred_mask.numpy().astype(bool))

    return masks

In [53]:
# One boolean mask per submission network, in test_loader order.
masks = predict(meta_network, loader=test_loader)

  0%|          | 0/500 [00:00<?, ?it/s]

In [54]:
if not os.path.exists('mntd_submission'):
    os.makedirs('mntd_submission')

with open(os.path.join('mntd_submission', 'predictions.pkl'), 'wb') as f:
    pickle.dump(masks, f)

!cd mntd_submission && zip ../mntd_submission.zip ./* && cd ..

  adding: predictions.pkl (deflated 99%)


In [55]:
# Sanity check: mntd_submission/ and mntd_submission.zip should now be listed.
!ls

example_submission.ipynb  mntd_submission.zip  README.md
mntd_submission		  __pycache__
