In [1]:
%load_ext autoreload
%autoreload 2
!source /home/murilo/RelNetCare/.env

In [45]:
from tqdm import tqdm
import pandas as pd
from transformers import BertTokenizer, BertModel
import torch
import glob
import json
import pickle
import numpy as np
import json
import re
from src.paths import LOCAL_PROCESSED_DATA_PATH

# Embed every dialogue of the dialog-re-ddrel splits with BERT ([CLS] vector)
# and store dialogue text, relation JSON, embedding, split name and a
# "was truncated" flag in a pickled DataFrame.

from pathlib import Path  # cell-local import: used to derive the split name from a file path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize BERT model and tokenizer
print("Initializing BERT model and tokenizer...")
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
model = BertModel.from_pretrained('bert-large-uncased')
max_length = tokenizer.model_max_length  # Longest sequence the tokenizer reports for this model
print("max_length=",max_length)
print(f"Moving model into device={device}...")
model.to(device)
# from_pretrained already returns the model in evaluation mode; stated
# explicitly so dropout stays off even if this cell is adapted later.
model.eval()

# One dict per dialogue; turned into a DataFrame once at the end
df_list = []

# Batch size
batch_size = 8
exceed_count = 0  # Running number of dialogues longer than max_length tokens
total_count = 0   # Running number of dialogues seen

# Input directory and output file
data_path = LOCAL_PROCESSED_DATA_PATH / 'dialog-re-ddrel'
output_path = data_path / "df_embeddings_original.pkl"

print(f"Starting to loop over JSON files in {data_path}...")
for file in glob.glob(str(data_path / "*.json")):
    print(f"Reading {file}...")
    with open(file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # File name without extension, e.g. ".../train.json" -> "train"
    dataset_name = Path(file).stem

    # Each record is indexed as d[0] (list of utterances) and d[1] (relations).
    dialogues = ['\n'.join(d[0]) for d in data]
    # The pattern requires a trailing colon, so only "Speaker N:" occurrences
    # inside the serialized relations are shortened to "SN:".
    relations = [re.sub(r"Speaker (\d+):", r"S\1:", json.dumps(d[1], indent=0)) for d in data]

    # Break data into batches
    print("Breaking data into batches...")
    pbar = tqdm(range(0, len(dialogues), batch_size), desc="Processing batches")
    for start_idx in pbar:
        end_idx = start_idx + batch_size  # slicing clamps at the end of the list

        dialogues_batch = dialogues[start_idx:end_idx]
        relations_batch = relations[start_idx:end_idx]

        # Untruncated token length (special tokens included) decides whether
        # a dialogue will be cut by the truncating tokenizer call below.
        # Computed once here and stored with the record, so no second
        # tokenization pass over the whole dataset is needed.
        exceeded_batch = [
            len(tokenizer.encode(dialogue)) > max_length
            for dialogue in dialogues_batch
        ]
        total_count += len(exceeded_batch)
        exceed_count += sum(exceeded_batch)

        pct_exceed = (exceed_count / total_count) * 100
        pbar.set_postfix({"% Exceeding Max Length": f"{pct_exceed:.2f}% ({exceed_count} / {total_count})"}, refresh=True)

        # Tokenize and generate tensors for the batch (padded, truncated to max_length)
        inputs = tokenizer(dialogues_batch, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
        inputs = {key: val.to(device) for key, val in inputs.items()}

        # Generate BERT embeddings
        with torch.no_grad():
            outputs = model(**inputs)

        # Position 0 of the last hidden state is the [CLS] token
        cls_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()

        for dialogue, relation, embedding, exceeded in zip(
            dialogues_batch, relations_batch, cls_embeddings, exceeded_batch
        ):
            df_list.append({
                'dialogue': dialogue,
                'relation': relation,
                'embedding': embedding,
                'dataset': dataset_name,
                'exceeded_max_length': exceeded,
            })

# Create DataFrame
print("Creating DataFrame...")
df = pd.DataFrame(df_list)

# Save as a pickle file (the message reports the path actually written)
print(f"Saving DataFrame to {output_path}...")
df.to_pickle(output_path)
print("Done!")


Using device: cuda
Initializing BERT model and tokenizer...
max_length= 512
Moving model into device=cuda...
Starting to loop over JSON files in /home/murilo/RelNetCare/data/processed/dialog-re-ddrel...
Reading /home/murilo/RelNetCare/data/processed/dialog-re-ddrel/train.json...
Breaking data into batches...


Processing batches:   0%|          | 2/764 [00:00<03:41,  3.43it/s, % Exceeding Max Length=0.00% (0 / 16)]Token indices sequence length is longer than the specified maximum sequence length for this model (559 > 512). Running this sequence through the model will result in indexing errors
Processing batches: 100%|██████████| 764/764 [02:27<00:00,  5.18it/s, % Exceeding Max Length=1.95% (119 / 6110)]


Reading /home/murilo/RelNetCare/data/processed/dialog-re-ddrel/dev.json...
Breaking data into batches...


Processing batches: 100%|██████████| 127/127 [00:26<00:00,  4.83it/s, % Exceeding Max Length=1.99% (142 / 7121)]


Reading /home/murilo/RelNetCare/data/processed/dialog-re-ddrel/test.json...
Breaking data into batches...


Processing batches: 100%|██████████| 121/121 [00:25<00:00,  4.73it/s, % Exceeding Max Length=2.03% (164 / 8088)]


Checking for exceeeded max length...
Creating DataFrame...
Creating DataFrame...
Saving DataFrame to /home/murilo/RelNetCare/data/processed/dialog-re-ddrel/df_embeddings.pkl...
Done!


Unnamed: 0,dialogue,relation,embedding,dataset,exceeded_max_length
0,Speaker 1: It's been an hour and not one of my...,"[\n{\n""x"": ""Speaker 2"",\n""y"": ""Chandler Bing"",...","[0.54284745, -0.70306915, -1.3126775, -0.05009...",train,False
1,"Speaker 1: So, eh... it's probably gonna be ha...","[\n{\n""x"": ""Speaker 2"",\n""y"": ""Boston"",\n""rid""...","[0.049894452, -0.66197634, -0.48381886, -0.193...",train,False
2,Speaker 1: Hi!\nSpeaker 2: Hi!\nSpeaker 1: So ...,"[\n{\n""x"": ""Speaker 2"",\n""y"": ""goodie-goodie"",...","[0.31069955, -0.4959651, -0.99742293, -0.19498...",train,False
3,Speaker 1: Hi.\nSpeaker 2: Hi.\nSpeaker 1: I j...,"[\n{\n""x"": ""Phoebe"",\n""y"": ""Mike"",\n""rid"": [\n...","[0.41672152, -0.71111834, -1.2734727, -0.16772...",train,False
4,"Speaker 1: 'Okay. Okay, daddy we'll see you to...","[\n{\n""x"": ""Speaker 2"",\n""y"": ""wethead"",\n""rid...","[0.25623024, -0.26421997, -1.0316484, -0.06013...",train,False
...,...,...,...,...,...
8083,Speaker 1: I feel like you're turning me into ...,"[\n{\n""x"": ""Speaker 1"",\n""x_type"": ""PER"",\n""y""...","[0.20662983, -0.22525145, -0.7552583, -0.35369...",test,False
8084,Speaker 1: You have to go. I mean it.\nSpeake...,"[\n{\n""x"": ""Speaker 1"",\n""x_type"": ""PER"",\n""y""...","[0.1891077, -0.5775761, -0.543093, -0.07402149...",test,False
8085,"Speaker 1: You're crazier than I thought, Lenn...","[\n{\n""x"": ""Speaker 1"",\n""x_type"": ""PER"",\n""y""...","[0.07690117, -0.59195894, -0.9255361, 0.373093...",test,False
8086,"Speaker 1: What's going on?\nSpeaker 2: Faith,...","[\n{\n""x"": ""Speaker 1"",\n""x_type"": ""PER"",\n""y""...","[-0.2005296, -0.42849943, -1.0686105, -0.01745...",test,False
