In [1]:
import functools
import os
import random

import joblib
import lightly
import pytorch_lamb
import timm
import torch
import torchvision
import tqdm

import wasmshield
import wasmshield.models.conv
import wasmshield.preprocessing
import wasmshield.training.trainer
import wasmshield.utils

# Side length of the square model inputs: passed to the model as `img_size`
# and to the image preprocessor as `size`.
size = 64

# Prefer Apple's Metal (MPS) backend, which this notebook was written for, but
# fall back to CUDA and then CPU so the notebook still runs where MPS is absent.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

  _torch_pytree._register_pytree_node(


In [2]:
import joblib
# Cached semantic-classification split stored under evaluation_logs/
# (presumably inputs `X`, targets `y` and the train/test index arrays).
# NOTE(review): joblib.load unpickles its input -- only load trusted files.
semantic_X = joblib.load('evaluation_logs/semantic_X')
semantic_train_idx = joblib.load('evaluation_logs/semantic_train_idx')
semantic_test_idx = joblib.load('evaluation_logs/semantic_test_idx')
semantic_y = joblib.load('evaluation_logs/semantic_y')

In [3]:
# `load_datasets` returns eight datasets, ordered as a (train, test) pair for
# each source: FormAI, WasmBench, obfuscated FormAI and MSR.
[
    train_set, test_set,
    wasm_bench_train_set, wasm_bench_test_set,
    obfuscated_train_set, obfuscated_test_set,
    msr_train_set, msr_test_set,
] = wasmshield.utils.load_datasets()

100%|██████████| 14860/14860 [00:03<00:00, 3921.61it/s] 
100%|██████████| 3715/3715 [00:01<00:00, 3284.44it/s]


In [4]:
import torch.nn as nn
import torch.nn.functional as F
import torch

from fastai.layers import ConvLayer, NormType

class SelfAttention(nn.Module):
    """Self-attention layer for `n_channels`.

    Spatial positions are flattened and every output position attends over all
    input positions. The attended features are scaled by a learnable `gamma`
    (initialised to 0, so the layer starts out as the identity) and added back
    to the input.

    Args:
        n_channels: number of input (and output) channels. The query/key
            projections use `n_channels // 8` channels, so this should be >= 8.
    """

    def __init__(self, n_channels):
        super().__init__()
        self.query, self.key, self.value = [
            self._conv(n_channels, c)
            for c in (n_channels // 8, n_channels // 8, n_channels)
        ]
        self.gamma = nn.Parameter(torch.tensor([0.]))

    def _conv(self, n_in, n_out):
        # 1x1 1-D convolution with spectral normalisation, no activation, no bias.
        return ConvLayer(n_in, n_out, ks=1, ndim=1, norm_type=NormType.Spectral, act_cls=None, bias=False)

    def forward(self, x):
        """Apply self-attention over the trailing (spatial) dims of `x`.

        Args:
            x: tensor of shape (batch, n_channels, *spatial); it does not need
                to be contiguous.

        Returns:
            A contiguous tensor with the same shape as `x`.
        """
        # Notation (f, g, h, beta, o) from the paper.
        size = x.size()
        # `reshape` rather than `view`: `view` raises on non-contiguous inputs
        # (e.g. transposed or channels_last tensors), while `reshape` returns a
        # free view whenever one exists and copies only when it must.
        x = x.reshape(*size[:2], -1)
        f, g, h = self.query(x), self.key(x), self.value(x)
        # beta[b, i, j]: weight of input position i for output position j,
        # normalised over i.
        beta = F.softmax(torch.bmm(f.transpose(1, 2), g), dim=1)
        o = self.gamma * torch.bmm(h, beta) + x
        return o.reshape(size).contiguous()

In [7]:
import torch, torchsummary, torch.nn as nn


class ResBlock(torch.nn.Module):
    """Residual block: (3x3 conv -> BatchNorm -> ReLU) + identity, then ReLU.

    Args:
        n: number of channels; input and output both have `n` channels.
        att: accepted so callers can pass it uniformly; unused by this block.
    """

    def __init__(self, n, att=False):
        super().__init__()
        # Keep a Sequential named `conv` with children 0/1/2 so the state_dict
        # keys (conv.0.*, conv.1.*) stay the same.
        layers = [
            nn.Conv2d(n, n, kernel_size=3, stride=1, padding=1),
            torch.nn.BatchNorm2d(n),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        residual_sum = self.conv(x) + x
        return torch.relu(residual_sum)
    
class ConvRes(torch.nn.Module):
    """Residual CNN backbone that maps images to a flat feature vector.

    One stage per entry of `widths`. Each stage consists of:
      * a 1x1 Conv -> ReLU -> BatchNorm projection when the channel count changes,
      * `ll` ResBlocks,
      * BatchNorm followed by a 2x2 max-pool (halving height and width),
      * a SelfAttention layer when `att` is set.
    The final feature map is then reduced and flattened.

    Args:
        att: append a SelfAttention layer after every stage.
        reduction: 'avg' or 'max' for global pooling (output size
            `widths[-1]`), or None to keep the spatial map (output size
            `widths[-1] * (H // 2**len(widths)) * (W // 2**len(widths))`).
        ll: number of ResBlocks per stage.
        in_channels: number of channels of the input images.
        widths: output channel count of each stage.

    Raises:
        ValueError: if `reduction` is not None, 'avg' or 'max'.
    """

    # Supported global-pooling reductions (None means no pooling).
    _REDUCTIONS = {
        'avg': torch.nn.AdaptiveAvgPool2d,
        'max': torch.nn.AdaptiveMaxPool2d,
    }

    def __init__(self, att=False, reduction=None, ll=3, in_channels=3, widths=(32, 64, 64, 128, 128)):
        super().__init__()

        # Previously any unrecognised string silently meant "no pooling", which
        # changes the output width and only fails later with a shape mismatch.
        if reduction is not None and reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be None, 'avg' or 'max', got {reduction!r}"
            )

        # Extra input channels on top of `in_channels`; always 0, kept as an
        # attribute for backward compatibility.
        self.num = 0

        blocks = []
        last_chan = in_channels + self.num
        for new_chan in widths:
            if last_chan != new_chan:
                # 1x1 projection to the new width (ReLU before BatchNorm).
                blocks.extend([
                    nn.Conv2d(last_chan, new_chan, kernel_size=1, stride=1),
                    nn.ReLU(inplace=True),
                    torch.nn.BatchNorm2d(new_chan),
                ])
            blocks.extend(ResBlock(new_chan, att=att) for _ in range(ll))
            blocks.append(torch.nn.BatchNorm2d(new_chan))
            blocks.append(nn.MaxPool2d(2, 2))
            if att:
                blocks.append(SelfAttention(new_chan))
            last_chan = new_chan

        self.reduction = (
            self._REDUCTIONS[reduction](1) if reduction is not None
            else torch.nn.Identity()
        )

        # Same child layout as before (blocks, reduction, flatten), so existing
        # `conv1.<idx>.*` state_dict keys keep loading.
        self.conv1 = torch.nn.Sequential(
            *blocks,
            self.reduction,
            torch.nn.Flatten(),
        )

    def forward(self, x):
        """Return the flattened features of image batch `x` (N, C, H, W)."""
        return self.conv1(x)

# Print a layer-by-layer summary of the self-attention variant (att=True) for a
# 3x64x64 input. NOTE: the training cell below builds ConvRes with att=False,
# so this summary does not describe the backbone that is actually trained.
torchsummary.summary(ConvRes(att=True, reduction='max'), (3,64,64))

''

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 128]                 --
|    └─Conv2d: 2-1                       [-1, 32, 64, 64]          128
|    └─ReLU: 2-2                         [-1, 32, 64, 64]          --
|    └─BatchNorm2d: 2-3                  [-1, 32, 64, 64]          64
|    └─ResBlock: 2-4                     [-1, 32, 64, 64]          --
|    |    └─Sequential: 3-1              [-1, 32, 64, 64]          9,312
|    └─ResBlock: 2-5                     [-1, 32, 64, 64]          --
|    |    └─Sequential: 3-2              [-1, 32, 64, 64]          9,312
|    └─ResBlock: 2-6                     [-1, 32, 64, 64]          --
|    |    └─Sequential: 3-3              [-1, 32, 64, 64]          9,312
|    └─BatchNorm2d: 2-7                  [-1, 32, 64, 64]          64
|    └─MaxPool2d: 2-8                    [-1, 32, 32, 32]          --
|    └─SelfAttention: 2-9                [-1, 32, 32, 32]          --
|    

''

In [10]:
import torch

# Contrastive (NT-Xent) training of a ConvRes backbone inside wasmshield's
# ConvNet, swept over pooling reductions x loss temperatures (currently a
# single combination). `size` and `device` come from the configuration cell.
BatchType = wasmshield.training.trainer.BatchType

SEED = 0
temperatures = [
    0.15,
]
reductions = [
    'max',
]

# Optimisation / batching settings shared by every run of the sweep.
learning_rate = 0.0025
batch_size = 64
max_batch_size = 64
n_batches_per_epoch = 5
batch_size_semantic = 256
criterion_semantic = nn.CrossEntropyLoss()

# Training schedule: (number of epochs, batch types handed to the trainer).
batch_types_levels = [
    (60, [
        BatchType.obfuscated_formai,
        BatchType.mutated_formai,
        BatchType.mutated_optim_levels_formai,
        BatchType.optim_level_formai,
        BatchType.mutated_wasm_bench,
        BatchType.mutated_msr,
        BatchType.semantic_classification,
    ]),
]

# Passed to the trainer as `n_reps_per_batch`; presumably how many times each
# batch type is repeated per step (TODO confirm). semantic_classification is
# scheduled above but given 0 here.
n_reps_per_batch = {
    BatchType.obfuscated_formai: 7,
    BatchType.mutated_formai: 1,
    BatchType.optim_level_formai: 5,
    BatchType.mutated_optim_levels_formai: 4,
    BatchType.mutated_msr: 1,
    BatchType.mutated_wasm_bench: 1,
    BatchType.semantic_classification: 0,
}

for reduction in reductions:
    for temp in temperatures:
        print('Training ')

        # Seed before building the model so weight initialisation is repeatable.
        # NOTE(review): the wasmshield trainer may also draw from numpy's RNG --
        # TODO confirm and seed it as well if so.
        random.seed(SEED)
        torch.manual_seed(SEED)

        # ConvRes ends with 128 channels after five 2x2 max-pools: global
        # pooling yields 128 features, no pooling yields 128 * (size // 32) ** 2
        # (512 for size=64).
        vec_size = 128 if reduction is not None else 128 * (size // 32) ** 2
        hidden_size = 128 if reduction is not None else 256
        out_size = 128

        backbone = ConvRes(att=False, reduction=reduction, ll=3)
        model_name = f'Emprique60_2_ResBin18_woSA_b64_i{size}_pil_v{vec_size}_r{reduction}'

        criterion = lightly.loss.NTXentLoss(
            temperature=temp,
            memory_bank_size=0,
        )
        is_for_simclr = True  # not read in this cell
        # Only the temperature suffix has its '.' removed (0.15 -> '_t015').
        name = model_name + f'_t{criterion.temperature}'.replace('.', '')

        model = wasmshield.models.conv.ConvNet(
            img_size=size,
            backbone=backbone,
            vec_size=vec_size,
            hidden_size=hidden_size,
            out_size=out_size,
            use_head=True,
        )

        trainable_model = wasmshield.training.trainer.TrainableModel(
            model=model,
            name=name,
            device=device,
        )
        optimizer = torch.optim.Adam(trainable_model.model.parameters(), lr=learning_rate, weight_decay=0)

        # A partial (unlike a lambda) can be pickled, in case the trainer hands
        # the preprocessor to worker processes.
        preprocessor = functools.partial(
            wasmshield.preprocessing.preprocess_image,
            use_pil=True,
            compress=False,
            size=size,
        )

        trainer = wasmshield.training.trainer.Trainer(
            trainable_model=trainable_model,
            optimizer=optimizer,
            preprocessor=preprocessor,
            training_preprocessor=wasmshield.preprocessing.preprocess_image_for_training,
            formai_dataset=(train_set, test_set),
            obfuscated_formai_dataset=(obfuscated_train_set, obfuscated_test_set),
            wasm_bench_dataset=(wasm_bench_train_set, wasm_bench_test_set),
            msr_dataset=(msr_train_set, msr_test_set),
            semantic_dataset=(semantic_X, semantic_y, semantic_train_idx, semantic_test_idx),
            device=device,
        )

        print('Training', f'{trainable_model.name=}')

        for nb_epochs, bt in batch_types_levels:
            batch_types = bt
            print('')
            print('=' * 10)
            print(bt)
            print('=' * 10)

            # Resume from the model's current epoch and train `nb_epochs` more.
            trainer.train(
                startepoch=trainable_model.epoch,
                nbepochs=trainable_model.epoch + nb_epochs,
                batch_size=batch_size,
                max_batch_size=max_batch_size,
                batch_types=list(batch_types),
                n_reps_per_batch=n_reps_per_batch,
                n_batches_per_epoch=n_batches_per_epoch,
                criterion=criterion,
                do_shuffle=True,
                semantic_mlp=torch.nn.Identity(),
                criterion_semantic=criterion_semantic,
                batch_size_semantic=batch_size_semantic,
            )

        print('Saving', f"{trainable_model.name=}")
        trainable_model.save()


Training 
Training trainable_model.name='Emprique60_2_ResBin18_woSA_b64_i64_pil_v128_rmax_t015'

[<BatchType.obfuscated_formai: 0>, <BatchType.mutated_formai: 3>, <BatchType.mutated_optim_levels_formai: 2>, <BatchType.optim_level_formai: 1>, <BatchType.mutated_wasm_bench: 4>, <BatchType.mutated_msr: 5>, <BatchType.semantic_classification: 6>]
Stepping backward.
| epoch = 1 | batch = 1 | loss = 4.29028 |

Stepping backward.
| epoch = 1 | batch = 2 | loss = 3.27331 |

Stepping backward.
| epoch = 1 | batch = 3 | loss = 3.03764 |

Stepping backward.
| epoch = 1 | batch = 4 | loss = 2.84464 |

Stepping backward.
| epoch = 1 | batch = 5 | loss = 2.78920 |


 Test_loss= 3.98969140506926 Test_type= BatchType.obfuscated_formai

 Test_loss= 3.1925763130187987 Test_type= BatchType.optim_level_formai

 Test_loss= 3.3112929264704385 Test_type= BatchType.mutated_optim_levels_formai

 Test_loss= 2.854667584101359 Test_type= BatchType.mutated_formai

 Test_loss= 2.5980478127797446 Test_type= BatchTyp