# Overall pipeline
1. Website element detection
2. Credential classifier
3. Phishpedia logo identification
4. Layout matching model

<img src="../example.png" style="width:2000px;height:350px"/>

In [1]:
import os
# Move the working directory to the parent folder; the relative paths used by the cells
# below (e.g. 'configs/...', 'phishpedia/...') are resolved against it.
# NOTE(review): this is not idempotent — every re-run of the cell moves one level further up.
os.chdir('..')

In [2]:
from detectron2_1.datasets import WebMapper
from tqdm import tqdm
import cv2
import matplotlib.pyplot as plt
import funcy
from IPython.display import clear_output
from detectron2.utils.visualizer import Visualizer
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from pycocotools import cocoeval, coco
from detectron2.data import build_detection_test_loader, MetadataCatalog, DatasetCatalog
import numpy as np
import tldextract
import pickle
import torch
import torch.nn.functional as F
import time

In [3]:
from credential_classifier.bit_pytorch.models import FCMaxPool
from credential_classifier.bit_pytorch.grid_divider import read_img_reverse

In [4]:
from phishpedia.models import KNOWN_MODELS
from phishpedia.utils import brand_converter
from phishpedia.inference import siamese_inference, pred_siamese
from phishpedia.utils import brand_converter

In [5]:
from layout_matcher.layout_matcher_knn import bipartite_web
from layout_matcher.misc import load_yaml

In [15]:
from tqdm import tqdm

In [6]:

def element_config(rcnn_weights_path, rcnn_cfg_path):
    """Build the Detectron2 configuration and predictor for element detection.

    Args:
        rcnn_weights_path: value stored in ``cfg.MODEL.WEIGHTS``.
        rcnn_cfg_path: config file merged on top of the defaults from ``get_cfg()``.

    Returns:
        tuple: ``(cfg, predictor)`` where ``predictor`` is a ``DefaultPredictor`` built from ``cfg``.
    """
    detector_cfg = get_cfg()
    detector_cfg.merge_from_file(rcnn_cfg_path)
    detector_cfg.MODEL.WEIGHTS = rcnn_weights_path
    # Lower this threshold to report more boxes
    detector_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.3

    return detector_cfg, DefaultPredictor(detector_cfg)


def element_recognition(img, model):
    """Run the element detector on a single screenshot.

    Args:
        img: an ``np.ndarray`` image, or anything else (e.g. a file path), which is
            handed to ``cv2.imread`` first.
        model: callable predictor; its result is indexed with ``"instances"``.

    Returns:
        tuple: ``(pred_classes, pred_boxes, pred_scores)`` read from the CPU copy of the
        predicted instances.
    """
    image = img if isinstance(img, np.ndarray) else cv2.imread(img)

    instances = model(image)["instances"].to("cpu")
    return (
        instances.pred_classes,       # box types
        instances.pred_boxes.tensor,  # box coordinates
        instances.scores,             # box prediction scores
    )

In [7]:
def credential_config(checkpoint):
    """Load the credential classifier from a checkpoint file.

    Args:
        checkpoint: path passed to ``torch.load``; the loaded object must expose the
            weights under the ``"model"`` key.

    Returns:
        The ``FCMaxPool`` network in eval mode, placed on CUDA when available, else on CPU.
    """
    target_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    classifier = FCMaxPool()
    # Deserialize onto the CPU first; the module itself is moved to the target device below
    state = torch.load(checkpoint, map_location="cpu")
    classifier.load_state_dict(state["model"])
    classifier.to(target_device)
    classifier.eval()
    return classifier

def credential_classifier(img, coords, types, model):
    """Predict whether a page is a credential-taking page.

    Args:
        img: screenshot, forwarded to ``read_img_reverse``.
        coords: detected box coordinates, forwarded to ``read_img_reverse``.
        types: detected box classes, forwarded to ``read_img_reverse``.
        model: classifier called on the grid array.

    Returns:
        int: 0 for 'credential', 1 for 'noncredential'.

    Raises:
        AssertionError: if the grid array is not of shape ``(9, 10, 10)``.
    """
    target_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Encode the detected elements as a grid array
    grid_arr = read_img_reverse(img, coords, types)
    assert grid_arr.shape == (9, 10, 10)

    with torch.no_grad():
        logits = model(grid_arr.type(torch.float).to(target_device))
        probabilities = F.softmax(logits, dim=-1)
        label = probabilities.argmax(dim=-1).item()  # 'credential': 0, 'noncredential': 1
    return label

In [8]:
def phishpedia_config(num_classes, weights_path, targetlist_path, grayscale=False):
    """Load the logo model and embed every reference logo of the target list.

    Args:
        num_classes: passed as ``head_size`` when the "BiT-M-R50x1" model is built.
        weights_path: state-dict file loaded into the model.
        targetlist_path: directory with one sub-directory per target; each holds logo images.
        grayscale: forwarded to ``pred_siamese`` for every reference logo.

    Returns:
        tuple: ``(model, logo_feat_list, file_name_list)`` — the model in eval mode, and two
        numpy arrays holding the ``pred_siamese`` output and the file path of each embedded
        logo, in matching order.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Initialize model
    model = KNOWN_MODELS["BiT-M-R50x1"](head_size=num_classes, zero_head=True)
    # Load weights
    model.load_state_dict(torch.load(weights_path, map_location='cpu'))
    model.to(device)
    model.eval()

    # Prediction for targetlists
    logo_feat_list = []
    file_name_list = []

    for target in os.listdir(targetlist_path):
        if target.startswith('.'):  # skip hidden files
            continue
        target_dir = os.path.join(targetlist_path, target)
        if not os.path.isdir(target_dir):  # stray files next to the target folders
            continue
        for logo_path in os.listdir(target_dir):
            if not logo_path.endswith(('.png', '.jpeg', '.jpg')):
                continue
            # Skip homepage/loginpage screenshots. The previous condition
            # (`startswith('loginpage') and not startswith('homepage')`) only ever
            # skipped the loginpage files, so homepage screenshots were embedded as logos.
            if logo_path.startswith(('loginpage', 'homepage')):
                continue
            logo_file = os.path.join(target_dir, logo_path)
            logo_feat_list.append(pred_siamese(img=logo_file, model=model, grayscale=grayscale))
            file_name_list.append(str(logo_file))

    return model, np.asarray(logo_feat_list), np.asarray(file_name_list)
        

def phishpedia_classifier(pred_classes, pred_boxes, 
                          domain_map_path,
                          model, logo_feat_list, file_name_list, shot_path, 
                          url, 
                          ts,
                          grayscale=False):
    """Match the detected logo boxes against the reference logos.

    Args:
        pred_classes: class index of every detected box; class 0 is treated as logo.
        pred_boxes: box coordinates, indexable by a boolean mask over ``pred_classes``.
        domain_map_path: pickle file holding the target -> domains map.
        model: logo model forwarded to ``siamese_inference``.
        logo_feat_list: reference logo features forwarded to ``siamese_inference``.
        file_name_list: reference logo file names forwarded to ``siamese_inference``.
        shot_path: screenshot forwarded to ``siamese_inference``.
        url: URL of the page; its registered domain is compared with the matched target's domains.
        ts: threshold forwarded to ``siamese_inference`` as ``t_s``.
        grayscale: forwarded to ``siamese_inference``; defaults to the previously
            hard-coded ``False``. Should agree with the value used in ``phishpedia_config``.

    Returns:
        ``brand_converter`` applied to the first matched target whose domain list does not
        contain the URL's domain, or applied to ``None`` when there is no such match.
    """
    # targetlist domain list
    # NOTE(security): pickle.load can execute arbitrary code; only point this at a trusted file.
    with open(domain_map_path, 'rb') as handle:
        domain_map = pickle.load(handle)

    # look at boxes for logo class only
    logo_boxes = pred_boxes[pred_classes == 0]

    # run logo matcher
    pred_target = None
    if len(logo_boxes) > 0:
        # The URL does not change between boxes, so extract its domain once
        url_domain = tldextract.extract(url).domain

        # siamese prediction for logo box
        for coord in logo_boxes:
            min_x, min_y, max_x, max_y = coord
            bbox = [float(min_x), float(min_y), float(max_x), float(max_y)]
            target_this, domain_this = siamese_inference(model, domain_map,
                                                         logo_feat_list, file_name_list,
                                                         shot_path, bbox, t_s=ts, grayscale=grayscale)
            # domain matcher to avoid FP
            if target_this is not None and url_domain not in domain_this:
                pred_target = target_this
                break  # break if target is matched

    return brand_converter(pred_target)
        

In [9]:
def layout_config(cfg_dir, ref_dir, matched_brand, element_model=None):
    """Build the layout reference list for one brand.

    Args:
        cfg_dir: YAML file loaded with ``load_yaml``.
        ref_dir: directory containing one sub-directory of reference screenshots per brand.
        matched_brand: brand whose references are loaded; must be one of the supported brands.
        element_model: element detector used on the reference screenshots. When omitted,
            the notebook-level ``ele_model`` is used, which keeps existing calls working
            but requires that cell to have been run.

    Returns:
        tuple: ``(cfg, gt_coords_arr, gt_files_arr, gt_shot_size_arr)`` — the loaded config and
        three parallel lists with the detected boxes (numpy), the file path and the image
        shape of every reference screenshot.

    Raises:
        AssertionError: if ``matched_brand`` is not a supported brand.
    """
    # load cfg
    cfg = load_yaml(cfg_dir)

    assert matched_brand in ['Amazon', 'Facebook', 'Google', 'Instagram', 'LinkedIn Corporation', 'ms_skype', 'Twitter, Inc.']

    if element_model is None:
        element_model = ele_model  # notebook-level detector (backward-compatible default)

    #TODO: save pred coords or not? 
    gt_coords_arr = [] # save ref layout coords
    gt_files_arr = [] # save ref layout filename
    gt_shot_size_arr = [] # save ref layout screenshot size

    brand_dir = os.path.join(ref_dir, matched_brand)
    for template in os.listdir(brand_dir):
        if template.startswith('.'): # skip hidden file
            continue
        template_path = os.path.join(brand_dir, template)
        img = cv2.imread(template_path)
        if img is None: # cv2.imread returns None for files it cannot decode
            continue
        _, pred_boxes, _ = element_recognition(img, element_model)
        gt_coords_arr.append(pred_boxes.numpy())
        gt_files_arr.append(template_path)
        gt_shot_size_arr.append(img.shape)

    return cfg, gt_coords_arr, gt_files_arr, gt_shot_size_arr
        
def layout_matcher(pred_boxes, img, 
                   gt_coords_arr, gt_files_arr, gt_shot_size_arr,
                   cfg):
    """Find the reference layout most similar to the detected boxes.

    Args:
        pred_boxes: detected boxes; a ``torch.Tensor`` is converted to numpy first.
        img: screenshot as ``np.ndarray``; anything else is read with ``cv2.imread``.
        gt_coords_arr: reference boxes, one entry per reference screenshot.
        gt_files_arr: reference file names, parallel to ``gt_coords_arr``.
        gt_shot_size_arr: reference image shapes, parallel to ``gt_coords_arr``.
        cfg: configuration forwarded to ``bipartite_web``.

    Returns:
        tuple: ``(max_s, max_site)`` — the highest similarity and the reference file that
        produced it; ``(0, None)`` when at most one box was detected or no reference exists.
    """
    if isinstance(pred_boxes, torch.Tensor):
        pred_boxes = pred_boxes.numpy()
    if not isinstance(img, np.ndarray):
        img = cv2.imread(img)
    shot_size = img.shape

    # With at most one reported box there is nothing to match
    if len(pred_boxes) <= 1:
        return 0, None

    best_score, best_file = 0, None
    for idx, ref_boxes in enumerate(gt_coords_arr):
        score, _, _, _, _, _, _, _ = bipartite_web(
            ref_boxes, pred_boxes, gt_shot_size_arr[idx], shot_size, cfg)

        # `>=` means a later reference wins a tie
        if score >= best_score:
            best_score, best_file = score, gt_files_arr[idx]

    return best_score, best_file

## Main function

In [10]:
# Build the element detector once; `ele_model` is reused by every cell below
# (and read as a global inside layout_config).
ele_cfg, ele_model = element_config(rcnn_weights_path = 'output/website/model_final.pth', 
                            rcnn_cfg_path='configs/faster_rcnn_web.yaml')


In [12]:
# Load the credential classifier (returned in eval mode by credential_config)
cls_model = credential_config(checkpoint='credential_classifier/FCMax_0.05.pth.tar')

In [13]:
# Load the logo model and embed all reference logos of the expanded target list.
# num_classes is used as the head size of the network inside phishpedia_config.
pedia_model, logo_feat_list, file_name_list = phishpedia_config(num_classes=180, 
                                                                weights_path='phishpedia/resnetv2_rgb.pth',
                                                                targetlist_path='phishpedia/expand_targetlist/')

  "Palette images with Transparency expressed in bytes should be "


- Expand domain map

In [18]:
# Define the path here as well: this cell runs before the evaluation cells that also set
# `domain_map_path`, so on a fresh kernel the name would otherwise be undefined.
domain_map_path = 'phishpedia/domain_map.pkl'

# NOTE(security): pickle.load can execute arbitrary code; only load a trusted file.
with open(domain_map_path, 'rb') as handle:
    domain_map = pickle.load(handle)

In [None]:
# for target in os.listdir('phishpedia/expand_targetlist/'):
#     if brand_converter(target) not in domain_map.keys():
#         print(brand_converter(target))

In [25]:
# This brand has two domains, so it is written directly; DomainDict below stores
# exactly one domain per brand.
domain_map['Lloyds TSB Group'] = ['lloydsbankinggroup', 'lloydsbank']

class DomainDict():
    """Adds the brands of the expanded target list to a brand -> domains map.

    Every brand in ``_BRAND_TO_DOMAIN`` is written as a one-element list, replacing any
    entry the map already has for that brand.
    """

    # Brand name (as used in the target list) -> registered domain of the brand's site.
    # A single mapping replaces the former parallel key/value lists, so a brand and its
    # domain can no longer drift out of alignment.
    # Two values were corrected because they could never equal a URL's registered domain:
    #   'Commonwealth Bank of Australia' (the key had been copied into the values) -> 'commbank'
    #   'tech_target' -> 'techtarget'
    _BRAND_TO_DOMAIN = {
        'barclays': 'barclays', 'VKontakte': 'vk', 'lcl': 'lcl', 'dkb': 'dkb',
        'la_banque_postale': 'labanquepostale', 'La Poste': 'laposte', 'EMS': 'ems',
        'bcp': 'viabcp', 'barclaycard': 'barclaycard', 'smiles': 'smiles', 'gov_uk': 'gov',
        'cryptobridge': 'cryptobridge', 'ieee': 'ieee', 'strato': 'strato',
        'adidas': 'adidas', 'fsnb': 'fsnb', 'hfe': 'hfe', 'cetelem': 'cetelem',
        'Commonwealth Bank of Australia': 'commbank', 'zoominfo': 'zoominfo',
        'file_transfer': 'filetransfer', 'bradesco': 'banco', 'postbank': 'postbank',
        'Airbnb, Inc.': 'airbnb', 'SMBC': 'smbc', 'snapchat': 'snapchat',
        'docmagic': 'docmagic', 'Halifax Bank of Scotland Plc': 'halifax', 'GMX Mail': 'gmx',
        'sicil_shop': 'sicilshop', 'cathay_bank': 'cathaybank',
        'otrs': 'otrs', 'mdpd': 'mps', 'shoptet': 'shoptet', 'tech_target': 'techtarget',
        'Yandex': 'yandex', 'summit_bank': 'summitbank', 'Capitec Bank Limited': 'capitecbank',
        'Rakuten': 'rakuten', 'ziggo': 'ziggo', 'Magalu': 'magazineluiza', 'wp60': 'wp60',
        'Rabobank Nederland': 'rabobank', 'latam': 'latam', 'capital_one': 'capitalone',
        'db': 'db', 'qnb': 'qnb', 'momentum_office_design': 'momentumoffice',
        'Fifth Third Bank': '53', 'banco_de_occidente': 'bancodeoccidente', 'htb': 'htb',
        'orange_rockland': 'oru', 'Twitter, Inc.': 'twitter', 'Azul': 'azul',
        'ameli_fr': 'ameli', 'typeform': 'typeform', 'cogeco': 'cogeco',
        'banco_inter': 'bancointer', 'itunes': 'apple', 'netsons': 'netsons',
        'Three UK': 'three', 'bahia': 'casasbahia', 'test_rite': 'testritegroup',
        'anadolubank': 'anadolubank',
        'mbank': 'mbank', 'walmart': 'walmart', 'cloudns': 'cloudns',
        'crate_and_barrel': 'crateandbarrel', 'Boxberry': 'boxberry', 'xtrix_tv': 'xtrixtv',
        'etrade': 'etrade',
        'taxact': 'taxact', 'BNP Paribas': 'bnpparibas', 'ocn': 'tving',
        'Raiffeisen Bank S.A.': 'raiffeisen', 'fnac': 'fnac', 'arnet_tech': 'arnettechnologies',
        'nordea': 'nordea', 'sunrise': 'sunrise', 'infinisource': 'infinisource',
        'paschoalotto': 'paschoalotto', 'grupo_bancolombia': 'grupobancolombia',
        'youtube': 'youtube',
        'Banco de Cordoba': 'bancor', 'erste': 'erstegroup', 'cloudconvert': 'cloudconvert',
        'ms_bing': 'bing', 'EE Limited': 'ee', 'timeweb': 'timeweb',
        'knab': 'knab', 'sharp': 'sharp', 'smartsheet': 'smartsheet',
        'bestchange': 'bestchange', 'blizzard': 'blizzard', 'ms_skype': 'skype',
        'eharmony': 'eharmony',
    }

    def __init__(self, domain_map):
        # The map is stored by reference and updated in place by main()
        self.domain_map = domain_map

    def _assign_value(self, key, value):
        """Store ``value`` as the only domain of ``key``, replacing any existing entry."""
        self.domain_map[key] = [value]

    def main(self):
        """Write every brand of ``_BRAND_TO_DOMAIN`` into the map and return the map."""
        for brand, domain in self._BRAND_TO_DOMAIN.items():
            self._assign_value(brand, domain)

        return self.domain_map

In [26]:
# Wrap the loaded map; DomainDict keeps a reference to it rather than a copy
domaindict = DomainDict(domain_map)

In [27]:
# main() updates the wrapped dict in place and returns that same object, so
# `domain_map_after` and `domain_map` refer to one dict.
domain_map_after = domaindict.main()

In [29]:
# Write the expanded map back to the file it was loaded from (the file is overwritten)
with open(domain_map_path, 'wb') as out_file:
    pickle.dump(domain_map_after, out_file)

- Normal phishpedia without credential checking

In [45]:
# Benchmark folder, domain map and result file of the plain Phishpedia run
phish30k_dir = '../phishpedia/benchmark/test30k/phish_sample_30k'
domain_map_path = 'phishpedia/domain_map.pkl'
write_txt = 'results/phishpedia_phish30k_0.83.txt'

# Start a fresh tab-separated result file that holds only the header row
with open(write_txt, 'w') as f:
    f.write('\t'.join(['folder',
                       'true_brand',
                       'phish_category',
                       'pred_brand',
                       'runtime_element_recognition',
                       'runtime_siamese']) + '\n')

In [None]:
for folder in tqdm(os.listdir(phish30k_dir)):

    phish_category = 0 # 0 for benign, 1 for phish, by default is benign
    pred_target = None # predicted target, default is None
    # Timings of this sample only. A fresh dict per sample replaces the former
    # "'name' in locals()" checks plus `del` in a bare except, which could pick up a
    # value left behind by an interrupted run or an earlier cell.
    runtime = {}

    img_path = os.path.join(phish30k_dir, folder, 'shot.png')
    # NOTE(security): eval() executes whatever info.txt contains; only run this on a trusted
    # dataset. If the files hold plain Python literals, ast.literal_eval is the safe choice.
    with open(os.path.join(phish30k_dir, folder, 'info.txt'), encoding = "ISO-8859-1") as info_file:
        url = eval(info_file.read())
    url = url['url'] if isinstance(url, dict) else url

    # Element recognition module
    start_time = time.time()
    pred_classes, pred_boxes, pred_scores = element_recognition(img=img_path, model=ele_model)
    runtime['element_recognition'] = time.time() - start_time

    # If no element is reported
    if len(pred_boxes) == 0:
        phish_category = 0 # Report as benign

    # If at least one element is reported
    else:
        # Phishpedia module
        start_time = time.time()
        pred_target = phishpedia_classifier(pred_classes=pred_classes, pred_boxes=pred_boxes,
                                            domain_map_path=domain_map_path,
                                            model=pedia_model,
                                            logo_feat_list=logo_feat_list, file_name_list=file_name_list,
                                            url=url,
                                            shot_path=img_path,
                                            ts=0.83)
        runtime['siamese'] = time.time() - start_time

        # Phishpedia reports target -> phish, otherwise benign
        phish_category = 1 if pred_target is not None else 0

    # One tab-separated row per sample; stages that did not run leave an empty field
    row = [
        folder,
        brand_converter(folder.split('+')[0]),                             # true brand
        str(phish_category),                                               # phish/benign
        brand_converter(pred_target) if pred_target is not None else '',  # phishing target
        str(runtime['element_recognition']),
        str(runtime['siamese']) if 'siamese' in runtime else '',
    ]
    with open(write_txt, 'a+') as f:
        f.write('\t'.join(row) + '\n')

 40%|███▉      | 11739/29496 [59:48<1:18:15,  3.78it/s]

- PhishIntention

In [None]:
# Benchmark folder, model inputs, threshold and result file of the PhishIntention run
phish30k_dir = '../phishpedia/benchmark/test30k/phish_sample_30k'
domain_map_path = 'phishpedia/domain_map.pkl'
layout_cfg_dir = 'layout_matcher/configs.yaml'
layout_ref_dir = 'layout_matcher/layout_reference'
layout_ts = 0.5 # TODO: set this ts
write_txt = 'results/phishintention_phish30k_0.83.txt'

# Start a fresh tab-separated result file that holds only the header row
with open(write_txt, 'w') as f:
    f.write('\t'.join(['folder',
                       'true_brand',
                       'phish_category',
                       'pred_brand',
                       'runtime_element_recognition',
                       'runtime_credential_classifier',
                       'runtime_siamese',
                       'runtime_layout']) + '\n')

for folder in tqdm(os.listdir(phish30k_dir)):

    phish_category = 0 # 0 for benign, 1 for suspicious, 2 for phish
    pred_target = None # predicted target, default is None
    # Timings of this sample only. A fresh dict per sample replaces the former
    # "'name' in locals()" checks plus `del` in a bare except, which could pick up a
    # value left behind by an interrupted run or an earlier cell.
    runtime = {}

    img_path = os.path.join(phish30k_dir, folder, 'shot.png')
    # NOTE(security): eval() executes whatever info.txt contains; only run this on a trusted
    # dataset. If the files hold plain Python literals, ast.literal_eval is the safe choice.
    with open(os.path.join(phish30k_dir, folder, 'info.txt'), encoding = "ISO-8859-1") as info_file:
        url = eval(info_file.read())
    url = url['url'] if isinstance(url, dict) else url

    # Element recognition module
    start_time = time.time()
    pred_classes, pred_boxes, pred_scores = element_recognition(img=img_path, model=ele_model)
    runtime['element_recognition'] = time.time() - start_time

    # If no element is reported
    if len(pred_boxes) == 0:
        phish_category = 0 # Report as benign

    # If at least one element is reported
    else:
        # Credential classifier module
        start_time = time.time()
        cre_pred = credential_classifier(img=img_path, coords=pred_boxes, types=pred_classes, model=cls_model)
        runtime['credential_classifier'] = time.time() - start_time

        # Credential page
        if cre_pred == 0:
            # Phishpedia module
            start_time = time.time()
            pred_target = phishpedia_classifier(pred_classes=pred_classes, pred_boxes=pred_boxes,
                                                domain_map_path=domain_map_path,
                                                model=pedia_model,
                                                logo_feat_list=logo_feat_list, file_name_list=file_name_list,
                                                url=url,
                                                shot_path=img_path,
                                                ts=0.83)
            runtime['siamese'] = time.time() - start_time

            # Phishpedia reports target
            if pred_target is not None:
                # Layout module is only built w.r.t specific brands (social media brands)
                if pred_target not in ['Amazon', 'Facebook', 'Google', 'Instagram',
                                       'LinkedIn Corporation', 'ms_skype', 'Twitter, Inc.']:
                    phish_category = 2 # Report as phish

                else:
                    # NOTE(review): the references are rebuilt for every matching sample; they
                    # could be cached per brand if bipartite_web does not modify its inputs.
                    layout_cfg, gt_coords_arr, gt_files_arr, gt_shot_size_arr = layout_config(cfg_dir=layout_cfg_dir,
                                                                                              ref_dir=layout_ref_dir,
                                                                                              matched_brand=pred_target)
                    start_time = time.time()
                    max_s, max_site = layout_matcher(pred_boxes=pred_boxes, img=img_path,
                                                     gt_coords_arr=gt_coords_arr, gt_files_arr=gt_files_arr,
                                                     gt_shot_size_arr=gt_shot_size_arr,
                                                     cfg=layout_cfg)
                    runtime['layout'] = time.time() - start_time

                    # Successful layout match -> phish, otherwise suspicious
                    phish_category = 2 if max_s >= layout_ts else 1

            # Phishpedia does not report target
            else: # Report as benign
                phish_category = 0

        # Non-credential page
        elif cre_pred == 1:
            # TODO: dynamic module here
            phish_category = 0 # Report as benign

    # One tab-separated row per sample; stages that did not run leave an empty field
    row = [
        folder,
        brand_converter(folder.split('+')[0]),                             # true brand
        str(phish_category),                                               # phish/benign/suspicious
        brand_converter(pred_target) if pred_target is not None else '',  # phishing target
        str(runtime['element_recognition']),
    ] + [str(runtime[stage]) if stage in runtime else ''
         for stage in ('credential_classifier', 'siamese', 'layout')]
    with open(write_txt, 'a+') as f:
        f.write('\t'.join(row) + '\n')
            

- Test single site

In [24]:
# Single screenshot used to walk through the pipeline step by step
img_path = 'datasets/train_imgs/Amazon.com Inc.+2020-08-16-15`54`25.png'
# Step 1: element detection (classes, boxes and scores of the detected elements)
pred_classes, pred_boxes, pred_scores = element_recognition(img=img_path, model=ele_model)

In [25]:
# Step 2: credential classifier — returns 0 for 'credential', 1 for 'noncredential'
cls_pred = credential_classifier(img_path, pred_boxes, pred_classes, cls_model)

In [26]:
domain_map_path = 'phishpedia/domain_map.pkl'

# Step 3: logo identification. A target is only reported when the URL's domain is not
# among the matched brand's domains.
# NOTE(review): the URL looks like a placeholder chosen so that check cannot suppress the match — confirm.
pred_target = phishpedia_classifier(pred_classes=pred_classes, pred_boxes=pred_boxes, 
                                    domain_map_path=domain_map_path,
                                    model=pedia_model, 
                                    logo_feat_list=logo_feat_list, file_name_list=file_name_list, 
                                    shot_path=img_path,
                                    url='https://www.kkk.com',
                                    ts=0.83)


(3123,)


In [33]:
# Step 4a: load the layout references of the predicted brand.
# layout_config asserts that the brand is one of the supported brands, so this cell raises
# AssertionError when pred_target is None or any other brand.
layout_cfg, gt_coords_arr, gt_files_arr, gt_shot_size_arr = layout_config(cfg_dir='layout_matcher/configs.yaml', 
                                                                   ref_dir='layout_matcher/layout_reference', 
                                                                   matched_brand=pred_target)


In [41]:
# Step 4b: compare the detected boxes with every reference layout; yields the best
# similarity and the reference file that produced it.
max_s, max_site = layout_matcher(pred_boxes=pred_boxes, img=img_path, 
                                 gt_coords_arr=gt_coords_arr, gt_files_arr=gt_files_arr, gt_shot_size_arr=gt_shot_size_arr,
                                 cfg=layout_cfg)

In [33]:
# check = cv2.imread(img_path)
# for j, box in enumerate(pred_boxes):
#     cv2.rectangle(check, (box[0], box[1]), (box[2], box[3]), (36, 255, 12), 2)
#     cv2.putText(check, str(pred_classes[j].item()), (box[0], box[1]), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)

# plt.figure(figsize=(20,20))
# plt.imshow(check[:, :, ::-1])

## Evaluation

In [None]:
import pandas as pd

In [None]:
def evaluate(result_path, gt_type):
    assert gt_type in ['phish', 'benign'] # gt type is either phish/benign
    result = open(result_path).readlines()
    result_df = pd.DataFrame(result)
    