In [1]:
import torch
import torch.nn as nn
from torchvision import models
from transformers import DistilBertModel
from test_harness import test_harness

class PhishingClassifier(nn.Module):
    """Two-branch phishing classifier that fuses image and URL-text features.

    A ResNet-18 backbone (with its final ``fc`` layer replaced by
    ``nn.Identity``) embeds the image, DistilBERT embeds the tokenized URL,
    and the two embeddings are concatenated and passed through a small MLP
    head that produces two class logits.
    """

    def __init__(self):
        """Build both pretrained feature extractors and the fusion head."""
        super().__init__()

        # ImageNet-pretrained ResNet-18. ``weights=`` supersedes the deprecated
        # ``pretrained=True`` flag; IMAGENET1K_V1 is the checkpoint that flag
        # used to select, so the loaded parameters are the same.
        self.cnn = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        # Record the backbone's embedding width before its head is removed.
        self.cnn_fc_features = self.cnn.fc.in_features
        # Drop the ImageNet classification layer so the CNN emits raw features.
        self.cnn.fc = nn.Identity()

        # Pre-trained DistilBERT for URL feature extraction
        self.bert = DistilBertModel.from_pretrained('distilbert-base-uncased')

        # Classifier combining both CNN and BERT features
        self.classifier = nn.Sequential(
            nn.Linear(self.cnn_fc_features + self.bert.config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 2)
        )

    def forward(self, url_input_ids, url_attention_mask, image):
        """Compute class logits for a batch of (URL, image) pairs.

        Args:
            url_input_ids: Token ids forwarded to DistilBERT as ``input_ids``.
            url_attention_mask: Mask forwarded to DistilBERT as
                ``attention_mask``.
            image: Image batch forwarded to the ResNet-18 backbone.
                (assumes the usual ``(batch, 3, H, W)`` float layout --
                TODO confirm against the data loader)

        Returns:
            Raw (un-normalized) logits with a trailing dimension of 2.
        """
        image_features = self.cnn(image)

        bert_outputs = self.bert(input_ids=url_input_ids, attention_mask=url_attention_mask)
        # Keep only sequence position 0 as the URL summary vector (the [CLS]
        # token when the ids come from the standard DistilBERT tokenizer).
        url_features = bert_outputs.last_hidden_state[:, 0, :]

        # Concatenate along the feature axis: one fused vector per sample.
        combined_features = torch.cat((image_features, url_features), dim=1)

        logits = self.classifier(combined_features)
        return logits

    def test_name(self):
        """Return the identifier string for this model variant."""
        return 'cnn_with_url'

# Choose the compute device by preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

# Build the classifier and move its parameters onto the selected device.
model = PhishingClassifier()
model = model.to(device)

# Run the project's harness for 5 epochs (the recorded cell output shows it
# reporting per-epoch training loss and test metrics).
test_harness(model, epochs=5)

# Persist only the weights (state_dict), named after the model variant.
torch.save(model.state_dict(), f"{model.test_name()}_phishing_classifier.pt")





Epoch 1/5:   0%|          | 0/7080 [00:00<?, ?it/s]

Epoch 1/5, Average Loss: 0.11199124391476155
Epoch 1/5, Test Loss: 0.0626, Precision: 0.9687, Recall: 0.9727, F1 Score: 0.9707, Accuracy: 0.9777


Epoch 2/5:   0%|          | 0/7080 [00:00<?, ?it/s]

Epoch 2/5, Average Loss: 0.053912262396750385
Epoch 2/5, Test Loss: 0.0619, Precision: 0.9696, Recall: 0.9778, F1 Score: 0.9737, Accuracy: 0.9800


Epoch 3/5:   0%|          | 0/7080 [00:00<?, ?it/s]

Epoch 3/5, Average Loss: 0.03008491509588227
Epoch 3/5, Test Loss: 0.0557, Precision: 0.9832, Recall: 0.9708, F1 Score: 0.9770, Accuracy: 0.9826


Epoch 4/5:   0%|          | 0/7080 [00:00<?, ?it/s]

Epoch 4/5, Average Loss: 0.019016353558725385
Epoch 4/5, Test Loss: 0.0848, Precision: 0.9827, Recall: 0.9634, F1 Score: 0.9730, Accuracy: 0.9797


Epoch 5/5:   0%|          | 0/7080 [00:00<?, ?it/s]

Epoch 5/5, Average Loss: 0.013983464400673521
Epoch 5/5, Test Loss: 0.0857, Precision: 0.9686, Recall: 0.9793, F1 Score: 0.9739, Accuracy: 0.9801
