# Shared-Private Architecture: Training Dataset by Dataset

In [1]:

# Importing from the src directory
import sys
sys.path.append('./src')  # Ensure src is in the Python path

# Import necessary modules and functions
import torch
from torch.utils.data import DataLoader, random_split
from data import HumorDataset, sharedprivate_load_and_split_data
from models import SharedPrivateModel
from training import sharedprivate_train_shared_bert, sharedprivate_train_private_model
from utils import sharedprivate_eval_model
from transformers import BertTokenizer, BertForMaskedLM, BertModel
import pandas as pd
import os
from tqdm import tqdm
from transformers import BertTokenizer


In [2]:
# Cell 2: Set Up Device
# Backend priority: Apple MPS first, then CUDA, and CPU as the fallback.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")


Using device: mps


In [3]:
# Cell 3: Define Parameters

# Tokenization / batching
MAX_LENGTH = 64   # length argument handed to HumorDataset
BATCH_SIZE = 32   # kept modest for memory reasons

# Training schedule
EPOCHS_SHARED = 3    # epochs for the shared BERT fine-tuning stage
EPOCHS_PRIVATE = 3   # epochs for the private BERT fine-tuning stage
LEARNING_RATE = 1e-5

# Classification target
NUM_LABELS = 2  # humorous vs. not humorous


In [4]:
# Cell 4: Paths to Models and Data

# Model locations and tokenizer identifier
SHARED_MODEL_PATH = './models/bert-classification'  # saved shared BERT model
PRIVATE_MODEL_PATH = './models/bert-mlm'  # pre-trained BERT-MLM model
TOKENIZER_NAME = 'bert-base-uncased'  # tokenizer name

# CSV file read by the load-and-split step
CSV_FILE_PATH = './data/shared_private_dataset.csv'


In [5]:
# Cell 5: Load and Split Data
print("Loading and splitting data...")
(
    train_df,
    val_df,
    test_df,
    humor_type_to_idx,
) = sharedprivate_load_and_split_data(CSV_FILE_PATH)

# Humor type names, taken from the mapping's keys in their stored order
humor_types = [*humor_type_to_idx]
print(f"Original Humor types: {humor_types}")
print(f"Original Humor type to index mapping: {humor_type_to_idx}")


# Per-type row counts for the train, validation, and test frames follow
print("Counting filtered humor types in datasets...")

def count_humor_types(df, dataset_name):
    """Print how many rows of ``df`` carry each value of its 'type' column.

    Args:
        df: DataFrame to summarise; the humor type is read from column 'type'.
        dataset_name: Label inserted into the printed header or notice.

    Returns:
        None. When the 'type' column is absent, a notice is printed instead
        of the counts.
    """
    if 'type' not in df.columns:
        print(f"Column 'type' not found in {dataset_name} dataset.")
        return

    per_type_counts = df['type'].value_counts()
    print(f"\n{dataset_name} Humor Type Counts:")
    print(per_type_counts)

# Report the per-type counts for each split, in train / validation / test order
for _split_df, _split_name in (
    (train_df, "Training"),
    (val_df, "Validation"),
    (test_df, "Test"),
):
    count_humor_types(_split_df, _split_name)

Loading and splitting data...
Original Humor types: ['body punchlines', 'news headlines', 'puns', 'storylines']
Original Humor type to index mapping: {'body punchlines': 0, 'news headlines': 1, 'puns': 2, 'storylines': 3}
Counting filtered humor types in datasets...

Training Humor Type Counts:
body punchlines    14984
storylines         14824
puns               14798
news headlines     14791
Name: type, dtype: int64

Validation Humor Type Counts:
storylines         3222
puns               3193
body punchlines    3189
news headlines     3124
Name: type, dtype: int64

Test Humor Type Counts:
puns               3256
storylines         3198
body punchlines    3152
news headlines     3122
Name: type, dtype: int64


In [6]:
# Cell 6: Create Datasets and Dataloaders
print("Creating datasets and dataloaders...")

# Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained(TOKENIZER_NAME)

# Create datasets (HumorDataset is handed the tokenizer *name*, not the instance above)
train_dataset = HumorDataset(train_df, TOKENIZER_NAME, MAX_LENGTH)
val_dataset = HumorDataset(val_df, TOKENIZER_NAME, MAX_LENGTH)
test_dataset = HumorDataset(test_df, TOKENIZER_NAME, MAX_LENGTH)

# Settings common to all three loaders.
# Pinned host memory is only requested when the selected device is CUDA; on the
# MPS and CPU backends chosen in Cell 2 PyTorch cannot use it and merely warns.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=4,
    pin_memory=(device.type == 'cuda'),
)

# Create dataloaders (only the training split is shuffled)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Print dataset sizes
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")


Creating datasets and dataloaders...
Number of training samples: 59397
Number of validation samples: 12728
Number of test samples: 12728


In [7]:
# Cell 7: Initialize and Train Shared BERT Model
print("Initializing and training the shared BERT model...")

# Directory that receives the fine-tuned shared BERT model.
# exist_ok=True makes the call idempotent, so no separate existence check is needed.
shared_model_save_dir = './models/shared_bert_finetuned'
os.makedirs(shared_model_save_dir, exist_ok=True)

# Private model paths, keyed by humor type index.
# Every humor type starts from the same MLM-pretrained BERT here; point individual
# entries at different directories if separate per-type MLM models are available.
private_model_paths = {humor_type_to_idx[ht]: PRIVATE_MODEL_PATH for ht in humor_types}

# Initialize the SharedPrivateModel
shared_model = SharedPrivateModel(
    shared_model_path=SHARED_MODEL_PATH,
    private_model_paths=private_model_paths,  # humor type index -> private BERT path
    num_labels=NUM_LABELS
)
def load_model_weights(model, checkpoint_path):
    """Copy shared, private and classifier weights from a checkpoint into ``model``.

    The checkpoint is read with ``torch.load`` onto the CPU. It may be either a
    flat state dict or a dict holding that state dict under the
    ``"model_state_dict"`` key. Entries are routed by their leading key prefix:

    * ``shared_bert.``              -> ``model.shared_bert``
    * ``private_bert_dict.<key>.``  -> the matching entry of ``model.private_bert_dict``
    * ``classifier.``               -> ``model.classifier``

    Loading is non-strict, but it is no longer silent: a component with no
    matching entries, or with missing/unexpected keys, produces a ``[WARN]``
    line instead of an unconditional success message.

    Args:
        model: Object exposing ``shared_bert``, ``private_bert_dict`` (with
            ``.items()``) and ``classifier`` sub-modules.
        checkpoint_path: Path of the file to read with ``torch.load``.

    Returns:
        The same ``model`` instance, updated in place.
    """
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Accept checkpoints that wrap the weights under 'model_state_dict'.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        checkpoint = checkpoint["model_state_dict"]

    def _load_component(module, prefix, label):
        """Load the entries starting with ``prefix`` into ``module``; True if any matched."""
        # Slice off only the leading prefix. str.replace would also delete later
        # occurrences (e.g. 'classifier.classifier.weight') and corrupt the key.
        component_state = {
            key[len(prefix):]: value
            for key, value in checkpoint.items()
            if key.startswith(prefix)
        }
        if not component_state:
            print(f"[WARN] No '{prefix}*' entries in checkpoint; {label} weights left unchanged.")
            return False

        outcome = module.load_state_dict(component_state, strict=False)
        if outcome.missing_keys or outcome.unexpected_keys:
            print(
                f"[WARN] {label}: {len(outcome.missing_keys)} missing and "
                f"{len(outcome.unexpected_keys)} unexpected key(s) while loading."
            )
        return True

    # Load shared BERT weights
    if _load_component(model.shared_bert, "shared_bert.", "Shared BERT"):
        print("[INFO] Shared BERT weights loaded successfully.")

    # Load private BERT weights (one sub-module per humor type key)
    private_loaded = [
        _load_component(
            private_model,
            f"private_bert_dict.{humor_type_idx}.",
            f"Private BERT '{humor_type_idx}'",
        )
        for humor_type_idx, private_model in model.private_bert_dict.items()
    ]
    if all(private_loaded):
        print("[INFO] Private BERT weights loaded successfully.")

    # Load classifier weights
    if _load_component(model.classifier, "classifier.", "Classifier"):
        print("[INFO] Classifier weights loaded successfully.")

    return model
    
# Location of the fine-tuned shared-private weights (read for warm start, written after training)
final_shared_model_save_path = os.path.join(shared_model_save_dir, 'shared_private_best_model.pt')

# Warm-start from an earlier run only when its checkpoint exists; on a fresh
# checkout there is nothing to load yet and training starts from the initial weights.
if os.path.exists(final_shared_model_save_path):
    shared_model = load_model_weights(shared_model, final_shared_model_save_path)

shared_model = shared_model.to(device)

# NOTE(review): in the recorded run sharedprivate_train_shared_bert found this same
# file in save_dir, skipped training, and raised ValueError because the file has no
# 'model_state_dict' entry. The file is written below as a bare state_dict, so the
# two formats presumably disagree -- confirm against the training module.
# Train the shared-private model (which effectively fine-tunes the private BERT and classifier)
shared_model = sharedprivate_train_shared_bert(
    model=shared_model,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=EPOCHS_SHARED,
    learning_rate=LEARNING_RATE,
    device=device,
    save_dir=shared_model_save_dir,
    save_interval=1  # Save model every epoch
)

# Save the fine-tuned shared-private model
torch.save(shared_model.state_dict(), final_shared_model_save_path)
print(f"Shared BERT fine-tuned model saved to {final_shared_model_save_path}")


Some weights of BertModel were not initialized from the model checkpoint at ./models/bert-mlm and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertModel were not initialized from the model checkpoint at ./models/bert-mlm and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertModel were not initialized from the model checkpoint at ./models/bert-mlm and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Initializing and training the shared BERT model...
[INFO] No checkpoint found in './models/bert-classification'. Loading shared BERT using 'from_pretrained'.
[INFO] Shared BERT loaded from directory successfully.


Some weights of BertModel were not initialized from the model checkpoint at ./models/bert-mlm and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[INFO] Shared BERT weights loaded successfully.
[INFO] Private BERT weights loaded successfully.
[INFO] Classifier weights loaded successfully.
Model found at ./models/shared_bert_finetuned/shared_private_best_model.pt. Skipping training and loading the model.
Error: model_state_dict not found in the checkpoint!


ValueError: The saved checkpoint does not contain a valid model_state_dict.

In [8]:
import torch.nn as nn

# Score the shared-private model on the test loader with a cross-entropy criterion
criterion = nn.CrossEntropyLoss()
test_metrics = sharedprivate_eval_model(
    shared_model,
    test_loader,
    criterion,
    device,
    len(test_loader.dataset),
)

# One-line summary of the returned metrics, four decimals each
test_summary = (
    f"Test Metrics - Accuracy: {test_metrics['accuracy']:.4f}, "
    f"Precision: {test_metrics['precision']:.4f}, "
    f"Recall: {test_metrics['recall']:.4f}, "
    f"F1-Score: {test_metrics['f1_score']:.4f}"
)
print(test_summary)

Evaluating batches:   8%|█▊                    | 33/398 [00:33<04:42,  1.29it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x116030c10>
Traceback (most recent call last):
  File "/Users/manav/miniconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/Users/manav/miniconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1443, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/Users/manav/miniconda3/lib/python3.9/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/Users/manav/miniconda3/lib/python3.9/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
  File "/Users/manav/miniconda3/lib/python3.9/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/Users/manav/miniconda3/lib/python3.9/selectors.py", line 416, in sel

KeyboardInterrupt: 

In [2]:
# Cell 8: Train Shared-Private Models for Each Humor Dataset
print("\n--- Training Shared-Private Models for Each Humor Dataset ---")

# Directory that receives the per-humor-type shared-private models.
# exist_ok=True makes the call idempotent, so no separate existence check is needed.
shared_private_models_dir = './models/updated_shared_private'
os.makedirs(shared_private_models_dir, exist_ok=True)

# Every humor type starts from the same fine-tuned checkpoint, so resolve its path once.
best_shared_model_path = os.path.join(shared_model_save_dir, 'shared_private_best_model.pt')

# Loader settings common to all splits. Pinned host memory is only requested for
# CUDA; on MPS and CPU PyTorch cannot use it and merely warns.
subset_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=4,
    pin_memory=(device.type == 'cuda'),
)

for dataset_name in humor_types:
    print(f"\n--- Training on {dataset_name} dataset ---")
    # Filter data for the current humor type
    train_subset = train_df[train_df['type'] == dataset_name]
    val_subset = val_df[val_df['type'] == dataset_name]
    test_subset = test_df[test_df['type'] == dataset_name]

    # Create datasets and dataloaders.
    # Note: these reuse (and therefore overwrite) the names bound in Cell 6.
    train_dataset = HumorDataset(train_subset, TOKENIZER_NAME, MAX_LENGTH)
    val_dataset = HumorDataset(val_subset, TOKENIZER_NAME, MAX_LENGTH)
    test_dataset = HumorDataset(test_subset, TOKENIZER_NAME, MAX_LENGTH)

    train_loader = DataLoader(train_dataset, shuffle=True, **subset_loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **subset_loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **subset_loader_kwargs)

    # Build a fresh SharedPrivateModel for this humor type, pointing the shared
    # BERT at the fine-tuned directory and reusing the same private BERT paths.
    model = SharedPrivateModel(
        shared_model_path=shared_model_save_dir,  # Path to the shared BERT fine-tuned model
        private_model_paths=private_model_paths,   # Same private BERT paths
        num_labels=NUM_LABELS
    )
    model = model.to(device)

    # Load the best model state. map_location='cpu' keeps deserialisation independent
    # of the device the checkpoint was saved from; load_state_dict then copies the
    # values into the parameters already placed on `device`.
    model.load_state_dict(torch.load(best_shared_model_path, map_location='cpu'))
    print(f"Loaded shared BERT fine-tuned model from {best_shared_model_path}")

    # File name for this humor type's final model (spaces replaced by underscores)
    sanitized_dataset_name = dataset_name.replace(" ", "_")
    final_model_save_path = os.path.join(shared_private_models_dir, f'shared_private_model_{sanitized_dataset_name}.pt')

    # Train the model on the specific humor dataset.
    # NOTE(review): save_dir is the same for every humor type, and the recorded output
    # shows an epoch checkpoint named 'shared_private_1.pt' with no humor type in it,
    # so per-epoch checkpoints of different types presumably overwrite each other -- confirm.
    print(f"Starting training for {dataset_name}...")
    model = sharedprivate_train_private_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        epochs=EPOCHS_PRIVATE,
        learning_rate=LEARNING_RATE,
        device=device,
        save_dir=shared_private_models_dir,
        save_interval=1  # Save model every epoch
    )

    # Save the trained model
    torch.save(model.state_dict(), final_model_save_path)
    print(f"Model for {dataset_name} training complete and saved to {final_model_save_path}")



--- Training Shared-Private Models for Each Humor Dataset ---

--- Training on puns dataset ---
[INFO] Found checkpoint 'shared_private_best_model.pt' in './models/shared_bert_finetuned'. Loading shared BERT from checkpoint.
[INFO] Shared BERT loaded from checkpoint successfully.
Loaded shared BERT fine-tuned model from ./models/shared_bert_finetuned/shared_private_best_model.pt
Starting training for body punchlines...
Training shared-private model on device: mps

--- Epoch 1/3 ---
Training batches: 100%|███████████████████████| 469/469 [20:06<00:00,  2.57s/it]
Epoch 1 - Loss: 0.0085, Acc: 0.9221
Evaluating batches: 100%|█████████████████████| 100/100 [01:45<00:00,  1.06s/it]
Validation Metrics - Accuracy: 0.9260, Precision: 0.9245, Recall: 0.9260, F1-Score: 0.9252
Saving checkpoint to: ./models/updated_shared_private/shared_private_1.pt

--- Epoch 2/3 ---
Training batches: 100%|███████████████████████| 469/469 [19:50<00:00,  2.54s/it]
Epoch 2 - Loss: 0.0032, Acc: 0.9310
Evaluating ba

In [None]:
import torch

# Fine-tuned shared-private checkpoint to inspect
checkpoint_path = './models/shared_bert_finetuned/shared_private_best_model.pt'

# Deserialize onto the CPU, then list the top-level keys
checkpoint = torch.load(checkpoint_path, map_location='cpu')
checkpoint_keys = checkpoint.keys()
print("Checkpoint keys:", checkpoint_keys)
