# Fine-tuning BERT to identify an entity's website URL among scraped search results

In [1]:
pip install -r requirements.txt

Collecting torchaudio (from -r requirements.txt (line 18))
  Downloading torchaudio-2.2.2-cp312-cp312-win_amd64.whl.metadata (6.4 kB)
Collecting torchvision (from -r requirements.txt (line 19))
  Downloading torchvision-0.17.2-cp312-cp312-win_amd64.whl.metadata (6.6 kB)
Downloading torchaudio-2.2.2-cp312-cp312-win_amd64.whl (2.4 MB)
   ---------------------------------------- 0.0/2.4 MB ? eta -:--:--
   ---------------------------------------- 0.0/2.4 MB ? eta -:--:--
   - -------------------------------------- 0.1/2.4 MB 1.2 MB/s eta 0:00:02
   --- ------------------------------------ 0.2/2.4 MB 1.7 MB/s eta 0:00:02
   ----- ---------------------------------- 0.3/2.4 MB 2.1 MB/s eta 0:00:01
   -------- ------------------------------- 0.5/2.4 MB 2.4 MB/s eta 0:00:01
   ---------- ----------------------------- 0.6/2.4 MB 2.5 MB/s eta 0:00:01
   -------------- ------------------------- 0.8/2.4 MB 2.6 MB/s eta 0:00:01
   ---------------- ----------------------- 1.0/2.4 MB 2.7 MB/s eta 0:0

### Data preprocessing

In [15]:
from urllib.parse import urlparse

import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import BertTokenizer

# Function to preprocess and extract a comparable domain from URLs
def extract_domain(url):
    """Normalise a URL (or bare host) to a comparable, lower-cased domain string.

    Double-quote characters and surrounding whitespace are removed *before*
    parsing, scheme/credentials/port/path/query/fragment are discarded, the host
    is lower-cased and a single leading ``www.`` is dropped, e.g.
    ``'"https://www.NBB.be/nl/"'`` -> ``'nbb.be'`` and
    ``'www.example.com/a/b'`` -> ``'example.com'``.

    Args:
        url: raw URL value from the CSV; non-strings (NaN/None) are allowed.

    Returns:
        The normalised domain, or ``""`` when ``url`` is not a string, is empty,
        or cannot be parsed. ``""`` means "no domain" and must not be treated as
        a real domain when comparing.
    """
    # Non-strings (e.g. NaN or None) carry no domain
    if not isinstance(url, str):
        return ""

    # Clean first: quotes/whitespace left in place would be parsed as part of the
    # scheme or path and corrupt the result.
    cleaned = url.replace('"', '').strip()
    if not cleaned:
        return ""

    # urlparse only fills netloc when a scheme or a leading '//' is present;
    # bare hosts like 'www.x.be/path' would otherwise end up entirely in .path.
    if '://' not in cleaned and not cleaned.startswith('//'):
        cleaned = '//' + cleaned

    try:
        # .hostname drops credentials and port and lower-cases the host
        host = urlparse(cleaned).hostname or ""
    except ValueError:  # e.g. malformed IPv6 literal such as 'http://[abc'
        return ""

    # Strip only a *leading* 'www.'; str.replace would also mangle hosts like 'awww.be'
    if host.startswith('www.'):
        host = host[len('www.'):]
    return host

# Read the ground-truth file (known URL + the search query used) and the scraped results file
dataset_query = pd.read_csv('dataset_incl_query.csv')
dataset_scraped = pd.read_csv('search_results_DDG.csv')

# Join on 'EntityNumber' (pandas' default inner join: only entities present in both files survive)
merged_dataset = dataset_query[['EntityNumber', 'URL', 'SearchQuery']].merge(
    dataset_scraped, on='EntityNumber'
)

# Normalise the gold URL and each of the five scraped URLs to bare domains, appended
# as CorrectDomain, URL1Domain, ..., URL5Domain (in that column order)
merged_dataset = merged_dataset.assign(
    CorrectDomain=merged_dataset['URL'].apply(extract_domain),
    **{f'URL{i}Domain': merged_dataset[f'URL{i}'].apply(extract_domain) for i in range(1, 6)},
)

# Prepare labels: If the correct domain matches one of the scraped domains, label with that index; otherwise, label as -1
# Adjust the labeling function to handle multiple correct URLs
def mark_correct_labels(row):
    labels = []
    for i in range(1, 6):
        # Check if each scraped domain matches the correct domain
        if row['CorrectDomain'] == row[f'URL{i}Domain']:
            labels.append(1)  # Mark as correct
        else:
            labels.append(0)  # Mark as incorrect
    return labels

# Attach each row's per-candidate match vector as the 'Labels' column
merged_dataset['Labels'] = merged_dataset.apply(mark_correct_labels, axis=1)

# Sanity check: query, correct domain, the five candidate domains and the resulting labels
preview_cols = [
    'EntityNumber', 'SearchQuery', 'CorrectDomain',
    'URL1Domain', 'URL2Domain', 'URL3Domain', 'URL4Domain', 'URL5Domain',
    'Labels',
]
print(merged_dataset[preview_cols].head())


   EntityNumber                                        SearchQuery  \
0  0201.310.929                                      IGL 3600 Genk   
1  0202.239.951                           PROXIMUS 1030 Schaarbeek   
2  0203.201.340             Nationale Bank van België 1000 Brussel   
3  0206.460.639  Intergemeentelijk Samenwerkingsverband van het...   
4  0206.653.946  Rijksinstituut voor Ziekte- en Invaliditeitsve...   

           CorrectDomain             URL1Domain      URL2Domain  \
0  extranet.iglimburg.be           iglimburg.be  intergalva.com   
1           proximus.com           proximus.com     proximus.be   
2                 nbb.be                 nbb.be          nbb.be   
3           interwaas.be  erfgoedcelwaasland.be         vvsg.be   
4          inami.fgov.be          riziv.fgov.be   riziv.fgov.be   

      URL3Domain               URL4Domain         URL5Domain           Labels  
0   mapcarta.com       roamtechnology.com         geruro.com  [0, 0, 0, 0, 0]  
1    proximus.be

In [22]:
# Drop the raw URL columns now that their normalised *Domain counterparts exist.
# errors='ignore' makes the cell idempotent: re-running it after the columns are
# already gone would otherwise raise KeyError.
merged_dataset = merged_dataset.drop(
    columns=['URL', 'URL1', 'URL2', 'URL3', 'URL4', 'URL5'], errors='ignore'
)
print(merged_dataset.head())

   EntityNumber                                        SearchQuery  \
0  0201.310.929                                      IGL 3600 Genk   
1  0202.239.951                           PROXIMUS 1030 Schaarbeek   
2  0203.201.340             Nationale Bank van België 1000 Brussel   
3  0206.460.639  Intergemeentelijk Samenwerkingsverband van het...   
4  0206.653.946  Rijksinstituut voor Ziekte- en Invaliditeitsve...   

           CorrectDomain             URL1Domain      URL2Domain  \
0  extranet.iglimburg.be           iglimburg.be  intergalva.com   
1           proximus.com           proximus.com     proximus.be   
2                 nbb.be                 nbb.be          nbb.be   
3           interwaas.be  erfgoedcelwaasland.be         vvsg.be   
4          inami.fgov.be          riziv.fgov.be   riziv.fgov.be   

      URL3Domain               URL4Domain         URL5Domain           Labels  
0   mapcarta.com       roamtechnology.com         geruro.com  [0, 0, 0, 0, 0]  
1    proximus.be

### Data preparation for BERT

In [23]:
class URLDomainDataset(Dataset):
    """One example per search query, scored jointly against all of its candidate domains.

    Each item encodes the query and its candidate domains as a single BERT text
    pair, ``[CLS] query [SEP] d1 [SEP] d2 ... [SEP]``. One forward pass of a
    multi-label classifier with ``num_labels == len(candidates)`` therefore yields
    one logit per candidate. Earlier, each item stacked one encoding per candidate
    into a ``(n_candidates, max_len)`` tensor. Such items batch to a 3-D tensor,
    which ``BertForSequenceClassification`` rejects.

    Args:
        queries: search-query strings, one per example. Non-strings (e.g. NaN)
            are encoded as an empty query.
        domains: per-example sequences of candidate domain strings.
        labels: per-example 0/1 vectors, one entry per candidate domain.
        tokenizer: a Hugging Face tokenizer exposing ``encode_plus``.
        max_len: fixed padded/truncated sequence length.

    Items are dicts holding ``input_ids`` and ``attention_mask`` of shape
    ``(max_len,)`` and ``labels`` (a float tensor with one entry per candidate).
    They also hold ``token_type_ids`` when the tokenizer returns them.
    """

    def __init__(self, queries, domains, labels, tokenizer, max_len=512):
        self.queries = queries
        self.domains = domains
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.queries)

    def __getitem__(self, idx):
        query = self.queries[idx]
        if not isinstance(query, str):
            query = ""  # NaN/None queries would otherwise break tokenisation

        # Candidates are separated by literal "[SEP]", which BERT tokenizers map to the SEP token
        candidates = " [SEP] ".join(
            domain if isinstance(domain, str) else "" for domain in self.domains[idx]
        )

        # Encode as a (query, candidates) pair so the two segments get distinct token_type_ids
        inputs = self.tokenizer.encode_plus(
            query,
            text_pair=candidates,
            add_special_tokens=True,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )

        item = {
            'input_ids': inputs["input_ids"].squeeze(0),
            'attention_mask': inputs["attention_mask"].squeeze(0),
            'labels': torch.tensor(self.labels[idx], dtype=torch.float),  # float for BCE loss
        }
        if "token_type_ids" in inputs:
            item['token_type_ids'] = inputs["token_type_ids"].squeeze(0)
        return item


In [24]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Extract model inputs: query text, its five candidate domains, and the 0/1 targets
queries = merged_dataset['SearchQuery'].tolist()
domains = merged_dataset[[f'URL{i}Domain' for i in range(1, 6)]].values.tolist()
labels = merged_dataset['Labels'].tolist()

# Initializing Dataset
dataset = URLDomainDataset(queries, domains, labels, tokenizer)

# Hold out a validation split so the model can be evaluated on unseen queries.
# Fixed-seed generators make both the split and the epoch shuffling reproducible.
SEED = 42
VAL_FRACTION = 0.2
n_val = int(len(dataset) * VAL_FRACTION)
train_dataset, val_dataset = random_split(
    dataset,
    [len(dataset) - n_val, n_val],
    generator=torch.Generator().manual_seed(SEED),
)

# `loader` is the training loader consumed by the training loop; `val_loader` feeds evaluation
loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)


### Model definition

In [None]:
from transformers import BertForSequenceClassification
# transformers.AdamW is deprecated in favour of the PyTorch optimizer and is absent from newer releases
from torch.optim import AdamW
import torch

# Use a GPU when available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed before building the model so the randomly initialised classification head is reproducible
torch.manual_seed(42)

model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=5,  # one logit per scraped candidate domain (URL1..URL5)
    problem_type="multi_label_classification",  # independent sigmoid per candidate; several may be correct
).to(device)

# eps and weight_decay match the defaults of the previously used transformers.AdamW
# (torch.optim.AdamW would otherwise default to eps=1e-8, weight_decay=0.01).
optimizer = AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

#### Training loop

In [None]:
from torch.nn import BCEWithLogitsLoss
from tqdm.auto import tqdm  # renders as a notebook widget in Jupyter, plain bar elsewhere

# Specify the number of epochs
epochs = 3

# Multi-label objective: binary cross-entropy on the raw logits of each candidate
loss_func = BCEWithLogitsLoss()

for epoch in range(epochs):
    model.train()  # (re-)enable dropout every epoch, even if an eval pass switched it off
    running_loss = 0.0
    loop = tqdm(loader, desc=f'Epoch {epoch + 1}', leave=True)
    for step, batch in enumerate(loop, start=1):
        batch = {k: v.to(device) for k, v in batch.items()}

        outputs = model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            # Segment ids for (query, candidates) pairs; None (BERT's default of zeros) if absent
            token_type_ids=batch.get('token_type_ids'),
        )
        loss = loss_func(outputs.logits, batch['labels'])

        optimizer.zero_grad()
        loss.backward()  # Backpropagation
        optimizer.step()

        running_loss += loss.item()
        loop.set_postfix(loss=loss.item(), avg_loss=running_loss / step)


## Model Evaluation

In [None]:
# Evaluate on held-out data. Expects a DataLoader named `val_loader` yielding the same
# batch dicts as the training loader. If none exists, report that and skip instead of
# raising NameError.
eval_loader = globals().get('val_loader')
THRESHOLD = 0.5  # sigmoid probability at or above which a candidate is predicted correct

model.eval()  # Set model to evaluation mode (disables dropout)
if eval_loader is None:
    print("No `val_loader` defined - skipping evaluation.")
else:
    total_loss, n_batches = 0.0, 0
    tp = fp = fn = exact_matches = n_examples = 0
    with torch.no_grad():
        for batch in eval_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            logits = model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                token_type_ids=batch.get('token_type_ids'),
            ).logits
            targets = batch['labels']
            total_loss += torch.nn.functional.binary_cross_entropy_with_logits(logits, targets).item()
            n_batches += 1

            preds = (torch.sigmoid(logits) >= THRESHOLD).float()
            tp += int(((preds == 1) & (targets == 1)).sum())
            fp += int(((preds == 1) & (targets == 0)).sum())
            fn += int(((preds == 0) & (targets == 1)).sum())
            # A query is an exact match only if every candidate prediction is right
            exact_matches += int((preds == targets).all(dim=1).sum())
            n_examples += targets.size(0)

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    print(f"val loss: {total_loss / max(n_batches, 1):.4f}")
    print(f"micro precision: {precision:.3f}  recall: {recall:.3f}  F1: {f1:.3f}")
    print(f"exact-match accuracy: {exact_matches / max(n_examples, 1):.3f}")


### Save model

In [None]:
# Persist the fine-tuned weights/config and the tokenizer files in sibling directories
model_dir = "./BERT_model/model"
tokenizer_dir = "./BERT_model/tokenizer"
model.save_pretrained(model_dir)
tokenizer.save_pretrained(tokenizer_dir)
