Preprocess the HDF5 file created by datasetcreation.ipynb: convert each page's HTML to plain text, then tokenize that text and the page URL into input IDs and attention masks.
Save the processed data to a new HDF5 file. This speeds up training iterations because the dataset objects no longer need to tokenize text at training time.

In [None]:
import h5py
import numpy as np
from transformers import DistilBertTokenizer
from tqdm import tqdm
from custom_html_parser import CustomHTML2Text

# Input file (read-only) and the output file that receives the tokenized
# copy of each slice.
original_file_path = '/Users/imack/transfer/phishing_output.h5'
new_file_path = '/Users/imack/transfer/phishing_output_tokenized.h5'

# HTML -> plain-text converter applied to every page before tokenization.
converter = CustomHTML2Text()

# Shared tokenizer used for both the page text and the URLs.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

def tokenize_with_overlap(html_content, max_chunk_length=512, stride=256):
    """Tokenize text into fixed-length, overlapping windows of token IDs.

    The text is tokenized once (without special tokens) using the module-level
    ``tokenizer`` and then cut into windows of ``max_chunk_length`` tokens
    whose start positions are ``stride`` tokens apart. Each window is
    right-padded with ``tokenizer.pad_token_id`` so every row has the same
    length.

    Args:
        html_content: Text to tokenize. Empty, ``None`` or whitespace-only
            input is replaced by the placeholder ``'<html></html>'``.
        max_chunk_length: Number of token positions in every returned row.
        stride: Distance, in tokens, between the starts of consecutive
            windows.

    Returns:
        A ``(chunks, attention_masks)`` tuple of ``int32`` arrays, both of
        shape ``(num_chunks, max_chunk_length)`` with ``num_chunks >= 1``.
        ``attention_masks`` is 1 where ``chunks`` holds a real token and 0
        where it holds padding.
    """
    if not html_content or not html_content.strip():
        html_content = '<html></html>'

    encoded = tokenizer(html_content, add_special_tokens=False)
    # Flatten explicitly instead of squeeze(): squeezing a single-token result
    # yields a 0-d array, which has no len() and cannot be sliced.
    tokens = np.asarray(encoded["input_ids"], dtype=np.int32).reshape(-1)

    # Always emit at least one window, even if the tokenizer produced no
    # tokens, so callers never receive an empty result.
    starts = range(0, max(len(tokens), 1), stride)

    chunks = np.full(
        (len(starts), max_chunk_length),
        tokenizer.pad_token_id,
        dtype=np.int32,
    )
    attention_masks = np.zeros((len(starts), max_chunk_length), dtype=np.int32)

    for row, start in enumerate(starts):
        window = tokens[start:start + max_chunk_length]
        chunks[row, :len(window)] = window
        attention_masks[row, :len(window)] = 1

    return chunks, attention_masks

def _get_or_create_dataset(group, name, **create_kwargs):
    """Return ``group[name]``, creating it with ``create_kwargs`` if missing."""
    if name in group:
        return group[name]
    return group.create_dataset(name, **create_kwargs)


def _write_row(dataset, index, value):
    """Store ``value`` as row ``index`` after sizing the first axis to ``index + 1``.

    Writing at an explicit index (instead of appending) keeps every output
    dataset aligned with the source rows: if a previous run was interrupted
    part-way through an item, the leftover rows are dropped on resume.
    """
    dataset.resize(index + 1, axis=0)
    dataset[index] = value


# Tokenize every slice of the source file into the output file. Progress is
# checkpointed per item in 'last_processed_index', so the cell can be re-run
# to resume after an interruption.
with h5py.File(original_file_path, 'r') as original_file, h5py.File(new_file_path, 'a') as outfile:
    # Iterate through slices (train, dev, test)
    for slice_name in original_file.keys():
        slice_group = original_file[slice_name]
        print(f"Processing slice: {slice_name}")

        # Labels and URLs are copied as-is; the source's screenshots and raw
        # html_content are not copied.
        new_group = outfile.require_group(slice_name)
        if 'labels' not in new_group:
            new_group.create_dataset('labels', data=slice_group['labels'][:])
        if 'urls' not in new_group:
            new_group.create_dataset('urls', data=slice_group['urls'][:])
        if 'last_processed_index' not in new_group:
            new_group.create_dataset('last_processed_index', data=np.array([-1], dtype=np.int32))

        # Index of the last fully written item (-1 when nothing was written).
        last_processed_index = int(new_group['last_processed_index'][0])

        html_contents = slice_group['html_content'][:]
        urls = slice_group['urls'][:]

        # Fetch or create each output dataset individually so a partially
        # initialised group is completed rather than causing a KeyError.
        vlen_int32 = h5py.special_dtype(vlen=np.dtype('int32'))
        input_ids_dataset = _get_or_create_dataset(
            new_group, "html_input_ids",
            shape=(0,), maxshape=(None,), dtype=vlen_int32,
        )
        attention_masks_dataset = _get_or_create_dataset(
            new_group, "html_attention_masks",
            shape=(0,), maxshape=(None,), dtype=vlen_int32,
        )
        url_input_ids_dataset = _get_or_create_dataset(
            new_group, "url_input_ids",
            shape=(0, 128), maxshape=(None, 128), dtype=np.int32,
        )
        url_attention_masks_dataset = _get_or_create_dataset(
            new_group, "url_attention_masks",
            shape=(0, 128), maxshape=(None, 128), dtype=np.int32,
        )
        html_content_dataset = _get_or_create_dataset(
            new_group, "html_content",
            shape=(0,),  # Start with zero entries
            maxshape=(None,),  # Allow unlimited resizing along the first axis
            dtype=h5py.string_dtype(encoding="utf-8"),  # Variable-length string type
            chunks=(1,),  # Enable chunking, specify chunk size
        )

        print(f"last_processed_index: {last_processed_index}")
        # Resume processing from the last processed index; `initial` makes the
        # progress bar start at the resume point instead of at zero.
        first_index = last_processed_index + 1
        total_items = len(html_contents)
        for i in tqdm(range(first_index, total_items), total=total_items, initial=first_index):
            html_content = html_contents[i].decode('utf-8')
            url = urls[i].decode('utf-8')

            plain_text = converter.handle(html_content)
            chunks, attention_masks = tokenize_with_overlap(plain_text)

            # Each item's chunk matrix is stored flattened in a variable-length
            # row; reshape(-1) lays the chunks end to end in order.
            flat_input_ids = np.asarray(chunks).reshape(-1).astype(np.int32)
            flat_attention_masks = np.asarray(attention_masks).reshape(-1).astype(np.int32)

            encoded_url_input = tokenizer(
                url,
                padding='max_length',
                truncation=True,
                max_length=128,
                return_tensors='np'
            )

            _write_row(input_ids_dataset, i, flat_input_ids)
            _write_row(attention_masks_dataset, i, flat_attention_masks)
            _write_row(html_content_dataset, i, plain_text)
            # The tokenizer returns a (1, 128) batch; store its single row.
            _write_row(url_input_ids_dataset, i, encoded_url_input['input_ids'][0])
            _write_row(url_attention_masks_dataset, i, encoded_url_input['attention_mask'][0])

            # Checkpoint only after every dataset holds row i.
            new_group['last_processed_index'][0] = i
            outfile.flush()

print("Preprocessing complete. Saved to", new_file_path)


Processing slice: dev
last_processed_index: -1


  0%|          | 0/7126 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (2327 > 512). Running this sequence through the model will result in indexing errors
 10%|▉         | 689/7126 [00:48<05:13, 20.51it/s]