In [1]:
import h5py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import DistilBertTokenizer, DistilBertModel
import os
from tqdm.notebook import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from test_harness import evaluate_model

class PhishingDataset(Dataset):
    """Dataset of URL + HTML text samples with integer labels, read from HDF5.

    The file is expected to contain, under the group named by ``split``, the
    datasets ``urls``, ``html_content`` and ``labels``. All three are copied
    into memory at construction time and the file is closed again
    immediately, so the instance holds no open file handle.

    Args:
        h5_file: Path to the HDF5 file.
        split: Name of the top-level group to read (default ``'train'``).
    """

    def __init__(self, h5_file, split='train'):
        # Everything is materialised with [:], so the handle is only needed
        # for the duration of the read. Keeping it open on the instance would
        # leak the handle and prevent the dataset from being pickled (h5py
        # objects are not picklable), which DataLoader workers may require.
        with h5py.File(h5_file, 'r') as h5:
            self.urls = h5[f'{split}/urls'][:]
            self.html_content = h5[f'{split}/html_content'][:]
            self.labels = h5[f'{split}/labels'][:]
        self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

    @staticmethod
    def _as_text(value):
        """Return ``value`` as ``str``, decoding UTF-8 when it is ``bytes``."""
        if isinstance(value, bytes):
            return value.decode('utf-8')
        return str(value)

    def __len__(self):
        """Return the number of samples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return it as a dict of tensors.

        Returns:
            A dict with ``'input_ids'`` and ``'attention_mask'`` (the
            tokenizer output with its leading batch axis removed) and
            ``'label'`` (a ``torch.long`` scalar tensor).
        """
        url = self._as_text(self.urls[idx])
        html_content = self._as_text(self.html_content[idx])
        text = f"URL: {url} CONTENT: {html_content}"
        encoded_input = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return {
            # The tokenizer is asked for 'pt' tensors for a single text, so
            # axis 0 is the batch axis; drop only that axis explicitly.
            'input_ids': encoded_input['input_ids'].squeeze(0),
            'attention_mask': encoded_input['attention_mask'].squeeze(0),
            'label': label
        }

class PhishingClassifier(nn.Module):
    """DistilBERT encoder followed by a single linear layer with two outputs."""

    def __init__(self):
        super().__init__()
        encoder = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.bert = encoder
        # Linear head sized from the encoder's configured hidden size.
        self.classifier = nn.Linear(encoder.config.hidden_size, 2)

    def forward(self, input_ids, attention_mask):
        """Return the logits computed from the first token's hidden state.

        Args:
            input_ids: Token ids passed through to the encoder.
            attention_mask: Attention mask passed through to the encoder.

        Returns:
            The output of the linear head applied to position 0 of the
            encoder's ``last_hidden_state``.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        first_token_state = encoded.last_hidden_state[:, 0, :]
        return self.classifier(first_token_state)

def _select_device():
    """Return the preferred torch device: CUDA first, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    # torch.backends.mps does not exist on every torch build, so look it up
    # defensively instead of assuming the attribute is present.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')
    return torch.device('cpu')


def train_model(dataset_path, epochs=3, batch_size=32, learning_rate=0.00002):
    """Fine-tune a PhishingClassifier and evaluate it after every epoch.

    Args:
        dataset_path: Path to the HDF5 file holding the 'train' and 'test'
            splits read by PhishingDataset.
        epochs: Number of passes over the training split.
        batch_size: Batch size used for both the train and test loaders.
        learning_rate: Learning rate passed to AdamW.

    Returns:
        The trained PhishingClassifier (on the selected device).

    Side effects:
        Prints per-epoch losses and metrics, and writes
        ``phishing_classifier_epoch_<n>.pt`` to the current working directory
        after each epoch.
    """
    # Initialize dataset and dataloaders
    train_dataset = PhishingDataset(dataset_path, split='train')
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataset = PhishingDataset(dataset_path, split='test')
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Set device
    device = _select_device()
    model = PhishingClassifier().to(device)

    # Define optimizer and loss function
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    # Training loop
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{epochs}")
        for batch in progress_bar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            progress_bar.set_postfix(loss=batch_loss)

        # An empty training split yields zero batches; report 0 instead of
        # raising ZeroDivisionError.
        avg_loss = total_loss / max(len(train_dataloader), 1)
        print(f"Epoch {epoch + 1}/{epochs}, Average Loss: {avg_loss}")

        # Save the checkpoint BEFORE evaluating, so that a failure in the
        # evaluation step cannot discard the epoch that was just trained.
        torch.save(model.state_dict(), f"phishing_classifier_epoch_{epoch + 1}.pt")

        # Evaluation after each epoch
        # NOTE(review): a recorded run of this notebook finished epoch 1 and
        # then raised "TypeError: PhishingClassifier.forward() missing 2
        # required positional arguments", apparently from this call. Confirm
        # how test_harness.evaluate_model invokes the model and what argument
        # order it expects.
        test_loss, precision, recall, f1, accuracy = evaluate_model(model, test_dataloader, device, criterion)
        print(f"Epoch {epoch + 1}/{epochs}, Test Loss: {test_loss:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}, Accuracy: {accuracy:.4f}")

    return model


# Notebook entry point: train on the HDF5 export in the user's home directory,
# then persist the final weights next to the per-epoch checkpoints.
dataset_path = os.path.expanduser("~/transfer/phishing_output.h5")
model = train_model(dataset_path=dataset_path)
torch.save(
    model.state_dict(),
    "phishing_classifier.pt",
)




Epoch 1/3:   0%|          | 0/1770 [00:00<?, ?it/s]

Epoch 1/3, Average Loss: 0.11464344092308001


TypeError: PhishingClassifier.forward() missing 2 required positional arguments: 'input_ids' and 'attention_mask'