In [1]:
import os
import json
import torch
torch.set_grad_enabled(False)
import numpy as np
import sys
sys.path.append(r'..\..\..')
from nebula.models.attention import TransformerEncoderChunks
from nebula import JSONTokenizerBPE, PEDynamicFeatureExtractor
from nebula.misc import get_path, clear_cuda_cache, set_random_seed
from nebula.constants import *
from bertviz import model_view, head_view
from collections import Counter

SCRIPT_PATH = get_path(type="notebook")
ROOT = os.path.join(SCRIPT_PATH, r"..\..\..")

In [2]:
def get_attn(x, model):
    attentions = []
    x = model.pos_encoder(model.encoder(x))
    for layer in model.transformer_encoder.layers:
        this_layer_attn = layer.self_attn(x,x,x, average_attn_weights=False)[1]
        attentions.append(this_layer_attn.cpu())
        x = layer(x)
    return attentions

def viz(start, end, attentions, exampleTokenized, threshold=0.005, layer=1, heads=[]):
    m = 1/threshold
    tokens = exampleTokenized[start:end]
    att_crop = [x[:, :, start:end, start:end]*m for x in attentions]
    head_view(att_crop, tokens, prettify_tokens=False, heads=heads, layer=layer)

def report_where_attends(seq_1, seq_2, token_location, tokenized_input):
    # TODO: somewhat broken?
    msg = ""
    token = tokenized_input[token_location]
    where_attends_seq_2 = np.where(seq_1 == token_location)[0]
    if len(where_attends_seq_2) > 0:
        counterpart_tokens = list(set([tokenized_input[x] for x in seq_2[where_attends_seq_2].tolist()]))
        if token in counterpart_tokens:
            counterpart_tokens.remove(token)
        if counterpart_tokens: # TODO: this part reports wrong results
            msg = f"\t  It has connections with tokens at following locations and values:\n"
            msg += f"\t\tLocated: {where_attends_seq_2}\n"
            msg += f"\t\tValues: {counterpart_tokens}\n"
    return msg

def analyze_attentions(
        attentions,
        tokenized_input,
        threshold=0.005,
        diff=20,
        most_common=5,
        types=["proximity", "frequency"],
        limit=20,
        verbose=True,
        token_whitelist=[],
):
    counter = 0
    print(f"\n[!] Analyzing attentions based on {types}... ")
    if token_whitelist:
        print(f"\n[!] Whitelisted tokens: {token_whitelist}... Everything else will be ignored!")
    for layer_nr, layer in enumerate(attentions):
        idxs = np.where(layer > threshold)
        print(f"\n[!] Total {len(idxs[0])} tokens has strong activations at layer {layer_nr} with threshold: {threshold}!")
        if len(idxs[0]) > 0:
            heads = idxs[1].tolist()
            attn_seq_1 = idxs[2]
            attn_seq_2 = idxs[3]

            if "proximity" in types:
                # PROXIMITY CHECK
                print(f"[*] Performing a token attention proximity check with token difference: {diff}... ")
                for j, (a1, a2) in enumerate(zip(attn_seq_1, attn_seq_2)):
                    # check if difference between a1 and a2 is below threshold
                    token_at_a1 = tokenized_input[a1]
                    token_at_a2 = tokenized_input[a2]
                    if token_at_a1 == token_at_a2:
                        continue
                    if token_whitelist and \
                        (token_at_a1.strip("▁") not in token_whitelist \
                         and token_at_a2.strip("▁") not in token_whitelist):
                            continue
                    if abs(a1 - a2) > diff:
                        if verbose:
                            print(f"\tThese two tokens has strong activations but are far (diff: {abs(a1 - a2)}):\n\t\t{token_at_a1}\n\t\t{token_at_a2}")
                    elif abs(a1 - a2) <= diff:
                        counter += 1
                        if counter > limit:
                            return
                        which_heads = heads[j]
                        print("\tThese two close sequence tokens has strong activations:", a1, a2, f"at layer {layer_nr} head", which_heads)
                        print(f"\t\tToken at location {a1} -> {token_at_a1}")
                        print(f"\t\tToken at location {a2} -> {token_at_a2}")
                        aa = sorted([a1, a2])
                        a1, a2 = aa[0]-5, aa[1]+5
                        a1 = 0 if a1 < 0 else a1
                        a2 = len(tokenized_input) if a2 > len(tokenized_input) else a2
                        vizString = f"viz({a1}, {a2}, attentions, tokenized_input, layer={layer_nr}, heads=[{which_heads}])"
                        print("Calling:", vizString)
                        viz(a1, a2, attentions, tokenized_input, layer=layer_nr, heads=[which_heads])
            if "count" in types or "frequency" in types:
                # COUNTER CHECK
                print("[*] Performing a token frequency check ... ")
                c = attn_seq_2.tolist()
                c.extend((attn_seq_1.tolist()))
                c = Counter(c).most_common(most_common)        
                for token_location, token_frequency in c:
                    appear_in_heads = [heads[i] for i in np.where(attn_seq_1 == token_location)[0].tolist()]
                    appear_in_heads.extend([heads[i] for i in np.where(attn_seq_2 == token_location)[0].tolist()])
                    appear_in_heads = list(set(appear_in_heads))

                    token_value = tokenized_input[token_location]
                    msg1 = report_where_attends(attn_seq_1, attn_seq_2, token_location, tokenized_input)
                    msg2 = report_where_attends(attn_seq_2, attn_seq_2, token_location, tokenized_input)
                    if msg1 or msg2:
                        msg = f"\n\t> Token '{token_value}' location in sequence: {token_location}\n"
                        msg += f"\t  It appears {token_frequency} times, in layer {layer_nr}, heads: {appear_in_heads}\n"
                        msg += msg1 + msg2
                        print(msg)
                        
                        counter += 1
                        if counter > limit:
                            return

def get_attention_report(exampleFile, model, maxLen=512, random_seed=42, diff=20):
    with open(exampleFile) as f:
        example = json.load(f)

    extractor = PEDynamicFeatureExtractor()
    exampleProcessed = extractor.filter_and_normalize_report(example[0])
    
    tokenizer = JSONTokenizerBPE(
        vocab_size=50001,
        seq_len=maxLen,
        model_path=os.path.join(ROOT, r"nebula\objects\speakeasy_BPE_50000_sentencepiece.model")
    )

    exampleTokenized = tokenizer.tokenize(exampleProcessed)[0]
    exampleEncoded = tokenizer.encode(exampleProcessed)[0]
    x = torch.Tensor(exampleEncoded).long().reshape(1,-1).cuda()

    logit_bad = model(x[:, :maxLen]).cpu()
    prob_bad = torch.sigmoid(logit_bad)
    print("[!] Model scores on this sample -- logit:", logit_bad[0][0].numpy(), " and probability:", prob_bad[0][0].numpy())

    clear_cuda_cache()
    set_random_seed(random_seed)
    attentions = get_attn(x, model)

    # analyze_attentions(attentions, exampleTokenized, types=["proximity"], threshold=0.004, diff=diff)
    # analyze_attentions(attentions, exampleTokenized, types=["frequency"], most_common=10)
    return attentions, exampleTokenized

In [3]:
model_file = r"nebula\objects\speakeasy_BPE_50000_torch.model"
state_dict = torch.load(os.path.join(ROOT, model_file))
# in state_dict drop weights and bias of preTrainLayers
state_dict = {k: v for k, v in state_dict.items() if "pretrain_layers" not in k}

model_config = {
    "vocab_size": 50001,  # size of vocabulary
    "maxlen": 512,  # maximum length of the input sequence
    "dModel": 64,  # embedding & transformer dimension
    "nHeads": 8,  # number of heads in nn.MultiheadAttention
    "dHidden": 256,  # dimension of the feedforward network model in nn.TransformerEncoder
    "nLayers": 2,  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
    "numClasses": 1, # binary classification
    "hiddenNeurons": [64],
    "dropout": 0.3
}
model = TransformerEncoderChunks(**model_config)
model.load_state_dict(state_dict)
_ = model.cuda()

In [4]:
exampleFile = os.path.join("..", "malware", "0009064322cdc719a82317553b805cbbc64230a9212d3b8aad1ea1b78d3bf10a.json")
# https://www.virustotal.com/gui/file/0009064322cdc719a82317553b805cbbc64230a9212d3b8aad1ea1b78d3bf10a
attentions, tokenized_input = get_attention_report(exampleFile, model, maxLen=512, random_seed=1492, diff=10)



[!] Model scores on this sample -- logit: 9.93311  and probability: 0.9999515


# Getting attention visualizations

In [13]:
analyze_attentions(
    attentions,
    tokenized_input,
    types=["proximity"],
    threshold=0.02,
    diff=30,
    limit=3,
)


[!] Analyzing attentions based on ['proximity']... 

[!] Total 0 tokens has strong activations at layer 0 with threshold: 0.02!

[!] Total 24 tokens has strong activations at layer 1 with threshold: 0.02!
[*] Performing a token attention proximity check with token difference: 30... 
	These two tokens has strong activations but are far (diff: 157):
		▁0x1
		▁0x7850
	These two tokens has strong activations but are far (diff: 236):
		▁kernel32
		info
	These two tokens has strong activations but are far (diff: 36):
		▁file
		▁windows
	These two close sequence tokens has strong activations: 33 36 at layer 1 head 3
		Token at location 33 -> ▁ini
		Token at location 36 -> ▁windows
Calling: viz(28, 41, attentions, tokenized_input, layer=1, heads=[3])


<IPython.core.display.Javascript object>

	These two tokens has strong activations but are far (diff: 37):
		▁0x7e
		▁windows
	These two tokens has strong activations but are far (diff: 69):
		▁getprocaddress
		▁windows
	These two tokens has strong activations but are far (diff: 74):
		▁getprocaddress
		▁windows
	These two tokens has strong activations but are far (diff: 115):
		▁flsgetvalue
		▁decodepointer
	These two tokens has strong activations but are far (diff: 104):
		▁kernel32
		▁windows
	These two tokens has strong activations but are far (diff: 147):
		▁0x4244d0
		▁windows
	These two tokens has strong activations but are far (diff: 180):
		▁kernel32
		▁windows
	These two close sequence tokens has strong activations: 216 227 at layer 1 head 3
		Token at location 216 -> ▁kernel32
		Token at location 227 -> ▁decodepointer
Calling: viz(211, 232, attentions, tokenized_input, layer=1, heads=[3])


<IPython.core.display.Javascript object>

	These two tokens has strong activations but are far (diff: 283):
		▁0xf002
		▁windows
	These two tokens has strong activations but are far (diff: 308):
		▁0x20
		▁windows
	These two tokens has strong activations but are far (diff: 315):
		▁kernel32
		▁windows
	These two tokens has strong activations but are far (diff: 334):
		▁000
		▁windows
	These two tokens has strong activations but are far (diff: 343):
		3b80
		▁windows
	These two tokens has strong activations but are far (diff: 352):
		1e
		▁windows
	These two tokens has strong activations but are far (diff: 385):
		▁encodepointer
		▁windows
	These two tokens has strong activations but are far (diff: 387):
		▁0x41573e
		▁windows
	These two tokens has strong activations but are far (diff: 464):
		▁0x220
		▁windows
	These two tokens has strong activations but are far (diff: 65):
		▁heapalloc
		vironment


In [14]:
analyze_attentions(
    attentions,
    tokenized_input,
    types=["proximity"],
    threshold=0.005,
    diff=30,
    limit=3,
    token_whitelist=["gettickcount"]
)


[!] Analyzing attentions based on ['proximity']... 

[!] Whitelisted tokens: ['gettickcount']... Everything else will be ignored!

[!] Total 56123 tokens has strong activations at layer 0 with threshold: 0.005!
[*] Performing a token attention proximity check with token difference: 30... 
	These two tokens has strong activations but are far (diff: 68):
		▁gettickcount
		info
	These two tokens has strong activations but are far (diff: 47):
		▁gettickcount
		▁temp
	These two tokens has strong activations but are far (diff: 42):
		▁gettickcount
		▁open
	These two close sequence tokens has strong activations: 76 95 at layer 0 head 0
		Token at location 76 -> ▁gettickcount
		Token at location 95 -> ▁0x0
Calling: viz(71, 100, attentions, tokenized_input, layer=0, heads=[0])


<IPython.core.display.Javascript object>

	These two close sequence tokens has strong activations: 76 101 at layer 0 head 0
		Token at location 76 -> ▁gettickcount
		Token at location 101 -> ▁kernel32
Calling: viz(71, 106, attentions, tokenized_input, layer=0, heads=[0])


<IPython.core.display.Javascript object>

	These two tokens has strong activations but are far (diff: 52):
		▁gettickcount
		▁tlssetvalue
	These two tokens has strong activations but are far (diff: 136):
		▁gettickcount
		▁initializecriticalsectionandspincount
	These two tokens has strong activations but are far (diff: 137):
		▁gettickcount
		▁0x424560
	These two tokens has strong activations but are far (diff: 174):
		▁gettickcount
		▁getmodulehandlew
	These two tokens has strong activations but are far (diff: 273):
		▁gettickcount
		▁getenvironmentstringsw
	These two tokens has strong activations but are far (diff: 302):
		▁gettickcount
		755
	These two tokens has strong activations but are far (diff: 337):
		▁gettickcount
		▁0x7bf0
	These two tokens has strong activations but are far (diff: 357):
		▁gettickcount
		▁0x7d91
	These two tokens has strong activations but are far (diff: 410):
		▁gettickcount
		▁flsgetvalue
	These two tokens has strong activations but are far (diff: 68):
		▁gettickcount
		info
	These two tokens ha

<IPython.core.display.Javascript object>

## Step-by-step

In [15]:
with open(exampleFile) as f:
    example = json.load(f)
print(example[0].keys())
print(example[0]['file_access'][-2:])

extractor = PEDynamicFeatureExtractor()
exampleProcessed = extractor.filter_and_normalize_report(example[0])
exampleProcessed['file_access']

dict_keys(['ep_type', 'start_addr', 'ep_args', 'apihash', 'apis', 'ret_val', 'error', 'file_access', 'dynamic_code_segments', 'dropped_files'])
[{'event': 'create', 'path': 'Cd.exedows\\temp\\HGDraw.dll', 'open_flags': ['CREATE_ALWAYS'], 'access_flags': ['GENERIC_WRITE']}, {'event': 'write', 'path': 'Cd.exedows\\temp\\HGDraw.dll', 'data': 'TVqQAAMAAAAEAAAA//8AALgAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA8AAAAA4fug4AtAnNIbgBTM0hVGhpcyBwcm9ncmFtIGNhbm5vdCBiZSBydW4gaW4gRE9TIG1vZGUuDQ0KJAAAAAAAAABJLe9jDUyBMA1MgTANTIEwYjoqMCBMgTBiOh8wG0yBMGI6KzCOTIEwBDQSMAJMgTANTIAwqEyBMGI6LjAfTIEwYjobMAxMgTBiOhwwDEyBMFJpY2gNTIEwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAUEUAAEwBAwBWBhxSAAAAAAAAAADgAAIBCwEKAACQAQAAhgQAAAAAAFrFAAAAEAAAAKABAAAAQAAAEAAAAAIAAAUAAQAAAAAABQABAAAAAAAA4AYAAAQAANnPBQACAECBAAAQAAAQAAAAABAAABAAAAAAAAAQAAAAAAAAAAAAAADcFAIAjAAAAABgBgC6YgAAAAAAAAAAAAAAAAAAAAAAAABABgB0DwAAQKIBABwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAudGV4dA

[{'event': 'create', 'path': '<drive>\\windows\\temp\\golfinfo.ini'},
 {'event': 'create', 'path': '<drive>\\windows\\temp\\golfset.ini'},
 {'event': 'create', 'path': '<drive>\\windows\\temp\\golfinfo.ini'},
 {'event': 'write', 'path': '<drive>\\windows\\temp\\golfinfo.ini'},
 {'event': 'open', 'path': '<drive>\\windows\\system32\\<sha256>'},
 {'event': 'read', 'path': '<drive>\\windows\\system32\\<sha256>'},
 {'event': 'create', 'path': 'cd.exedows\\temp\\hgdraw.dll'},
 {'event': 'write', 'path': 'cd.exedows\\temp\\hgdraw.dll'}]

In [18]:
tokenizer = JSONTokenizerBPE(
    vocab_size=50001,
    seq_len=512,
    model_path=os.path.join(ROOT, r"nebula\objects\speakeasy_BPE_50000_sentencepiece.model")
)
exampleTokenized = tokenizer.tokenize(exampleProcessed)[0]
exampleEncoded = tokenizer.encode(exampleProcessed)[0]

print(exampleTokenized[0:20])
print(exampleEncoded[0:20])
x = torch.Tensor(exampleEncoded).long().reshape(1,-1).cuda()



['▁file', '▁access', '▁create', '▁drive', '▁windows', '▁temp', '▁g', 'olf', 'info', '▁ini', '▁create', '▁drive', '▁windows', '▁temp', '▁g', 'olf', 'set', '▁ini', '▁create', '▁drive']
[1313 1036  278  971 6394 2053  312 1054  204 1091  278  971 6394 2053
  312 1054  949 1091  278  971]


In [19]:
logit_bad = model(x[:, :512]).cpu()
prob_bad = torch.sigmoid(logit_bad)
print("Model scores on this sample -- logit:", logit_bad[0][0].numpy(), " and probability:", prob_bad[0][0].numpy())

Model scores on this sample -- logit: 7.1716046  and probability: 0.99923253


### Getting attention activations

In [20]:
clear_cuda_cache()
set_random_seed(42)
attentions = get_attn(x, model)

In [21]:
analyze_attentions(attentions, exampleTokenized, types=["count"])


[!] Analyzing attentions based on ['count']... 

[!] Total 56935 tokens has strong activations at layer 0 with threshold: 0.005!
[*] Performing a token frequency check ... 

	> Token '▁0xf001' location in sequence: 301
	  It appears 1514 times, in layer 0, heads: [0, 1, 2, 3, 4, 5, 6, 7]
	  It has connections with tokens at following locations and values:
		Located: [ 3298  3299  3300 12315 12316 12317 12318 12319 12320 12321 12322 12323
 12324 12325 12326 12327 12328 12329 12330 12331 12332 12333 12334 12335
 12336 12337 21717 21718 21719 21720 21721 28096 28097 28098 32430 32431
 38879 38880 38881 38882 38883 38884 38885 38886 38887 38888 38889 38890
 38891 38892 38893 46259 46260 46261 46262 46263 53502 53503 53504 53505
 53506 53507 53508 53509 53510]
		Values: ['b78', '▁getstdhandle', '2cdc', '▁kernel32', '755', '▁getacp', '▁0x77000000', '▁0xf', '▁0x1', '▁tlssetvalue', '▁0x41573d', '▁0x4244e8', '▁gettickcount', '▁initializecriticalsectionandspincount', '▁h', '▁0x4244b8', '▁getpro