# In [None]:
import h5py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision import models
from transformers import DistilBertTokenizer, DistilBertModel
import os
from tqdm.notebook import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from PIL import Image

class PhishingDataset(Dataset):
    """Dataset of (URL, screenshot, label) samples from one split of an HDF5 file.

    The arrays ``{split}/urls``, ``{split}/screenshots`` and ``{split}/labels``
    are copied into memory when the dataset is built and the HDF5 file is
    closed again immediately, so no file handle stays open for the lifetime
    of the dataset.

    Args:
        h5_file: Path of the HDF5 file to read.
        split: Name of the top-level group to load (default ``'train'``).
    """

    def __init__(self, h5_file, split='train'):
        # ``[:]`` materialises each HDF5 dataset as an in-memory array, so the
        # file is only needed inside this ``with`` block.
        with h5py.File(h5_file, 'r') as h5:
            self.urls = h5[f'{split}/urls'][:]
            self.screenshots = h5[f'{split}/screenshots'][:]
            self.labels = h5[f'{split}/labels'][:]

        self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        # NOTE(review): ToPILImage + a 3-channel Normalize presumably expect the
        # stored screenshots to be H x W x 3 image arrays -- confirm against the
        # code that wrote the HDF5 file.  No resizing is done here, so batching
        # relies on all screenshots already sharing one size.
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor(),
            # Per-channel mean/std commonly used with ImageNet-pretrained
            # torchvision models.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of samples, taken from the label array."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors.

        Keys: ``'input_ids'`` and ``'attention_mask'`` (tokenized URL, padded
        or truncated to 128 tokens), ``'image'`` (transformed screenshot) and
        ``'label'`` (``torch.long`` scalar).
        """
        # URLs are stored as byte strings in the HDF5 file.
        url = self.urls[idx].decode('utf-8')
        screenshot = self.screenshots[idx]
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        # Preprocess the screenshot
        image = self.transform(screenshot)

        # Tokenize the URL
        encoded_input = self.tokenizer(
            url,
            padding='max_length',
            truncation=True,
            max_length=128,
            return_tensors='pt'
        )

        # The tokenizer returns tensors with a leading batch axis of size 1;
        # drop exactly that axis so the DataLoader can stack the samples.
        return {
            'input_ids': encoded_input['input_ids'].squeeze(0),
            'attention_mask': encoded_input['attention_mask'].squeeze(0),
            'image': image,
            'label': label
        }

class PhishingClassifier(nn.Module):
    """Two-branch classifier over a page screenshot and its URL.

    A pre-trained ResNet-18 (classification layer removed) embeds the
    screenshot, a pre-trained DistilBERT embeds the tokenized URL, and a small
    MLP maps the concatenated embeddings to two logits.
    """

    def __init__(self):
        super().__init__()
        # Pre-trained ResNet for image feature extraction.
        # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
        # favour of ``weights=``; kept as-is to avoid raising the minimum version.
        self.cnn = models.resnet18(pretrained=True)
        # Read the width of the final layer BEFORE replacing it: nn.Identity
        # has no ``in_features`` attribute, so asking afterwards raises
        # AttributeError.
        image_feature_dim = self.cnn.fc.in_features
        self.cnn.fc = nn.Identity()  # Remove the classification layer

        # Pre-trained DistilBERT for URL feature extraction
        self.bert = DistilBertModel.from_pretrained('distilbert-base-uncased')
        url_feature_dim = self.bert.config.hidden_size

        # Classifier combining both CNN and BERT features
        self.classifier = nn.Sequential(
            nn.Linear(image_feature_dim + url_feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 2)
        )

    def forward(self, input_ids, attention_mask, image):
        """Return the two-class logits for a batch.

        Args:
            input_ids: Token ids of the URLs, passed to DistilBERT.
            attention_mask: Attention mask matching ``input_ids``.
            image: Batch of screenshot tensors, passed to the ResNet.

        Returns:
            Tensor of logits with one row per sample and two columns.
        """
        # Extract features from the screenshot using CNN
        image_features = self.cnn(image)

        # Extract features from the URL using BERT; the hidden state at
        # sequence position 0 serves as the embedding of the whole URL.
        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        url_features = bert_outputs.last_hidden_state[:, 0, :]

        # Concatenate image and URL features along the feature axis
        combined_features = torch.cat((image_features, url_features), dim=1)

        # Classify the combined features
        logits = self.classifier(combined_features)
        return logits

def evaluate_model(model, dataloader, device):
    """Score ``model`` on every batch of ``dataloader``.

    The model is switched to eval mode and run without gradient tracking.
    The predicted class of each sample is the argmax over its logits.

    Args:
        model: Callable taking ``(input_ids, attention_mask, images)`` and
            returning per-class logits.
        dataloader: Iterable of batches, each a dict with the keys
            ``'input_ids'``, ``'attention_mask'``, ``'image'`` and ``'label'``.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        Tuple ``(precision, recall, f1, accuracy)``; the first three are
        computed with ``average='binary'``.
    """
    model.eval()
    predicted, expected = [], []
    batch_keys = ('input_ids', 'attention_mask', 'image', 'label')

    with torch.no_grad():
        for batch in dataloader:
            ids, mask, imgs, targets = (batch[key].to(device) for key in batch_keys)
            logits = model(ids, mask, imgs)

            # Collect results on the CPU so sklearn can consume them.
            predicted.extend(logits.argmax(dim=1).cpu().numpy())
            expected.extend(targets.cpu().numpy())

    binary_scores = tuple(
        metric(expected, predicted, average='binary')
        for metric in (precision_score, recall_score, f1_score)
    )
    return (*binary_scores, accuracy_score(expected, predicted))

def train_model(dataset_path, epochs=3, batch_size=8, learning_rate=2e-5):
    """Train a PhishingClassifier on the 'train' split and score it on 'test'.

    Args:
        dataset_path: Path of the HDF5 file handed to PhishingDataset.
        epochs: Number of passes over the training split.
        batch_size: Batch size for both the training and the test loader.
        learning_rate: Learning rate given to AdamW.

    Returns:
        The trained model, left on the device selected for training.

    Side effects:
        After every epoch, prints the average training loss and the test
        metrics and saves the model's state dict to
        ``cnn_phishing_classifier_epoch_<n>.pt`` (relative path).
    """
    # Data: shuffled training batches, test batches in stored order.
    train_set = PhishingDataset(dataset_path, split='train')
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_set = PhishingDataset(dataset_path, split='test')
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

    # Device preference: CUDA first, then MPS, otherwise CPU.
    if torch.cuda.is_available():
        device_name = 'cuda'
    elif torch.backends.mps.is_available():
        device_name = 'mps'
    else:
        device_name = 'cpu'
    device = torch.device(device_name)
    model = PhishingClassifier().to(device)

    # Optimisation set-up
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    batch_keys = ('input_ids', 'attention_mask', 'image', 'label')
    for epoch in range(epochs):
        # evaluate_model() leaves the model in eval mode, so training mode is
        # re-enabled at the start of every epoch.
        model.train()
        running_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")

        for batch in progress_bar:
            input_ids, attention_mask, images, labels = (
                batch[key].to(device) for key in batch_keys
            )

            optimizer.zero_grad()
            loss = criterion(model(input_ids, attention_mask, images), labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            progress_bar.set_postfix(loss=batch_loss)

        avg_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch + 1}/{epochs}, Average Loss: {avg_loss}")

        # Held-out metrics for this epoch
        precision, recall, f1, accuracy = evaluate_model(model, test_loader, device)
        print(f"Epoch {epoch + 1}/{epochs}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}, Accuracy: {accuracy:.4f}")

        # Per-epoch checkpoint
        torch.save(model.state_dict(), f"cnn_phishing_classifier_epoch_{epoch + 1}.pt")

    return model


# Script entry point: these statements run as soon as the file/cell executes.
# "~" is expanded to the current user's home directory.
dataset_path = os.path.expanduser("~/transfer/phishing_output.h5")
# Train with train_model's default hyperparameters (epochs, batch size, lr).
model = train_model(dataset_path)
# Save the final weights; train_model has already written one checkpoint per epoch.
torch.save(model.state_dict(), "cnn_phishing_classifier.pt")
