# 1. Dataset generation & collection of activations

In [None]:
import time
import copy
import numpy as np
import matplotlib.pyplot as plt
import os
import glob

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Seed both RNGs. torch covers any stochastic model code. NumPy must be seeded
# too, because add_watermark draws its characters and positions from np.random;
# without it the generated datasets differ on every run.
torch.manual_seed(17)
np.random.seed(17)

## 1.1 Dataset generation

In [None]:
#v2
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
import textwrap
import numpy as np

# Candidate watermark alphabets. add_watermark picks 7 characters from one of
# these lists per image (np.random.choice samples with replacement by default).

# Chinese: 20 common characters (presumably chosen by frequency like the lists
# below; no source is cited for this one).
characters_CH = ['的', '一', '是', '不', '了', '在', '人', '有', '我', '他',
            '这', '个', '们', '中', '来', '上', '大', '为', '和', '国']

# Latin: 20 capital letters, apparently ordered by letter frequency.
# Source: https://www.sttmedia.com/characterfrequency-latin
characters_lat = ['I','E','A','U','T','S','R','N','O','M','C','L','P',
                  'D','B','Q','G','V','F','H']

# Hindi (Devanagari): 20 frequent characters. Several are combining vowel signs
# or the virama, which render attached to the neighbouring character.
# Source: https://www.sttmedia.com/characterfrequency-hindi
characters_hindi = ['ा','क','े','र','ह','स','न','ी','ं','म','ि','्','त',
                    'प','ल','ो','य','ै','ब','द']

# Arabic numerals 0-9.
characters_numbers = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '0']

# The TrueType fonts below must exist under ../input/specialfonts/ before the
# datasets can be generated; all are loaded at size 30. FONT_LAT is shared by the
# Latin and Arabic-numeral watermarks. FONT_CH is not used by the generation
# loop in this cell.

FONT_CH= ImageFont.truetype('../input/specialfonts/SongTi.ttf', 30)
FONT_LAT= ImageFont.truetype('../input/specialfonts/times.ttf', 30)
FONT_HINDI= ImageFont.truetype('../input/specialfonts/NirmalaB.ttf', 30)

def add_watermark(img_path, save_path, font, characters):
    """Draw a random 7-character string onto an image and save the result.

    The string is sampled with replacement from ``characters`` using the global
    NumPy RNG. It is drawn in white at a random position, with the text anchored
    at its left end on the baseline (``anchor='ls'``).

    Args:
        img_path: path of the source image.
        save_path: path where the watermarked image is written.
        font: PIL font object used to render the text.
        characters: sequence of characters to sample the message from.
    """
    msg = ''.join(np.random.choice(characters, 7))

    # `with` closes the underlying file handle once the image has been saved.
    with Image.open(img_path) as image:
        draw = ImageDraw.Draw(image)
        w_img, h_img = image.size

        if hasattr(font, 'getbbox'):
            # FreeTypeFont.getsize() was deprecated in Pillow 9.2 and removed in
            # 10.0. The right/bottom of the bbox match what getsize() returned.
            _, _, w, h = font.getbbox(msg)
        else:
            w, h = font.getsize(msg)

        # Keep the text inside the image. If the image is smaller than the text,
        # fall back to the nearest valid position instead of raising ValueError
        # from randint(low, high) with high <= low.
        x = np.random.randint(0, max(1, w_img - w))
        y = np.random.randint(h, max(h + 1, h_img - h))

        draw.text((x, y), msg, font=font, fill="white", anchor='ls')
        image.save(save_path)
    
# Output directory name -> [font, alphabet] for each watermark variant.
# NOTE(review): FONT_CH / characters_CH are defined above but no 'chinese'
# variant is generated here, although class_names expects ../dataset/chinese/.
# Confirm that set is produced elsewhere.
pairs = {'latin': [FONT_LAT, characters_lat],
         'hindi': [FONT_HINDI, characters_hindi],
         'arabic_numerals': [FONT_LAT, characters_numbers],
         }

# glob order is arbitrary. Sorting makes the image -> watermark assignment
# reproducible, because add_watermark consumes the shared np.random stream in
# file order.
baseline_paths = sorted(glob.glob('../dataset/baseline/*'))

for mode, (font, characters) in pairs.items():
    os.makedirs(mode, exist_ok=True)
    for path in tqdm(baseline_paths):
        # Keep the original file name so the variants line up with the baseline.
        new_path = os.path.join('.', mode, os.path.basename(path))
        add_watermark(path, new_path, font, characters)

!zip -r watermarks.zip  /kaggle/working/

## 1.2 Collecting the activations
### Collecting the activations from the output layers

In [None]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Input pipeline: convert the image to a float tensor, then normalise each
# channel with the mean/std used by the ImageNet-pretrained models.
# NOTE: this rebinds the name `transforms`, which was previously the
# torchvision.transforms module (imported at the top of the notebook); from here
# on it refers to this Compose pipeline.
# NOTE(review): there is no Resize, so images are presumably already a uniform
# size; default DataLoader batching needs equal shapes. Confirm.
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
!pip install timm
import torchvision.models as models
from timm import create_model


def get_model(model_name):
    """Build an ImageNet-pretrained classifier by name and move it to ``device``.

    Args:
        model_name: a torchvision model constructor name (e.g. ``'resnet50'``)
            or one of the timm names ``'vit_base_patch16_224'`` /
            ``'beit_base_patch16_224'``.

    Returns:
        The model on the module-level ``device``. It is NOT switched to eval
        mode; callers do that.

    Raises:
        ValueError: if ``model_name`` is not a supported architecture.
    """
    timm_names = {'vit_base_patch16_224', 'beit_base_patch16_224'}
    torchvision_names = {
        'resnet18', 'resnet50', 'resnet101', 'resnet152',
        'resnext101_32x8d', 'wide_resnet101_2',
        'alexnet', 'googlenet', 'inception_v3',
        'densenet121', 'densenet161', 'densenet201',
        'vgg11', 'vgg13', 'vgg16', 'vgg19',
        'mobilenet_v2', 'shufflenet_v2_x1_0',
    }

    if model_name in timm_names:
        net = create_model(model_name, pretrained=True)
    elif model_name in torchvision_names:
        # Look the constructor up by its own name. A hand-written if-chain makes
        # it easy to map a name to the wrong architecture (e.g. 'resnet101' -> resnet50).
        net = getattr(models, model_name)(pretrained=True)
    else:
        # Fail loudly instead of returning None and crashing later on .eval().
        raise ValueError('Unsupported model name: {!r}'.format(model_name))
    return net.to(device)

# Every architecture whose output (logit) layer is recorded, in processing order.
model_names = [
    # ResNet family
    'resnet18', 'resnet50', 'resnet101', 'resnet152',
    'resnext101_32x8d', 'wide_resnet101_2',
    # AlexNet and the two timm vision transformers
    'alexnet',
    'vit_base_patch16_224', 'beit_base_patch16_224',
    # Assorted other CNNs
    'inception_v3', 'densenet161', 'mobilenet_v2', 'shufflenet_v2_x1_0',
    'densenet121', 'densenet201', 'googlenet',
    # VGG family
    'vgg11', 'vgg13', 'vgg16', 'vgg19',
]

In [None]:
from tqdm import tqdm
from torch.utils.data import DataLoader,Dataset
import cv2

# Image variants, one dataset folder each. A name's position in this list is
# the index used on axis 1 of the activation tensors built below.
class_names = ['baseline', 'chinese', 'latin', 'hindi', 'arabic_numerals']

class ImageDataset(Dataset):
    """Map-style dataset over every ``*.JPEG`` file directly inside ``root``.

    Items are returned without labels, in sorted path order. The activation
    loops rely on this deterministic order to line up the same image across the
    different variant folders.

    Args:
        root: directory to scan, with or without a trailing path separator.
        transform: callable applied to each RGB image array loaded via OpenCV.
    """

    def __init__(self, root, transform):
        self.root = root
        self.transform = transform
        # os.path.join works whether or not root already ends with a separator.
        self.image_names = sorted(glob.glob(os.path.join(self.root, '*.JPEG')))

    def __len__(self):
        """Return the number of images found under ``root``."""
        return len(self.image_names)

    def __getitem__(self, index):
        """Load image ``index``, convert BGR -> RGB and apply ``transform``.

        Raises:
            OSError: if OpenCV cannot read the file.
        """
        path = self.image_names[index]
        image = cv2.imread(path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising,
            # which would otherwise surface as an obscure cvtColor error.
            raise OSError('Could not read image: {}'.format(path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return self.transform(image)

In [None]:
# Run every image variant through each classifier in model_names. For each
# model, save the raw outputs as a tensor of shape
# (n_images, len(class_names), 1000): row i of every class slice is the i-th
# file of that folder in sorted order.
os.makedirs('../activations', exist_ok=True)

with torch.no_grad():
    # The datasets do not depend on the model, so build them once.
    datasets = [ImageDataset('../dataset/{class_name}/'.format(class_name=class_name),
                             transforms)
                for class_name in class_names]

    # Each class slice is filled row by row, so every folder must hold the same
    # number of images. Derive the count instead of hard-coding it, so a mismatch
    # fails here rather than leaving zero rows or breaking mid-run.
    n_images = len(datasets[0])
    if n_images == 0 or any(len(ds) != n_images for ds in datasets):
        raise ValueError('Expected the same non-zero number of images in every class folder, got {}'
                         .format({name: len(ds) for name, ds in zip(class_names, datasets)}))

    for model_name in model_names:
        model = get_model(model_name)
        model.eval()

        # Last axis: the models' 1000-way classification output.
        logit_scores = torch.zeros([n_images, len(class_names), 1000])

        for c, dataset in enumerate(datasets):
            testloader = torch.utils.data.DataLoader(dataset,
                                                     batch_size=512,
                                                     shuffle=False,
                                                     num_workers=2)
            counter = 0
            for x in tqdm(testloader):
                x = x.float().to(device)
                outputs = model(x)
                # Copy explicitly into the CPU-resident result tensor.
                logit_scores[counter:counter + x.shape[0], c, :] = outputs.cpu()
                counter += x.shape[0]

        torch.save(logit_scores, '../activations/{name}_wtrmrks.tnsr'.format(name=model_name))

### Collecting the activations from the feature extractors

In [None]:
import torch

model_features_names = ['resnet18',
                        'alexnet',
                        'googlenet',
                        'vit_base_patch16_224',
                        'beit_base_patch16_224',
                        'inception_v3',
                        'densenet161',
                        'mobilenet_v2',
                        'shufflenet_v2_x1_0',
                        'vgg11']

# Filled by the forward hook registered in make_hooks; the latest batch's
# features are stored under the key 'features'.
activation = {}


def make_hooks(model_name, net=None):
    """Register a forward hook that captures a model's feature vector.

    On every forward pass, the hook stores the hooked layer's output in the
    module-level ``activation`` dict under ``'features'``. Convolutional
    feature maps are first averaged over their two spatial axes.

    Args:
        model_name: one of ``model_features_names``.
        net: the model instance to hook. Defaults to the module-level ``model``,
            which is what existing callers rely on.

    Returns:
        The feature dimensionality ``d`` the hooked layer produces per image.

    Raises:
        ValueError: if ``model_name`` has no hook specification.
    """
    # name -> (selects the submodule to hook, average over dims 2 and 3?, d)
    # NOTE(review): torchvision's DenseNet applies a ReLU after `features` before
    # pooling, so the densenet161 entry averages pre-ReLU maps. Confirm intended.
    specs = {
        'resnet18': (lambda m: m.avgpool, True, 512),
        'alexnet': (lambda m: m.classifier[5], False, 4096),
        'vit_base_patch16_224': (lambda m: m.fc_norm, False, 768),
        'beit_base_patch16_224': (lambda m: m.fc_norm, False, 768),
        'inception_v3': (lambda m: m.avgpool, True, 2048),
        'densenet161': (lambda m: m.features, True, 2208),
        'mobilenet_v2': (lambda m: m.features, True, 1280),
        'shufflenet_v2_x1_0': (lambda m: m.conv5, True, 1024),
        'vgg11': (lambda m: m.classifier[5], False, 4096),
        'googlenet': (lambda m: m.dropout, False, 1024),
    }
    try:
        select_layer, spatial_mean, dim = specs[model_name]
    except KeyError:
        # Previously an unknown name returned None, which only failed later.
        raise ValueError('No feature hook defined for model {!r}'.format(model_name)) from None

    if net is None:
        net = model  # module-level model assigned by the extraction loop

    def hook(module, inputs, output):
        # Conv maps (N, C, H, W) are reduced to (N, C) by averaging H and W.
        # The other hooked layers are expected to emit (N, d) directly.
        activation['features'] = output.mean(dim=(2, 3)) if spatial_mean else output

    select_layer(net).register_forward_hook(hook)
    return dim
        
# Run every image variant through each feature extractor and save the hooked
# features as a tensor of shape (n_images, len(class_names), d) per model,
# where d is the dimensionality reported by make_hooks.
os.makedirs('../activations', exist_ok=True)

with torch.no_grad():
    # The datasets do not depend on the model, so build them once.
    datasets = [ImageDataset('../dataset/{class_name}/'.format(class_name=class_name),
                             transforms)
                for class_name in class_names]

    # Each class slice is filled row by row, so every folder must hold the same
    # number of images. Derive the count instead of hard-coding it.
    n_images = len(datasets[0])
    if n_images == 0 or any(len(ds) != n_images for ds in datasets):
        raise ValueError('Expected the same non-zero number of images in every class folder, got {}'
                         .format({name: len(ds) for name, ds in zip(class_names, datasets)}))

    for model_name in model_features_names:
        model = get_model(model_name)
        # make_hooks hooks the module-level `model` bound on the line above.
        d = make_hooks(model_name)

        model.eval()

        logit_scores = torch.zeros([n_images, len(class_names), d])

        for c, dataset in enumerate(datasets):
            testloader = torch.utils.data.DataLoader(dataset,
                                                     batch_size=512,
                                                     shuffle=False,
                                                     num_workers=2)
            counter = 0
            for x in tqdm(testloader):
                x = x.float().to(device)
                model(x)  # output unused; the forward hook records the features
                # pop() rather than index: if the hook did not fire for this
                # batch, fail instead of silently reusing the previous batch.
                features = activation.pop("features")
                logit_scores[counter:counter + x.shape[0], c, :] = features.cpu()
                counter += x.shape[0]

        torch.save(logit_scores, '../activations/{name}_features_wtrmrks.tnsr'.format(name=model_name))