In [10]:
import warnings
# Suppress specific, known-noisy warnings. The filters are registered here, ahead of the
# third-party imports further down, so they are already in place when those libraries load.
# `message` and `module` are both interpreted by the warnings module as regular expressions
# (`message` is matched against the start of the warning text), hence the raw strings for
# the module patterns. The first two filters cover the same `resume_download` FutureWarning
# worded without and with backticks.
warnings.filterwarnings("ignore", message="resume_download is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use force_download=True.", category=FutureWarning)
warnings.filterwarnings("ignore", message="`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.", category=FutureWarning)
warnings.filterwarnings("ignore", message="TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class.", category=UserWarning, module=r"torch\._utils")
warnings.filterwarnings("ignore", message="promote has been superseded by promote_options='default'", category=FutureWarning, module=r"datasets\.table")

# Standard library imports
import argparse
import os
import sys
import signal
import time
from collections import Counter
import traceback
import math
from multiprocessing import Value
from queue import Empty

# PyTorch imports
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset
import torch.optim as optim
from torch.distributed.optim import ZeroRedundancyOptimizer

# Third-party library imports
import numpy as np
import pandas as pd
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
import sst
from datasets import load_dataset
import wandb

# Custom utility imports
from utils import (
    setup_environment, 
    prepare_device, 
    fix_random_seeds,
    convert_numeric_to_labels, 
    convert_labels_to_tensor,
    format_time, 
    convert_sst_label, 
    get_activation,
    set_threads,
    signal_handler,
    cleanup_and_exit,
    get_optimizer,
    get_shape_color,
    print_rank_memory_summary,
    tensor_to_numpy,
    print_label_dist,
    get_scheduler,
    parse_dict
)
from datawaza_funcs import eval_model
from torch_ddp_finetune_neural_classifier import TorchDDPNeuralClassifier, SentimentDataset
from colors import *

# Suppress Hugging Face library warnings: raise the threshold of the loggers below to ERROR
# so that their lower-severity records (DEBUG / INFO / WARNING) are no longer emitted.
import logging
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("datasets").setLevel(logging.ERROR)
logging.getLogger("huggingface_hub.repocard").setLevel(logging.ERROR)

from torch.utils.data import Dataset

from multiprocessing import Queue, Pipe, set_start_method

def load_label_dicts(label_template):
    """
    Build the label <-> index mappings for a named label template.

    Args:
        label_template: One of 'neg_neu_pos', 'bin_neu', 'bin_pos' or 'bin_neg'.

    Returns:
        A tuple `(label_dict, numeric_dict)` where `label_dict` maps each class name to its
        integer index and `numeric_dict` is the inverse mapping (index -> class name).

    Raises:
        ValueError: If `label_template` is not one of the supported templates.
    """
    # Class names per template, listed in index order (position == numeric label).
    templates = (
        ('neg_neu_pos', ('negative', 'neutral', 'positive')),
        ('bin_neu', ('non-neutral', 'neutral')),
        ('bin_pos', ('non-positive', 'positive')),
        ('bin_neg', ('non-negative', 'negative')),
    )

    for template_name, class_names in templates:
        if label_template == template_name:
            numeric_dict = dict(enumerate(class_names))
            label_dict = {name: index for index, name in numeric_dict.items()}
            return label_dict, numeric_dict

    raise ValueError(f"Unknown label template: {label_template}. Options are: 'neg_neu_pos', 'bin_neu', 'bin_pos', 'bin_neg'")

def save_data_archive(X_train, X_val, X_test, y_train, y_val, y_test, X_test_sent, world_size, device_type, data_dir):
    """
    Save the processed splits to a timestamped, compressed `.npz` archive.

    The file is written to `data_dir` and named `data_<world_size>_gpu_<timestamp>.npz`
    when `device_type` is 'cuda', or `data_1_cpu_<timestamp>.npz` otherwise.

    Args:
        X_train, X_test: Train / test features; passed through `tensor_to_numpy`.
        X_val: Validation features, or None when there is no validation split.
        y_train, y_test: Train / test labels; passed through `tensor_to_numpy`.
        y_val: Validation labels, or None.
        X_test_sent: Test sentences; stored as given, without conversion.
        world_size: Number of processes; only used in the file name for 'cuda'.
        device_type: Device type string; 'cuda' selects the GPU file-name suffix.
        data_dir: Target directory; created if it does not exist.

    Returns:
        None. The archive path is printed.

    Notes:
        The validation entries ('X_val', 'y_val') are written only when `X_val` is not None.
        NOTE(review): if `X_test_sent` is an object-dtype array NumPy will pickle it, which is
        presumably why `load_data_archive` opens archives with `allow_pickle=True` - confirm.
    """
    # exist_ok=True replaces the racy "check, then create" sequence: the call succeeds
    # whether or not the directory already exists (or appears in between).
    os.makedirs(data_dir, exist_ok=True)

    # Create filename with appropriate suffix and timestamp
    suffix = f'_{world_size}_gpu' if device_type == 'cuda' else '_1_cpu'
    timestamp = time.strftime('%Y%m%d-%H%M%S')
    filepath = os.path.join(data_dir, f'data{suffix}_{timestamp}.npz')

    # Collect the arrays in the order they appear in the archive, converting tensors to
    # NumPy arrays where necessary.
    has_validation = X_val is not None
    arrays = {'X_train': tensor_to_numpy(X_train)}
    if has_validation:
        arrays['X_val'] = tensor_to_numpy(X_val)
    arrays['X_test'] = tensor_to_numpy(X_test)
    arrays['y_train'] = tensor_to_numpy(y_train)
    if has_validation:
        arrays['y_val'] = tensor_to_numpy(y_val) if y_val is not None else None
    arrays['y_test'] = tensor_to_numpy(y_test)
    arrays['X_test_sent'] = X_test_sent

    # Save data to archive file
    np.savez_compressed(filepath, **arrays)
    print(f"\nData saved to: {filepath}")

def load_data_archive(data_file, device, rank, sample_percent=None):
    """
    Load train / validation / test arrays from an `.npz` archive, optionally subsampling them.

    Args:
        data_file: Path to the `.npz` archive. Must contain 'X_train', 'X_test', 'y_train',
            'y_test' and 'X_test_sent'; 'X_val' and 'y_val' are optional.
        device: Not used by this function; kept for interface compatibility.
        rank: Process rank. Progress and summary output is printed on rank 0 only.
        sample_percent: Optional fraction of each split to keep. When given, each split is
            subsampled with an independent random permutation from NumPy's global RNG
            (labels and test sentences follow the indices of their own split).

    Returns:
        Tuple `(X_train, X_val, X_test, y_train, y_val, y_test, X_test_sent)`; `X_val` and
        `y_val` are None when the archive has no such entries.

    Raises:
        ValueError: If `data_file` is None.
        FileNotFoundError: If `data_file` does not exist.
        RuntimeError: If anything fails while reading, sampling or summarising the data
            (the original exception is attached as `__cause__`).
    """
    load_archive_start = time.time()
    is_main = rank == 0

    # Check if the archive file path is provided
    if data_file is None:
        raise ValueError(f"{red}No archive file provided to load data from{reset}")

    # Check if the archive file exists
    if not os.path.exists(data_file):
        raise FileNotFoundError(f"{red}Archive file not found: {data_file}{reset}")

    # Attempt to load the data from the archive file
    try:
        if is_main:
            print(f"\n{sky_blue}Loading archived data from: {data_file}...{reset}")

        # SECURITY: allow_pickle=True lets NumPy unpickle object arrays, which can execute
        # arbitrary code. Only load archives from a trusted source.
        with np.load(data_file, allow_pickle=True) as data:
            X_train = data['X_train']
            X_val = data['X_val'] if 'X_val' in data else None
            X_test = data['X_test']
            y_train = data['y_train']
            y_val = data['y_val'] if 'y_val' in data else None
            y_test = data['y_test']
            X_test_sent = data['X_test_sent']

        # Sample data if sample_percent is provided
        if sample_percent is not None:
            if is_main:
                print(f"Sampling {sample_percent:.0%} of data...")
            num_train_samples = int(len(X_train) * sample_percent)
            num_val_samples = int(len(X_val) * sample_percent) if X_val is not None else None
            num_test_samples = int(len(X_test) * sample_percent)

            # Create a permutation of indices (train, validation, test - in that order, so
            # the sequence of RNG draws is stable)
            train_indices = np.random.permutation(len(X_train))[:num_train_samples]
            val_indices = np.random.permutation(len(X_val))[:num_val_samples] if X_val is not None else None
            test_indices = np.random.permutation(len(X_test))[:num_test_samples]

            # Sample the data
            X_train = X_train[train_indices]
            y_train = y_train[train_indices]
            X_val = X_val[val_indices] if X_val is not None else None
            y_val = y_val[val_indices] if y_val is not None else None
            X_test = X_test[test_indices]
            y_test = y_test[test_indices]
            X_test_sent = X_test_sent[test_indices]

            if is_main:
                if X_val is not None:
                    print(f"Sampled Train size: {len(X_train)}, Sampled Validation size: {len(X_val)}, Sampled Evaluation size: {len(X_test)}")
                else:
                    print(f"Sampled Train size: {len(X_train)}, Sampled Evaluation size: {len(X_test)}")

        if is_main:
            # Print a summary of the loaded data
            print(f"X Train shape: {list(X_train.shape)}, y Train shape: {list(y_train.shape)}")
            if X_val is not None:
                print(f"X Validation shape: {list(X_val.shape)}, y Validation shape: {list(y_val.shape)}")
            # The label previously read "y Dev shape" although the value is y_test.
            print(f"X Test shape: {list(X_test.shape)}, y Test shape: {list(y_test.shape)}")
            print(f"X Test Sentences shape: {list(X_test_sent.shape)}")
            # Print label distributions
            print("Train label distribution:")
            print_label_dist(y_train)
            if X_val is not None:
                print("Validation label distribution:")
                print_label_dist(y_val)
            print("Test label distribution:")
            print_label_dist(y_test)
            print(f"Archived data loaded ({time.time() - load_archive_start:.2f}s)")
    except Exception as e:
        # Chain explicitly so the original failure is reported as the direct cause.
        raise RuntimeError(f"Failed to load data from archive file {data_file}: {str(e)}") from e

    return X_train, X_val, X_test, y_train, y_val, y_test, X_test_sent

def initialize_bert_model(weights_name, device, rank, debug):
    """
    Load a BERT tokenizer and model for `weights_name` and move the model to `device`.

    Every rank loads its own copy, then all ranks synchronise on `dist.barrier()`, so this
    must run inside an initialised process group. The model is returned as loaded (it is
    not wrapped in DistributedDataParallel here).

    Args:
        weights_name: Name or path passed to `from_pretrained`.
        device: Device the model is moved to.
        rank: Process rank; status output is printed on rank 0 only.
        debug: When true, rank 0 also prints the tokenizer and model representations.

    Returns:
        Tuple `(bert_tokenizer, bert_model)`.
    """
    started_at = time.time()
    is_main = rank == 0

    if is_main:
        print(f"\nInitializing '{weights_name}' tokenizer and model...")

    tokenizer = BertTokenizer.from_pretrained(weights_name)
    encoder = BertModel.from_pretrained(weights_name)
    encoder = encoder.to(device)

    # Wait until every rank has finished loading before reporting.
    dist.barrier()

    if is_main and debug:
        print(f"Bert Tokenizer:\n{tokenizer}")
        print(f"Bert Model:\n{encoder}")
    if is_main:
        elapsed = time.time() - started_at
        print(f"Tokenizer and model initialized ({format_time(elapsed)})")

    return tokenizer, encoder

def initialize_transformer_model(weights_name, device, rank, debug):
    """
    Load an `AutoTokenizer` / `AutoModel` pair for `weights_name`, downloading once on rank 0.

    Flow of a single attempt:
      1. Rank 0 tries to load tokenizer and model with `local_files_only=True`; if that
         fails it loads them again with `local_files_only=False` (i.e. downloads allowed).
      2. All ranks meet at a barrier.
      3. Every rank loads tokenizer and model with `local_files_only=True` and moves the
         model to `device`.
      4. All ranks meet at a second barrier and rank 0 reports the elapsed time.

    The whole attempt is retried up to 3 times; the wait between attempts starts at
    5 seconds and doubles after each failure.

    NOTE(review): the barriers sit inside the retried block. If an attempt fails on only
    some ranks, the ranks may no longer reach the barriers the same number of times -
    confirm this cannot stall a multi-process run.

    Args:
        weights_name: Name or path passed to `from_pretrained`.
        device: Device the model is moved to.
        rank: Process rank; status output is printed on rank 0 only.
        debug: When true, rank 0 also prints the tokenizer and model representations.

    Returns:
        Tuple `(tokenizer, model)`.

    Raises:
        Exception: The error of the last attempt is re-raised once all attempts failed.
        RuntimeError: If the retry loop finishes without returning or raising.
    """
    started_at = time.time()
    is_main = rank == 0
    if is_main:
        print(f"\n{sky_blue}Initializing '{weights_name}' tokenizer and model...{reset}")

    max_retries = 3
    retry_delay = 5

    def _ensure_files_cached():
        # Probe the local cache first; fall back to fetching when the probe fails.
        try:
            print(f"Checking for local files...")
            AutoTokenizer.from_pretrained(weights_name, local_files_only=True)
            AutoModel.from_pretrained(weights_name, local_files_only=True)
            print(f"Found all files in cache, skipping download")
        except Exception:
            print(f"Some files not found locally, downloading...")
            # Tokenizer files first, then model files
            AutoTokenizer.from_pretrained(weights_name, local_files_only=False)
            AutoModel.from_pretrained(weights_name, local_files_only=False)
            print(f"Download complete")

    for attempt in range(max_retries):
        try:
            # Only rank 0 checks and downloads files
            if is_main:
                _ensure_files_cached()

            # Everyone waits here until rank 0 is done checking/downloading
            dist.barrier()

            # From this point every rank reads from the local files
            if is_main:
                print(f"All ranks loading tokenizer from local files...")
            tokenizer = AutoTokenizer.from_pretrained(weights_name, local_files_only=True)

            if is_main:
                print(f"All ranks loading model from local files...")
            model = AutoModel.from_pretrained(weights_name, local_files_only=True).to(device)

            # Final sync point
            dist.barrier()

            if is_main:
                if debug:
                    print(f"Tokenizer:\n{tokenizer}")
                    print(f"Model:\n{model}")
                print(f"Tokenizer and model initialized ({format_time(time.time() - started_at)})")
            return tokenizer, model

        except Exception as e:
            out_of_retries = attempt >= max_retries - 1
            if out_of_retries:
                if is_main:
                    print(f"\n{red}Failed to initialize tokenizer/model after {max_retries} attempts{reset}")
                raise e

            if is_main:
                print(f"\n{yellow}Attempt {attempt + 1} failed. Retrying in {retry_delay} seconds...{reset}")
                print(f"Error: {str(e)}")
            time.sleep(retry_delay)
            retry_delay *= 2  # Exponential backoff

    raise RuntimeError("Failed to initialize transformer model and tokenizer")

def load_data(dataset, eval_dataset, sample_percent, eval_split, use_val_split, val_percent,
              world_size, rank, debug):
    """
    Load the train, optional validation and evaluation sets on rank 0 and share them with all ranks.

    Rank 0 reads the data (from Hugging Face via `load_dataset`, or from local CSV files),
    optionally subsamples it, and - when `world_size > 1` - broadcasts the three objects to
    the other ranks with `dist.broadcast_object_list`. The function ends with an
    unconditional `dist.barrier()`, so it must run inside an initialised process group.

    Supported dataset IDs: 'sst_local', 'sst', 'dynasent' / 'dynasent_r1', 'dynasent_r2',
    'mteb_tweet', 'merged_local', 'merged_balanced', 'merged_neutral', 'merged_positive',
    'merged_negative'.

    Args:
        dataset: ID of the dataset used for the 'train' (and 'validation') split.
        eval_dataset: ID of the dataset used for evaluation; None means "same as `dataset`".
        sample_percent: Optional fraction passed to `.sample(frac=...)` on every split.
        eval_split: Split name of the evaluation dataset to load (e.g. 'dev', 'test').
        use_val_split: When true, the 'validation' split of `dataset` is loaded as well;
            otherwise the returned validation set is None.
        val_percent: Only used in a log message here (no split is carved out in this function).
        world_size: Number of processes; a broadcast happens only when it is greater than 1.
        rank: Process rank; loading and most output happen on rank 0.
        debug: When true, extra per-rank size information is printed after the broadcast.

    Returns:
        Tuple `(train, validation, test)`. Hugging Face and CSV sources yield pandas
        DataFrames; 'sst_local' yields whatever `sst.sentiment_reader` returns
        (presumably also a DataFrame - confirm).

    Raises:
        ValueError: For an unknown dataset ID or an unknown split name.
    """
    data_load_start = time.time()

    hugging_face = 'Hugging Face'
    local = 'Local'
    merged_dir = os.path.join('data', 'merged')
    merged_name = 'Merged DynaSent Round 1, Round 2 and SST'

    # Dataset ID -> (display name, source, path, Hugging Face subset or None)
    dataset_catalog = {
        'sst_local': ('Stanford Sentiment Treebank (SST)', local, os.path.join('data', 'sentiment'), None),
        'sst': ('Stanford Sentiment Treebank (SST)', hugging_face, 'SetFit/sst5', None),
        'dynasent': ('DynaSent Round 1', hugging_face, 'dynabench/dynasent', 'dynabench.dynasent.r1.all'),
        'dynasent_r1': ('DynaSent Round 1', hugging_face, 'dynabench/dynasent', 'dynabench.dynasent.r1.all'),
        'dynasent_r2': ('DynaSent Round 2', hugging_face, 'dynabench/dynasent', 'dynabench.dynasent.r2.all'),
        'mteb_tweet': ('MTEB Tweet Sentiment Extraction', hugging_face, 'mteb/tweet_sentiment_extraction', None),
        'merged_local': (merged_name, local, merged_dir, None),
        'merged_neutral': (f'{merged_name}: Neutral Only', local, merged_dir, None),
        'merged_positive': (f'{merged_name}: Positive Only', local, merged_dir, None),
        'merged_negative': (f'{merged_name}: Negative Only', local, merged_dir, None),
        'merged_balanced': (f'{merged_name}: Balanced', local, merged_dir, None),
    }

    # Local SST: accepted split name -> CSV file
    sst_local_files = {
        'train': 'sst3-train.csv',
        'dev': 'sst3-dev.csv',
        'validation': 'sst3-dev.csv',
        'test': 'sst3-test-labeled.csv',
        'test-labeled': 'sst3-test-labeled.csv',
        'test-unlabeled': 'sst3-test-unlabeled.csv',
    }

    # Merged datasets that use the 'label' column as is: ID -> {split -> CSV file}
    merged_files = {
        'merged_local': {'train': 'train_all.csv', 'validation': 'val_all.csv', 'test': 'test_all.csv'},
        'merged_balanced': {'train': 'train_balanced.csv', 'validation': 'val_all.csv', 'test': 'test_all.csv'},
    }

    # Binary merged datasets share one set of files and differ in the column used as label
    binary_files = {'train': 'train_all_binary.csv', 'validation': 'val_all_binary.csv', 'test': 'test_all_binary.csv'}
    binary_label_columns = {
        'merged_neutral': 'neutral_label',
        'merged_positive': 'positive_label',
        'merged_negative': 'negative_label',
    }

    def get_split(data, split):
        """Return one split of a Hugging Face dataset as a DataFrame ('dev' means 'validation')."""
        split = 'validation' if split == 'dev' else split
        if split not in ('train', 'validation', 'test'):
            raise ValueError(f"Unknown split: {split}")
        return data[split].to_pandas()

    def get_data(dataset_id, split, purpose):
        """Load one split of the dataset identified by `dataset_id`; `purpose` labels the log line."""
        if dataset_id not in dataset_catalog:
            raise ValueError(f"Unknown dataset: {dataset_id}")
        dataset_name, dataset_source, dataset_path, dataset_subset = dataset_catalog[dataset_id]
        if rank == 0:
            print(f"{purpose} Data: {dataset_name} from {dataset_source}: '{dataset_path}'")

        if dataset_id == 'sst_local':
            if split not in sst_local_files:
                raise ValueError(f"Unknown split: {split}")
            src = os.path.join(dataset_path, sst_local_files[split])
            return sst.sentiment_reader(src, include_subtrees=False, dedup=False)

        if dataset_id == 'sst':
            data = load_dataset(dataset_path)
            data = data.rename_column('label', 'label_orig')
            for split_name in ('train', 'validation', 'test'):
                # Derive the 'label' column from 'label_text' and expose the text as 'sentence'
                dis = [convert_sst_label(s) for s in data[split_name]['label_text']]
                data[split_name] = data[split_name].add_column('label', dis)
                data[split_name] = data[split_name].add_column('sentence', data[split_name]['text'])
            return get_split(data, split)

        if dataset_id in ('dynasent', 'dynasent_r1', 'dynasent_r2'):
            data = load_dataset(dataset_path, dataset_subset)
            data = data.rename_column('gold_label', 'label')
            return get_split(data, split)

        if dataset_id == 'mteb_tweet':
            data = load_dataset(dataset_path)
            data = data.rename_column('label', 'label_orig')
            data = data.rename_column('label_text', 'label')
            data = data.rename_column('text', 'sentence')
            # 'dev' is served from the 'test' split for this dataset
            return get_split(data, 'test' if split == 'dev' else split)

        # Remaining IDs are the local merged CSV datasets.
        csv_split = 'validation' if split == 'dev' else split
        if dataset_id in merged_files:
            files = merged_files[dataset_id]
            label_column = None
        else:
            files = binary_files
            label_column = binary_label_columns[dataset_id]

        # An unrecognised split used to leave the result unbound (UnboundLocalError);
        # fail with the same ValueError the other loaders raise instead.
        if csv_split not in files:
            raise ValueError(f"Unknown split: {split}")

        data_split = pd.read_csv(os.path.join(dataset_path, files[csv_split]), index_col=None)
        if label_column is not None:
            # Keep the original label as 'label_orig' and promote the binary column to 'label'
            data_split = data_split.rename(columns={'label': 'label_orig'})
            data_split = data_split.rename(columns={label_column: 'label'})
        return data_split

    if rank == 0:
        print(f"\n{sky_blue}Loading data...{reset}")
        if eval_dataset is not None:
            print("Using different datasets for training and evaluation")
        else:
            eval_dataset = dataset
            print("Using the same dataset for training and evaluation")
        print("Splits:")
        print(f"- Train: Using {dataset} 'train' split")
        if use_val_split:
            print(f"- Validation: Using {dataset} 'validation' split")
        else:
            print(f"- Validation: Using {val_percent} of {dataset} 'train' split")
        print(f"- Evaluation: Using {eval_dataset} '{eval_split}' split")

        train = get_data(dataset, 'train', 'Train')
        validation = get_data(dataset, 'validation', 'Validation') if use_val_split else None
        test = get_data(eval_dataset, eval_split, 'Evaluation')

        print(f"Train size: {len(train)}")
        if validation is not None:
            print(f"Validation size: {len(validation)}")
        print(f"Evaluation size: {len(test)}")

        if sample_percent is not None:
            print(f"Sampling {sample_percent:.0%} of data...")
            train = train.sample(frac=sample_percent)
            if validation is not None:
                validation = validation.sample(frac=sample_percent)
            test = test.sample(frac=sample_percent)
            print(f"Sampled Train size: {len(train)}")
            if validation is not None:
                print(f"Sampled Validation size: {len(validation)}")
            print(f"Sampled Evaluation size: {len(test)}")
    else:
        train = None
        validation = None
        test = None

    # Broadcast the data to all ranks
    if world_size > 1:
        object_list = [train, validation, test]
        dist.broadcast_object_list(object_list, src=0)
        train, validation, test = object_list

        dist.barrier()
        if debug:
            if rank == 0:
                print("Data broadcasted to all ranks")
            print(f"Rank {rank}: Train size: {len(train)}, Validation size: {len(validation) if validation is not None else None}, Evaluation size: {len(test)}")

    if rank == 0:
        print("Train label distribution:")
        print_label_dist(train)
        if validation is not None:
            print("Validation label distribution:")
            print_label_dist(validation)
        print("Evaluation label distribution:")
        print_label_dist(test)
        print(f"Data loaded ({format_time(time.time() - data_load_start)})")
    dist.barrier()

    return train, validation, test

def process_data(bert_tokenizer, bert_model, pooling, world_size, train, validation, test, device, batch_size, rank, debug, save_archive,
                 save_dir, num_workers, prefetch, empty_cache, finetune_bert, freeze_bert, chunk_size=None):
    """
    Turn the loaded splits into model inputs (X) and labels (y).

    Labels come from the `label` column and sentences from the `sentence` column of each
    split. When `finetune_bert` is true the raw sentences are returned as X; otherwise each
    split is encoded through `process_data_chunks`. On multi-GPU runs the number of encoded
    rows is gathered from all ranks and compared on rank 0. The function calls
    `dist.barrier()`, so it must run inside an initialised process group.

    Args:
        bert_tokenizer, bert_model: Passed through to `process_data_chunks`.
        pooling: Pooling mode name (a string); passed through to `process_data_chunks`.
        world_size: Number of processes.
        train, validation, test: Splits exposing `.label.values` and `.sentence.values`;
            `validation` may be None.
        device: Torch device; the cross-rank size check only runs for `device.type == 'cuda'`.
        batch_size, num_workers, prefetch, empty_cache, chunk_size: Passed through to
            `process_data_chunks`.
        rank: Process rank; output, preview sampling and archiving happen on rank 0.
        debug: When true, rank 0 prints the per-rank size summary.
        save_archive: When true, rank 0 writes the results with `save_data_archive`.
        save_dir: Directory for the archive.
        finetune_bert: When true, skip encoding and return sentences as X.
        freeze_bert: Not used by this function.

    Returns:
        Tuple `(X_train, X_val, X_test, y_train, y_val, y_test, X_test_sent)`; `X_val` and
        `y_val` are None when `validation` is None.

    Raises:
        ValueError: On rank 0, if the per-rank row counts of any split differ by more than
            `world_size`.
    """
    data_process_start = time.time()
    is_main = rank == 0
    has_validation = validation is not None

    if is_main:
        print(f"\n{sky_blue}Processing data...{reset}")
        print(f"(Batch size: {batch_size}, Pooling: {pooling.upper() if pooling == 'cls' else pooling.capitalize()}, Fine Tune BERT: {finetune_bert}, Chunk size: {chunk_size})...")
        print("Extracting sentences and labels...")

    # Extract y labels
    y_train = train.label.values
    y_val = validation.label.values if has_validation else None
    y_test = test.label.values

    # Extract X sentences
    X_train_sent = train.sentence.values
    X_val_sent = validation.sentence.values if has_validation else None
    X_test_sent = test.sentence.values

    def _pick_examples(tag, sentences, labels, count=3):
        # Draw up to `count` random (prefix, sentence, label-suffix) preview triples.
        # Capping at the split size keeps tiny splits (fewer than `count` rows) from
        # failing in np.random.choice(..., replace=False).
        count = min(count, len(sentences))
        if count == 0:
            return []
        picks = np.random.choice(len(sentences), count, replace=False)
        # Labels are upper-cased for display, so they are expected to be strings.
        return [(f'{tag}[{i}]: ', sentences[i], f' - {labels[i].upper()}') for i in picks]

    if is_main:
        # Drawn in train / validation / test order to keep the RNG sequence stable.
        train_samples = _pick_examples('Train', X_train_sent, y_train)
        val_samples = _pick_examples('Validation', X_val_sent, y_val) if has_validation else []
        test_samples = _pick_examples('Evaluation', X_test_sent, y_test)
    else:
        train_samples = None
        val_samples = None
        test_samples = None

    if finetune_bert:
        # For fine-tuning, we just return the sentences
        X_train = X_train_sent
        X_val = X_val_sent
        X_test = X_test_sent
    else:
        def _encode(sentences, samples, split):
            # Tokenize and encode one split for the non-fine-tuning workflow.
            return process_data_chunks(sentences, bert_tokenizer, bert_model, pooling, world_size, device, batch_size,
                                       samples, rank, debug, split=split, num_workers=num_workers, prefetch=prefetch,
                                       empty_cache=empty_cache, chunk_size=chunk_size)

        X_train = _encode(X_train_sent, train_samples, 'Train')
        X_val = _encode(X_val_sent, val_samples, 'Validation') if has_validation else None
        X_test = _encode(X_test_sent, test_samples, 'Evaluation')

    # Data integrity check, make sure the sizes are consistent across ranks
    if not finetune_bert and device.type == 'cuda' and world_size > 1:
        def _gather_sizes(local_size):
            # Collect this split's row count from every rank (collective call: all ranks
            # must invoke it in the same order).
            sizes = [torch.tensor(local_size, device=device) for _ in range(world_size)]
            dist.all_gather(sizes, sizes[rank])
            return sizes

        train_sizes = _gather_sizes(X_train.shape[0])
        val_sizes = _gather_sizes(X_val.shape[0]) if has_validation else None
        test_sizes = _gather_sizes(X_test.shape[0])

        if is_main:
            # Convert to CPU for easier handling
            train_sizes = [size.cpu().item() for size in train_sizes]
            val_sizes = [size.cpu().item() for size in val_sizes] if has_validation else None
            test_sizes = [size.cpu().item() for size in test_sizes]

            train_spread = max(train_sizes) - min(train_sizes)
            val_spread = max(val_sizes) - min(val_sizes) if has_validation else None
            test_spread = max(test_sizes) - min(test_sizes)

            if debug:
                print("\nDataset size summary:")
                print(f"Train sizes across ranks: {train_sizes}")
                if has_validation:
                    print(f"Validation sizes across ranks: {val_sizes}")
                print(f"Test sizes across ranks: {test_sizes}")

                if len(set(train_sizes)) > 1 or len(set(test_sizes)) > 1 or (has_validation and len(set(val_sizes)) > 1):
                    # Close the colour with `reset` (it was left open with a second `red`).
                    print(f"{red}WARNING: Mismatch in dataset sizes across ranks!{reset}")
                    print(f"Train size mismatch: {train_spread}")
                    if has_validation:
                        print(f"Validation size mismatch: {val_spread}")
                    print(f"Test size mismatch: {test_spread}")
                else:
                    print("All ranks have consistent dataset sizes.")

                print(f"Total train samples: {sum(train_sizes)}")
                if has_validation:
                    print(f"Total validation samples: {sum(val_sizes)}")
                print(f"Total test samples: {sum(test_sizes)}")

            # Check for significant mismatch and raise error if necessary. The validation
            # split is included so the hard check covers the same splits as the warning.
            spreads = [train_spread, test_spread]
            if has_validation:
                spreads.append(val_spread)
            max_mismatch = max(spreads)
            if max_mismatch > world_size:  # Allow for small mismatches due to uneven division
                raise ValueError(f"{red}Significant mismatch in dataset sizes across ranks. Max difference: {max_mismatch}{reset}")

    if save_archive and is_main:
        save_data_archive(X_train, X_val, X_test, y_train, y_val, y_test, X_test_sent, world_size, device.type, save_dir)

    dist.barrier()
    if is_main:
        if has_validation:
            print(f"X Train shape: {list(np.shape(X_train))}, X Validation shape: {list(np.shape(X_val))}, X Test shape: {list(np.shape(X_test))}")
            print(f"y Train shape: {list(np.shape(y_train))}, y Validation shape: {list(np.shape(y_val))}, y Test shape: {list(np.shape(y_test))}")
        else:
            print(f"X Train shape: {list(np.shape(X_train))}, X Test shape: {list(np.shape(X_test))}")
            print(f"y Train shape: {list(np.shape(y_train))}, y Test shape: {list(np.shape(y_test))}")
        print(f"Data processed ({format_time(time.time() - data_process_start)})")

    return X_train, X_val, X_test, y_train, y_val, y_test, X_test_sent

def process_data_chunks(texts, tokenizer, model, pooling, world_size, device, batch_size, sample_texts, rank, debug, split,
                        num_workers, prefetch, empty_cache, chunk_size=None):
    """Encode `texts` with `bert_phi`, optionally splitting the work into chunks.

    When `chunk_size` is None or at least `len(texts)`, the call is forwarded to
    `bert_phi` unchanged and its result is returned as-is. Otherwise the texts are
    encoded `chunk_size` items at a time; each chunk's embeddings are moved to the
    CPU as a numpy array before the next chunk starts, and the chunks are
    concatenated along axis 0.

    NOTE: the two paths return different container types: whatever `bert_phi`
    returns (a torch tensor) when not chunked, and a numpy array when chunked.

    Args:
        texts: sequence of texts to encode.
        tokenizer, model, pooling, world_size, device, batch_size, rank, debug,
        num_workers, prefetch, empty_cache: forwarded to `bert_phi`.
        sample_texts: forwarded to `bert_phi` for the first chunk only.
        split: name of the data split, used in log messages.
        chunk_size: maximum number of texts per chunk, or None to disable chunking.

    Returns:
        The embeddings for all texts (see NOTE above for the container type).

    Raises:
        ValueError: if `chunk_size` is given but is not positive.
    """
    # A zero or negative chunk size would otherwise surface as a ZeroDivisionError
    # or as an empty np.concatenate call far from the actual cause.
    if chunk_size is not None and chunk_size <= 0:
        raise ValueError(f"{red}chunk_size must be a positive integer or None. Got {chunk_size}{reset}")

    if chunk_size is None or chunk_size >= len(texts):
        return bert_phi(texts, tokenizer, model, pooling, world_size, device, batch_size, sample_texts, rank, debug, split,
                        num_workers, prefetch, empty_cache)

    if rank == 0:
        print(f"\n{sky_blue}Processing {split} data in chunks of size {chunk_size}...{reset}")

    all_embeddings = []
    num_chunks = math.ceil(len(texts) / chunk_size)

    for chunk_number, chunk_start in enumerate(range(0, len(texts), chunk_size), start=1):
        chunk_end = min(chunk_start + chunk_size, len(texts))
        chunk_texts = texts[chunk_start:chunk_end]

        if rank == 0:
            print(f"\n{sky_blue}Processing chunk {chunk_number}/{num_chunks} (samples {chunk_start} to {chunk_end-1})...{reset}")

        # Only pass sample_texts for the first chunk
        current_sample_texts = sample_texts if chunk_number == 1 else None

        chunk_embeddings = bert_phi(chunk_texts, tokenizer, model, pooling, world_size, device, batch_size,
                                    current_sample_texts, rank, debug, f"{split}_chunk_{chunk_number}",
                                    num_workers, prefetch, empty_cache, chunk_number, num_chunks)

        # Move embeddings to CPU and convert to numpy to save GPU memory
        all_embeddings.append(chunk_embeddings.cpu().numpy())

        # Clear CUDA cache
        if empty_cache and device.type == 'cuda':
            torch.cuda.empty_cache()

        # Keep all ranks in step before starting the next chunk
        dist.barrier()

    # Concatenate all chunk embeddings
    final_embeddings = np.concatenate(all_embeddings, axis=0)

    if rank == 0:
        print(f"Finished processing all chunks for {split} data.")

    return final_embeddings


def bert_phi(texts, tokenizer, model, pooling, world_size, device, batch_size, sample_texts, rank, debug, split, num_workers, prefetch, empty_cache,
             chunk_id=None, num_chunks=None):
    """Encode texts into pooled transformer embeddings.

    Texts are tokenized in batches (padded/truncated to 512 tokens), passed through
    `model` without gradients, and each sequence's last hidden state is reduced with
    the requested pooling strategy. On CUDA with `world_size > 1` the texts are padded
    to a multiple of `world_size`, sharded by rank, encoded locally and all-gathered
    so that every rank holds the full embedding matrix; in every other configuration
    the calling process encodes all texts itself.

    Args:
        texts: list, tuple or numpy array of strings to encode.
        tokenizer: tokenizer providing `batch_encode_plus`, `tokenize` and `pad_token`.
        model: encoder whose output exposes `last_hidden_state`; it is switched to
            eval mode.
        pooling: 'cls', 'mean' or 'max'.
        world_size: number of participating processes.
        device: torch device that receives the tokenized inputs.
        batch_size: number of texts per forward pass.
        sample_texts: optional iterable of 3-item sequences; item 1 is tokenized and
            embedded for a preview on rank 0, items 0 and 2 are printed around it.
        rank: rank of the calling process.
        debug: when true, rank 0 also prints the final shape (distributed path).
        split: name of the data split, used in log messages.
        num_workers, prefetch: accepted for call compatibility; not used here.
        empty_cache: free per-batch tensors and the CUDA cache after every batch.
        chunk_id, num_chunks: optional chunk position shown in progress messages.

    Returns:
        A tensor with one pooled embedding row per input text.

    Raises:
        TypeError: if `texts` is not a list, tuple or numpy array.
        ValueError: if `texts` is empty or `pooling` is unknown.
    """
    encoding_start = time.time()

    # Normalise the container so slicing and concatenation behave the same in both paths
    if isinstance(texts, np.ndarray):
        texts = texts.tolist()
    elif isinstance(texts, tuple):
        texts = list(texts)
    elif not isinstance(texts, list):
        raise TypeError(f"{red}texts must be a list, tuple, or numpy array. Got {type(texts)}{reset}")

    total_texts = len(texts)
    if total_texts == 0:
        # Previously this surfaced as an UnboundLocalError / torch.cat failure further down
        raise ValueError(f"{red}texts must contain at least one item to encode{reset}")

    embeddings = []

    # Everything below is inference only, including the sample preview
    model.eval()

    def tokenize(texts, tokenizer, device):
        # Convert NumPy array to list if necessary
        if isinstance(texts, np.ndarray):
            texts = texts.tolist()

        encoded = tokenizer.batch_encode_plus(
            texts,
            add_special_tokens=True,
            padding='max_length',
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )
        input_ids = encoded['input_ids'].to(device)
        attention_mask = encoded['attention_mask'].to(device)
        return input_ids, attention_mask

    def pool(last_hidden_state, attention_mask, pooling):
        if pooling == 'cls':
            return last_hidden_state[:, 0, :]
        elif pooling == 'mean':
            token_counts = attention_mask.sum(-1).unsqueeze(-1).clamp(min=1)  # guard against an all-zero mask
            return (last_hidden_state * attention_mask.unsqueeze(-1)).sum(1) / token_counts
        elif pooling == 'max':
            # Padded positions must never win the max. Multiplying by the mask turns them
            # into 0, which beats every negative activation, so fill them with the
            # smallest representable value instead.
            padded_positions = attention_mask.unsqueeze(-1) == 0
            lowest = torch.finfo(last_hidden_state.dtype).min
            return torch.max(last_hidden_state.masked_fill(padded_positions, lowest), dim=1)[0]
        else:
            raise ValueError(f"{red}Unknown pooling method: {pooling}{reset}")

    # Process and display sample texts first
    def display_sample_texts(sample_texts):
        if sample_texts is None:
            return
        print(f"\n{sky_blue}Displaying samples from {split.capitalize()} data:{reset}")
        for text in sample_texts:
            # Tokenize the text and get the tokens
            tokens = tokenizer.tokenize(text[1])
            print(f"{text[0]}{text[1]}{text[2]}")
            print(f"Tokens: {tokens}")

            # Encode the text (including special tokens) and get embeddings
            input_ids, attention_mask = tokenize([text[1]], tokenizer, device)

            with torch.no_grad():
                outputs = model(input_ids, attention_mask=attention_mask)

            embedding = pool(outputs.last_hidden_state, attention_mask, pooling)

            print(f"Embedding: {embedding[0, :6].cpu().numpy()} ...")
            print()

            if device.type == 'cuda':
                del input_ids, attention_mask, outputs, embedding
                torch.cuda.empty_cache()

    # Shard the encoding across ranks when several GPUs are available
    if device.type == 'cuda' and world_size > 1:
        if rank == 0:
            # Display sample texts
            display_sample_texts(sample_texts)

            print(f"\n{sky_blue}Encoding {split.capitalize()} data of {total_texts} texts distributed across {world_size} GPUs...{reset}")
            print(f"Batch Size: {batch_size}, Pooling: {pooling.upper() if pooling == 'cls' else pooling.capitalize()}, Empty Cache: {empty_cache}")

        dist.barrier()
        # all_gather needs identically shaped tensors, so pad to a multiple of world_size
        texts_per_rank = math.ceil(total_texts / world_size)
        padded_total = texts_per_rank * world_size
        padding_texts = padded_total - total_texts

        if padding_texts > 0 and rank == 0:
            print(f"Padding {split.capitalize()} data to {padded_total} texts for even distribution across {world_size} ranks...")

        # Create padding texts using the pad token
        pad_text = tokenizer.pad_token * 10  # Arbitrary length, will be truncated if too long
        texts_with_padding = texts + [pad_text] * padding_texts

        # Distribute texts evenly across ranks
        start_idx = rank * texts_per_rank
        end_idx = start_idx + texts_per_rank
        local_texts = texts_with_padding[start_idx:end_idx]
        # ceil() avoids reporting one batch too many when the size divides evenly
        local_batch_count = math.ceil(len(local_texts) / batch_size)
        total_batches = math.ceil(total_texts / batch_size)

        if rank == 0:
            print(f"Texts per rank: {texts_per_rank}, Total batches: {total_batches}")

        dist.barrier()
        print(f"Rank {rank}: Processing {len(local_texts)} texts (indices {start_idx} to {end_idx-1}) in {local_batch_count} batches...")

        dist.barrier()

        for i in range(0, len(local_texts), batch_size):
            batch_start = time.time()
            batch_texts = local_texts[i:i + batch_size]
            input_ids, attention_mask = tokenize(batch_texts, tokenizer, device)

            with torch.no_grad():
                outputs = model(input_ids, attention_mask=attention_mask)

            batch_embeddings = pool(outputs.last_hidden_state, attention_mask, pooling)

            embeddings.append(batch_embeddings)

            batch_shape = list(batch_embeddings.shape)

            if empty_cache:
                # Drop the local references (the list above still holds the embeddings)
                del outputs, input_ids, attention_mask, batch_embeddings
                # Empty CUDA cache
                torch.cuda.empty_cache()

            shape_color = get_shape_color(batch_size, batch_shape)
            if chunk_id is not None:
                print(f"Rank {bright_white}{bold}{rank}{reset}: Chunk {purple}{bold}{chunk_id}{reset} / {num_chunks}, Batch {sky_blue}{bold}{(i // batch_size) + 1:2d}{reset} / {local_batch_count}, Shape: {shape_color}{bold}{batch_shape}{reset}, Time: {format_time(time.time() - batch_start)}")
            else:
                print(f"Rank {bright_white}{bold}{rank}{reset}: Batch {sky_blue}{bold}{(i // batch_size) + 1:2d}{reset} / {local_batch_count}, Shape: {shape_color}{bold}{batch_shape}{reset}, Time: {format_time(time.time() - batch_start)}")

            if rank == 0:
                if (i // batch_size) % 5 == 0:
                    print_rank_memory_summary(world_size, rank, all_local=True, verbose=False)

        local_embeddings = torch.cat(embeddings, dim=0)

        # Collect every rank's shard so that all ranks hold the complete matrix
        gathered_embeddings = [torch.zeros_like(local_embeddings) for _ in range(world_size)]
        dist.all_gather(gathered_embeddings, local_embeddings)
        all_embeddings = torch.cat(gathered_embeddings, dim=0)

        if rank == 0:  # Only one process needs to do this check
            total_embeddings = all_embeddings.shape[0]
            padding_embeddings = total_embeddings - total_texts

            print(f"Total embeddings: {total_embeddings}")
            print(f"Original texts: {total_texts}")
            print(f"Expected padding: {padding_texts}")
            print(f"Actual padding: {padding_embeddings}")

            if padding_embeddings != padding_texts:
                print(f"{red}WARNING: Mismatch in padding count!{reset}")

            if padding_embeddings > 0:
                padding_embeds = all_embeddings[-padding_embeddings:]

                if padding_embeds.shape[0] > 1:  # Ensure there are at least 2 embeddings to compare
                    max_diff = torch.max(torch.pdist(padding_embeds))
                    print(f"Maximum difference between padding embeddings: {max_diff}")

                    if max_diff > 1e-6:
                        print(f"{red}WARNING: Padding embeddings are not similar.{reset}")
                    else:
                        print("Padding embeddings verified as very similar.")
                else:
                    print(f"{yellow}Not enough padding embeddings to calculate differences (found {padding_embeds.shape[0]}).{reset}")

        # Now slice off the padding
        final_embeddings = all_embeddings[:total_texts]

        dist.barrier()
        if rank == 0:
            if debug:
                print(f"Final embeddings shape: {list(final_embeddings.shape)}")
            print(f"Encoding completed ({format_time(time.time() - encoding_start)})")

    else:
        # Take a more straightforward approach for CPU or single GPU
        if rank == 0:
            device_string = 'GPU' if device.type == 'cuda' else 'CPU'
            print(f"\n{sky_blue}Encoding {split.capitalize()} data of {total_texts} texts on a single {device_string}...{reset}")
            print(f"Batch Size: {batch_size}, Pooling: {pooling.upper() if pooling == 'cls' else pooling.capitalize()}, Empty Cache: {empty_cache}")

            # Display sample texts
            display_sample_texts(sample_texts)

        total_batches = math.ceil(total_texts / batch_size)

        for i in range(0, total_texts, batch_size):
            batch_start = time.time()
            batch_texts = texts[i:i + batch_size]
            input_ids, attention_mask = tokenize(batch_texts, tokenizer, device)

            with torch.no_grad():
                outputs = model(input_ids, attention_mask=attention_mask)

            # Perform only the selected pooling strategy for all batches
            batch_embeddings = pool(outputs.last_hidden_state, attention_mask, pooling)

            embeddings.append(batch_embeddings)
            print(f"Batch {bright_white}{bold}{(i // batch_size) + 1:2d}{reset} / {total_batches}, Shape: {list(batch_embeddings.shape)}, Time: {format_time(time.time() - batch_start)}")

            if empty_cache and device.type == 'cuda':
                # Drop the local references (the list above still holds the embeddings)
                del outputs, input_ids, attention_mask, batch_embeddings
                # Empty CUDA cache
                torch.cuda.empty_cache()

        # Concatenate once after the loop instead of re-concatenating on every batch
        final_embeddings = torch.cat(embeddings, dim=0)

    return final_embeddings


def initialize_classifier(bert_model, bert_tokenizer, finetune_bert, finetune_layers, num_layers, hidden_dim, batch_size,
                          epochs, lr, lr_decay, early_stop, hidden_activation, n_iter_no_change, tol, rank, world_size, device, debug,
                          checkpoint_dir, checkpoint_interval, resume_from_checkpoint, filename=None, use_saved_params=True,
                          optimizer_name=None, use_zero=True, scheduler_name=None, l2_strength=0.0, pooling='cls',
                          target_score=None, interactive=False, response_pipe=None, accumulation_steps=1, max_grad_norm=None,
                          freeze_bert=False, dropout_rate=0.0, show_progress=False, advance_epochs=1, wandb_run=None, val_percent=0.1,
                          random_seed=42, label_dict=None, optimizer_kwargs=None, scheduler_kwargs=None):
    """Build a `TorchDDPNeuralClassifier` and optionally restore saved state.

    The activation, optimizer and scheduler are resolved through `get_activation`,
    `get_optimizer` and `get_scheduler`, the classifier is constructed from the
    given hyper-parameters, and state is loaded either from `filename` or, when
    `resume_from_checkpoint` is set, via the 'checkpoint_epoch' pattern in
    `checkpoint_dir`. All ranks synchronise on a barrier before returning.

    Args:
        optimizer_kwargs: extra keyword arguments for the optimizer; None means an
            empty dict created fresh for this call.
        scheduler_kwargs: extra keyword arguments for the scheduler; None means an
            empty dict created fresh for this call.
        filename: specific saved model to load from `checkpoint_dir`; takes
            precedence over `resume_from_checkpoint`.
        (all remaining arguments are forwarded to the classifier constructor or to
        the helper lookups as named.)

    Returns:
        Tuple `(classifier, start_epoch, model_state_dict, optimizer_state_dict)`.
        Without a loaded model, `start_epoch` is 1 and both state dicts are None.
    """
    # Mutable default arguments are shared between calls; create fresh dicts instead
    optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
    scheduler_kwargs = {} if scheduler_kwargs is None else scheduler_kwargs

    class_init_start = time.time()
    if rank == 0:
        print(f"\n{sky_blue}Initializing DDP Neural Classifier...{reset}")
    hidden_activation = get_activation(hidden_activation, hidden_dim)
    optimizer_class = get_optimizer(optimizer_name, use_zero, device, rank, world_size)
    scheduler_class = get_scheduler(scheduler_name, device, rank, world_size)
    if rank == 0:
        print(f"Layers: {num_layers}, Hidden Dim: {hidden_dim}, Hidden Act: {hidden_activation.__class__.__name__}, Dropout: {dropout_rate}, Optimizer: {optimizer_class.__name__}, L2 Strength: {l2_strength}, Pooling: {pooling.upper()}, Accumulation Steps: {accumulation_steps}, Max Grad Norm: {max_grad_norm}")
        print(f"Batch Size: {batch_size}, Max Epochs: {epochs}, LR: {lr}, Early Stop: {early_stop}, Fine-tune BERT: {finetune_bert}, Fine-tune Layers: {finetune_layers}, Freeze BERT: {freeze_bert}, Target Score: {target_score}, Interactive: {interactive}")

    classifier = TorchDDPNeuralClassifier(
        bert_model=bert_model,
        bert_tokenizer=bert_tokenizer,
        finetune_bert=finetune_bert,
        finetune_layers=finetune_layers,
        pooling=pooling,
        num_layers=num_layers,
        early_stopping=early_stop,
        hidden_dim=hidden_dim,
        hidden_activation=hidden_activation,
        batch_size=batch_size,
        max_iter=epochs,
        n_iter_no_change=n_iter_no_change,
        tol=tol,
        eta=lr,
        lr_decay=lr_decay,
        rank=rank,
        world_size=world_size,
        debug=debug,
        checkpoint_dir=checkpoint_dir,
        checkpoint_interval=checkpoint_interval,
        resume_from_checkpoint=resume_from_checkpoint,
        device=device,
        optimizer_class=optimizer_class,
        use_zero=use_zero,
        scheduler_class=scheduler_class,
        target_score=target_score,
        interactive=interactive,
        response_pipe=response_pipe,
        gradient_accumulation_steps=accumulation_steps,
        max_grad_norm=max_grad_norm,
        freeze_bert=freeze_bert,
        dropout_rate=dropout_rate,
        l2_strength=l2_strength,
        show_progress=show_progress,
        advance_epochs=advance_epochs,
        wandb_run=wandb_run,
        validation_fraction=val_percent,
        random_seed=random_seed,
        label_dict=label_dict,
        optimizer_kwargs=optimizer_kwargs,
        scheduler_kwargs=scheduler_kwargs
    )

    if filename is not None:
        if rank == 0:
            print(f"Loading model from: {checkpoint_dir}/{filename}...")
        start_epoch, model_state_dict, optimizer_state_dict = classifier.load_model(directory=checkpoint_dir, filename=filename, pattern=None, use_saved_params=use_saved_params, rank=rank, debug=debug)
    elif resume_from_checkpoint:
        if rank == 0:
            print("Resuming training from the latest checkpoint...")
        start_epoch, model_state_dict, optimizer_state_dict = classifier.load_model(directory=checkpoint_dir, filename=None, pattern='checkpoint_epoch', use_saved_params=use_saved_params, rank=rank, debug=debug)
    else:
        # Fresh run: nothing to restore
        start_epoch = 1
        model_state_dict = None
        optimizer_state_dict = None

    dist.barrier()
    if rank == 0:
        if debug:
            print(classifier)
        print(f"Classifier initialized ({format_time(time.time() - class_init_start)})")

    return classifier, start_epoch, model_state_dict, optimizer_state_dict

def evaluate_model(model, bert_tokenizer, X_test, y_test, label_dict, numeric_dict, world_size, device, rank, debug, save_preds,
                   save_dir, X_test_sent, wandb_run=None, decimal=2, pos_label=1, threshold=0.5, save_plots=False,
                   model_name=None, weights_name=None):
    """Run the classifier on the test data and report metrics on rank 0.

    Every rank computes predictions for `X_test` and takes part in the all-gather;
    rank 0 then converts predictions to labels, calls `eval_model`, optionally saves
    the predictions as a timestamped CSV in `save_dir`, prints the macro F1 score
    from `model.score` and, when `wandb_run` is given, logs the results to that run.

    Args:
        model: classifier wrapper exposing `model`, `finetune_bert`, `batch_size`
            and `score`.
        bert_tokenizer: tokenizer used to build the dataset when fine-tuning.
        X_test: raw texts (fine-tuning) or pre-computed features.
        y_test: true labels, as keys of `label_dict`.
        label_dict: mapping from label to numeric class.
        numeric_dict: mapping from numeric class to label.
        world_size, device, rank, debug: distributed/runtime settings.
        save_preds: write predictions to CSV (and a W&B table when logging).
        save_dir: directory for saved predictions and plots.
        X_test_sent: test sentences stored next to the predictions.
        wandb_run: optional Weights & Biases run used for logging.
        decimal, pos_label, threshold, save_plots: forwarded to `eval_model`
            (`decimal` also sets the classification report precision).
        model_name: display name; defaults to the run name, then `weights_name`,
            then 'Neural Classifier'.
        weights_name: fallback display name.

    Returns:
        None.
    """
    eval_start = time.time()
    if rank == 0:
        print(f"\n{sky_blue}Evaluating model...{reset}")
    model.model.eval()
    with torch.no_grad():
        if rank == 0 and debug:
            print("Making predictions...")
        if model.finetune_bert:
            dataset = SentimentDataset(X_test, [0] * len(X_test), bert_tokenizer)  # Dummy labels
            dataloader = DataLoader(dataset, batch_size=model.batch_size, shuffle=False)
            preds = []
            for batch in dataloader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                outputs = model.model(input_ids, attention_mask=attention_mask)
                preds.append(outputs)
            preds = torch.cat(preds, dim=0)
        else:
            if not torch.is_tensor(X_test):
                X_test = torch.tensor(X_test, device=device)
            preds = model.model(X_test)
        all_preds = [torch.zeros_like(preds) for _ in range(world_size)]
        dist.all_gather(all_preds, preds)
        if rank == 0:
            # Keep only as many rows as there are true labels
            all_preds = torch.cat(all_preds, dim=0)[:len(y_test)]
            y_pred = convert_numeric_to_labels(all_preds.argmax(dim=1).cpu().numpy(), numeric_dict)
            if debug:
                print(f"Predictions: {len(y_pred)}, True labels: {len(y_test)}")

            # Convert text labels to numeric labels
            y_test_numeric = np.array([label_dict[label] for label in y_test])
            y_pred_numeric = np.array([label_dict[label] for label in y_pred])

            # Set model name based on run name
            if model_name is None:
                if wandb_run is not None:
                    model_name = wandb_run.name
                elif weights_name is not None:
                    model_name = weights_name
                else:
                    model_name = 'Neural Classifier'

            # Use the DataWaza eval_model function
            metrics = eval_model(
                y_test=y_test_numeric,
                y_pred=y_pred_numeric,
                class_map=numeric_dict,
                estimator=model,
                x_test=X_test,
                class_type='multi' if len(numeric_dict) > 2 else 'binary',
                model_name=model_name,
                plot=False,
                save_plots=save_plots,
                save_dir=save_dir,
                debug=debug,
                pos_label=pos_label,
                decimal=decimal,
                return_metrics=True,
                threshold=threshold,
                wandb_run=wandb_run
            )

            # Save predictions if requested
            if save_preds:
                df = pd.DataFrame({
                    'X_test_sent': X_test_sent,
                    'y_test': y_test,
                    'y_pred': y_pred
                })
                # Create a save directory if it doesn't exist
                if not os.path.exists(save_dir):
                    print(f"Creating save directory: {save_dir}")
                # exist_ok avoids a crash if the directory appears between check and creation
                os.makedirs(save_dir, exist_ok=True)
                # Create a filename timestamp
                timestamp = time.strftime("%Y%m%d-%H%M%S")
                save_path = os.path.join(save_dir, f'predictions_{timestamp}.csv')
                df.to_csv(save_path, index=False)
                print(f"Saved predictions: {save_dir}/predictions_{timestamp}.csv")
                if wandb_run is not None:
                    wandb_run.log({
                        "eval/predictions": wandb.Table(
                            data=[[sent, true, pred] for sent, true, pred in zip(X_test_sent, y_test, y_pred)],
                            columns=["X_test_sent", "y_test", "y_pred"]
                        )
                    })

            macro_f1_score = model.score(X_test, y_test, device, debug)
            print(f"\n{bright_white}{bold}Macro F1 Score:{reset} {bright_cyan}{bold}{macro_f1_score:.2f}{reset}")

            # Log evaluation metrics to Weights & Biases
            if wandb_run is not None:
                # The report is only consumed by the logger, so build it here
                class_report = classification_report(y_test, y_pred, digits=decimal, zero_division=0, output_dict=True)
                # Log through the run that was passed in, like the predictions table above,
                # rather than through whatever run the wandb module treats as current
                wandb_run.log({
                    'eval/macro_f1_score': macro_f1_score,
                    'eval/classification_report': class_report,
                    'eval/metrics': metrics,
                })

            print(f"\nEvaluation completed ({format_time(time.time() - eval_start)})")

def make_predictions(classifier, tokenizer, transformer_model, predict_file, numeric_dict, rank, debug, save_dir, device, pooling,
                     world_size, batch_size, num_workers, prefetch, empty_cache, finetune_transformer, freeze_transformer, chunk_size):
    """Predict labels for an unlabeled CSV and save them next to the sentences.

    The CSV at `predict_file` must contain a `sentence` column. Unless
    `finetune_transformer` is set, the sentences are first encoded with `bert_phi`
    (all ranks take part); rank 0 then runs the classifier, adds a `prediction`
    column and writes a timestamped `test_predictions_*.csv` into `save_dir`.
    All ranks meet on a barrier before returning.

    Args:
        classifier: classifier wrapper exposing `model` and `batch_size`.
        tokenizer: tokenizer used for encoding / dataset construction.
        transformer_model: encoder passed to `bert_phi`.
        predict_file: path of the CSV to read.
        numeric_dict: mapping from numeric class to label.
        rank, world_size, device, debug: distributed/runtime settings.
        save_dir: output directory (created when missing).
        pooling, batch_size, num_workers, prefetch, empty_cache: forwarded to `bert_phi`.
        finetune_transformer: when true, feed raw texts through the classifier
            instead of pre-computed embeddings.
        freeze_transformer, chunk_size: accepted but currently unused by this function.

    Returns:
        None.
    """
    predictions_start = time.time()
    if rank == 0:
        print(f"\n{sky_blue}Predicting on unlabled test dataset...{reset}")
    # Load the test dataset
    test_df = pd.read_csv(predict_file, index_col=None)
    test_texts = test_df.sentence.values
    # DataFrame.sample raises when asked for more rows than exist, so cap the preview
    preview_rows = min(3, len(test_df))
    if rank == 0:
        print(f"Loaded test dataset at: {predict_file}")
        print(f"Test dataset size: {len(test_texts)}")
        print(f"Test dataset columns: {list(test_df.columns)}")
        print(f"Test dataset sample:\n{test_df[['sentence']].sample(preview_rows)}")

    # Tokenize and encode the test dataset
    if not finetune_transformer:
        X_test = bert_phi(test_texts, tokenizer, transformer_model, pooling, world_size, device, batch_size,
                                    None, rank, debug, 'Test',
                                    num_workers, prefetch, empty_cache, None, None)
    else:
        X_test = test_texts

    if rank == 0:
        classifier.model.eval()
        with torch.no_grad():
            if finetune_transformer:
                # The dataset is only needed (and only valid) for raw texts
                dataset = SentimentDataset(X_test, [0] * len(X_test), tokenizer)  # Dummy labels
                dataloader = DataLoader(dataset, batch_size=classifier.batch_size, shuffle=False)
                preds = []
                for batch in dataloader:
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    outputs = classifier.model(input_ids, attention_mask=attention_mask)
                    preds.append(outputs)
                preds = torch.cat(preds, dim=0)
            else:
                if not torch.is_tensor(X_test):
                    X_test = torch.tensor(X_test, device=device)
                preds = classifier.model(X_test)
            preds_labels = convert_numeric_to_labels(preds.argmax(dim=1).cpu().numpy(), numeric_dict)
            test_df['prediction'] = preds_labels
            print(f"Sample test predictions:\n{test_df[['sentence', 'prediction']].sample(preview_rows)}")

            # Create a save directory if it doesn't exist
            if not os.path.exists(save_dir):
                print(f"Creating save directory: {save_dir}")
            # exist_ok avoids a crash if the directory appears between check and creation
            os.makedirs(save_dir, exist_ok=True)
            # Create a filename timestamp
            timestamp = time.strftime("%Y%m%d-%H%M%S")
            save_path = os.path.join(save_dir, f'test_predictions_{timestamp}.csv')
            test_df.to_csv(save_path, index=False)

            print(f"Saved test predictions: {save_dir}/test_predictions_{timestamp}.csv")
            print(f"Test prediction completed ({format_time(time.time() - predictions_start)})")
    # Hold the other ranks until rank 0 has finished writing
    dist.barrier()


In [11]:
device = prepare_device(0, 'cpu')

In [12]:
def setup_environment(rank, world_size, backend, device, debug, port='12355'):
    """Point torch.distributed at localhost and join the default process group.

    NOTE: this cell-level definition shadows the `setup_environment` imported from
    `utils` at the top of the notebook.

    Args:
        rank: rank of the calling process.
        world_size: total number of processes in the group.
        backend: torch.distributed backend name (e.g. 'gloo').
        device: device reported in the log line.
        debug: accepted for signature compatibility; not used here.
        port: rendezvous port, as a string or an int.
    """
    os.environ["MASTER_ADDR"] = "localhost"
    # os.environ only accepts strings, so tolerate an integer port
    os.environ["MASTER_PORT"] = str(port)
    # Re-running the cell must not try to create the default process group a second time
    if not dist.is_initialized():
        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    print(f"Rank {rank} - Device: {device}")
    dist.barrier()
    if rank == 0:
        print(f"{world_size} process groups initialized with '{backend}' backend on {os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}")


In [None]:
setup_environment(0, 1, 'gloo', device, False, '12356')

In [15]:
fix_random_seeds(42)

In [18]:
# Load the DynaSent Round 1 train / validation / test splits in a single process (world_size=1, rank=0).
train_dyn1, val_dyn1, test_dyn1 = load_data(dataset='dynasent_r1', eval_dataset=None, eval_split='test',
                                            use_val_split=True, val_percent=0, sample_percent=None,
                                            world_size=1, rank=0, debug=False)


[38;5;117mLoading data...[0m
Using the same dataset for training and evaluation
Splits:
- Train: Using dynasent_r1 'train' split
- Validation: Using dynasent_r1 'validation' split
- Evaluation: Using dynasent_r1 'test' split
Train Data: DynaSent Round 1 from Hugging Face: 'dynabench/dynasent'
Validation Data: DynaSent Round 1 from Hugging Face: 'dynabench/dynasent'
Evaluation Data: DynaSent Round 1 from Hugging Face: 'dynabench/dynasent'
Train size: 80488
Validation size: 3600
Evaluation size: 3600
Train label distribution:
	      Negative: 14021
	       Neutral: 45076
	      Positive: 21391
Validation label distribution:
	      Negative: 1200
	       Neutral: 1200
	      Positive: 1200
Evaluation label distribution:
	      Negative: 1200
	       Neutral: 1200
	      Positive: 1200
Data loaded (2s)


In [20]:
# Load the local SST train / validation / test splits in a single process (world_size=1, rank=0).
train_sst, val_sst, test_sst = load_data(dataset='sst_local', eval_dataset=None, eval_split='test',
                                            use_val_split=True, val_percent=0, sample_percent=None,
                                            world_size=1, rank=0, debug=False)


[38;5;117mLoading data...[0m
Using the same dataset for training and evaluation
Splits:
- Train: Using sst_local 'train' split
- Validation: Using sst_local 'validation' split
- Evaluation: Using sst_local 'test' split
Train Data: Stanford Sentiment Treebank (SST) from Local: 'data/sentiment'
Validation Data: Stanford Sentiment Treebank (SST) from Local: 'data/sentiment'
Evaluation Data: Stanford Sentiment Treebank (SST) from Local: 'data/sentiment'
Train size: 8544
Validation size: 1101
Evaluation size: 2210
Train label distribution:
	      Negative: 3310
	       Neutral: 1624
	      Positive: 3610
Validation label distribution:
	      Negative: 428
	       Neutral: 229
	      Positive: 444
Evaluation label distribution:
	      Negative: 912
	       Neutral: 389
	      Positive: 909
Data loaded (276ms)


In [21]:
# Load the DynaSent Round 2 train / validation / test splits in a single process (world_size=1, rank=0).
train_dyn2, val_dyn2, test_dyn2 = load_data(dataset='dynasent_r2', eval_dataset=None, eval_split='test',
                                            use_val_split=True, val_percent=0, sample_percent=None,
                                            world_size=1, rank=0, debug=False)


[38;5;117mLoading data...[0m
Using the same dataset for training and evaluation
Splits:
- Train: Using dynasent_r2 'train' split
- Validation: Using dynasent_r2 'validation' split
- Evaluation: Using dynasent_r2 'test' split
Train Data: DynaSent Round 2 from Hugging Face: 'dynabench/dynasent'
Validation Data: DynaSent Round 2 from Hugging Face: 'dynabench/dynasent'
Evaluation Data: DynaSent Round 2 from Hugging Face: 'dynabench/dynasent'
Train size: 13065
Validation size: 720
Evaluation size: 720
Train label distribution:
	      Negative: 4579
	       Neutral: 2448
	      Positive: 6038
Validation label distribution:
	      Negative: 240
	       Neutral: 240
	      Positive: 240
Evaluation label distribution:
	      Negative: 240
	       Neutral: 240
	      Positive: 240
Data loaded (2s)


In [27]:
# Record which dataset every frame came from, so rows stay traceable after merging.
for source_name, source_frames in (
    ('dynasent_r1', (train_dyn1, val_dyn1, test_dyn1)),
    ('dynasent_r2', (train_dyn2, val_dyn2, test_dyn2)),
    ('sst_local', (train_sst, val_sst, test_sst)),
):
    for source_frame in source_frames:
        source_frame['source'] = source_name

In [28]:
# Record which split every frame belongs to, so rows stay traceable after merging.
for split_name, split_frames in (
    ('train', (train_dyn1, train_dyn2, train_sst)),
    ('validation', (val_dyn1, val_dyn2, val_sst)),
    ('test', (test_dyn1, test_dyn2, test_sst)),
):
    for split_frame in split_frames:
        split_frame['split'] = split_name

In [29]:
train_dyn1[['sentence', 'label', 'source', 'split']].head()

Unnamed: 0,sentence,label,source,split
0,Roto-Rooter is always good when you need someo...,positive,dynasent_r1,train
1,It's so worth the price of cox service over he...,positive,dynasent_r1,train
2,"I placed my order of ""sticky ribs"" as an appet...",neutral,dynasent_r1,train
3,"There is mandatory valet parking, so make sure...",neutral,dynasent_r1,train
4,My wife and I couldn't finish it.,neutral,dynasent_r1,train


In [30]:
train_dyn2[['sentence', 'label', 'source', 'split']].head()

Unnamed: 0,sentence,label,source,split
0,We enjoyed our first and last meal in Toronto ...,positive,dynasent_r2,train
1,I tried a new place. I can't wait to return an...,positive,dynasent_r2,train
2,"The buffalo chicken was not good, but very cos...",negative,dynasent_r2,train
3,The hotel offered complimentary breakfast.,positive,dynasent_r2,train
4,It work very well,positive,dynasent_r2,train


In [31]:
train_sst[['sentence', 'label', 'source', 'split']].head()

Unnamed: 0,sentence,label,source,split
0,The Rock is destined to be the 21st Century 's...,positive,sst_local,train
71,The gorgeously elaborate continuation of `` Th...,positive,sst_local,train
144,Singer\/composer Bryan Adams contributes a sle...,positive,sst_local,train
221,You 'd think by now America would have had eno...,neutral,sst_local,train
258,Yet the act is still charming here .,positive,sst_local,train


In [32]:
# Stack the three training sets, keeping only the shared columns and renumbering the rows.
train_all = pd.concat(
    [frame[['sentence', 'label', 'source', 'split']] for frame in (train_dyn1, train_dyn2, train_sst)],
    ignore_index=True,
)

In [33]:
train_all.sample(10)

Unnamed: 0,sentence,label,source,split
84457,Those 2 drinks are part of the HK culture and ...,negative,dynasent_r2,train
33315,I was told by the repair company that was doin...,negative,dynasent_r1,train
95755,It is there to give them a good time .,neutral,sst_local,train
99353,Like leafing through an album of photos accomp...,negative,sst_local,train
23628,Johnny was a talker and liked to have fun.,positive,dynasent_r1,train
81277,It as burnt to a crisp black flavorless,negative,dynasent_r2,train
61427,I called Moveaholics Jason was amazing he even...,positive,dynasent_r1,train
13393,Is this place expensive?,neutral,dynasent_r1,train
11373,It's likely crowded at the busier times so kee...,neutral,dynasent_r1,train
59563,I was just looking at bottom line price and my...,neutral,dynasent_r1,train


In [34]:
# Stack the three validation sets, keeping only the shared columns and renumbering the rows.
val_all = pd.concat(
    [frame[['sentence', 'label', 'source', 'split']] for frame in (val_dyn1, val_dyn2, val_sst)],
    ignore_index=True,
)

In [35]:
val_all.sample(10)

Unnamed: 0,sentence,label,source,split
1381,I do like Tasty Asian kitchen across the street.,positive,dynasent_r1,validation
2337,"Sadly, they get mixed up with the other locati...",negative,dynasent_r1,validation
5213,"Professionally speaking , it 's tempting to ju...",negative,sst_local,validation
3691,The dining was sublime.,positive,dynasent_r2,validation
1234,Came back to Elara around 4am.,neutral,dynasent_r1,validation
4091,You dont have to miss church and the game. tru...,neutral,dynasent_r2,validation
2786,"As long as I can get my job done, I'll wait.",neutral,dynasent_r1,validation
2427,They are doing themselves a great disservice b...,negative,dynasent_r1,validation
5400,The end result is a film that 's neither .,negative,sst_local,validation
1345,I HAVE NEVER BEEN SO APPALLED BY THE MEDICAL C...,negative,dynasent_r1,validation


In [36]:
# Stack the three test frames (test_dyn1, test_dyn2, test_sst) into a single frame,
# keeping only the four common columns and renumbering the index from 0
test_all = pd.concat(
    [part[['sentence', 'label', 'source', 'split']] for part in (test_dyn1, test_dyn2, test_sst)],
    ignore_index=True,
)

In [37]:
# Spot-check 10 random rows of the merged test frame.
# No random_state is passed, so the rows shown depend on the RNG state when the cell runs.
test_all.sample(10)

Unnamed: 0,sentence,label,source,split
934,"If I lived on the Strip, this bar would be my ...",positive,dynasent_r1,test
5091,"A recent favourite at Sundance , this white-tr...",positive,sst_local,test
6511,Rainy days and movies about the disintegration...,negative,sst_local,test
57,As our appetizers arrived (hummus for hubby an...,positive,dynasent_r1,test
3859,"That really sucks, was it your fault?",negative,dynasent_r2,test
3674,"Hence, that's why they do not take walk ins.",neutral,dynasent_r2,test
2040,We checked in around 12:30pm and had some frie...,positive,dynasent_r1,test
586,"HOOD RAT, HOOD RAT, HOOCHIE MAMA!",negative,dynasent_r1,test
4380,Disney has always been hit-or-miss when bringi...,neutral,sst_local,test
1900,Let me start by saying the food is good.,positive,dynasent_r1,test


In [38]:
# (rows, columns) of the merged train, validation and test frames, in that order
print(train_all.shape, val_all.shape, test_all.shape)

(102097, 4) (5421, 4) (6530, 4)


In [39]:
# First rows of the merged training frame
train_all.head()

Unnamed: 0,sentence,label,source,split
0,Roto-Rooter is always good when you need someo...,positive,dynasent_r1,train
1,It's so worth the price of cox service over he...,positive,dynasent_r1,train
2,"I placed my order of ""sticky ribs"" as an appet...",neutral,dynasent_r1,train
3,"There is mandatory valet parking, so make sure...",neutral,dynasent_r1,train
4,My wife and I couldn't finish it.,neutral,dynasent_r1,train


In [40]:
# Last rows of the merged training frame
train_all.tail()

Unnamed: 0,sentence,label,source,split
102092,A real snooze .,negative,sst_local,train
102093,No surprises .,negative,sst_local,train
102094,We 've seen the hippie-turned-yuppie plot befo...,positive,sst_local,train
102095,Her fans walked out muttering words like `` ho...,negative,sst_local,train
102096,In this case zero .,negative,sst_local,train


In [41]:
# Randomly shuffle the training, validation and test data.
# random_state=42 fixes the permutation and reset_index(drop=True) renumbers the rows from 0.
# Each frame is rebound to its own shuffled copy, so re-running this cell shuffles the
# already-shuffled frames again and yields a different row order than a single run.
train_all = train_all.sample(frac=1, random_state=42).reset_index(drop=True)
val_all = val_all.sample(frac=1, random_state=42).reset_index(drop=True)
test_all = test_all.sample(frac=1, random_state=42).reset_index(drop=True)

In [42]:
# First rows of the training frame (compare the `source` column with the pre-shuffle preview)
train_all.head()

Unnamed: 0,sentence,label,source,split
0,Those 2 drinks are part of the HK culture and ...,negative,dynasent_r2,train
1,I was told by the repair company that was doin...,negative,dynasent_r1,train
2,It is there to give them a good time .,neutral,sst_local,train
3,Like leafing through an album of photos accomp...,negative,sst_local,train
4,Johnny was a talker and liked to have fun.,positive,dynasent_r1,train


In [43]:
# Last rows of the training frame
train_all.tail()

Unnamed: 0,sentence,label,source,split
102092,I thought this place was supposed to be good.,negative,dynasent_r1,train
102093,They claim it's because people didn't like it ...,negative,dynasent_r1,train
102094,There is also another marbled-out full bathroo...,neutral,dynasent_r1,train
102095,You put in your cell phone number & select a d...,neutral,dynasent_r1,train
102096,I came in for a second opinion on a crown I wa...,neutral,dynasent_r1,train


In [44]:
# First rows of the validation frame
val_all.head()

Unnamed: 0,sentence,label,source,split
0,Found Thai Spoon on the Vegan Pittsburgh website.,neutral,dynasent_r1,validation
1,Our bill came out to around $27 and we ate lik...,positive,dynasent_r1,validation
2,State Farm broke down the costs for me of the ...,neutral,dynasent_r1,validation
3,The only con for this resto is the wait to get...,negative,dynasent_r1,validation
4,We could hear the people above us stomping aro...,negative,dynasent_r1,validation


In [45]:
# Last rows of the validation frame
val_all.tail()

Unnamed: 0,sentence,label,source,split
5416,I think it's really a matter of mastering the ...,neutral,dynasent_r2,validation
5417,A bloated gasbag thesis grotesquely impressed ...,negative,sst_local,validation
5418,"Its story may be a thousand years old , but wh...",negative,sst_local,validation
5419,I felt sad for Lise not so much because of wha...,neutral,sst_local,validation
5420,"We always eat in the restaurant, so I can't co...",neutral,dynasent_r1,validation


In [46]:
# First rows of the test frame
test_all.head()

Unnamed: 0,sentence,label,source,split
0,I had called in advance to see if they offered...,positive,dynasent_r1,test
1,EVERY SINGLE ITEM WAS INEDIBLE.,negative,dynasent_r1,test
2,Rooms are small.,negative,dynasent_r1,test
3,"Without resorting to hyperbole , I can state t...",positive,sst_local,test
4,Used a $12 Groupon deal (to cover $20 meal).,neutral,dynasent_r1,test


In [47]:
# Last rows of the test frame
test_all.tail()

Unnamed: 0,sentence,label,source,split
6525,I went back in to ask for cilantro dressing th...,positive,dynasent_r2,test
6526,"Here , Adrian Lyne comes as close to profundit...",positive,sst_local,test
6527,The actors are so terrific at conveying their ...,neutral,sst_local,test
6528,It should be mentioned that the set design and...,positive,sst_local,test
6529,She greeted customers by holding the scanner t...,negative,dynasent_r1,test


In [48]:
# Save the combined training, validation and test data to CSV files.
# to_csv does not create missing parent folders, so make sure the output directory
# exists first; exist_ok=True makes this a no-op when it is already there.
os.makedirs('data/merged', exist_ok=True)
train_all.to_csv('data/merged/train_all.csv', index=False)
val_all.to_csv('data/merged/val_all.csv', index=False)
test_all.to_csv('data/merged/test_all.csv', index=False)

In [50]:
# Read the three merged splits back from disk (train, validation, test - in that order)
train_all_df, val_all_df, test_all_df = [
    pd.read_csv(csv_path)
    for csv_path in (
        'data/merged/train_all.csv',
        'data/merged/val_all.csv',
        'data/merged/test_all.csv',
    )
]

In [52]:
# (rows, columns) of the frames read back from CSV: train, validation, test
print(train_all_df.shape, val_all_df.shape, test_all_df.shape)

(102097, 4) (5421, 4) (6530, 4)


In [51]:
# First rows of the training frame as read back from CSV
train_all_df.head()

Unnamed: 0,sentence,label,source,split
0,Those 2 drinks are part of the HK culture and ...,negative,dynasent_r2,train
1,I was told by the repair company that was doin...,negative,dynasent_r1,train
2,It is there to give them a good time .,neutral,sst_local,train
3,Like leafing through an album of photos accomp...,negative,sst_local,train
4,Johnny was a talker and liked to have fun.,positive,dynasent_r1,train


In [53]:
# One-vs-rest target for the neutral class: rows labelled 'neutral' keep that value,
# every other row is marked 'non-neutral'
train_all_df['neutral_label'] = np.where(train_all_df['label'] == 'neutral', 'neutral', 'non-neutral')
val_all_df['neutral_label'] = np.where(val_all_df['label'] == 'neutral', 'neutral', 'non-neutral')
test_all_df['neutral_label'] = np.where(test_all_df['label'] == 'neutral', 'neutral', 'non-neutral')

In [54]:
# One-vs-rest target for the positive class: rows labelled 'positive' keep that value,
# every other row is marked 'non-positive'
train_all_df['positive_label'] = np.where(train_all_df['label'] == 'positive', 'positive', 'non-positive')
val_all_df['positive_label'] = np.where(val_all_df['label'] == 'positive', 'positive', 'non-positive')
test_all_df['positive_label'] = np.where(test_all_df['label'] == 'positive', 'positive', 'non-positive')

In [55]:
# One-vs-rest target for the negative class: rows labelled 'negative' keep that value,
# every other row is marked 'non-negative'
train_all_df['negative_label'] = np.where(train_all_df['label'] == 'negative', 'negative', 'non-negative')
val_all_df['negative_label'] = np.where(val_all_df['label'] == 'negative', 'negative', 'non-negative')
test_all_df['negative_label'] = np.where(test_all_df['label'] == 'negative', 'negative', 'non-negative')

In [56]:
# Check that the neutral_label, positive_label and negative_label columns are present
train_all_df.head()

Unnamed: 0,sentence,label,source,split,neutral_label,positive_label,negative_label
0,Those 2 drinks are part of the HK culture and ...,negative,dynasent_r2,train,non-neutral,non-positive,negative
1,I was told by the repair company that was doin...,negative,dynasent_r1,train,non-neutral,non-positive,negative
2,It is there to give them a good time .,neutral,sst_local,train,neutral,non-positive,non-negative
3,Like leafing through an album of photos accomp...,negative,sst_local,train,non-neutral,non-positive,negative
4,Johnny was a talker and liked to have fun.,positive,dynasent_r1,train,non-neutral,positive,non-negative


In [57]:
# Save the binary classification training, validation and test data to CSV files
# (index=False leaves the row index out of the files)
train_all_df.to_csv('data/merged/train_all_binary.csv', index=False)
val_all_df.to_csv('data/merged/val_all_binary.csv', index=False)
test_all_df.to_csv('data/merged/test_all_binary.csv', index=False)

In [58]:
# Load all the datasets from CSV files.
# These are the *_all.csv files, not the *_all_binary.csv ones, so rebinding the names here
# replaces the in-memory frames that carried the *_label columns.
train_all_df, val_all_df, test_all_df = [
    pd.read_csv(csv_path)
    for csv_path in (
        'data/merged/train_all.csv',
        'data/merged/val_all.csv',
        'data/merged/test_all.csv',
    )
]

In [59]:
def oversample_minority_datasets(df, source_col='source', multipliers=None, balanced=True, random_state=None):
    """
    Oversample minority datasets by duplicating their rows.

    Each distinct value of `source_col` is treated as one dataset. Every dataset's rows are
    repeated a whole number of times, the pieces are concatenated, and the combined frame is
    shuffled. A before/after summary is printed to stdout.

    Parameters:
    -----------
    df : pandas.DataFrame
        Input DataFrame containing the data. It is not modified.
    source_col : str, optional
        Name of the column containing dataset source information
    multipliers : dict, optional
        Custom integer multipliers for each dataset, e.g. {'dynasent_r1': 1, 'dynasent_r2': 3, 'sst_local': 3}
        Datasets that are not in the dict get a multiplier of 1. A multiplier of 1 or less
        leaves the dataset's rows unchanged.
        If None and balanced=True, will calculate multipliers to roughly balance datasets
    balanced : bool, optional
        Only used when multipliers is not provided. If True, each dataset gets the multiplier
        ceil(largest_count / its_count); because this rounds up, a minority dataset can end up
        somewhat larger than the largest one. If False, no oversampling is done.
    random_state : int, numpy.random.RandomState, numpy.random.Generator or None, optional
        Passed to DataFrame.sample for the final shuffle. Give a fixed value to get the same
        row order on every run. The default (None) keeps the original unseeded shuffle.

    Returns:
    --------
    pandas.DataFrame
        New DataFrame with oversampled minority datasets, shuffled and re-indexed from 0.
        Rows whose source is missing are kept once. An empty input returns an empty copy.
    """
    # Get dataset sizes (value_counts leaves out rows whose source is missing)
    dataset_counts = df[source_col].value_counts()

    # Nothing to duplicate or shuffle; without this guard pd.concat below would be called
    # with an empty list and raise ValueError
    if len(df) == 0:
        return df.copy()

    max_count = dataset_counts.max()

    if multipliers is None:
        if balanced:
            # Calculate multipliers to roughly balance datasets
            multipliers = {
                dataset: int(np.ceil(max_count / count))
                for dataset, count in dataset_counts.items()
            }
        else:
            # Default to no oversampling
            multipliers = {dataset: 1 for dataset in dataset_counts.index}

    # Print statistics before oversampling
    print("\nBefore oversampling:")
    print("-" * 60)
    print(f"{'Dataset':<15} {'Count':<10} {'Percent':<10} {'Multiplier':<10} {'New Count':<10}")
    print("-" * 60)

    total = len(df)
    oversampled_dfs = []

    for dataset in dataset_counts.index:
        # Get multiplier for this dataset
        mult = multipliers.get(dataset, 1)

        # Get rows for this dataset
        dataset_df = df[df[source_col] == dataset]
        count = len(dataset_df)

        # Calculate new count after oversampling
        new_count = count * mult

        # Print statistics
        print(f"{dataset:<15} {count:<10} {count/total*100:>6.2f}%    {mult:>6.0f}x    {new_count:<10}")

        # Duplicate rows based on multiplier
        if mult > 1:
            dataset_df = pd.concat([dataset_df] * mult, ignore_index=True)

        oversampled_dfs.append(dataset_df)

    # Rows with a missing source match none of the datasets above; keep them once so that
    # they are not silently lost from the result
    missing_source_df = df[df[source_col].isna()]
    if len(missing_source_df) > 0:
        oversampled_dfs.append(missing_source_df)

    # Combine all datasets
    result_df = pd.concat(oversampled_dfs, ignore_index=True)

    # Print final statistics
    print("\nAfter oversampling:")
    print(f"Total samples: {len(result_df):,} (before: {len(df):,})")
    new_counts = result_df[source_col].value_counts()
    for dataset, count in new_counts.items():
        print(f"{dataset:<15} {count:<10} {count/len(result_df)*100:>6.2f}%")

    # Shuffle the DataFrame (reproducible when random_state is given)
    result_df = result_df.sample(frac=1.0, random_state=random_state).reset_index(drop=True)

    return result_df

def print_label_distribution(df, label_col='label', source_col='source'):
    """
    Report how the rows of each dataset source are split across labels.

    For every distinct value in `source_col` (in order of first appearance) this prints the
    number of rows from that source, followed by one line per label with its count and its
    share of that source's rows. Nothing is returned.

    Parameters:
    -----------
    df : pandas.DataFrame
        Input DataFrame containing the data
    label_col : str, optional
        Name of the column containing labels
    source_col : str, optional
        Name of the column containing dataset source information
    """
    rule = "-" * 60
    print("\nLabel distribution by dataset:")
    print(rule)

    sources = df[source_col]
    for name in sources.unique():
        subset = df.loc[sources == name]
        n_rows = len(subset)

        print(f"\n{name} (Total: {n_rows:,}):")
        # value_counts lists the most frequent label first
        for label, count in subset[label_col].value_counts().items():
            share = count / n_rows * 100
            print(f"{label:<10} {count:<10} {share:>6.2f}%")

In [60]:
# Automatically balance datasets: with no explicit multipliers and balanced=True, each source is
# repeated ceil(largest_source_count / its_count) times. The call also prints before/after statistics.
train_balanced_df = oversample_minority_datasets(train_all_df, balanced=True)


Before oversampling:
------------------------------------------------------------
Dataset         Count      Percent    Multiplier New Count 
------------------------------------------------------------
dynasent_r1     80488       78.83%         1x    80488     
dynasent_r2     13065       12.80%         7x    91455     
sst_local       8544         8.37%        10x    85440     

After oversampling:
Total samples: 257,383 (before: 102,097)
dynasent_r2     91455       35.53%
sst_local       85440       33.20%
dynasent_r1     80488       31.27%


In [61]:
# Write the oversampled training frame to disk (index=False leaves the row index out of the file)
train_balanced_df.to_csv('data/merged/train_balanced.csv', index=False)

In [62]:
# (rows, columns) of train_all_df, val_all_df and test_all_df; train_all_df is the original frame,
# the oversampled rows live in train_balanced_df
print(train_all_df.shape, val_all_df.shape, test_all_df.shape)

(102097, 4) (5421, 4) (6530, 4)


In [63]:
# First rows of the test frame loaded from CSV
test_all_df.head()

Unnamed: 0,sentence,label,source,split
0,I had called in advance to see if they offered...,positive,dynasent_r1,test
1,EVERY SINGLE ITEM WAS INEDIBLE.,negative,dynasent_r1,test
2,Rooms are small.,negative,dynasent_r1,test
3,"Without resorting to hyperbole , I can state t...",positive,sst_local,test
4,Used a $12 Groupon deal (to cover $20 meal).,neutral,dynasent_r1,test


In [64]:
# Last rows of the test frame loaded from CSV
test_all_df.tail()

Unnamed: 0,sentence,label,source,split
6525,I went back in to ask for cilantro dressing th...,positive,dynasent_r2,test
6526,"Here , Adrian Lyne comes as close to profundit...",positive,sst_local,test
6527,The actors are so terrific at conveying their ...,neutral,sst_local,test
6528,It should be mentioned that the set design and...,positive,sst_local,test
6529,She greeted customers by holding the scanner t...,negative,dynasent_r1,test
