In [1]:
%load_ext autoreload
%autoreload 2

# Imports
import json
from mitreattack.stix20 import MitreAttackData
import tensorflow as tf
import recommender
import random
import math
import importlib
import pandas as pd

tf.config.run_functions_eagerly(True)

assert tf.executing_eagerly()

importlib.reload(recommender)

<module 'recommender' from '/Users/mjturner/code/technique-inference-engine/models/recommender/__init__.py'>

In [2]:
def get_mitre_technique_ids(stix_filepath: str) -> frozenset[str]:
    """Gets all MITRE technique ids."""
    mitre_attack_data = MitreAttackData(stix_filepath)
    techniques = mitre_attack_data.get_techniques(remove_revoked_deprecated=True)

    all_technique_ids = set()

    for technique in techniques:
        external_references = technique.get("external_references")
        mitre_references = tuple(filter(lambda external_reference: external_reference.get("source_name") == "mitre-attack", external_references))
        assert len(mitre_references) == 1
        mitre_technique_id = mitre_references[0]["external_id"]
        all_technique_ids.add(mitre_technique_id)

    return frozenset(all_technique_ids)

def get_campaign_techniques(filepath: str) -> tuple[frozenset[str]]:
    """Gets a set of MITRE technique ids present in each campaign."""

    with open(filepath) as f:
        data = json.load(f)

    campaigns = data["bags_of_techniques"]

    ret = []

    for campaign in campaigns:

        techniques = campaign["mitre_techniques"]
        ret.append(frozenset(techniques.keys()))

    return ret

def train_test_split(indices: list, values: list, test_ratio: float=0.1) -> tuple:
    n = len(indices)
    assert len(values) == n

    indices_for_test_set = frozenset(random.sample(range(n), k=math.floor(test_ratio * n)))

    train_indices = []
    test_indices = []
    train_values = []
    test_values = []

    for i in range(n):
        if i in indices_for_test_set:
            test_indices.append(indices[i])
            test_values.append(values[i])
        else:
            train_indices.append(indices[i])
            train_values.append(values[i])

    return train_indices, train_values, test_indices, test_values


In [3]:
def main():
    # want matrix of campaigns on horizontal, techniques on vertical
    all_mitre_technique_ids = tuple(get_mitre_technique_ids("../enterprise-attack.json"))
    mitre_technique_ids_to_index = {all_mitre_technique_ids[i]: i for i in range(len(all_mitre_technique_ids))}

    campaigns = get_campaign_techniques("../data/combined_dataset_full_frequency.json")

    indices = []
    values = []

    # for each campaign, make a vector, filling in each present technique with a 1
    for i in range(len(campaigns)):

        campaign = campaigns[i]

        for mitre_technique_id in campaign:
            if mitre_technique_id in mitre_technique_ids_to_index:
                # campaign id, technique id
                index = [i, mitre_technique_ids_to_index[mitre_technique_id]]

                indices.append(index)
                values.append(1)


    train_indices, train_values, test_indices, test_values = train_test_split(indices, values)

    training_data = tf.SparseTensor(
        indices=train_indices,
        values=train_values,
        dense_shape=(len(campaigns), len(all_mitre_technique_ids))
    )
    test_data = tf.SparseTensor(
        indices=test_indices,
        values=test_values,
        dense_shape=(len(campaigns), len(all_mitre_technique_ids))
    )

    # train
    model = recommender.FactorizationRecommender(m=len(campaigns), n=len(all_mitre_technique_ids), k=10)
    model.fit(training_data, num_iterations=1000, learning_rate=10.)

    evaluation = model.evaluate(test_data)
    print("MSE", evaluation)

    predictions = model.predict()

    predictions_dataframe = pd.DataFrame(predictions, columns=all_mitre_technique_ids)

    print(predictions_dataframe)


In [4]:
main()

MSE 0.34170407
        T1578  T1547.014     T1030     T1112  T1550.003     T1049     T1092  \
0    0.784831   0.377630  0.101901  1.001325  -0.158771  0.694602  0.411023   
1   -0.763482   0.224698  1.636720  1.009761   1.009708  0.661873  0.355017   
2    0.544853  -0.018921  0.331601  1.006808   0.202223  1.563948  0.297718   
3    0.902573   0.454326  0.440447  0.998493  -1.109977  0.140948  0.042689   
4   -0.050296   0.301442  0.311153  1.002875   1.002803  0.711632  0.073821   
..        ...        ...       ...       ...        ...       ...       ...   
186  0.437564  -0.162894  0.365576  0.994123  -0.190575  0.995859  0.380639   
187  1.199729   0.407794 -0.073012  0.963695  -0.157554  0.741034  0.225907   
188  0.499829   0.735397  0.651445  1.002326   0.695246  0.366094  0.020164   
189  0.358869   0.484445  0.639233  1.018354   0.295820  1.279500  0.540394   
190  1.458898   0.310654 -0.401749  0.997384  -0.898985  1.000898  0.677263   

     T1505.004  T1218.014  T1564.005