# Meta Classifier
Detect images generated using Meta AI. 

## Setup 

In [1]:
import torch
import os
import sys
import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
import torch.nn as nn

import utils_img
import utils

sys.path.append('src')
from loss.loss_provider import LossProvider

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"


In [35]:
import gc 

def clear_memory():
    '''
    Delete unused tensors on GPU to reduce OOM issues. 
    '''
    try:
        del msg_extractor
    except: 
        pass
    try: 
        del img
    except:
        pass
    try: 
        del imgs
    except:
        pass
    try: 
        del targets
    except:
        pass
    gc.collect()
    torch.cuda.empty_cache()

clear_memory()

## Load Data

In [32]:
img_size = 256

# augment training set with some transformations 
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(img_size),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(degrees=180),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
    ])


In [29]:
from torch.utils.data import DataLoader, Subset
from torchvision.datasets.folder import default_loader
from torch.nn.utils.rnn import pad_sequence
import numpy as np 

batch_size = 4
train_dir = "data/meta/train"
train_size = 4770
val_dir = "data/meta/val"
val_size = 596

class ImageFolder:
    """An image folder dataset intended for supervised learning."""

    def __init__(self, path, transform=None, loader=default_loader):
        self.samples = [x for x in utils.get_image_paths(path) if not "aug" in x]
        self.loader = loader
        self.transform = transform

    def __getitem__(self, idx: int):
        """
        Returns the image with its corresponding label. Images are 
        labeled 0 if the image is not watermarked, else 1. 
        """
        assert 0 <= idx < len(self)
        path = self.samples[idx]
        img = self.loader(path)
        label = 0 if "orig" in path else 1
        if self.transform:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.samples)


def get_dataloader(data_dir, transform, batch_size=128, num_imgs=None, shuffle=False, num_workers=4):
    """ Get dataloader for the images in the data_dir. The data_dir must be of the form: input/0/... """
    dataset = ImageFolder(data_dir, transform=transform)
    if num_imgs is not None:
        dataset = Subset(dataset, np.random.choice(len(dataset), num_imgs, replace=False))
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True, drop_last=False, collate_fn=None)


train_loader = get_dataloader(train_dir, train_transform, batch_size, num_imgs=train_size, shuffle=True)

# can try validate with image augmentations or without
val_loader = get_dataloader(val_dir, train_transform, 2, num_imgs=val_size, shuffle=True)


## Create Model 
Create a model by finetuning a 64-bit stable signature watermark extractor. Watermark extractor was trained for 300 epochs with the following command:

`torchrun --nproc_per_node=8 main.py --local_rank 0 --val_dir /home/test2017/ --train_dir /home/train2017/ --output_dir output64 --eval_freq 5   --img_size 256 --num_bits 64  --batch_size 16 --epochs 300   --scheduler CosineLRScheduler,lr_min=1e-6,t_initial=300,warmup_lr_init=1e-6,warmup_t=5  --optimizer Lamb,lr=2e-2   --p_color_jitter 0.0 --p_blur 0.0 --p_rot 0.0 --p_crop 1.0 --p_res 1.0 --p_jpeg 1.0   --scaling_w 0.3 --scale_channels False --attenuation none   --loss_w_type bce --loss_margin 1 
`

Watermark extractor was then whitened to ensure that the output bits more independent and well distributed (refer to stable signature paper). A whitened extractor was generated by running the finetune_ldm_decoder.py code once.

`python finetune_ldm_decoder.py --train_dir data/train --val_dir data/val --batch_size 1 --output_dir key0 --seed 0`


In [5]:
msg_extractor = torch.jit.load("hidden/runpod/checkpoint299_whit.pth").to("cpu")

model = nn.Sequential(
    nn.Sequential(*(list(msg_extractor.children())[:-1])),
    nn.Dropout(p=0.1),
    nn.Linear(in_features=64, out_features=1)
)

for param in model.parameters():
    param.requires_grad = True
    
model

Sequential(
  (0): Sequential(
    (0): RecursiveScriptModule(
      original_name=HiddenDecoder
      (layers): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=ConvBNRelu
          (layers): RecursiveScriptModule(
            original_name=Sequential
            (0): RecursiveScriptModule(original_name=Conv2d)
            (1): RecursiveScriptModule(original_name=BatchNorm2d)
            (2): RecursiveScriptModule(original_name=GELU)
          )
        )
        (1): RecursiveScriptModule(
          original_name=ConvBNRelu
          (layers): RecursiveScriptModule(
            original_name=Sequential
            (0): RecursiveScriptModule(original_name=Conv2d)
            (1): RecursiveScriptModule(original_name=BatchNorm2d)
            (2): RecursiveScriptModule(original_name=GELU)
          )
        )
        (2): RecursiveScriptModule(
          original_name=ConvBNRelu
          (layers): RecursiveScriptModule(

## Train Model 

In [41]:
from sklearn.metrics import confusion_matrix

def train(model, data_loader, dataset_size, criterion, optimizer, scheduler, epoch):
    """
    Train model using the data_loader for 1 epoch. Prints out a confusion matrix,
    loss, and accuracy every 100 steps and at the end of the epoch.
    """
    model.train()
    pred_len = 0
    y_pred = []
    y_true = []
    test_correct = 0

    running_loss = 0.0
    # Iterate over data.
    for step, data in enumerate(tqdm.tqdm(data_loader)):
        clear_memory()
        inputs, labels = data
        inputs = inputs.to(device, dtype=torch.float)
        labels = labels.to(device, dtype=torch.float)
        optimizer.zero_grad()

        with torch.set_grad_enabled(True):
            outputs = model(inputs)
            targets = labels.unsqueeze(1)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            preds = torch.sigmoid(outputs).round().cpu().detach().numpy().squeeze()
            y_pred.extend(preds)
            labels = labels.cpu().numpy()
            y_true.extend(labels)
            test_correct += int((preds == labels).sum())
            pred_len += preds.size
            if step%100 == 0:
                print(f'Step: {step}, Loss:  {loss.item():.4f}, Acc: {test_correct/pred_len}')
                cf_matrix = confusion_matrix(y_true, y_pred)
                print(cf_matrix)

        running_loss += loss.item() * inputs.size(0)
    epoch_loss = running_loss / dataset_size
    print(f'Epoch: {epoch}, Loss: {epoch_loss:.4f}, Acc: {test_correct/pred_len}')
    cf_matrix = confusion_matrix(y_true, y_pred)
    print(cf_matrix)
    scheduler.step()
    return model 
    
def validate(model, data_loader, dataset_size, criterion):
    """
    Validates model using the data_loader. Prints out a confusion matrix,
    loss, and accuracy every 100 steps and at the end.
    """
    model.eval()
    pred_len = 0
    y_pred = []
    y_true = []
    test_correct = 0

    running_loss = 0.0
    # Iterate over data.
    for step, d in enumerate(tqdm.tqdm(data_loader)):
        clear_memory()
        inputs, labels = d
        inputs = inputs.to(device, dtype=torch.float)
        labels = labels.to(device, dtype=torch.float)
        with torch.no_grad():
            outputs = model(inputs)
            targets = labels.unsqueeze(1)
            loss = criterion(outputs, targets)
            preds = torch.sigmoid(outputs).round().cpu().detach().numpy().squeeze()
            y_pred.extend(preds)
            labels = labels.cpu().numpy()
            y_true.extend(labels)
            test_correct += int((preds == labels).sum())
            pred_len += preds.size
            if step%100 == 0:
                print(f'Step: {step}, Loss:  {loss.item():.4f}, Acc: {test_correct/pred_len}')
                cf_matrix = confusion_matrix(y_true, y_pred)
                print(cf_matrix)

        running_loss += loss.item() * inputs.size(0)
    epoch_loss = running_loss / dataset_size
    print(f'Val Loss: {epoch_loss:.4f}, Val Acc: {test_correct/pred_len}')
    cf_matrix = confusion_matrix(y_true, y_pred)
    print(cf_matrix)


def train_model(model, train_loader, train_size, val_loader, val_size, criterion, optimizer, scheduler, num_epochs):
    """
    Runs train and validate for num_epochs. 
    """
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        model = train(model, train_loader, train_size, criterion, optimizer, scheduler, epoch)
        clear_memory()
        validate(model, val_loader, val_size, criterion)
        return model  


In [42]:
import torch.optim as optim
from torch.optim import lr_scheduler

model = model.to(device)
criterion = nn.BCEWithLogitsLoss()

optimizer_ft = optim.Adam(model.parameters(), lr=0.001)
lr_sch = lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)

model = train_model(model, train_loader, train_size, val_loader, val_size, criterion, optimizer_ft, lr_sch, 1)


Epoch 0/0
----------


  0%|▏                                                                                                                                                                     | 1/1193 [00:00<08:46,  2.26it/s]

Step: 0, Loss:  0.0481, Acc: 1.0
[[2 0]
 [0 2]]


  8%|█████████████▉                                                                                                                                                      | 101/1193 [00:25<04:34,  3.99it/s]

Step: 100, Loss:  0.0476, Acc: 0.9158415841584159
[[175  16]
 [ 18 195]]


 17%|███████████████████████████▋                                                                                                                                        | 201/1193 [00:50<04:12,  3.93it/s]

Step: 200, Loss:  0.0092, Acc: 0.9154228855721394
[[359  33]
 [ 35 377]]


 25%|█████████████████████████████████████████▍                                                                                                                          | 301/1193 [01:15<03:46,  3.94it/s]

Step: 300, Loss:  0.0482, Acc: 0.9269102990033222
[[546  43]
 [ 45 570]]


 34%|███████████████████████████████████████████████████████                                                                                                             | 401/1193 [01:41<03:22,  3.92it/s]

Step: 400, Loss:  0.0165, Acc: 0.9345386533665836
[[735  52]
 [ 53 764]]


 42%|████████████████████████████████████████████████████████████████████▊                                                                                               | 501/1193 [02:06<02:55,  3.95it/s]

Step: 500, Loss:  0.0165, Acc: 0.9316367265469062
[[933  66]
 [ 71 934]]


 50%|██████████████████████████████████████████████████████████████████████████████████▌                                                                                 | 601/1193 [02:31<02:30,  3.94it/s]

Step: 600, Loss:  0.0319, Acc: 0.9313643926788685
[[1117   80]
 [  85 1122]]


 59%|████████████████████████████████████████████████████████████████████████████████████████████████▎                                                                   | 701/1193 [02:56<02:04,  3.95it/s]

Step: 700, Loss:  0.1900, Acc: 0.9322396576319544
[[1298   92]
 [  98 1316]]


 67%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                      | 801/1193 [03:22<01:41,  3.87it/s]

Step: 800, Loss:  0.0380, Acc: 0.9307116104868914
[[1487  111]
 [ 111 1495]]


 76%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                        | 901/1193 [03:48<01:15,  3.88it/s]

Step: 900, Loss:  0.1045, Acc: 0.9311875693673696
[[1672  120]
 [ 128 1684]]


 84%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                          | 1001/1193 [04:13<00:49,  3.88it/s]

Step: 1000, Loss:  0.3705, Acc: 0.9328171828171828
[[1872  132]
 [ 137 1863]]


 92%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍            | 1101/1193 [04:39<00:23,  3.85it/s]

Step: 1100, Loss:  0.0322, Acc: 0.9332425068119891
[[2072  145]
 [ 149 2038]]


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1193/1193 [05:02<00:00,  3.94it/s]


Epoch: 0, Loss: 0.1746, Acc: 0.9335429769392034
[[2243  157]
 [ 160 2210]]


  1%|█▋                                                                                                                                                                     | 3/298 [00:00<00:38,  7.63it/s]

Step: 0, Loss:  0.0024, Acc: 1.0
[[2]]


 35%|█████████████████████████████████████████████████████████                                                                                                            | 103/298 [00:09<00:15, 12.31it/s]

Step: 100, Loss:  0.3101, Acc: 0.9702970297029703
[[109   0]
 [  6  87]]


 68%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                    | 203/298 [00:17<00:07, 12.20it/s]

Step: 200, Loss:  0.4671, Acc: 0.9577114427860697
[[206   0]
 [ 17 179]]


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 298/298 [00:24<00:00, 11.96it/s]

Val Loss: 0.1242, Val Acc: 0.9563758389261745
[[299   1]
 [ 25 271]]





In [64]:
validate(model, val_loader, val_size, criterion)

  1%|█▋                                                                                                                                                                     | 3/298 [00:00<00:38,  7.57it/s]

Step: 0, Loss:  0.0261, Acc: 1.0
[[2]]


 35%|█████████████████████████████████████████████████████████                                                                                                            | 103/298 [00:08<00:16, 12.12it/s]

Step: 100, Loss:  0.0005, Acc: 0.9554455445544554
[[ 92   1]
 [  8 101]]


 68%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                    | 203/298 [00:16<00:07, 12.40it/s]

Step: 200, Loss:  2.1426, Acc: 0.9402985074626866
[[196   2]
 [ 22 182]]


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 298/298 [00:24<00:00, 12.24it/s]

Val Loss: 0.1309, Val Acc: 0.9496644295302014
[[298   2]
 [ 28 268]]





### Save model

In [63]:
jit_model = torch.jit.script(model)
torch.jit.save(jit_model, "models/meta_classifier.pt")

## Validation

Code to test on a single image or a test directory. The output 1 indicates that the image is generated by Meta AI, 0 otherwise. 

In [5]:
model = torch.jit.load("models/meta_classifier.pt").to(device)

In [62]:
from PIL import Image

with torch.no_grad():
    #img_path = "/ssd/watermarks/stable_signature/data/meta-aug/331_jpeg_80_w.png"
    img_path = "/ssd/watermarks/stable_signature/meta4.jpg"
    img = Image.open(img_path)
    img = val_transform(img).unsqueeze(0).to(device)
    model.eval()
    output = model(img)
pred = torch.sigmoid(output).round().cpu().detach().numpy().squeeze()
pred

array(1., dtype=float32)

In [55]:
test_dir = "test"
# test with no image augmentations
test_loader = get_dataloader(test_dir, val_transform, 2, num_imgs=500, shuffle=True)
validate(model, test_loader, val_size, criterion)

  1%|██                                                                                                                                                                     | 3/250 [00:00<00:29,  8.38it/s]

Step: 0, Loss:  0.0120, Acc: 1.0
[[1 0]
 [0 1]]


 41%|███████████████████████████████████████████████████████████████████▉                                                                                                 | 103/250 [00:08<00:11, 12.45it/s]

Step: 100, Loss:  0.0016, Acc: 0.995049504950495
[[ 88   0]
 [  1 113]]


 81%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                               | 203/250 [00:16<00:03, 12.50it/s]

Step: 200, Loss:  0.0005, Acc: 0.9925373134328358
[[197   0]
 [  3 202]]


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:20<00:00, 12.33it/s]

Val Loss: 0.0588, Val Acc: 0.994
[[250   0]
 [  3 247]]



