In [1]:
import warnings
# Suppress specific warnings
warnings.filterwarnings("ignore", message="`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.", category=FutureWarning)
warnings.filterwarnings("ignore", message="TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class.", category=UserWarning, module=r"torch\._utils")
warnings.filterwarnings("ignore", message="promote has been superseded by promote_options='default'", category=FutureWarning, module=r"datasets\.table")

# Standard library imports
import argparse
import os
import sys
import signal
import time
from collections import Counter
import traceback
import math
from multiprocessing import Value
from queue import Empty

# PyTorch imports
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset
import torch.optim as optim
from torch.distributed.optim import ZeroRedundancyOptimizer

# Third-party library imports
import numpy as np
import pandas as pd
from transformers import BertTokenizer, BertModel
from sklearn.metrics import classification_report
import sst
from datasets import load_dataset

# Custom utility imports
from utils import (
    setup_environment, 
    prepare_device, 
    fix_random_seeds,
    convert_numeric_to_labels, 
    convert_labels_to_tensor,
    format_time, 
    convert_sst_label, 
    get_activation,
    set_threads,
    signal_handler,
    cleanup_and_exit,
    get_optimizer,
    get_shape_color,
    print_rank_memory_summary,
    #get_scheduler
)
from torch_ddp_finetune_neural_classifier import TorchDDPNeuralClassifier, SentimentDataset
from colors import *

# Suppress Hugging Face library warnings
import logging
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("datasets").setLevel(logging.ERROR)
logging.getLogger("huggingface_hub.repocard").setLevel(logging.ERROR)

from torch.utils.data import Dataset

from multiprocessing import Queue, Pipe, set_start_method

def save_data_archive(X_train, X_dev, y_train, y_dev, X_dev_sent, world_size, device_type, data_dir):
    """Save the processed train/dev data to a timestamped, compressed ``.npz`` archive.

    The file is written to ``data_dir`` and named ``data_<world_size>_gpu_<timestamp>.npz``
    when ``device_type`` is ``'cuda'``, or ``data_1_cpu_<timestamp>.npz`` otherwise.
    The archive keys are ``X_train``, ``X_dev``, ``y_train``, ``y_dev`` and ``X_dev_sent``.

    Args:
        X_train: Training features to store.
        X_dev: Dev features to store.
        y_train: Training labels to store.
        y_dev: Dev labels to store.
        X_dev_sent: Dev sentences to store.
        world_size: Number of processes; only used in the filename for CUDA runs.
        device_type: Device type string; ``'cuda'`` selects the GPU filename suffix.
        data_dir: Directory to write the archive into (created if missing).
    """
    # exist_ok=True avoids the check-then-create race of a separate os.path.exists() test
    os.makedirs(data_dir, exist_ok=True)

    # Create filename with appropriate suffix and timestamp
    suffix = f'_{world_size}_gpu' if device_type == 'cuda' else '_1_cpu'
    timestamp = time.strftime('%Y%m%d-%H%M%S')
    filepath = os.path.join(data_dir, f'data{suffix}_{timestamp}.npz')

    # Save data to archive file
    np.savez_compressed(filepath, X_train=X_train, X_dev=X_dev, y_train=y_train, y_dev=y_dev, X_dev_sent=X_dev_sent)
    print(f"\nData saved to: {filepath}")

def load_data_archive(data_file, device, rank):
    """Load train/dev data from an ``.npz`` archive containing the keys
    ``X_train``, ``X_dev``, ``y_train``, ``y_dev`` and ``X_dev_sent``.

    Args:
        data_file: Path to the ``.npz`` archive.
        device: Accepted for interface compatibility; not used by this function.
        rank: Process rank; progress messages are printed on rank 0 only.

    Returns:
        Tuple ``(X_train, X_dev, y_train, y_dev, X_dev_sent)`` of arrays read from the archive.

    Raises:
        ValueError: If ``data_file`` is None.
        FileNotFoundError: If ``data_file`` does not exist.
        RuntimeError: If reading the archive fails; the underlying exception is
            attached as ``__cause__``.
    """
    load_archive_start = time.time()

    # Check if the archive file path is provided
    if data_file is None:
        raise ValueError("No archive file provided to load data from")

    # Check if the archive file exists
    if not os.path.exists(data_file):
        raise FileNotFoundError(f"Archive file not found: {data_file}")

    # Attempt to load the data from the archive file
    try:
        if rank == 0:
            print(f"\nLoading archived data from: {data_file}...")
        # NOTE(security): allow_pickle=True can execute arbitrary code while unpickling
        # object arrays. Only load archives from a trusted source.
        with np.load(data_file, allow_pickle=True) as data:
            X_train = data['X_train']
            X_dev = data['X_dev']
            y_train = data['y_train']
            y_dev = data['y_dev']
            X_dev_sent = data['X_dev_sent']
        if rank == 0:
            print(f"Archived data loaded ({format_time(time.time() - load_archive_start)})")
    except Exception as e:
        # Chain explicitly so the original failure stays visible as the direct cause
        raise RuntimeError(f"Failed to load data from archive file {data_file}: {str(e)}") from e

    return X_train, X_dev, y_train, y_dev, X_dev_sent

def initialize_bert_model(weights_name, device, rank, debug):
    """Load the BERT tokenizer and model for ``weights_name`` and move the model to ``device``.

    Every rank loads its own copy, then all ranks synchronize on a barrier.
    Progress (and, with ``debug``, the tokenizer/model reprs) is printed on rank 0 only.

    Args:
        weights_name: Name or path handed to ``from_pretrained``.
        device: Device the model is moved to.
        rank: Process rank.
        debug: When true, rank 0 also prints the tokenizer and the model.

    Returns:
        Tuple ``(bert_tokenizer, bert_model)``.
    """
    started_at = time.time()
    is_main_rank = rank == 0

    if is_main_rank:
        print(f"\nInitializing '{weights_name}' tokenizer and model...")

    tokenizer = BertTokenizer.from_pretrained(weights_name)
    encoder = BertModel.from_pretrained(weights_name)
    encoder = encoder.to(device)
    # The encoder is intentionally returned unwrapped (no DDP wrapper is applied here).

    # Wait until every rank has finished loading before reporting
    dist.barrier()

    if is_main_rank and debug:
        print(f"Bert Tokenizer:\n{tokenizer}")
        print(f"Bert Model:\n{encoder}")
    if is_main_rank:
        elapsed = format_time(time.time() - started_at)
        print(f"Tokenizer and model initialized ({elapsed})")

    return tokenizer, encoder

def load_data(dataset, eval_dataset, sample_percent, world_size, rank, debug):
    """Load the training and dev splits as pandas DataFrames and share them with all ranks.

    Rank 0 loads the data (from Hugging Face or the local SST reader) and optionally
    sub-samples it. When ``world_size > 1`` the two frames are broadcast to the other
    ranks with ``dist.broadcast_object_list``. Rank 0 prints the label distributions.

    Args:
        dataset: Training dataset ID. One of ``'sst_local'``, ``'sst'``, ``'dynasent'``,
            ``'dynasent_r1'``, ``'dynasent_r2'`` or ``'mteb_tweet'``.
        eval_dataset: Dataset ID for the dev split, or None to reuse ``dataset``.
        sample_percent: Fraction passed to ``DataFrame.sample(frac=...)`` for both
            splits, or None to keep all rows.
        world_size: Number of processes in the group.
        rank: Rank of this process.
        debug: When true, print extra per-rank information after the broadcast.

    Returns:
        Tuple ``(train, dev)``.

    Raises:
        ValueError: For an unknown dataset ID or split name.

    Note:
        ``dist.barrier()`` is always called, so the default process group must
        already be initialized, even for a single process.
    """
    data_load_start = time.time()

    hf_source = 'Hugging Face'
    sst_name = 'Stanford Sentiment Treebank (SST)'
    dynasent_url = 'https://huggingface.co/datasets/dynabench/dynasent'
    dynasent_r1 = ('DynaSent Round 1', hf_source, 'dynabench/dynasent', 'dynabench.dynasent.r1.all', dynasent_url)

    # Dataset ID -> (display name, source, path, subset, url)
    dataset_registry = {
        'sst_local': (sst_name, 'Local', os.path.join('data', 'sentiment'), None, None),
        # The URL matches the repository that is actually loaded ('SetFit/sst5')
        'sst': (sst_name, hf_source, 'SetFit/sst5', None, 'https://huggingface.co/datasets/SetFit/sst5'),
        'dynasent': dynasent_r1,
        'dynasent_r1': dynasent_r1,
        'dynasent_r2': ('DynaSent Round 2', hf_source, 'dynabench/dynasent', 'dynabench.dynasent.r2.all', dynasent_url),
        'mteb_tweet': ('MTEB Tweet Sentiment Extraction', hf_source, 'mteb/tweet_sentiment_extraction', None,
                       'https://huggingface.co/datasets/mteb/tweet_sentiment_extraction'),
    }

    # Split name used by this script -> split key in the Hugging Face dataset
    split_keys = {'train': 'train', 'dev': 'validation', 'test': 'test'}

    # Function to get a subset of data based on split name
    def get_split(data, split):
        if split not in split_keys:
            raise ValueError(f"Unknown split: {split}")
        return data[split_keys[split]].to_pandas()

    def print_label_dist(dataset, label_name='label'):
        # Named label_counts (not `dist`) so the torch.distributed alias is never shadowed
        label_counts = sorted(Counter(dataset[label_name]).items())
        for k, v in label_counts:
            print(f"\t{k.capitalize():>14s}: {v}")

    # Function to load data from Hugging Face or local based on ID and split name
    def get_data(dataset_id, split, rank, debug):
        # Identify the dataset and path from the ID
        if dataset_id not in dataset_registry:
            raise ValueError(f"Unknown dataset: {dataset_id}")
        dataset_name, dataset_source, dataset_path, dataset_subset, dataset_url = dataset_registry[dataset_id]
        if rank == 0:
            print(f"{split.capitalize()} Data: {dataset_name} from {dataset_source}: '{dataset_path}'")
            if dataset_url is not None:
                print(f"Dataset URL: {dataset_url}")

        # Load the dataset, do any pre-processing, and select appropriate split
        if dataset_id == 'sst_local':
            # The local reader only distinguishes train from everything else
            reader = sst.train_reader if split == 'train' else sst.dev_reader
            return reader(dataset_path)

        if dataset_id == 'sst':
            data = load_dataset(dataset_path)
            data = data.rename_column('label', 'label_orig')
            for split_name in ('train', 'validation', 'test'):
                # Map the original label text to the label scheme used here, and expose
                # the text under the 'sentence' column name the rest of the pipeline reads
                converted = [convert_sst_label(s) for s in data[split_name]['label_text']]
                data[split_name] = data[split_name].add_column('label', converted)
                data[split_name] = data[split_name].add_column('sentence', data[split_name]['text'])
        elif dataset_id in ('dynasent', 'dynasent_r1', 'dynasent_r2'):
            data = load_dataset(dataset_path, dataset_subset)
            data = data.rename_column('gold_label', 'label')
        else:  # 'mteb_tweet'
            data = load_dataset(dataset_path)
            data = data.rename_column('label', 'label_orig')
            data = data.rename_column('label_text', 'label')
            data = data.rename_column('text', 'sentence')
            # A 'dev' request is served from the 'test' split for this dataset
            split = 'test' if split == 'dev' else split

        return get_split(data, split)

    if rank == 0:
        print(f"\nLoading data...")
        if eval_dataset is not None:
            print("Using different datasets for training and evaluation")
        else:
            eval_dataset = dataset
            print("Using the same dataset for training and evaluation")

        train = get_data(dataset, 'train', rank, debug)
        dev = get_data(eval_dataset, 'dev', rank, debug)

        print(f"Train size: {len(train)}, Dev size: {len(dev)}")

        if sample_percent is not None:
            print(f"Sampling {sample_percent:.0%} of data...")
            train = train.sample(frac=sample_percent)
            dev = dev.sample(frac=sample_percent)
            print(f"Sampled Train size: {len(train)}, Sampled Dev size: {len(dev)}")

    else:
        train = None
        dev = None

    # Broadcast the data to all ranks
    if world_size > 1:
        object_list = [train, dev]
        dist.broadcast_object_list(object_list, src=0)
        train, dev = object_list

        dist.barrier()
        if rank == 0 and debug:
            print(f"Data broadcasted to all ranks")
        if debug:
            print(f"Rank {rank}: Train size: {len(train)}, Dev size: {len(dev)}")

    if rank == 0:
        print("Train label distribution:")
        print_label_dist(train)
        print("Dev label distribution:")
        print_label_dist(dev)
        print(f"Data loaded ({format_time(time.time() - data_load_start)})")
    dist.barrier()

    return train, dev

def process_data(bert_tokenizer, bert_model, pooling, world_size, train, dev, device, batch_size, rank, debug, save_archive,
                 save_dir, num_workers, prefetch, empty_cache, finetune_bert):
    """Turn the train/dev DataFrames into model inputs and label arrays.

    Labels come from the ``label`` column and sentences from the ``sentence`` column.
    When ``finetune_bert`` is false the sentences are encoded with ``bert_phi`` and the
    embeddings are returned as NumPy arrays; otherwise the raw sentences are returned.

    Args:
        bert_tokenizer: Tokenizer forwarded to ``bert_phi``.
        bert_model: Model forwarded to ``bert_phi``.
        pooling: Pooling strategy name forwarded to ``bert_phi``.
        world_size: Number of processes in the group.
        train: Training DataFrame with ``label`` and ``sentence`` columns.
        dev: Dev DataFrame with ``label`` and ``sentence`` columns.
        device: Torch device.
        batch_size: Encoding batch size.
        rank: Rank of this process.
        debug: When true, rank 0 prints a per-rank size summary.
        save_archive: When true, rank 0 writes the results with ``save_data_archive``.
        save_dir: Directory for the archive.
        num_workers: Forwarded to ``bert_phi``.
        prefetch: Forwarded to ``bert_phi``.
        empty_cache: Forwarded to ``bert_phi``.
        finetune_bert: When true, skip encoding and return the raw sentences as X.

    Returns:
        Tuple ``(X_train, X_dev, y_train, y_dev, X_dev_sent)``.

    Raises:
        ValueError: On every rank, if the per-rank dataset sizes differ by more than
            ``world_size``.
    """
    data_process_start = time.time()

    if rank == 0:
        print(f"\nProcessing data (Batch size: {batch_size}, Pooling: {pooling.upper() if pooling == 'cls' else pooling.capitalize()}, Fine Tune BERT: {finetune_bert})...")
        print(f"Extracting sentences and labels...")

    # Extract y labels
    y_train = train.label.values
    y_dev = dev.label.values

    # Extract X sentences
    X_train_sent = train.sentence.values
    X_dev_sent = dev.sentence.values

    if rank == 0:
        # Pick up to 3 random examples per split. The count is clamped so splits with
        # fewer than 3 rows do not make np.random.choice(replace=False) raise.
        train_indices = np.random.choice(len(X_train_sent), min(3, len(X_train_sent)), replace=False)
        dev_indices = np.random.choice(len(X_dev_sent), min(3, len(X_dev_sent)), replace=False)

        # Collect sample sentences as (prefix, sentence, suffix) triples for display
        train_samples = [(f'Train[{i}]: ', X_train_sent[i], f' - {y_train[i].upper()}') for i in train_indices]
        dev_samples = [(f'Dev[{i}]: ', X_dev_sent[i], f' - {y_dev[i].upper()}') for i in dev_indices]
    else:
        train_samples = None
        dev_samples = None

    if finetune_bert:
        # For fine-tuning, we just return the sentences
        X_train = X_train_sent
        X_dev = X_dev_sent
    else:
        # Process X sentences (tokenize and encode with BERT) for non-fine-tuning workflow
        X_train = bert_phi(X_train_sent, bert_tokenizer, bert_model, pooling, world_size, device, batch_size, train_samples, rank,
                           debug, split='train', num_workers=num_workers, prefetch=prefetch,
                           empty_cache=empty_cache).cpu().numpy()
        X_dev = bert_phi(X_dev_sent, bert_tokenizer, bert_model, pooling, world_size, device, batch_size, dev_samples, rank,
                         debug, split='dev', num_workers=num_workers, prefetch=prefetch,
                         empty_cache=empty_cache).cpu().numpy()

    # Data integrity check, make sure the sizes are consistent across ranks
    if not finetune_bert and device.type == 'cuda' and world_size > 1:
        # The local size lives in its own tensor, separate from the output buffers,
        # so all_gather never reads from and writes to the same storage
        local_train_size = torch.tensor(X_train.shape[0], device=device)
        local_dev_size = torch.tensor(X_dev.shape[0], device=device)
        gathered_train = [torch.zeros_like(local_train_size) for _ in range(world_size)]
        gathered_dev = [torch.zeros_like(local_dev_size) for _ in range(world_size)]

        dist.all_gather(gathered_train, local_train_size)
        dist.all_gather(gathered_dev, local_dev_size)

        # Every rank holds the same gathered sizes, so every rank can evaluate the check
        train_sizes = [size.item() for size in gathered_train]
        dev_sizes = [size.item() for size in gathered_dev]

        if rank == 0 and debug:
            print("\nDataset size summary:")
            print(f"Train sizes across ranks: {train_sizes}")
            print(f"Dev sizes across ranks: {dev_sizes}")

            if len(set(train_sizes)) > 1 or len(set(dev_sizes)) > 1:
                print("WARNING: Mismatch in dataset sizes across ranks!")
                print(f"Train size mismatch: {max(train_sizes) - min(train_sizes)}")
                print(f"Dev size mismatch: {max(dev_sizes) - min(dev_sizes)}")
            else:
                print("All ranks have consistent dataset sizes.")

            print(f"Total train samples: {sum(train_sizes)}")
            print(f"Total dev samples: {sum(dev_sizes)}")

        # Raise on ALL ranks: raising on rank 0 alone would leave the other ranks
        # blocked forever at the barrier below.
        max_mismatch = max(max(train_sizes) - min(train_sizes), max(dev_sizes) - min(dev_sizes))
        if max_mismatch > world_size:  # Allow for small mismatches due to uneven division
            raise ValueError(f"Significant mismatch in dataset sizes across ranks. Max difference: {max_mismatch}")

    if save_archive and rank == 0:
        save_data_archive(X_train, X_dev, y_train, y_dev, X_dev_sent, world_size, device.type, save_dir)

    dist.barrier()
    if rank == 0:
        print(f"X Train shape: {list(np.shape(X_train))}, X Dev shape: {list(np.shape(X_dev))}")
        print(f"y Train shape: {list(np.shape(y_train))}, y Dev shape: {list(np.shape(y_dev))}")
        print(f"Data processed ({format_time(time.time() - data_process_start)})")

    return X_train, X_dev, y_train, y_dev, X_dev_sent

def bert_phi(texts, tokenizer, model, pooling, world_size, device, batch_size, sample_texts, rank, debug, split, num_workers, prefetch, empty_cache):
    """Encode ``texts`` into pooled BERT embeddings.

    With CUDA and ``world_size > 1`` the texts are padded to a multiple of
    ``world_size``, sharded across ranks, encoded locally, all-gathered, and the padding
    rows are dropped, so every rank returns the full embedding matrix. Otherwise the
    calling process encodes all texts itself.

    Args:
        texts: Sequence of strings to encode.
        tokenizer: Tokenizer providing ``batch_encode_plus``, ``tokenize`` and ``pad_token``.
        model: Encoder whose output exposes ``last_hidden_state``. It is put into eval mode.
        pooling: ``'cls'``, ``'mean'`` or ``'max'``.
        world_size: Number of processes in the group.
        device: Torch device for inputs.
        batch_size: Number of texts per forward pass.
        sample_texts: ``(prefix, text, suffix)`` triples displayed on rank 0.
        rank: Rank of this process.
        debug: When true, rank 0 prints the final embedding shape (distributed path).
        split: Split name used in log messages.
        num_workers: Accepted for interface compatibility; not used.
        prefetch: Accepted for interface compatibility; not used.
        empty_cache: When true (and on CUDA), free per-batch tensors and empty the CUDA cache.

    Returns:
        Tensor with one embedding row per input text.

    Raises:
        ValueError: If ``texts`` is empty or ``pooling`` is unknown.
    """
    encoding_start = time.time()
    total_texts = len(texts)
    embeddings = []

    if total_texts == 0:
        # Fail with a clear message instead of an obscure error further down
        raise ValueError(f"No texts to encode for split: {split}")

    # Use inference mode on both code paths and for the sample display
    model.eval()

    def tokenize(texts, tokenizer, device):
        encoded = tokenizer.batch_encode_plus(
            texts,
            add_special_tokens=True,
            padding='max_length',
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )
        input_ids = encoded['input_ids'].to(device)
        attention_mask = encoded['attention_mask'].to(device)
        return input_ids, attention_mask

    def pool(last_hidden_state, attention_mask, pooling):
        if pooling == 'cls':
            return last_hidden_state[:, 0, :]
        token_mask = attention_mask.unsqueeze(-1)
        if pooling == 'mean':
            return (last_hidden_state * token_mask).sum(1) / attention_mask.sum(-1).unsqueeze(-1)
        if pooling == 'max':
            # Masked positions must never win the max. Multiplying by the mask would
            # turn them into 0, which beats every negative activation, so they are
            # filled with -inf instead.
            masked_states = last_hidden_state.masked_fill(token_mask == 0, float('-inf'))
            return torch.max(masked_states, dim=1)[0]
        raise ValueError(f"Unknown pooling method: {pooling}")

    # Process and display sample texts first
    def display_sample_texts(sample_texts):
        for text in sample_texts:
            # Tokenize the text and get the tokens
            tokens = tokenizer.tokenize(text[1])
            print(f"{text[0]}{text[1]}{text[2]}")
            print(f"Tokens: {tokens}")

            # Encode the text (including special tokens) and get embeddings
            input_ids, attention_mask = tokenize([text[1]], tokenizer, device)

            with torch.no_grad():
                outputs = model(input_ids, attention_mask=attention_mask)

            embedding = pool(outputs.last_hidden_state, attention_mask, pooling)

            print(f"Embedding: {embedding[0, :6].cpu().numpy()} ...")
            print()

            if device.type == 'cuda':
                del input_ids, attention_mask, outputs, embedding
                torch.cuda.empty_cache()

    # Shard the encoding work across multiple GPUs
    if device.type == 'cuda' and world_size > 1:
        if rank == 0:

            # Display sample texts
            print(f"\nDisplaying samples from {split.capitalize()} data:")
            display_sample_texts(sample_texts)

            print(f"\nEncoding {split.capitalize()} data of {total_texts} texts distributed across {world_size} GPUs...")
            print(f"Batch Size: {batch_size}, Pooling: {pooling.upper() if pooling == 'cls' else pooling.capitalize()}, Empty Cache: {empty_cache}")

        dist.barrier()
        # Calculate the number of texts that make the dataset evenly divisible by world_size
        texts_per_rank = math.ceil(total_texts / world_size)
        padded_total = texts_per_rank * world_size

        if padded_total > total_texts and rank == 0:
            print(f"Padding {split.capitalize()} data to {padded_total} texts for even distribution across {world_size} ranks...")

        # Calculate number of padding texts needed
        padding_texts = padded_total - total_texts

        # Create padding texts using [PAD] token
        pad_text = tokenizer.pad_token * 10  # Arbitrary length, will be truncated if too long
        texts_with_padding = list(texts) + [pad_text] * padding_texts  # Convert texts to list and then concatenate

        # Distribute texts evenly across ranks
        start_idx = rank * texts_per_rank
        end_idx = start_idx + texts_per_rank
        local_texts = texts_with_padding[start_idx:end_idx]

        # Ceil division: `n // batch_size + 1` over-counts by one when n divides evenly
        local_batch_count = math.ceil(len(local_texts) / batch_size)
        total_batches = math.ceil(len(texts) / batch_size)

        if rank == 0:
            print(f"Texts per rank: {texts_per_rank}, Total batches: {total_batches}")

        dist.barrier()
        print(f"Rank {rank}: Processing {len(local_texts)} texts (indices {start_idx} to {end_idx-1}) in {local_batch_count} batches...")

        dist.barrier()

        for i in range(0, len(local_texts), batch_size):
            batch_start = time.time()
            batch_texts = local_texts[i:i + batch_size]
            input_ids, attention_mask = tokenize(batch_texts, tokenizer, device)

            with torch.no_grad():
                outputs = model(input_ids, attention_mask=attention_mask)

            batch_embeddings = pool(outputs.last_hidden_state, attention_mask, pooling)

            embeddings.append(batch_embeddings)

            batch_shape = list(batch_embeddings.shape)

            if empty_cache:
                # Delete the unused objects (the embeddings list keeps its own reference)
                del outputs, input_ids, attention_mask, batch_embeddings
                # Empty CUDA cache
                torch.cuda.empty_cache()

            shape_color = get_shape_color(batch_size, batch_shape)
            print(f"Rank {bright_white}{bold}{rank}{reset}: Batch {bright_white}{bold}{(i // batch_size) + 1:2d}{reset} / {local_batch_count}, Shape: {shape_color}{bold}{batch_shape}{reset}, Time: {format_time(time.time() - batch_start)}")

            if rank == 0:
                if (i // batch_size) % 5 == 0:
                    print_rank_memory_summary(world_size, rank, all_local=True, verbose=False)

        local_embeddings = torch.cat(embeddings, dim=0)

        # Gather embeddings from all processes. Every shard has texts_per_rank rows,
        # so the buffers all share the local shape.
        gathered_embeddings = [torch.zeros_like(local_embeddings) for _ in range(world_size)]
        dist.all_gather(gathered_embeddings, local_embeddings)
        all_embeddings = torch.cat(gathered_embeddings, dim=0)

        if rank == 0:  # Only one process needs to do this check
            total_embeddings = all_embeddings.shape[0]
            padding_embeddings = total_embeddings - total_texts

            print(f"Total embeddings: {total_embeddings}")
            print(f"Original texts: {total_texts}")
            print(f"Expected padding: {padding_texts}")
            print(f"Actual padding: {padding_embeddings}")

            if padding_embeddings != padding_texts:
                print("WARNING: Mismatch in padding count!")

            if padding_embeddings > 1:
                # Get the embeddings of the padded texts
                padding_embeds = all_embeddings[-padding_embeddings:]

                # Calculate the maximum difference between any two padding embeddings
                max_diff = torch.max(torch.pdist(padding_embeds))

                print(f"Maximum difference between padding embeddings: {max_diff}")

                # You can adjust this threshold based on your observations
                if max_diff > 1e-6:
                    print("WARNING: Padding embeddings are not similar.")
                else:
                    print("Padding embeddings verified as very similar.")
            elif padding_embeddings == 1:
                # A single row has no pairwise distances: pdist returns an empty tensor
                # and torch.max on it would raise, so the comparison is skipped.
                print("Single padding embedding; nothing to compare.")

        # Now slice off the padding
        final_embeddings = all_embeddings[:total_texts]

        dist.barrier()
        if rank == 0:
            if debug:
                print(f"Final embeddings shape: {list(final_embeddings.shape)}")
            print(f"Encoding completed ({format_time(time.time() - encoding_start)})")

    else:
        # Take a more straightforward approach for CPU or single GPU
        if rank == 0:
            device_string = 'GPU' if device.type == 'cuda' else 'CPU'
            print(f"\nEncoding {split.capitalize()} data of {total_texts} texts on a single {device_string}...")
            print(f"Batch Size: {batch_size}, Pooling: {pooling.upper() if pooling == 'cls' else pooling.capitalize()}, Empty Cache: {empty_cache}")

            # Display sample texts
            print(f"\nDisplaying samples from {split.capitalize()} data:")
            display_sample_texts(sample_texts)

        total_batches = math.ceil(len(texts) / batch_size)

        for i in range(0, len(texts), batch_size):
            batch_start = time.time()
            # Hand the tokenizer a plain list, as the distributed path does
            batch_texts = list(texts[i:i + batch_size])
            input_ids, attention_mask = tokenize(batch_texts, tokenizer, device)

            with torch.no_grad():
                outputs = model(input_ids, attention_mask=attention_mask)

            # Perform only the selected pooling strategy for all batches
            batch_embeddings = pool(outputs.last_hidden_state, attention_mask, pooling)

            embeddings.append(batch_embeddings)
            print(f"Batch {(i // batch_size) + 1:2d} / {total_batches}, Shape: {list(batch_embeddings.shape)}, Time: {format_time(time.time() - batch_start)}")

            if empty_cache and device.type == 'cuda':
                # Delete the unused objects
                del outputs, input_ids, attention_mask, batch_embeddings
                # Empty CUDA cache
                torch.cuda.empty_cache()

        # Concatenate once after the loop. Doing it on every batch re-copies all
        # earlier embeddings each time.
        final_embeddings = torch.cat(embeddings, dim=0)

    return final_embeddings


def initialize_classifier(bert_model, bert_tokenizer, finetune_bert, finetune_layers, num_layers, hidden_dim, batch_size,
                          epochs, lr, early_stop, hidden_activation, n_iter_no_change, tol, rank, world_size, device, debug,
                          checkpoint_dir, checkpoint_interval, resume_from_checkpoint, filename=None, use_saved_params=True,
                          optimizer_name=None, scheduler_name=None, pooling='cls', target_score=None, interactive=False,
                          response_pipe=None, accumulation_steps=1):
    """Build a ``TorchDDPNeuralClassifier`` and optionally restore saved state.

    The activation and optimizer are resolved with ``get_activation`` and
    ``get_optimizer``. If ``filename`` is given, that file is loaded from
    ``checkpoint_dir``. Otherwise, if ``resume_from_checkpoint`` is true, the load is
    requested with the ``'checkpoint_epoch'`` pattern. With neither, training starts at
    epoch 1 with no saved state. ``scheduler_name`` is accepted but not used here.

    Returns:
        Tuple ``(classifier, start_epoch, model_state_dict, optimizer_state_dict)``.
    """
    started_at = time.time()
    is_main_rank = rank == 0

    if is_main_rank:
        print(f"\nInitializing DDP Neural Classifier...")

    hidden_activation = get_activation(hidden_activation)
    optimizer_class = get_optimizer(optimizer_name, device, rank, world_size)

    if is_main_rank:
        print(f"Layers: {num_layers}, Hidden Dim: {hidden_dim}, Hidden Act: {hidden_activation.__class__.__name__}, Optimizer: {optimizer_class.__name__}, Pooling: {pooling.upper()}, Batch Size: {batch_size}, Gradient Accumulation Steps: {accumulation_steps}")
        print(f"Max Epochs: {epochs}, LR: {lr}, Early Stop: {early_stop}, Fine-tune BERT: {finetune_bert}, Fine-tune Layers: {finetune_layers}, Target Score: {target_score}, Interactive: {interactive}")

    # Constructor arguments collected in one place, then expanded into the call
    classifier_kwargs = dict(
        bert_model=bert_model,
        bert_tokenizer=bert_tokenizer,
        finetune_bert=finetune_bert,
        finetune_layers=finetune_layers,
        pooling=pooling,
        num_layers=num_layers,
        early_stopping=early_stop,
        hidden_dim=hidden_dim,
        hidden_activation=hidden_activation,
        batch_size=batch_size,
        max_iter=epochs,
        n_iter_no_change=n_iter_no_change,
        tol=tol,
        eta=lr,
        rank=rank,
        debug=debug,
        checkpoint_dir=checkpoint_dir,
        checkpoint_interval=checkpoint_interval,
        resume_from_checkpoint=resume_from_checkpoint,
        device=device,
        optimizer_class=optimizer_class,
        target_score=target_score,
        interactive=interactive,
        response_pipe=response_pipe,
        gradient_accumulation_steps=accumulation_steps,
    )
    classifier = TorchDDPNeuralClassifier(**classifier_kwargs)

    # Decide what (if anything) to restore: an explicit file takes precedence over
    # resuming from a checkpoint
    load_target = None
    if filename is not None:
        if is_main_rank:
            print(f"Loading model from: {checkpoint_dir}/{filename}...")
        load_target = dict(filename=filename, pattern=None)
    elif resume_from_checkpoint:
        if is_main_rank:
            print("Resuming training from the latest checkpoint...")
        load_target = dict(filename=None, pattern='checkpoint_epoch')

    if load_target is None:
        start_epoch, model_state_dict, optimizer_state_dict = 1, None, None
    else:
        start_epoch, model_state_dict, optimizer_state_dict = classifier.load_model(
            directory=checkpoint_dir,
            filename=load_target['filename'],
            pattern=load_target['pattern'],
            use_saved_params=use_saved_params,
            rank=rank,
            debug=debug,
        )

    dist.barrier()
    if is_main_rank:
        if debug:
            print(classifier)
        print(f"Classifier initialized ({format_time(time.time() - started_at)})")

    return classifier, start_epoch, model_state_dict, optimizer_state_dict

def evaluate_model(model, bert_tokenizer, X_dev, y_dev, label_dict, numeric_dict, world_size, device, rank, debug, save_preds,
                   save_dir, X_dev_sent):
    """Predict on the dev data, gather predictions from all ranks and report on rank 0.

    Rank 0 converts the gathered outputs to labels, optionally writes them to a
    timestamped CSV in ``save_dir``, prints a classification report and prints the
    score from ``model.score``. ``label_dict`` is accepted but not used here.

    Returns:
        None.
    """
    eval_start = time.time()
    is_main_rank = rank == 0

    if is_main_rank:
        print("\nEvaluating model...")
    model.model.eval()

    with torch.no_grad():
        if is_main_rank and debug:
            print("Making predictions...")

        if model.finetune_bert:
            # The dataset requires labels; zeros are placeholders that are never read back
            placeholder_labels = [0] * len(X_dev)
            eval_loader = DataLoader(
                SentimentDataset(X_dev, placeholder_labels, bert_tokenizer),
                batch_size=model.batch_size,
                shuffle=False,
            )
            batch_outputs = []
            for batch in eval_loader:
                ids = batch['input_ids'].to(device)
                mask = batch['attention_mask'].to(device)
                batch_outputs.append(model.model(ids, attention_mask=mask))
            preds = torch.cat(batch_outputs, dim=0)
        else:
            if not torch.is_tensor(X_dev):
                X_dev = torch.tensor(X_dev, device=device)
            preds = model.model(X_dev)

        # Collect every rank's outputs
        gathered = [torch.zeros_like(preds) for _ in range(world_size)]
        dist.all_gather(gathered, preds)

        # Only rank 0 reports
        if not is_main_rank:
            return

        # Keep only as many rows as there are true labels
        merged = torch.cat(gathered, dim=0)[:len(y_dev)]
        predicted_ids = merged.argmax(dim=1).cpu().numpy()
        preds_labels = convert_numeric_to_labels(predicted_ids, numeric_dict)
        if debug:
            print(f"Predictions: {len(preds_labels)}, True labels: {len(y_dev)}")

        # Save predictions if requested
        if save_preds:
            predictions_frame = pd.DataFrame({
                'X_dev_sent': X_dev_sent,
                'y_dev': y_dev,
                'preds_labels': preds_labels
            })
            # Create a save directory if it doesn't exist
            if not os.path.exists(save_dir):
                print(f"Creating save directory: {save_dir}")
                os.makedirs(save_dir)
            # Create a filename timestamp
            timestamp = time.strftime("%Y%m%d-%H%M%S")
            predictions_frame.to_csv(os.path.join(save_dir, f'predictions_{timestamp}.csv'), index=False)
            print(f"Saved predictions: {save_dir}/predictions_{timestamp}.csv")

        print("\nClassification report:")
        print(classification_report(y_dev, preds_labels, digits=3, zero_division=0))
        macro_f1_score = model.score(X_dev, y_dev, device, debug)
        print(f"Macro F1 Score: {macro_f1_score:.2f}")
        print(f"\nEvaluation completed ({format_time(time.time() - eval_start)})")




In [2]:
# Pick the device for this single-process notebook session via utils.prepare_device.
# NOTE(review): the arguments look like (rank, device type) — confirm against utils.
device = prepare_device(0, 'cpu')

In [3]:
def setup_environment(rank, world_size, backend, device, debug, port='12355'):
    """Set the rendezvous address/port and join the default process group.

    NOTE(review): the notebook's first cell also imports ``setup_environment``
    from ``utils``; this definition rebinds that name for every later cell.

    Args:
        rank: Rank of the calling process, forwarded to ``dist.init_process_group``.
        world_size: Number of processes, forwarded to ``dist.init_process_group``.
        backend: Backend name forwarded to ``dist.init_process_group``.
        device: Only used in the per-rank status message.
        debug: Accepted for signature compatibility; not read by this function.
        port: Rendezvous port, given as a string or an integer.

    Side effects:
        Sets the ``MASTER_ADDR`` and ``MASTER_PORT`` environment variables,
        initializes the process group, waits on ``dist.barrier()``, and prints
        status lines (the summary line only on rank 0).
    """
    os.environ["MASTER_ADDR"] = "localhost"
    # Environment values must be strings; coercing here lets callers pass an
    # integer port instead of failing with TypeError.
    os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    print(f"Rank {rank} - Device: {device}")
    # Wait until every rank has joined before reporting success.
    dist.barrier()
    if rank == 0:
        print(f"{world_size} process groups initialized with '{backend}' backend on {os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}")


In [4]:
# Join a single-process group (rank 0 of 1) over the 'gloo' backend on port 12356.
setup_environment(
    rank=0,
    world_size=1,
    backend='gloo',
    device=device,
    debug=False,
    port='12356',
)

Rank 0 - Device: cpu
1 process groups initialized with 'gloo' backend on localhost:12356


In [5]:
# Seed the random number generators through utils.fix_random_seeds with seed 42.
# Which libraries get seeded is decided inside utils (not shown in this notebook).
fix_random_seeds(42)

In [16]:
# Load the DynaSent Round 1 train/dev frames.
# NOTE(review): only the dataset name is self-explanatory; the meaning of the other
# positional arguments (None, None, 1, 0, False) is defined by load_data, which is
# outside this excerpt - TODO confirm and consider passing them by keyword.
train_dyn1, dev_dyn1 = load_data('dynasent_r1', None, None, 1, 0, False)


Loading data...
Using the same dataset for training and evaluation
Train Data: DynaSent Round 1 from Hugging Face: 'dynabench/dynasent'
Dataset URL: https://huggingface.co/datasets/dynabench/dynasent
Dev Data: DynaSent Round 1 from Hugging Face: 'dynabench/dynasent'
Dataset URL: https://huggingface.co/datasets/dynabench/dynasent
Train size: 80488, Dev size: 3600
Train label distribution:
	      Negative: 14021
	       Neutral: 45076
	      Positive: 21391
Dev label distribution:
	      Negative: 1200
	       Neutral: 1200
	      Positive: 1200
Data loaded (2s)


In [17]:
# Load the local Stanford Sentiment Treebank train/dev frames.
# NOTE(review): the meaning of the trailing positional arguments (None, None, 1, 0, False)
# is defined by load_data, which is outside this excerpt - TODO confirm.
train_sst, dev_sst = load_data('sst_local', None, None, 1, 0, False)


Loading data...
Using the same dataset for training and evaluation
Train Data: Stanford Sentiment Treebank (SST) from Local: 'data/sentiment'
Dev Data: Stanford Sentiment Treebank (SST) from Local: 'data/sentiment'
Train size: 8544, Dev size: 1101
Train label distribution:
	      Negative: 3310
	       Neutral: 1624
	      Positive: 3610
Dev label distribution:
	      Negative: 428
	       Neutral: 229
	      Positive: 444
Data loaded (392ms)


In [18]:
# Load the DynaSent Round 2 train/dev frames.
# NOTE(review): the meaning of the trailing positional arguments (None, None, 1, 0, False)
# is defined by load_data, which is outside this excerpt - TODO confirm.
train_dyn2, dev_dyn2 = load_data('dynasent_r2', None, None, 1, 0, False)


Loading data...
Using the same dataset for training and evaluation
Train Data: DynaSent Round 2 from Hugging Face: 'dynabench/dynasent'
Dataset URL: https://huggingface.co/datasets/dynabench/dynasent
Dev Data: DynaSent Round 2 from Hugging Face: 'dynabench/dynasent'
Dataset URL: https://huggingface.co/datasets/dynabench/dynasent
Train size: 13065, Dev size: 720
Train label distribution:
	      Negative: 4579
	       Neutral: 2448
	      Positive: 6038
Dev label distribution:
	      Negative: 240
	       Neutral: 240
	      Positive: 240
Data loaded (1s)


In [28]:
# Stamp every split with the name of the dataset it was loaded from, so each
# row keeps its origin once the splits are concatenated.
for split_frame, source_name in (
    (train_dyn1, 'dynasent_r1'),
    (train_dyn2, 'dynasent_r2'),
    (train_sst, 'sst_local'),
    (dev_dyn1, 'dynasent_r1'),
    (dev_dyn2, 'dynasent_r2'),
    (dev_sst, 'sst_local'),
):
    split_frame['source'] = source_name

In [29]:
# Preview the first rows of the DynaSent Round 1 training split (sentence, label, source tag).
train_dyn1[['sentence', 'label', 'source']].head()

Unnamed: 0,sentence,label,source
0,Roto-Rooter is always good when you need someo...,positive,dynasent_r1
1,It's so worth the price of cox service over he...,positive,dynasent_r1
2,"I placed my order of ""sticky ribs"" as an appet...",neutral,dynasent_r1
3,"There is mandatory valet parking, so make sure...",neutral,dynasent_r1
4,My wife and I couldn't finish it.,neutral,dynasent_r1


In [30]:
# Preview the first rows of the DynaSent Round 2 training split (sentence, label, source tag).
train_dyn2[['sentence', 'label', 'source']].head()

Unnamed: 0,sentence,label,source
0,We enjoyed our first and last meal in Toronto ...,positive,dynasent_r2
1,I tried a new place. I can't wait to return an...,positive,dynasent_r2
2,"The buffalo chicken was not good, but very cos...",negative,dynasent_r2
3,The hotel offered complimentary breakfast.,positive,dynasent_r2
4,It work very well,positive,dynasent_r2


In [31]:
# Preview the first rows of the SST training split (sentence, label, source tag).
# The row labels in the rendered output (0, 71, 144, ...) are not consecutive for this split.
train_sst[['sentence', 'label', 'source']].head()

Unnamed: 0,sentence,label,source
0,The Rock is destined to be the 21st Century 's...,positive,sst_local
71,The gorgeously elaborate continuation of `` Th...,positive,sst_local
144,Singer\/composer Bryan Adams contributes a sle...,positive,sst_local
221,You 'd think by now America would have had eno...,neutral,sst_local
258,Yet the act is still charming here .,positive,sst_local


In [32]:
# Stack the three training splits, keeping the same three columns from each,
# and renumber the rows from zero.
merged_columns = ['sentence', 'label', 'source']
train_all = pd.concat(
    [split[merged_columns] for split in (train_dyn1, train_dyn2, train_sst)],
    ignore_index=True,
)

In [35]:
# Spot-check 10 random rows of the merged training frame.
# No random_state is passed, so which rows appear depends on the current global random state.
train_all.sample(10)

Unnamed: 0,sentence,label,source
78907,I took the car to Audi Charlotte the next day.,neutral,dynasent_r1
56638,One highlight was that Iove the color I picked...,positive,dynasent_r1
91210,"There was one server, and he became less atten...",negative,dynasent_r2
71599,Guess jeans are the only way to go.,positive,dynasent_r1
100656,Every joke is repeated at least four times .,negative,sst_local
54270,The food was actually better than the food i h...,positive,dynasent_r1
11674,"I really didn't want to buy a brand new tire, ...",neutral,dynasent_r1
877,I just thought I was doomed forever.,negative,dynasent_r1
95959,"The performances of the children , untrained i...",positive,sst_local
28011,Went here on a quick lunch break.,neutral,dynasent_r1


In [36]:
# Stack the three dev splits, keeping the same three columns from each,
# and renumber the rows from zero.
merged_columns = ['sentence', 'label', 'source']
dev_all = pd.concat(
    [split[merged_columns] for split in (dev_dyn1, dev_dyn2, dev_sst)],
    ignore_index=True,
)

In [37]:
# Spot-check 10 random rows of the merged dev frame.
# No random_state is passed, so which rows appear depends on the current global random state.
dev_all.sample(10)

Unnamed: 0,sentence,label,source
3261,Went to this venue for an adult game night eve...,positive,dynasent_r1
1207,Was told there'd be a ten minute but we weren'...,positive,dynasent_r1
1315,There I was with an overheated an engine one M...,negative,dynasent_r1
1245,After doing a bit of research I found this pla...,neutral,dynasent_r1
1749,Flies inside especially near the deli section ...,negative,dynasent_r1
3117,I never have written a review before on Yelp b...,negative,dynasent_r1
4032,I had to order the pulled chicken and some sid...,negative,dynasent_r2
2789,Had the windshield on the Cadillac replaced by...,neutral,dynasent_r1
1108,It may not look like much from the outside but...,positive,dynasent_r1
1765,I searched on Yelp for a reasonable and reliab...,neutral,dynasent_r1


In [38]:
# Report the (rows, columns) shape of the merged train and dev frames.
print(train_all.shape, dev_all.shape)

(102097, 3) (5421, 3)


In [39]:
# First rows of the merged training frame, still in concatenation order (DynaSent Round 1 comes first).
train_all.head()

Unnamed: 0,sentence,label,source
0,Roto-Rooter is always good when you need someo...,positive,dynasent_r1
1,It's so worth the price of cox service over he...,positive,dynasent_r1
2,"I placed my order of ""sticky ribs"" as an appet...",neutral,dynasent_r1
3,"There is mandatory valet parking, so make sure...",neutral,dynasent_r1
4,My wife and I couldn't finish it.,neutral,dynasent_r1


In [40]:
# Last rows of the merged training frame, still in concatenation order (SST was appended last).
train_all.tail()

Unnamed: 0,sentence,label,source
102092,A real snooze .,negative,sst_local
102093,No surprises .,negative,sst_local
102094,We 've seen the hippie-turned-yuppie plot befo...,positive,sst_local
102095,Her fans walked out muttering words like `` ho...,negative,sst_local
102096,In this case zero .,negative,sst_local


In [41]:
# Shuffle every row of both merged frames with a fixed seed (42), then renumber
# the rows from zero. The names are rebound in place, so re-running this cell
# shuffles the already-shuffled frames again.
train_all = (
    train_all
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)
dev_all = (
    dev_all
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)

In [42]:
# First rows of the training frame after shuffling; the sources are now interleaved.
train_all.head()

Unnamed: 0,sentence,label,source
0,Those 2 drinks are part of the HK culture and ...,negative,dynasent_r2
1,I was told by the repair company that was doin...,negative,dynasent_r1
2,It is there to give them a good time .,neutral,sst_local
3,Like leafing through an album of photos accomp...,negative,sst_local
4,Johnny was a talker and liked to have fun.,positive,dynasent_r1


In [48]:
# Last rows of the training frame after shuffling.
train_all.tail()

Unnamed: 0,sentence,label,source
102092,I thought this place was supposed to be good.,negative,dynasent_r1
102093,They claim it's because people didn't like it ...,negative,dynasent_r1
102094,There is also another marbled-out full bathroo...,neutral,dynasent_r1
102095,You put in your cell phone number & select a d...,neutral,dynasent_r1
102096,I came in for a second opinion on a crown I wa...,neutral,dynasent_r1


In [44]:
# First rows of the dev frame after shuffling.
dev_all.head()

Unnamed: 0,sentence,label,source
0,Found Thai Spoon on the Vegan Pittsburgh website.,neutral,dynasent_r1
1,Our bill came out to around $27 and we ate lik...,positive,dynasent_r1
2,State Farm broke down the costs for me of the ...,neutral,dynasent_r1
3,The only con for this resto is the wait to get...,negative,dynasent_r1
4,We could hear the people above us stomping aro...,negative,dynasent_r1


In [49]:
# Last rows of the dev frame after shuffling; the sources are now interleaved.
dev_all.tail()

Unnamed: 0,sentence,label,source
5416,I think it's really a matter of mastering the ...,neutral,dynasent_r2
5417,A bloated gasbag thesis grotesquely impressed ...,negative,sst_local
5418,"Its story may be a thousand years old , but wh...",negative,sst_local
5419,I felt sad for Lise not so much because of wha...,neutral,sst_local
5420,"We always eat in the restaurant, so I can't co...",neutral,dynasent_r1


In [50]:
# Save the combined training and dev data to CSV files (row index omitted).
# `to_csv` does not create missing parent directories, so make sure the target
# directory exists first; otherwise a fresh checkout fails with OSError.
os.makedirs('data/merged', exist_ok=True)
train_all.to_csv('data/merged/train_all.csv', index=False)
dev_all.to_csv('data/merged/dev_all.csv', index=False)