# CoLLIE inference demo

This demo lets you test the CoLLIE model with arbitrary guidelines and input text.

In [1]:
# Import dependencies

# Change path to load CoLLIE src
import sys

sys.path.append("../")

from src.model.load_model import load_model_for_inference
from src.tasks.utils_typing import Entity, Event

from dataclasses import dataclass
from typing import List
from IPython.display import display, Markdown
import rich
import black

[2023-08-01 10:44:54,264] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)


### Load the model and tokenizer

In [2]:
# Load the base LLaMA 7B weights together with the CoLLIE LoRA adapter.
# quantization=4 presumably means 4-bit loading — TODO confirm in load_model.
# NOTE(review): both paths are cluster-specific; edit them for your machine.
model, tokenizer = load_model_for_inference(
    quantization=4,
    weights_path="/gaueko1/hizkuntza-ereduak/LLaMA/lm/huggingface/7B/",
    lora_weights_name_or_path="/ikerlariak/osainz006/models/collie/CoLLIE-7b_lora4_flash",
)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



### Define the inference function

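The function reads its input from the source code of the most recently executed cell (IPython's `_i` variable), formats it with `black`, and asks the model to complete the `result =` line. **Important**: make sure the cell executed just before calling `inference()` contains the input for the model!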

In [3]:
def inference(code=None, max_new_tokens: int = 512) -> str:
    """Run CoLLIE on a guideline/text cell and return the generated annotations.

    The source is normalised with ``black``, the CoLLIE result header is
    appended, and the model greedily continues the prompt after ``result =``.

    Args:
        code: Python source containing the guideline dataclasses and the
            ``text`` to annotate. When ``None`` (the default) IPython's ``_i``
            is used, i.e. the source of the previously executed cell, so call
            this right after running the guideline cell. ``_i`` only exists
            inside an IPython/Jupyter session.
        max_new_tokens: Maximum number of tokens to generate (default 512).

    Returns:
        The decoded text after the last occurrence of ``"result = "``; normally
        the model's Python expression listing the annotation instances. It is
        not validated here.
    """
    if code is None:
        code = _i  # IPython: input of the previously executed cell
    prompt = black.format_str(code, mode=black.Mode())
    prompt = prompt + "\n\n# The annotation instances that take place in the text above are listed here\nresult ="

    model_input = tokenizer(prompt, add_special_tokens=True, return_tensors="pt")
    input_ids = model_input["input_ids"]
    attention_mask = model_input.get("attention_mask")

    # The prompt must end with "result =" so the model continues it. If the
    # tokenizer appended an EOS token, drop it (and its mask entry); never drop
    # a real prompt token, which the previous unconditional slice could do.
    eos_token_id = tokenizer.eos_token_id
    if eos_token_id is not None and input_ids[0, -1].item() == eos_token_id:
        input_ids = input_ids[:, :-1]
        if attention_mask is not None:
            attention_mask = attention_mask[:, :-1]

    model_output = model.generate(
        input_ids=input_ids.to(model.device),
        attention_mask=None if attention_mask is None else attention_mask.to(model.device),
        max_new_tokens=max_new_tokens,
        do_sample=False,
        min_new_tokens=0,
        num_beams=1,
    )
    result = tokenizer.batch_decode(model_output, skip_special_tokens=True)[0]
    # The decoded sequence includes the prompt; keep only the part after "result = ".
    result = result.split("result = ")[-1]

    return result

### Guideline and input definition

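The following cells define the input for the model. Each input must contain:
  * the guidelines, written as Python dataclasses (optionally documented with docstrings and comments describing each label and argument)
  * the input sentence, assigned to the `text` variable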

In [158]:
# The following lines describe the task definition
@dataclass
class DatabreachAttack(Event):
    """A DatabreachAttack Event happens when an attacker compromises a system
    to later remove or expose the data, e.g., to sell, publish or make it accessible.
    """

    mention: str  # The text span that triggers the event, such as 'exposed', 'published', 'steal', ...
    attacker: List[str]  # The agent (person or organization) of the attack
    attack_pattern: List[str]  # How the attack is done
    victim: List[str]  # The device, organization, person, product or website victim of the attack
    number_of_victim: List[str]  # The number of victims affected by the attack
    compromised_data: List[str]  # The data being compromised, such as 'information', 'data', ...
    number_of_data: List[str]  # The amount of compromised data
    damage_amount: List[str]  # The amount of damage done to the victim
    tool: List[str]  # The file, malware or website used to attack
    purpose: List[str]  # The reason or purpose behind the attack
    place: List[str]  # Where the attack occurred
    time: List[str]  # When the attack occurred


@dataclass
class PhisingAttack(Event):
    """A PhisingAttack Event happens when an attacker imitates another entity, in
    an attempt to get a victim to access malicious materials, such as a website or
    attachments.
    """

    mention: str  # The text span that triggers the event, such as 'phishing', ...
    attacker: List[str]  # The agent (person or organization) of the attack
    attack_method: List[str]  # How the attack is done
    victim: List[str]  # The device, organization, person, product or website victim of the attack
    trusted_entity: List[str]  # The entity imitated by the attacker
    damage_amount: List[str]  # The amount of damage done to the victim
    tool: List[str]  # The file, malware or website used to attack
    purpose: List[str]  # What wants to be attacked or stolen
    place: List[str]  # Where the attack occurred
    time: List[str]  # When the attack occurred


@dataclass
class VulnerabilityDiscover(Event):
    """A VulnerabilityDiscover Event happens when a security expert or other entity,
    like a company, finds a software vulnerability."""

    mention: str  # The text span that triggers the event, such as 'found', ...
    cve: List[str]  # The vulnerability identifier, such as 'CVE-2018-5003'
    capabilities: List[str]  # The capabilities of the vulnerability
    discoverer: List[str]  # The entity that discovered the vulnerability
    supported_platform: List[str]  # The platforms that support the vulnerability
    vulnerability: List[str]  # The vulnerabilities, such as 'vulnerability', ...
    vulnerable_system: List[str]  # The systems vulnerable to the vulnerability
    system_owner: List[str]  # The owners of the vulnerable system
    system_version: List[str]  # The version of the vulnerable system
    time: List[str]  # When the vulnerability was discovered


# The list called result contains the instances for the following events according to the guidelines above:
#    - "purports to be" triggers a PhisingAttack event.
#
text = "The attachment purports to be a flight confirmation or receipt but, of course, it's neither of these things."

In [151]:
# The following lines describe the task definition
@dataclass
class VulnerabilityPatch(Event):
    """A VulnerabilityPatch Event happens when a software company addresses a known
    vulnerability by releasing or describing an appropriate update.
    """

    mention: str
    """The text span that triggers the event, such as:
        - 'patch', 'fixed', 'addresses', 'implemented',
        'released', ...
    """
    cve: List[str]  # The vulnerability identifier, such as 'CVE-2018-5003'
    issues_addressed: List[str]  # What the patch fixed
    supported_platform: List[str]  # The platforms that support the vulnerability
    vulnerability: List[str]  # The vulnerability, such as 'vulnerability'
    vulnerable_system: List[str]  # The affected systems, such as 'infrastructures'
    releaser: List[str]  # The entity releasing the patch
    patch: List[str]  # What the patch was about
    patch_number: List[str]  # Number or name of the patch
    system_version: List[str]  # The version of the vulnerable system
    time: List[str]  # When the patch was implemented, the date


# The list called result contains the instances for the following events according to the guidelines above:
#    - "has released" triggers a VulnerabilityPatch event.
#    - "patch" triggers a VulnerabilityPatch event.
#
text = "Microsoft has released an emergency security update to patch below-reported crazy bad remote code execution vulnerability in its Microsoft Malware Protection Engine (MMPE) that affects Windows 7, 8.1, RT and 10 computers, as well as Windows Server 2016 operating systems."

In [147]:
# The following lines describe the task definition
@dataclass
class VulnerabilityPatch(Event):

    mention: str
    cve: List[str]
    issues_addressed: List[str]
    supported_platform: List[str]
    vulnerability: List[str]
    vulnerable_system: List[str]
    releaser: List[str]
    patch: List[str]
    patch_number: List[str]
    system_version: List[str]
    time: List[str]


# The list called result contains the instances for the following events according to the guidelines above:
#    - "has released" triggers a VulnerabilityPatch event.
#    - "patch" triggers a VulnerabilityPatch event.
#
text = "Microsoft has released an emergency security update to patch below-reported crazy bad remote code execution vulnerability in its Microsoft Malware Protection Engine (MMPE) that affects Windows 7, 8.1, RT and 10 computers, as well as Windows Server 2016 operating systems."

In [145]:
# The following lines describe the task definition
@dataclass
class DatabreachAttack(Event):
    """A DatabreachAttack Event happens when an attacker compromises a system
    to later remove or expose the data, e.g., to sell, publish or make it accessible.
    """

    mention: str
    """The text span that triggers the event, such as:
        - 'attack', 'expose', 'publish', 'steal', ...
    """


@dataclass
class PhisingAttack(Event):
    """A PhisingAttack Event happens when an attacker imitates another entity, in
    an attempt to get a victim to access malicious materials, such as a website or
    attachments.
    """

    mention: str
    """The text span that triggers the event, such as:
        - 'attack', 'purports to be', 'dupe', 'masquerading as',
        'pretending to be', 'scam', 'BEC (Business Email Compromise)', ...
    """


@dataclass
class RansomAttack(Event):
    """A RansomAttack Event happens when an attacker breaks into a system and
    encrypts data, and will only decrypt the data for a ransom payment.
    """

    mention: str
    """The text span that triggers the event, such as:
        - 'attack', 'ransomware', 'selling', 'ransom', ...
    """


@dataclass
class VulnerabilityDiscover(Event):
    """A VulnerabilityDiscover Event happens when a security expert or other entity,
    like a company, finds a software vulnerability.
    """

    mention: str
    """The text span that triggers the event, such as:
        - 'found', 'exploit', 'vulnerability', ...
    """


@dataclass
class VulnerabilityPatch(Event):
    """A VulnerabilityPatch Event happens when a software company addresses a known
    vulnerability by releasing or describing an appropriate update.
    """

    mention: str
    """The text span that triggers the event, such as:
        - 'patch', 'fixed', 'addresses', 'implemented',
        'released', ...
    """


# This is the text to analyze
text = "Microsoft has released an emergency security update to patch below-reported crazy bad remote code execution vulnerability in its Microsoft Malware Protection Engine (MMPE) that affects Windows 7, 8.1, RT and 10 computers, as well as Windows Server 2016 operating systems."

In [139]:
# The following lines describe the task definition
@dataclass
class DatabreachAttack(Event):
    mention: str


@dataclass
class PhisingAttack(Event):
    mention: str


@dataclass
class RansomAttack(Event):
    mention: str


@dataclass
class VulnerabilityDiscover(Event):
    mention: str


@dataclass
class VulnerabilityPatch(Event):
    mention: str


# This is the text to analyze
text = "Microsoft has released an emergency security update to patch below-reported crazy bad remote code execution vulnerability in its Microsoft Malware Protection Engine (MMPE) that affects Windows 7, 8.1, RT and 10 computers, as well as Windows Server 2016 operating systems."

In [152]:
# Annotate the previously executed guideline cell and pretty-print the raw output.
rich.print(result := inference())

In [94]:
# The following lines describe the task definition
@dataclass
class LABEL_1(Entity):
    """Vehicles that have four wheels, for example, automobiles."""

    span: str


@dataclass
class LABEL_2(Entity):
    """Vehicles that have two wheels, for example, bicycles."""

    span: str


@dataclass
class LABEL_3(Entity):
    """Vehicles without wheels, for example, submarines."""

    span: str


# This is the text to analyze
text = "The car overtake the motorbike very fast, however, the plane is much faster."

In [95]:
# Turn the model's textual answer into objects of the guideline classes above.
# NOTE(review): eval() executes model-generated text as Python, unvalidated.
# Acceptable only for local, trusted experimentation. It also requires the
# guideline classes (LABEL_1, ...) to be defined in this notebook's globals.
result = eval(inference())
rich.print(result)

In [96]:
# Show the text span of the first predicted instance (kept as the cell's last
# expression so the notebook displays its value).
result[0].span

'car'