In [10]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from typing import Tuple, Sequence, Callable
import math
from tqdm.notebook import tqdm
from appknn import app_k_nearest, mysample, calculate_margin, create_net, adf

%config InlineBackend.figure_format = 'retina'
%matplotlib inline

In [17]:
def classify_app(appid, labels):
    """Look up the malware label of a single app.

    Parameters
    ----------
    appid : scalar
        Value matched against the ``apn`` column of ``labels``.
    labels : pandas.DataFrame
        Must contain ``apn`` and ``malware_label`` columns.

    Returns
    -------
    The ``malware_label`` of the first row whose ``apn`` equals ``appid``.

    Raises
    ------
    IndexError
        If no row of ``labels`` has ``apn == appid``.
    """
    matching_labels = labels.loc[labels['apn'] == appid, 'malware_label']
    return matching_labels.iloc[0]

def verify_point(appid, net, labels, k, distance):
    """Classify ``appid`` by a neighbour from ``net`` and return (true label, predicted label).

    Parameters
    ----------
    appid : scalar
        The app (``apn``) to classify.
    net : sequence of apn
        Candidate neighbours passed to ``app_k_nearest``.
    labels : pandas.DataFrame
        Label table used by ``classify_app``.
    k : int
        Number of neighbours requested from ``app_k_nearest``.
    distance : callable
        Distance function forwarded to ``app_k_nearest``.

    Returns
    -------
    (ap_mal, pt_mal)
        ``ap_mal`` is the label of ``appid`` and ``pt_mal`` the label of the
        neighbour used for the prediction.
    """
    neighbours = app_k_nearest(k=k, apps=net, new_app=appid, distance=distance)
    # NOTE(review): assumes neighbours are ordered nearest-first, as the use of
    # element 0 as the prediction implies — confirm against app_k_nearest.
    predictor = neighbours[0]
    if predictor == appid:
        print("It found itself... not really a classification")
        # Using the app's own label would make the prediction trivially correct;
        # fall back to the next neighbour when more than one was returned.
        if len(neighbours) > 1:
            predictor = neighbours[1]

    ap_mal = classify_app(appid, labels)
    pt_mal = classify_app(predictor, labels)

    return ap_mal, pt_mal
    

def generate_net(v, labels, sample_size, problematic_pairs, random_state=None):
    """Sample apps, drop problematic ones, split train/test and build a net on the train part.

    Parameters
    ----------
    v : pandas.DataFrame
        Source rows; ``mysample`` draws the working sample from it, and the
        sample is expected to have ``apn`` and ``nf`` columns (used below).
    labels : pandas.DataFrame
        Label table, passed to ``calculate_margin``.
    sample_size :
        Passed to ``mysample``.
    problematic_pairs : iterable of pairs of apn
        Every apn appearing in these pairs is removed from the sample.
    random_state : int, numpy RandomState or None, optional
        Forwarded to ``train_test_split`` so the split can be reproduced.
        ``None`` (default) keeps the previous unseeded behaviour.

    Returns
    -------
    smp, train, test, net
        The filtered sample, the train and test apn arrays, and the net built
        by ``create_net`` over ``train`` with ``gamma = margin / 2``.
    """
    # Imported locally: the notebook otherwise only imports itertools in a later cell.
    import itertools

    problematic_pairs = list(problematic_pairs)

    smp = mysample(v, sample_size)
    print(f"sample created {smp.shape[0]/v.shape[0]:.2f}")
    if len(problematic_pairs) > 0:
        problematic = list(itertools.chain.from_iterable(problematic_pairs))
        print(f"Removing problematic {len(problematic)} apps")
        smp = smp[~smp.apn.isin(problematic)]

    # Set of `nf` values per apn; passed to adf as its lookup table.
    funcs_smp = smp.groupby(by='apn')['nf'].apply(set)
    margin, problematic_prs = calculate_margin(smp, labels, distance=lambda x, y, z: adf(x, y, funcs_smp))
    gamma = margin / 2.0

    # Remove both the caller's pairs and those reported by calculate_margin.
    # chain() works regardless of whether the two are lists or tuples.
    problematic = list(itertools.chain.from_iterable(
        itertools.chain(problematic_pairs, problematic_prs)))
    print(f"Second removal of problematics {len(problematic)}")
    smp = smp[~smp.apn.isin(problematic)]
    # funcs_smp is deliberately not recomputed: it was built from a superset of
    # the apps still in smp, so lookups for the remaining apps still succeed.

    train, test = train_test_split(smp.apn.unique(), random_state=random_state)
    print(f"Calculating net with {gamma}")
    net = create_net(gamma=gamma, apns=train, distance=lambda x, y: adf(x, y, funcs_smp))
    print(f"Net {len(net)} of {len(train)} created")

    return smp, train, test, net

In [15]:
# Load the inputs: `v` is later sampled and grouped by its `apn`/`nf` columns,
# `labels` is looked up by `apn` to read `malware_label`.
v, labels = (
    pd.read_csv('data/functions_encoded.csv'),
    pd.read_csv('./data/labels_encoded.csv'),
)



In [4]:
import itertools

# Every known problematic pair has app 20353 as its first element;
# only the partner app varies, so build the pairs from the partners.
problematic_pairs = [(20353, partner) for partner in (
    11822, 5960, 5279, 23352, 4508, 15342, 2049, 15167, 22414, 9094,
    25173, 7987, 10025, 5217, 9950, 24486, 17737, 17091, 24216, 23845,
    6845, 25822, 4544, 5104, 1342, 16752, 17521, 5748, 21368, 23385,
    24937, 10917, 27580, 17441, 16741, 4207, 8831, 22246, 15939, 21521,
    14873, 4419, 23693, 12381, 23648, 12363, 22947, 22142, 17493, 13548,
    14005, 14118, 17489, 11314, 23366, 24831, 12600, 22022, 17768, 23425,
    21211, 3520, 10499, 16335, 26645, 16786, 4553, 15159, 10735, 2865,
    440, 1441, 6307, 1503, 24948, 7846, 17565, 21973, 21833, 22007,
    22165, 16663, 5525, 6969, 13086, 9158, 21508, 22015, 8245, 5230,
    24770, 1785, 10814, 7186, 1638, 25910, 23679, 22521, 929, 465,
    154, 24845, 5108, 22169, 15855, 22805, 10557, 4718, 23042, 16932,
    805, 13492, 25587, 22874, 23334, 360, 2081, 23961, 22848, 16955,
    6856, 2258, 27284, 13875,
)]
# Flatten pairs into one list of apns, in pair order (20353 repeats once per pair).
problematic = [apn for pair in problematic_pairs for apn in pair]

In [21]:
def check(v, labels, problematic_pairs, sample_size):
    """Compare classification success using the net versus the full training set.

    Builds a net with ``generate_net``. Each held-out test app is classified
    twice with ``verify_point`` (k=2): once against the net and once against
    the full training set. Prints both success rates and every app where the
    two predictions disagree.

    Parameters
    ----------
    v, labels, problematic_pairs, sample_size
        Passed straight through to ``generate_net`` (``labels`` is also used
        by ``verify_point``).
    """
    print("Generating net")
    smp, train, test, net = generate_net(v, labels, sample_size, problematic_pairs)
    funcs_smp = smp.groupby(by='apn')['nf'].apply(set)

    # One distance function shared by both classifications, instead of two
    # identical lambdas rebuilt on every iteration.
    def distance(x, y):
        return adf(x, y, funcs_smp)

    print(f"Comparing classifications using net {len(net)} and full dataset ({len(train)})")
    if len(test) == 0:
        # Avoid ZeroDivisionError in the success-rate computation below.
        print("Empty test set, nothing to compare")
        return

    net_hits = 0
    train_hits = 0

    for p in tqdm(test):
        # Classify the same test app using the net and the full training set.
        # The true label (first element) is the same for both calls.
        is_app_malware, net_class = verify_point(p, net, labels, k=2, distance=distance)
        _, full_class = verify_point(p, train, labels, k=2, distance=distance)
        if is_app_malware == net_class:
            net_hits += 1

        if is_app_malware == full_class:
            train_hits += 1

        if full_class != net_class:
            print(f"Different classification result for {p}")

    print(f"Net succes rate: {net_hits/len(test)}")
    print(f"Full set sucess r: {train_hits/len(test)}")

In [22]:
# Run the net-vs-full-set comparison on a 500-app sample.
check(v=v, labels=labels, problematic_pairs=problematic_pairs, sample_size=500)

Generating net
sample created 0.03
Removing problematic 248 apps
Split finished: 227851 malicious, 830408 bening, 1058259 overall
Second removal of problematics 416
Calculating net with 0.5
Net 157 of 190 created
Comparing classifications using net 157 and full dataset (190)


HBox(children=(FloatProgress(value=0.0, max=64.0), HTML(value='')))


Net succes rate: 0.90625
Full set sucess r: 0.90625
